#!python

"""
    pcied
    PCIe device monitoring daemon for SONiC
"""

import os
import signal
import sys
import threading

from sonic_py_common import daemon_base, device_info, logger
from swsscommon import swsscommon

#
# Constants ====================================================================
#

# Maps each signal number exposed by the `signal` module to its symbolic name
# (e.g. signal.SIGINT -> 'SIGINT'). Names containing an underscore (SIG_DFL,
# SIG_IGN, SIG_BLOCK, ...) are skipped.
# TODO: Once we no longer support Python 2, we can eliminate this and get the
# name using the 'name' field (e.g., `signal.SIGINT.name`) starting with Python 3.5
SIGNALS_TO_NAMES_DICT = {
    getattr(signal, sig_name): sig_name
    for sig_name in dir(signal)
    if sig_name.startswith('SIG') and '_' not in sig_name
}

# Identifier handed to the module-level logger below and to DaemonPcied in main()
SYSLOG_IDENTIFIER = "pcied"

# NOTE(review): not referenced anywhere in the visible part of this file
PCIE_RESULT_REGEX = "PCIe Device Checking All Test"
# Table names opened against STATE_DB in DaemonPcied.__init__
PCIE_DEVICE_TABLE_NAME = "PCIE_DEVICE"
PCIE_STATUS_TABLE_NAME = "PCIE_DEVICES"
PCIE_DETACH_INFO_TABLE = "PCIE_DETACH_INFO"

# Field names looked up in each PCIE_DETACH_INFO entry
PCIE_DETACH_BUS_INFO_FIELD = "bus_info"
PCIE_DETACH_DPU_STATE_FIELD = "dpu_state"

# Timeout (seconds) used for the stop-event wait at the start of each run() cycle
PCIED_MAIN_THREAD_SLEEP_SECS = 60

# Process exit statuses used by DaemonPcied.__init__:
# PCIEUTIL_CONF_FILE_ERROR when connecting to STATE_DB / creating tables fails,
# PCIEUTIL_LOAD_ERROR when no PCIe utility object is available.
PCIEUTIL_CONF_FILE_ERROR = 1
PCIEUTIL_LOAD_ERROR = 2

# Module-wide handle to the platform PCIe utility; assigned in DaemonPcied.__init__
platform_pcieutil = None

# Module-level logger, used where the daemon's own log_* helpers are not used
log = logger.Logger(SYSLOG_IDENTIFIER)

# Value returned by main(); the signal handler sets it to 128 + signal number
# when a fatal signal is caught
exit_code = 0

# wrapper functions to call the platform api
def load_platform_pcieutil():
    """Instantiate and return the PCIe utility object for this platform.

    The platform-specific ``sonic_platform.pcie.Pcie`` class is tried first.
    If importing it fails, the common
    ``sonic_platform_base.sonic_pcie.pcie_common.PcieUtil`` is used instead.
    Both are constructed with the platform directory path.

    Returns:
        The constructed PCIe utility object.

    Raises:
        RuntimeError: if neither implementation could be imported.
    """
    platform_path, _ = device_info.get_paths_to_platform_and_hwsku_dirs()
    pcie_util = None

    try:
        from sonic_platform.pcie import Pcie
        pcie_util = Pcie(platform_path)
    except ImportError as platform_err:
        log.log_notice("Failed to load platform Pcie module. Error : {}, Fallback to default module".format(str(platform_err)), True)
        try:
            from sonic_platform_base.sonic_pcie.pcie_common import PcieUtil
            pcie_util = PcieUtil(platform_path)
        except ImportError as default_err:
            log.log_error("Failed to load default PcieUtil module. Error : {}".format(str(default_err)), True)

    if pcie_util is not None:
        return pcie_util

    # Neither the platform module nor the common fallback was importable.
    log.log_error("Failed to load any PCIe utility module. Exiting...", True)
    raise RuntimeError("Unable to load PCIe utility module.")

def read_id_file(device_name):
    """Read the contents of a PCI device's sysfs ``device`` file.

    Args:
        device_name: Device address without the domain part, in the
            "bus:dev.fn" form built by the caller (e.g. "03:00.0"). The
            "0000:" domain prefix is prepended here.

    Returns:
        The whitespace-stripped contents of
        ``/sys/bus/pci/devices/0000:<device_name>/device``, or None when that
        file does not exist.
    """
    dev_id_path = '/sys/bus/pci/devices/0000:%s/device' % device_name

    # Open directly instead of checking os.path.exists() first: the entry can
    # disappear between the check and the open, which would otherwise surface
    # as an exception rather than the documented None result.
    try:
        with open(dev_id_path, 'r') as fd:
            return fd.read().strip()
    except FileNotFoundError:
        return None

#
# Daemon =======================================================================
#


class DaemonPcied(daemon_base.DaemonBase):
    """PCIe device monitoring daemon.

    Each run() cycle waits for the poll interval, asks the platform PCIe
    utility for its check results, writes the device id and AER statistics of
    every passing device to the PCIE_DEVICE table and an overall PASSED/FAILED
    status to the PCIE_DEVICES table in STATE_DB.
    """

    def __init__(self, log_identifier):
        """Load the platform PCIe utility and open the STATE_DB tables.

        Args:
            log_identifier: identifier forwarded to the DaemonBase constructor.

        Raises:
            RuntimeError: propagated from load_platform_pcieutil() when no
                PCIe utility module can be loaded.
            SystemExit: with PCIEUTIL_CONF_FILE_ERROR when connecting to
                STATE_DB or creating a table fails.
        """
        super(DaemonPcied, self).__init__(log_identifier)

        self.timeout = PCIED_MAIN_THREAD_SLEEP_SECS
        self.stop_event = threading.Event()
        self.state_db = None
        self.device_table = None
        self.status_table = None
        # Initialised together with the other table handles so the attribute
        # always exists and the None guard in is_dpu_in_detaching_mode() is
        # meaningful.
        self.detach_info = None
        self.resultInfo = []
        self.device_name = None
        self.aer_stats = {}

        global platform_pcieutil

        # load_platform_pcieutil() raises RuntimeError on failure rather than
        # returning None; the check below is kept as a defensive fallback.
        platform_pcieutil = load_platform_pcieutil()
        if platform_pcieutil is None:
            sys.exit(PCIEUTIL_LOAD_ERROR)

        # Connect to STATE_DB and create pcie device table
        try:
            self.state_db = daemon_base.db_connect("STATE_DB")
            self.device_table = swsscommon.Table(self.state_db, PCIE_DEVICE_TABLE_NAME)
            self.status_table = swsscommon.Table(self.state_db, PCIE_STATUS_TABLE_NAME)
            self.detach_info = swsscommon.Table(self.state_db, PCIE_DETACH_INFO_TABLE)
        except Exception as e:
            log.log_error("Failed to connect to STATE_DB or create table. Error: {}".format(str(e)), True)
            sys.exit(PCIEUTIL_CONF_FILE_ERROR)

    def __del__(self):
        """Best-effort cleanup: delete every key from the device and status tables.

        Any exception raised while cleaning up is logged as a warning and
        otherwise ignored.
        """
        try:
            if self.device_table:
                table_keys = self.device_table.getKeys()
                for tk in table_keys:
                    self.device_table._del(tk)
            if self.status_table:
                stable_keys = self.status_table.getKeys()
                for stk in stable_keys:
                    self.status_table._del(stk)
        except Exception as e:
            log.log_warning("Exception during cleanup: {}".format(str(e)), True)

    # load aer-fields into statedb
    def update_aer_to_statedb(self):
        """Write self.aer_stats for self.device_name to the device table.

        self.aer_stats is treated as a mapping of key -> {field: value}; it
        is flattened into "key|field" -> value pairs before being written.
        Nothing is written when the stats are None or flatten to nothing.
        Exceptions are logged, not raised.
        """
        if self.aer_stats is None:
            self.log_debug("PCIe device {} has no AER Stats".format(self.device_name))
            return

        try:
            aer_fields = {
                f"{key}|{field}": value
                for key, fv in self.aer_stats.items()
                for field, value in fv.items()
            }
            if aer_fields:
                self.device_table.set(self.device_name, swsscommon.FieldValuePairs(list(aer_fields.items())))
            else:
                self.log_debug("PCIe device {} has no AER attributes".format(self.device_name))
        except Exception as e:
            self.log_error("Exception while updating AER attributes to STATE_DB for {}: {}".format(self.device_name, str(e)))

    # Check the PCIe AER Stats
    def check_n_update_pcie_aer_stats(self, Bus, Dev, Fn):
        """Publish the id and AER statistics of one PCIe device.

        Args:
            Bus, Dev, Fn: integer bus, device and function numbers; they are
                formatted as "bb:dd.f" (hex bus/device) to form the table key.

        The device is skipped when its sysfs id cannot be read. Exceptions
        are logged, not raised.
        """
        try:
            self.device_name = "%02x:%02x.%d" % (Bus, Dev, Fn)
            Id = read_id_file(self.device_name)
            # Reset so stats from a previously processed device are never reused.
            self.aer_stats = {}
            if Id is not None:
                self.device_table.set(self.device_name, swsscommon.FieldValuePairs([('id', Id)]))
                self.aer_stats = platform_pcieutil.get_pcie_aer_stats(bus=Bus, dev=Dev, func=Fn)
                self.update_aer_to_statedb()
        except Exception as e:
            self.log_error("Exception while checking AER attributes for {}: {}".format(self.device_name, str(e)))

    # Update the PCIe devices status to DB
    def update_pcie_devices_status_db(self, err):
        """Write the overall check result to the status table.

        Args:
            err: truthy (e.g. a non-zero failure count) records "FAILED",
                falsy records "PASSED".

        Exceptions raised by the table write are logged, not raised.
        """
        if err:
            pcie_status = "FAILED"
            self.log_error("PCIe device status check : {}".format(pcie_status))
        else:
            pcie_status = "PASSED"
            self.log_info("PCIe device status check : {}".format(pcie_status))

        try:
            self.status_table.set("status", swsscommon.FieldValuePairs([('status', pcie_status)]))
        except Exception as e:
            self.log_error("Exception while updating PCIe device status to STATE_DB: {}".format(str(e)))

    # Check if any PCI interface is in detaching mode by querying the state_db
    def is_dpu_in_detaching_mode(self, pcie_dev):
        """Tell whether *pcie_dev* is recorded as detaching in the detach-info table.

        Args:
            pcie_dev: bus-info string such as "0000:03:00.0".

        Returns:
            True when some table entry has bus_info equal to *pcie_dev* and
            dpu_state equal to "detaching"; False otherwise (including when
            the table handle is None or the table is empty).
        """
        # Ensure detach_info is not None
        if self.detach_info is None:
            self.log_debug("detach_info is None")
            return False

        # Query the state_db for the device detaching status
        detach_info_keys = list(self.detach_info.getKeys())
        if not detach_info_keys:
            return False

        for key in detach_info_keys:
            dpu_info = self.detach_info.get(key)
            if dpu_info:
                # Convert tuple of field-value pairs to dictionary for easier access
                dpu_dict = dict(dpu_info[1])
                bus_info = dpu_dict.get(PCIE_DETACH_BUS_INFO_FIELD)
                dpu_state = dpu_dict.get(PCIE_DETACH_DPU_STATE_FIELD)
                if bus_info == pcie_dev and dpu_state == "detaching":
                    return True

        return False

    # Check the PCIe devices
    def check_pcie_devices(self):
        """Run the platform PCIe check and publish the results.

        Each entry returned by platform_pcieutil.get_pcie_check() is read for
        its "result", "name", "bus", "dev" and "fn" keys (bus/dev/fn parsed
        as hexadecimal strings). Failed entries are counted and logged,
        unless this is a smartswitch and the device is currently detaching;
        all other entries get their AER statistics refreshed. The failure
        count then determines the overall status written to the status
        table. Nothing is written when the check returns None.
        """
        self.resultInfo = platform_pcieutil.get_pcie_check()
        err = 0
        if self.resultInfo is None:
            return

        for result in self.resultInfo:
            if result["result"] == "Failed":
                # Convert bus, device, and function to a bus_info format like "0000:03:00.0"
                pcie_dev = "0000:{:02x}:{:02x}.{}".format(int(result["bus"], 16), int(result["dev"], 16), int(result["fn"], 16))

                # Check if the device is in detaching mode
                if device_info.is_smartswitch() and self.is_dpu_in_detaching_mode(pcie_dev):
                    self.log_debug("PCIe Device: {} is in detaching mode, skipping warning.".format(pcie_dev))
                    continue

                self.log_warning("PCIe Device: " + result["name"] + " Not Found")
                err += 1
            else:
                Bus = int(result["bus"], 16)
                Dev = int(result["dev"], 16)
                Fn = int(result["fn"], 16)
                # update AER-attributes to DB
                self.check_n_update_pcie_aer_stats(Bus, Dev, Fn)

        # update PCIe Device Status to DB
        self.update_pcie_devices_status_db(err)

    # Override signal handler from DaemonBase
    def signal_handler(self, sig, frame):
        """Handle a delivered signal.

        SIGINT and SIGTERM set the stop event and record a non-zero exit
        code (128 + signal number); SIGHUP and any other signal are logged
        and ignored.

        Args:
            sig: signal number.
            frame: current stack frame (unused).
        """
        FATAL_SIGNALS = [signal.SIGINT, signal.SIGTERM]
        NONFATAL_SIGNALS = [signal.SIGHUP]

        global exit_code

        # Fall back to the raw number for signals without an entry in the
        # name table, so logging an unexpected signal can never raise KeyError
        # from inside the handler.
        sig_name = SIGNALS_TO_NAMES_DICT.get(sig, sig)

        if sig in FATAL_SIGNALS:
            self.log_info("Caught signal '{}' - exiting...".format(sig_name))
            exit_code = 128 + sig  # Make sure we exit with a non-zero code so that supervisor will try to restart us
            self.stop_event.set()
        elif sig in NONFATAL_SIGNALS:
            self.log_info("Caught signal '{}' - ignoring...".format(sig_name))
        else:
            self.log_warning("Caught unhandled signal '{}' - ignoring...".format(sig_name))

    # Main daemon logic
    def run(self):
        """Execute one daemon cycle.

        Waits up to self.timeout seconds on the stop event, then runs the
        PCIe device check.

        Returns:
            False when the stop event was set during the wait or the wait
            itself raised (the caller should stop looping); True after a
            completed check.
        """
        try:
            if self.stop_event.wait(self.timeout):
                # We received a fatal signal
                return False
        except Exception as e:
            self.log_error("Exception occurred during stop_event wait: {}".format(str(e)))
            return False

        self.check_pcie_devices()

        return True
#
# Main =========================================================================
#


def main():
    """Create the daemon, loop until run() reports a stop, and return the exit code.

    Returns:
        The module-level ``exit_code`` as it stands when the loop ends.
    """
    daemon = DaemonPcied(SYSLOG_IDENTIFIER)
    daemon.log_info("Starting up...")

    # Keep cycling for as long as run() reports that the daemon should continue.
    keep_running = True
    while keep_running:
        keep_running = daemon.run()

    daemon.log_info("Shutting down...")
    return exit_code

# When executed as a script, use main()'s return value as the process exit status.
if __name__ == '__main__':
    sys.exit(main())
