#!python

"""
    pcied
    PCIe device monitoring daemon for SONiC
"""

import os
import signal
import sys
import threading

from sonic_py_common import daemon_base, device_info, logger
from swsscommon import swsscommon

#
# Constants ====================================================================
#

# Reverse lookup from signal number to symbolic name (e.g. SIGINT) for log
# messages. Names containing an underscore (e.g. SIG_DFL, SIG_IGN) are not
# signals themselves and are skipped. When two names share a number, the one
# that sorts last in dir(signal) wins.
# TODO: Once we no longer support Python 2, we can eliminate this and get the
# name using the 'name' field (e.g., `signal.SIGINT.name`) starting with Python 3.5
SIGNALS_TO_NAMES_DICT = {
    getattr(signal, sig_name): sig_name
    for sig_name in dir(signal)
    if sig_name.startswith('SIG') and '_' not in sig_name
}

# Identifier for this daemon's syslog output (module logger below and the
# DaemonPcied instance created in main()).
SYSLOG_IDENTIFIER = "pcied"

# NOTE(review): not referenced anywhere in this file; presumably kept for
# external users/tests — confirm before removing.
PCIE_RESULT_REGEX = "PCIe Device Checking All Test"
# STATE_DB table with one row per present PCIe device (key "BB:DD.F"),
# holding its sysfs id and flattened AER counters.
PCIE_DEVICE_TABLE_NAME = "PCIE_DEVICE"
# STATE_DB table holding the overall PASSED/FAILED verdict under key "status".
PCIE_STATUS_TABLE_NAME = "PCIE_DEVICES"

# Seconds the main loop waits between successive PCIe checks.
PCIED_MAIN_THREAD_SLEEP_SECS = 60

# Process exit codes. PCIEUTIL_LOAD_ERROR is used when no platform PCIe
# utility could be loaded; PCIEUTIL_CONF_FILE_ERROR is not used in this file.
PCIEUTIL_CONF_FILE_ERROR = 1
PCIEUTIL_LOAD_ERROR = 2

# Platform PCIe utility object; populated by DaemonPcied.__init__ via
# load_platform_pcieutil().
platform_pcieutil = None

# Module-level logger, used by load_platform_pcieutil().
log = logger.Logger(SYSLOG_IDENTIFIER)

# Exit code returned by main(); the daemon's signal handler sets it to
# 128 + signal number when a fatal signal arrives.
exit_code = 0

# wrapper functions to call the platform api
def load_platform_pcieutil():
    """Instantiate the PCIe utility for the running platform.

    Tries the platform-specific ``sonic_platform.pcie.Pcie`` first and falls
    back to the generic ``PcieUtil`` from sonic_platform_base. Both are
    constructed with the platform directory path.

    Returns:
        The utility object, or None if neither module could be imported.
    """
    platform_path, _ = device_info.get_paths_to_platform_and_hwsku_dirs()

    # Preferred: vendor-provided implementation. Construction stays inside the
    # try so an ImportError raised while building it also triggers the fallback.
    try:
        from sonic_platform.pcie import Pcie
        return Pcie(platform_path)
    except ImportError as primary_err:
        log.log_notice("Failed to load platform Pcie module. Error : {}, Fallback to default module".format(str(primary_err)), True)

    # Fallback: the common implementation shipped with sonic_platform_base.
    try:
        from sonic_platform_base.sonic_pcie.pcie_common import PcieUtil
        return PcieUtil(platform_path)
    except ImportError as fallback_err:
        log.log_error("Failed to load default PcieUtil module. Error : {}".format(str(fallback_err)), True)

    return None

def read_id_file(device_name):
    """Return the PCI device id read from sysfs for *device_name*.

    Args:
        device_name: "BB:DD.F" string (bus:device.function); the PCI domain
            is hard-coded to 0000.

    Returns:
        The stripped contents of /sys/bus/pci/devices/0000:<device_name>/device,
        or None if that sysfs entry does not exist.

    Raises:
        IOError/OSError for any failure other than the entry being absent.
    """
    import errno

    dev_id_path = '/sys/bus/pci/devices/0000:%s/device' % device_name

    # Open directly rather than checking os.path.exists() first. A device that
    # disappears between the check and the open (hot removal) would otherwise
    # raise out of the daemon's polling loop.
    try:
        with open(dev_id_path, 'r') as fd:
            return fd.read().strip()
    except (IOError, OSError) as err:
        if err.errno == errno.ENOENT:
            return None
        raise

#
# Daemon =======================================================================
#


class DaemonPcied(daemon_base.DaemonBase):
    """PCIe device monitoring daemon.

    Each cycle waits up to PCIED_MAIN_THREAD_SLEEP_SECS seconds for a stop
    request, then runs the platform utility's get_pcie_check(). For every
    device not reported as "Failed" it writes the sysfs id and AER stats to
    the PCIE_DEVICE table in STATE_DB. The overall PASSED/FAILED verdict goes
    under the "status" key of the PCIE_DEVICES table.
    """

    def __init__(self, log_identifier):
        """Load the platform PCIe utility and open the STATE_DB tables.

        Exits the process with PCIEUTIL_LOAD_ERROR if no PCIe utility module
        can be loaded.
        """
        super(DaemonPcied, self).__init__(log_identifier)

        self.timeout = PCIED_MAIN_THREAD_SLEEP_SECS
        self.stop_event = threading.Event()
        self.state_db = None
        self.device_table = None
        # Assigned before the possible sys.exit() below so __del__ never sees
        # a missing attribute on an instance that bailed out early.
        self.status_table = None
        self.table = None  # NOTE(review): not referenced elsewhere in this file
        self.resultInfo = []
        self.device_name = None
        self.aer_stats = {}

        global platform_pcieutil

        platform_pcieutil = load_platform_pcieutil()
        if platform_pcieutil is None:
            sys.exit(PCIEUTIL_LOAD_ERROR)

        # Connect to STATE_DB and create pcie device table
        self.state_db = daemon_base.db_connect("STATE_DB")
        self.device_table = swsscommon.Table(self.state_db, PCIE_DEVICE_TABLE_NAME)
        self.status_table = swsscommon.Table(self.state_db, PCIE_STATUS_TABLE_NAME)

    def __del__(self):
        """Delete every key from the PCIe device and status tables on teardown."""
        # getattr() tolerates a partially constructed instance (e.g. if the
        # base-class constructor raised before these attributes were set).
        device_table = getattr(self, 'device_table', None)
        if device_table:
            for tk in device_table.getKeys():
                device_table._del(tk)
        status_table = getattr(self, 'status_table', None)
        if status_table:
            for stk in status_table.getKeys():
                status_table._del(stk)

    # load aer-fields into statedb
    def update_aer_to_statedb(self):
        """Write self.aer_stats for self.device_name into the device table.

        self.aer_stats is iterated as a mapping of key -> {field: value}. The
        exact schema comes from the platform's get_pcie_aer_stats(). Each
        entry is flattened into a "<key>|<field>" field of the device's row.
        Nothing is written when aer_stats is None or yields no fields.
        """
        if self.aer_stats is None:
            self.log_debug("PCIe device {} has no AER Stats".format(self.device_name))
            return

        aer_fields = {
            "{}|{}".format(key, field): value
            for key, fv in self.aer_stats.items()
            for field, value in fv.items()
        }

        if aer_fields:
            formatted_fields = swsscommon.FieldValuePairs(list(aer_fields.items()))
            self.device_table.set(self.device_name, formatted_fields)
        else:
            self.log_debug("PCIe device {} has no AER attriutes".format(self.device_name))

    # Check the PCIe AER Stats
    def check_n_update_pcie_aer_stats(self, Bus, Dev, Fn):
        """Record the sysfs id and AER stats of the device at Bus:Dev.Fn.

        Sets self.device_name to "%02x:%02x.%d" of the address. If the sysfs
        id file is absent, self.aer_stats is reset to {} and nothing is
        written to STATE_DB.
        """
        self.device_name = "%02x:%02x.%d" % (Bus, Dev, Fn)

        Id = read_id_file(self.device_name)

        self.aer_stats = {}
        if Id is not None:
            fvp = swsscommon.FieldValuePairs([('id', Id)])
            self.device_table.set(self.device_name, fvp)
            self.aer_stats = platform_pcieutil.get_pcie_aer_stats(bus=Bus, dev=Dev, func=Fn)
            self.update_aer_to_statedb()

    # Update the PCIe devices status to DB
    def update_pcie_devices_status_db(self, err):
        """Store "FAILED" (if err is truthy) or "PASSED" under key "status"."""
        if err:
            pcie_status = "FAILED"
            self.log_error("PCIe device status check : {}".format(pcie_status))
        else:
            pcie_status = "PASSED"
            self.log_info("PCIe device status check : {}".format(pcie_status))
        fvs = swsscommon.FieldValuePairs([
            ('status', pcie_status)
        ])

        self.status_table.set("status", fvs)

    # Check the PCIe devices
    def check_pcie_devices(self):
        """Run the platform PCIe check and publish per-device and overall results.

        Entries whose "result" is "Failed" are logged and counted. For the
        others, "bus"/"dev"/"fn" are parsed as hex and their AER stats are
        updated. If get_pcie_check() returns None, nothing is published.
        """
        self.resultInfo = platform_pcieutil.get_pcie_check()
        err = 0
        if self.resultInfo is None:
            return

        for result in self.resultInfo:
            if result["result"] == "Failed":
                self.log_warning("PCIe Device: " + result["name"] + " Not Found")
                err += 1
            else:
                Bus = int(result["bus"], 16)
                Dev = int(result["dev"], 16)
                Fn = int(result["fn"], 16)
                # update AER-attributes to DB
                self.check_n_update_pcie_aer_stats(Bus, Dev, Fn)

        # update PCIe Device Status to DB
        self.update_pcie_devices_status_db(err)

    # Override signal handler from DaemonBase
    def signal_handler(self, sig, frame):
        """Stop on SIGINT/SIGTERM; log and ignore SIGHUP and any other signal."""
        FATAL_SIGNALS = [signal.SIGINT, signal.SIGTERM]
        NONFATAL_SIGNALS = [signal.SIGHUP]

        global exit_code

        if sig in FATAL_SIGNALS:
            self.log_info("Caught signal '{}' - exiting...".format(SIGNALS_TO_NAMES_DICT[sig]))
            exit_code = 128 + sig  # Make sure we exit with a non-zero code so that supervisor will try to restart us
            self.stop_event.set()
        elif sig in NONFATAL_SIGNALS:
            self.log_info("Caught signal '{}' - ignoring...".format(SIGNALS_TO_NAMES_DICT[sig]))
        else:
            self.log_warning("Caught unhandled signal '{}' - ignoring...".format(SIGNALS_TO_NAMES_DICT[sig]))

    # Main daemon logic
    def run(self):
        """Run one monitoring cycle.

        Waits up to self.timeout seconds for stop_event. Returns False
        immediately if it was set; otherwise checks the PCIe devices and
        returns True.
        """
        if self.stop_event.wait(self.timeout):
            # We received a fatal signal
            return False

        self.check_pcie_devices()

        return True
#
# Main =========================================================================
#


def main():
    """Run the PCIe monitoring daemon until a fatal signal stops it.

    Returns:
        The module-level exit_code: 0 unless a fatal signal was caught, in
        which case the signal handler set it to 128 + signal number.
    """
    daemon = DaemonPcied(SYSLOG_IDENTIFIER)
    daemon.log_info("Starting up...")

    # run() returns False once the stop event has been set.
    keep_running = True
    while keep_running:
        keep_running = daemon.run()

    daemon.log_info("Shutting down...")
    return exit_code


if __name__ == '__main__':
    sys.exit(main())
