#!python

'''
    sensormond
    Sensor monitor daemon for SONiC
'''

import signal
import sys
import threading
import time
import yaml

import sonic_platform
from sonic_py_common import daemon_base, logger, device_info
from swsscommon import swsscommon
from sonic_platform_base.sensor_fs import VoltageSensorFs, CurrentSensorFs


# Identifier handed to the logger/daemon base classes of this file
SYSLOG_IDENTIFIER = 'sensormond'
# Placeholder used whenever a sensor attribute cannot be obtained
NOT_AVAILABLE     = 'N/A'
# Parent name used for sensors that are refreshed directly under the chassis
CHASSIS_INFO_KEY  = 'chassis 1'
# Fallback slot value used when the chassis does not report its slot
INVALID_SLOT      = -1

# Name of the table receiving position/parent records of the sensors
PHYSICAL_ENTITY_INFO_TABLE = 'PHYSICAL_ENTITY_INFO'

# Exit with non-zero exit code by default so supervisord will restart sensormond.
SENSORMON_ERROR_EXIT = 1
# Exit code returned by main(); overwritten by the signal handler on fatal signals
exit_code = SENSORMON_ERROR_EXIT

# Utility functions

def try_get(callback, default=NOT_AVAILABLE):
    '''
    Invoke a getter and substitute a default when it has nothing to offer
    :param callback: Zero-argument callable to be invoked
    :param default: Value returned when the callback raises NotImplementedError
                    or returns None
    :return: Result of the callback, or the default in the two cases above
    '''
    try:
        result = callback()
    except NotImplementedError:
        # The getter is not implemented; fall back to the default
        return default

    return default if result is None else result

def update_entity_info(table, parent_name, key, device, device_index):
    '''
    Write the position/parent record of a device to the given table
    :param table: Table object the record is written to
    :param parent_name: Name of the parent of the device
    :param key: Key under which the record is stored
    :param device: Device object providing get_position_in_parent
    :param device_index: Position used when the device does not report one
    :return:
    '''
    position = try_get(device.get_position_in_parent, device_index)
    entity_fields = [
        ('position_in_parent', str(position)),
        ('parent_name', parent_name),
    ]
    table.set(key, swsscommon.FieldValuePairs(entity_fields))

class SensorStatus(logger.Logger):
    '''
    Holds the last recorded value and the threshold flags of a single sensor
    '''

    def __init__(self):
        super(SensorStatus, self).__init__(SYSLOG_IDENTIFIER)

        # Last recorded value; None until one is recorded or after it became unavailable
        self.value = None
        # Threshold flags as left by the most recent set_*_threshold() calls
        self.over_threshold = False
        self.under_threshold = False

    def set_value(self, name, value):
        '''
        Remember the latest sensor value.
        A warning is logged when a previously known value turns unavailable.
        :param name: Name of the sensor.
        :param value: New value, or NOT_AVAILABLE.
        '''
        if value != NOT_AVAILABLE:
            self.value = value
            return

        # Warn only on the transition from a known value to unavailable
        if self.value is not None:
            self.log_warning('Value of {} became unavailable'.format(name))
            self.value = None

    def set_over_threshold(self, value, threshold):
        '''
        Re-evaluate the over threshold flag
        :param value: Current value
        :param threshold: High threshold
        :return: True if the over threshold flag changed else False
        '''
        if value == NOT_AVAILABLE or threshold == NOT_AVAILABLE:
            # Nothing to compare against: the flag is cleared
            previous = self.over_threshold
            self.over_threshold = False
            return previous != self.over_threshold

        is_over = value > threshold
        if is_over != self.over_threshold:
            self.over_threshold = is_over
            return True

        return False

    def set_under_threshold(self, value, threshold):
        '''
        Re-evaluate the under threshold flag
        :param value: Current value
        :param threshold: Low threshold
        :return: True if the under threshold flag changed else False
        '''
        if value == NOT_AVAILABLE or threshold == NOT_AVAILABLE:
            # Nothing to compare against: the flag is cleared
            previous = self.under_threshold
            self.under_threshold = False
            return previous != self.under_threshold

        is_under = value < threshold
        if is_under != self.under_threshold:
            self.under_threshold = is_under
            return True

        return False


#
# SensorUpdater  ======================================================================
#
class SensorUpdater(logger.Logger):
    '''
    Common base of the sensor updaters: owns the database tables the sensor
    records are written to and removes those records again on destruction.
    '''

    def __init__(self, table_name, chassis):
        '''
        Initializer of SensorUpdater
        :param table_name: Name of sensor table
        :param chassis: Object representing a platform chassis
        '''
        super(SensorUpdater, self).__init__(SYSLOG_IDENTIFIER)

        self.chassis = chassis
        state_db = daemon_base.db_connect("STATE_DB")
        self.table = swsscommon.Table(state_db, table_name)
        self.phy_entity_table = swsscommon.Table(state_db, PHYSICAL_ENTITY_INFO_TABLE)
        # Per-slot mirror table; stays None unless it can be opened below
        self.chassis_table = None

        self.is_chassis_system = chassis.is_modular_chassis()
        if self.is_chassis_system:
            my_slot = try_get(chassis.get_my_slot, INVALID_SLOT)
            if my_slot != INVALID_SLOT:
                try:
                    # Modular chassis may not have table CHASSIS_STATE_DB.
                    slot_table_name = table_name + '_' + str(my_slot)
                    chassis_state_db = daemon_base.db_connect("CHASSIS_STATE_DB")
                    self.chassis_table = swsscommon.Table(chassis_state_db, slot_table_name)
                except Exception:
                    # Best effort only: continue without the mirror table
                    self.chassis_table = None

    def __del__(self):
        '''
        Remove every key of the sensor table (and its per-slot mirror) and of
        the physical entity table.
        '''
        # __del__ also runs for objects whose __init__ raised part-way, in which
        # case some attributes were never assigned; getattr() keeps the
        # finalizer from raising AttributeError on such an instance.
        table = getattr(self, 'table', None)
        chassis_table = getattr(self, 'chassis_table', None)
        is_chassis_system = getattr(self, 'is_chassis_system', False)
        phy_entity_table = getattr(self, 'phy_entity_table', None)

        if table:
            table_keys = table.getKeys()
            for tk in table_keys:
                table._del(tk)
                if is_chassis_system and chassis_table is not None:
                    chassis_table._del(tk)
        if phy_entity_table:
            phy_entity_keys = phy_entity_table.getKeys()
            for pek in phy_entity_keys:
                phy_entity_table._del(pek)

    def _log_on_status_changed(self, normal_status, normal_log, abnormal_log):
        '''
        Log when any status changed
        :param normal_status: Expected status.
        :param normal_log: Log string for expected status (logged as notice).
        :param abnormal_log: Log string for unexpected status (logged as warning).
        :return:
        '''
        if normal_status:
            self.log_notice(normal_log)
        else:
            self.log_warning(abnormal_log)

#
# VoltageUpdater  ======================================================================
#
class VoltageUpdater(SensorUpdater):
    '''
    Updater that reads the voltage sensors and writes their status to the
    voltage information table
    '''
    # Voltage information table name in database
    VOLTAGE_INFO_TABLE_NAME = 'VOLTAGE_INFO'

    def __init__(self, chassis, fs_sensors):
        '''
        Initializer of VoltageUpdater
        :param chassis: Object representing a platform chassis
        :param fs_sensors: List of voltage sensor objects refreshed in addition
                           to, and ahead of, the sensors reported by the chassis
        '''
        super(VoltageUpdater, self).__init__(self.VOLTAGE_INFO_TABLE_NAME, chassis)

        # SensorStatus objects keyed by sensor name
        self.voltage_status_dict = {}

        if self.is_chassis_system:
            # (sensor, module name, index) tuples seen on modules in the last update
            self.module_voltage_sensors = set()

        self.fs_sensors = fs_sensors

    def update(self):
        '''
        Update all voltage information to database
        :return:
        '''
        self.log_debug("Start voltage update")

        # Sensors given at construction come first, then the chassis' own sensors
        sensor_list = self.fs_sensors + self.chassis.get_all_voltage_sensors()

        for index, voltage_sensor in enumerate(sensor_list):
            self._refresh_voltage_status(CHASSIS_INFO_KEY, voltage_sensor, index)

        if self.is_chassis_system:
            available_voltage_sensors = set()
            for module_index, module in enumerate(self.chassis.get_all_modules()):
                module_name = try_get(module.get_name, 'Module {}'.format(module_index + 1))

                for voltage_sensor_index, voltage_sensor in enumerate(module.get_all_voltage_sensors()):
                    available_voltage_sensors.add((voltage_sensor, module_name, voltage_sensor_index))
                    self._refresh_voltage_status(module_name, voltage_sensor, voltage_sensor_index)

            # Module sensors seen in the previous update but missing now are
            # removed from the database
            voltage_sensors_to_remove = self.module_voltage_sensors - available_voltage_sensors
            self.module_voltage_sensors = available_voltage_sensors
            for voltage_sensor, parent_name, voltage_sensor_index in voltage_sensors_to_remove:
                self._remove_voltage_sensor_from_db(voltage_sensor, parent_name, voltage_sensor_index)

        self.log_debug("End Voltage updating")

    def _refresh_voltage_status(self, parent_name, voltage_sensor, voltage_sensor_index):
        '''
        Get voltage status by platform API and write to database.
        Any exception is logged as a warning and not propagated.
        :param parent_name: Name of parent device of the voltage_sensor object
        :param voltage_sensor: Object representing a platform voltage_sensor
        :param voltage_sensor_index: Index of the voltage_sensor object in platform chassis
        :return:
        '''
        # Bind the name before entering the try block: the except handler
        # reports it, and get_name() itself may be the call that fails.
        name = '{} voltage_sensor {}'.format(parent_name, voltage_sensor_index + 1)
        try:
            name = try_get(voltage_sensor.get_name, name)

            update_entity_info(self.phy_entity_table, parent_name, name, voltage_sensor, voltage_sensor_index + 1)

            if name not in self.voltage_status_dict:
                self.voltage_status_dict[name] = SensorStatus()

            voltage_status = self.voltage_status_dict[name]

            high_threshold = NOT_AVAILABLE
            low_threshold = NOT_AVAILABLE
            high_critical_threshold = NOT_AVAILABLE
            low_critical_threshold = NOT_AVAILABLE
            maximum_voltage = NOT_AVAILABLE
            minimum_voltage = NOT_AVAILABLE
            unit = NOT_AVAILABLE
            voltage = try_get(voltage_sensor.get_value)
            is_replaceable = try_get(voltage_sensor.is_replaceable, False)
            # The remaining attributes are only queried when a reading exists.
            # NOTE(review): set_value() is never called with NOT_AVAILABLE from
            # here, so its 'became unavailable' warning cannot fire, and the
            # threshold flags keep their previous state while the reading is
            # unavailable - confirm this is intended.
            if voltage != NOT_AVAILABLE:
                voltage_status.set_value(name, voltage)
                unit = try_get(voltage_sensor.get_unit)
                minimum_voltage = try_get(voltage_sensor.get_minimum_recorded)
                maximum_voltage = try_get(voltage_sensor.get_maximum_recorded)
                high_threshold = try_get(voltage_sensor.get_high_threshold)
                low_threshold = try_get(voltage_sensor.get_low_threshold)
                high_critical_threshold = try_get(voltage_sensor.get_high_critical_threshold)
                low_critical_threshold = try_get(voltage_sensor.get_low_critical_threshold)

            # Log only when a threshold flag actually flipped
            if voltage != NOT_AVAILABLE and voltage_status.set_over_threshold(voltage, high_threshold):
                self._log_on_status_changed(
                    not voltage_status.over_threshold,
                    'High voltage warning cleared: {} voltage restored to {}{}, high threshold {}{}'.
                    format(name, voltage, unit, high_threshold, unit),
                    'High voltage warning: {} current voltage {}{}, high threshold {}{}'.
                    format(name, voltage, unit, high_threshold, unit)
                )

            if voltage != NOT_AVAILABLE and voltage_status.set_under_threshold(voltage, low_threshold):
                self._log_on_status_changed(
                    not voltage_status.under_threshold,
                    'Low voltage warning cleared: {} voltage restored to {}{}, low threshold {}{}'.
                    format(name, voltage, unit, low_threshold, unit),
                    'Low voltage warning: {} current voltage {}{}, low threshold {}{}'.
                    format(name, voltage, unit, low_threshold, unit)
                )

            warning = voltage_status.over_threshold or voltage_status.under_threshold

            fvs = swsscommon.FieldValuePairs(
                [('voltage', str(voltage)),
                 # str() like every other field, in case the unit is not a string
                 ('unit', str(unit)),
                 ('minimum_voltage', str(minimum_voltage)),
                 ('maximum_voltage', str(maximum_voltage)),
                 ('high_threshold', str(high_threshold)),
                 ('low_threshold', str(low_threshold)),
                 ('warning_status', str(warning)),
                 ('critical_high_threshold', str(high_critical_threshold)),
                 ('critical_low_threshold', str(low_critical_threshold)),
                 ('is_replaceable', str(is_replaceable)),
                 ('timestamp', time.strftime('%Y%m%d %H:%M:%S'))
                 ])

            self.table.set(name, fvs)
            if self.is_chassis_system and self.chassis_table is not None:
                self.chassis_table.set(name, fvs)
        except Exception as e:
            self.log_warning('Failed to update voltage_sensor status for {} - {}'.format(name, repr(e)))

    def _remove_voltage_sensor_from_db(self, voltage_sensor, parent_name, voltage_sensor_index):
        '''
        Delete the record of a voltage sensor from the sensor table and, when
        present, from the per-slot chassis table
        :param voltage_sensor: Object representing a platform voltage_sensor
        :param parent_name: Name of parent device of the voltage_sensor object
        :param voltage_sensor_index: Index of the voltage_sensor object in its parent
        :return:
        '''
        name = try_get(voltage_sensor.get_name, '{} voltage_sensor {}'.format(parent_name, voltage_sensor_index + 1))
        self.table._del(name)

        if self.chassis_table is not None:
            self.chassis_table._del(name)

    
#
# CurrentUpdater  ======================================================================
#
class CurrentUpdater(SensorUpdater):
    '''
    Updater that reads the current sensors and writes their status to the
    current information table
    '''
    # Current information table name in database
    CURRENT_INFO_TABLE_NAME = 'CURRENT_INFO'

    def __init__(self, chassis, fs_sensors):
        '''
        Initializer of CurrentUpdater
        :param chassis: Object representing a platform chassis
        :param fs_sensors: List of current sensor objects refreshed in addition
                           to, and ahead of, the sensors reported by the chassis
        '''
        super(CurrentUpdater, self).__init__(self.CURRENT_INFO_TABLE_NAME, chassis)

        # SensorStatus objects keyed by sensor name
        self.current_status_dict = {}
        if self.is_chassis_system:
            # (sensor, module name, index) tuples seen on modules in the last update
            self.module_current_sensors = set()

        self.fs_sensors = fs_sensors

    def update(self):
        '''
        Update all current information to database
        :return:
        '''
        self.log_debug("Start current updating")

        # Sensors given at construction come first, then the chassis' own sensors
        sensor_list = self.fs_sensors + self.chassis.get_all_current_sensors()

        for index, current_sensor in enumerate(sensor_list):
            self._refresh_current_status(CHASSIS_INFO_KEY, current_sensor, index)

        if self.is_chassis_system:
            available_current_sensors = set()
            for module_index, module in enumerate(self.chassis.get_all_modules()):
                module_name = try_get(module.get_name, 'Module {}'.format(module_index + 1))

                for current_sensor_index, current_sensor in enumerate(module.get_all_current_sensors()):
                    available_current_sensors.add((current_sensor, module_name, current_sensor_index))
                    self._refresh_current_status(module_name, current_sensor, current_sensor_index)

            # Module sensors seen in the previous update but missing now are
            # removed from the database
            current_sensors_to_remove = self.module_current_sensors - available_current_sensors
            self.module_current_sensors = available_current_sensors
            for current_sensor, parent_name, current_sensor_index in current_sensors_to_remove:
                self._remove_current_sensor_from_db(current_sensor, parent_name, current_sensor_index)

        self.log_debug("End Current updating")

    def _refresh_current_status(self, parent_name, current_sensor, current_sensor_index):
        '''
        Get current status by platform API and write to database.
        Any exception is logged as a warning and not propagated.
        :param parent_name: Name of parent device of the current_sensor object
        :param current_sensor: Object representing a platform current_sensor
        :param current_sensor_index: Index of the current_sensor object in platform chassis
        :return:
        '''
        # Bind the name before entering the try block: the except handler
        # reports it, and get_name() itself may be the call that fails.
        name = '{} current_sensor {}'.format(parent_name, current_sensor_index + 1)
        try:
            name = try_get(current_sensor.get_name, name)

            update_entity_info(self.phy_entity_table, parent_name, name, current_sensor, current_sensor_index + 1)

            if name not in self.current_status_dict:
                self.current_status_dict[name] = SensorStatus()

            current_status = self.current_status_dict[name]

            unit = NOT_AVAILABLE
            high_threshold = NOT_AVAILABLE
            low_threshold = NOT_AVAILABLE
            high_critical_threshold = NOT_AVAILABLE
            low_critical_threshold = NOT_AVAILABLE
            maximum_current = NOT_AVAILABLE
            minimum_current = NOT_AVAILABLE
            current = try_get(current_sensor.get_value)
            is_replaceable = try_get(current_sensor.is_replaceable, False)
            # The remaining attributes are only queried when a reading exists.
            # NOTE(review): set_value() is never called with NOT_AVAILABLE from
            # here, so its 'became unavailable' warning cannot fire, and the
            # threshold flags keep their previous state while the reading is
            # unavailable - confirm this is intended.
            if current != NOT_AVAILABLE:
                current_status.set_value(name, current)
                unit = try_get(current_sensor.get_unit)
                minimum_current = try_get(current_sensor.get_minimum_recorded)
                maximum_current = try_get(current_sensor.get_maximum_recorded)
                high_threshold = try_get(current_sensor.get_high_threshold)
                low_threshold = try_get(current_sensor.get_low_threshold)
                high_critical_threshold = try_get(current_sensor.get_high_critical_threshold)
                low_critical_threshold = try_get(current_sensor.get_low_critical_threshold)

            # Log only when a threshold flag actually flipped
            if current != NOT_AVAILABLE and current_status.set_over_threshold(current, high_threshold):
                self._log_on_status_changed(
                    not current_status.over_threshold,
                    'High Current warning cleared: {} current restored to {}{}, high threshold {}{}'.
                    format(name, current, unit, high_threshold, unit),
                    'High Current warning: {} current Current {}{}, high threshold {}{}'.
                    format(name, current, unit, high_threshold, unit)
                )

            if current != NOT_AVAILABLE and current_status.set_under_threshold(current, low_threshold):
                self._log_on_status_changed(
                    not current_status.under_threshold,
                    'Low current warning cleared: {} current restored to {}{}, low threshold {}{}'.
                    format(name, current, unit, low_threshold, unit),
                    'Low current warning: {} current current {}{}, low threshold {}{}'.
                    format(name, current, unit, low_threshold, unit)
                )

            warning = current_status.over_threshold or current_status.under_threshold

            fvs = swsscommon.FieldValuePairs(
                [('current', str(current)),
                 # str() like every other field, in case the unit is not a string
                 ('unit', str(unit)),
                 ('minimum_current', str(minimum_current)),
                 ('maximum_current', str(maximum_current)),
                 ('high_threshold', str(high_threshold)),
                 ('low_threshold', str(low_threshold)),
                 ('warning_status', str(warning)),
                 ('critical_high_threshold', str(high_critical_threshold)),
                 ('critical_low_threshold', str(low_critical_threshold)),
                 ('is_replaceable', str(is_replaceable)),
                 ('timestamp', time.strftime('%Y%m%d %H:%M:%S'))
                 ])

            self.table.set(name, fvs)
            if self.is_chassis_system and self.chassis_table is not None:
                self.chassis_table.set(name, fvs)
        except Exception as e:
            self.log_warning('Failed to update current_sensor status for {} - {}'.format(name, repr(e)))

    def _remove_current_sensor_from_db(self, current_sensor, parent_name, current_sensor_index):
        '''
        Delete the record of a current sensor from the sensor table and, when
        present, from the per-slot chassis table
        :param current_sensor: Object representing a platform current_sensor
        :param parent_name: Name of parent device of the current_sensor object
        :param current_sensor_index: Index of the current_sensor object in its parent
        :return:
        '''
        name = try_get(current_sensor.get_name, '{} current_sensor {}'.format(parent_name, current_sensor_index + 1))
        self.table._del(name)

        if self.chassis_table is not None:
            self.chassis_table._del(name)

#
# Daemon =======================================================================
#
class SensorMonitorDaemon(daemon_base.DaemonBase):
    '''
    Daemon that periodically runs the voltage and current updaters
    '''

    # Initial update interval
    INITIAL_INTERVAL = 5
    # Periodic Update interval
    UPDATE_INTERVAL = 60
    # Update time threshold. If update time exceeds this threshold, log warning msg.
    UPDATE_ELAPSED_THRESHOLD = 30

    def __init__(self):
        '''
        Initializer of SensorMonitorDaemon
        :raises Exception: Whatever the platform raised if the chassis cannot
                           be obtained (logged before being re-raised)
        '''
        super(SensorMonitorDaemon, self).__init__(SYSLOG_IDENTIFIER)

        # Set minimum logging level to INFO
        self.set_min_log_priority_info()

        # Set by the signal handler to interrupt the wait in run()
        self.stop_event = threading.Event()

        # Delay before the next update; the first one uses the short initial interval
        self.wait_time = self.INITIAL_INTERVAL

        self.interval  = self.UPDATE_INTERVAL

        self._voltage_sensor_fs = []
        self._current_sensor_fs = []
        self.sensors_yaml_file = None

        try:
            self.chassis = sonic_platform.platform.Platform().get_chassis()
        except Exception as e:
            self.log_error("Failed to get chassis info, err: {}".format(repr(e)))
            # The updaters below cannot work without a chassis. Re-raise the
            # real cause instead of failing a few lines later with an
            # AttributeError on the never-assigned self.chassis.
            raise

        # Initialize voltage and current sensors lists from data file if available
        try:
            (platform_path, _) = device_info.get_paths_to_platform_and_hwsku_dirs()
            self.sensors_yaml_file = platform_path + "/sensors.yaml"

            with open(self.sensors_yaml_file, 'r') as f:
                # An empty file loads as None; treat it as "no sensors declared"
                sensors_data = yaml.safe_load(f) or {}

            if 'voltage_sensors' in sensors_data:
                self._voltage_sensor_fs = VoltageSensorFs.factory(VoltageSensorFs, sensors_data['voltage_sensors'])
            if 'current_sensors' in sensors_data:
                self._current_sensor_fs = CurrentSensorFs.factory(CurrentSensorFs, sensors_data['current_sensors'])
        except OSError:
            # Sensors yaml file is not available
            pass
        except Exception as e:
            # The file exists but could not be used (e.g. malformed content).
            # Keep running with whatever was loaded, but do not hide the problem.
            self.log_warning("Failed to load sensors from {}, err: {}".format(self.sensors_yaml_file, repr(e)))

        self.voltage_updater = VoltageUpdater(self.chassis, self._voltage_sensor_fs)

        self.current_updater = CurrentUpdater(self.chassis, self._current_sensor_fs)


    # Override signal handler from DaemonBase
    def signal_handler(self, sig, frame):
        '''
        Signal handler
        :param sig: Signal number
        :param frame: not used
        :return:
        '''
        FATAL_SIGNALS = [signal.SIGINT, signal.SIGTERM]
        NONFATAL_SIGNALS = [signal.SIGHUP]

        global exit_code

        if sig in FATAL_SIGNALS:
            self.log_info("Caught signal '{}' - exiting...".format(signal.Signals(sig).name))
            exit_code = 128 + sig  # Make sure we exit with a non-zero code so that supervisor will try to restart us
            self.stop_event.set()
        elif sig in NONFATAL_SIGNALS:
            self.log_info("Caught signal '{}' - ignoring...".format(signal.Signals(sig).name))
        else:
            self.log_warning("Caught unhandled signal '{}' - ignoring...".format(signal.Signals(sig).name))

    # Main daemon logic
    def run(self):
        '''
        Run one iteration of the daemon: wait, then update all sensors
        :return: False if the stop event was set while waiting, else True
        '''
        if self.stop_event.wait(self.wait_time):
            # We received a fatal signal
            return False

        begin = time.time()

        self.voltage_updater.update()
        self.current_updater.update()

        # Schedule the next update so that updates start UPDATE_INTERVAL apart;
        # if this one overran the interval, only wait the short initial interval.
        elapsed = time.time() - begin
        if elapsed <  self.interval:
            self.wait_time = self.interval - elapsed
        else:
            self.wait_time = self.INITIAL_INTERVAL

        if elapsed > self.UPDATE_ELAPSED_THRESHOLD:
            self.log_warning('Sensors update took a long time : '
                                    '{} seconds'.format(elapsed))

        return True

#
# Main =========================================================================
#
def main():
    '''
    Create the sensor monitor daemon and run it until it asks to stop
    :return: Process exit code (module-level exit_code)
    '''
    daemon = SensorMonitorDaemon()
    daemon.log_info("Starting up...")

    # run() returns False once the daemon has been told to stop
    keep_running = True
    while keep_running:
        keep_running = daemon.run()

    daemon.log_info("Shutting down with exit code {}".format(exit_code))
    return exit_code


# Script entry point: the value returned by main() becomes the process exit status
if __name__ == '__main__':
    sys.exit(main())
