#!python

import argparse
import click
import json
import os
import sys
import utilities_common.multi_asic as multi_asic_util

from collections import OrderedDict, namedtuple
from datetime import datetime, timezone, timedelta
from natsort import natsorted
from sonic_py_common import multi_asic
from swsscommon.swsscommon import APP_FABRIC_PORT_TABLE_NAME, COUNTERS_TABLE, COUNTERS_FABRIC_PORT_NAME_MAP, COUNTERS_FABRIC_QUEUE_NAME_MAP
from tabulate import tabulate
from utilities_common import constants
from utilities_common.cli import json_serial, UserCache
from utilities_common.netstat import format_number_with_comma, table_as_json, ns_diff, format_prate

# Unit-test hook: when UTILITIES_UNIT_TESTING is "1" or "2", put the parent
# directory and its tests/ directory at the front of sys.path and import the
# mocked DB connector, so the redis access is mocked for unit test purposes.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "1" or os.environ["UTILITIES_UNIT_TESTING"] == "2":
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        import mock_tables.dbconnector
    # The multi-ASIC test topology additionally imports the multi-ASIC mock
    # and loads the mocked namespace configuration.
    if os.environ["UTILITIES_UNIT_TESTING_TOPOLOGY"] == "multi_asic":
        import mock_tables.mock_multi_asic
        mock_tables.dbconnector.load_namespace_config()
except KeyError:
    # One of the environment variables is unset; whatever was set up before
    # the failing lookup stays in effect and the rest is skipped.
    pass

# Leading part of every fabric port name; stripped to obtain the port id.
PORT_NAME_PREFIX = "PORT"
# Key prefixes: "<COUNTERS_TABLE>:" for counter entries and
# "<APP_FABRIC_PORT_TABLE_NAME>|" for fabric port status entries.
COUNTER_TABLE_PREFIX = f"{COUNTERS_TABLE}:"
FABRIC_PORT_STATUS_TABLE_PREFIX = f"{APP_FABRIC_PORT_TABLE_NAME}|"
# Field of a fabric port status entry that holds the port state.
FABRIC_PORT_STATUS_FIELD = "STATUS"
# Placeholder reported when a value is missing from the database.
STATUS_NA = "N/A"

# Snapshot locations. These are placeholders until main() rebinds them to the
# user cache directory and the port/queue snapshot file prefixes.
cnstat_dir = "N/A"
cnstat_fqn_file_port = "N/A"
cnstat_fqn_file_queue = "N/A"

class FabricStat(object):
    """
    Common base of the fabric statistics views for one namespace.

    Subclasses are expected to override get_cnstat(), save_fresh_stats() and
    cnstat_print(); the versions here only assert.
    """

    def __init__(self, namespace):
        self.namespace = namespace
        # No connection yet; get_port_state() and get_counters() read through
        # self.db, so it has to be set before they are called.
        self.db = None
        self.multi_asic = multi_asic_util.MultiAsic(constants.DISPLAY_ALL, namespace)

    def get_cnstat_dict(self):
        """
        Start from an empty self.cnstat_dict, fill it through collect_stat()
        and return it.
        """
        self.cnstat_dict = OrderedDict()
        self.collect_stat()
        return self.cnstat_dict

    @multi_asic_util.run_on_multi_asic
    def collect_stat(self):
        """
        Merge the result of get_cnstat() into self.cnstat_dict.
        """
        fresh = self.get_cnstat()
        self.cnstat_dict.update(fresh)

    def get_port_state(self, port_name):
        """
        Return the STATUS field of the fabric port entry in STATE_DB, or
        STATUS_NA when the field is not present.
        """
        state_key = FABRIC_PORT_STATUS_TABLE_PREFIX + port_name
        state = self.db.get(self.db.STATE_DB, state_key, FABRIC_PORT_STATUS_FIELD)
        return STATUS_NA if state is None else state

    def get_counters(self, counter_bucket_dict, table_id):
        """
        Read the counters named in counter_bucket_dict (position -> counter
        name) from the COUNTERS_DB entry of table_id.

        Returns a list of strings indexed by position; a counter that is
        missing from the database is reported as STATUS_NA.
        """
        counters_key = COUNTER_TABLE_PREFIX + table_id
        values = ["0"] * len(counter_bucket_dict)
        for index, stat_name in counter_bucket_dict.items():
            raw = self.db.get(self.db.COUNTERS_DB, counters_key, stat_name)
            if raw is None:
                values[index] = STATUS_NA
                continue
            if values[index] != STATUS_NA:
                values[index] = str(int(values[index]) + int(raw))
        return values

    def get_cnstat(self):
        """
        Return the counters read from the database (subclass hook).
        """
        assert False, 'Need to override this method'

    def save_fresh_stats(self):
        """
        Save the current counters as the new snapshot (subclass hook).
        """
        assert False, 'Need to override this method'

    def cnstat_print(self, cnstat_dict, errors_only=False):
        """
        Print the given counters (subclass hook).
        """
        assert False, 'Need to override this method'

# One record per fabric port. The field order matches the positions in
# port_counter_bucket_list below.
PortStat = namedtuple(
    "PortStat",
    [
        "in_cell",
        "in_octet",
        "out_cell",
        "out_octet",
        "crc",
        "fec_correctable",
        "fec_uncorrectable",
        "symbol_err",
    ],
)

# SAI counter names read for every fabric port, in PortStat field order.
port_counter_bucket_list = [
    'SAI_PORT_STAT_IF_IN_FABRIC_DATA_UNITS',
    'SAI_PORT_STAT_IF_IN_OCTETS',
    'SAI_PORT_STAT_IF_OUT_FABRIC_DATA_UNITS',
    'SAI_PORT_STAT_IF_OUT_OCTETS',
    'SAI_PORT_STAT_IF_IN_ERRORS',
    'SAI_PORT_STAT_IF_IN_FEC_CORRECTABLE_FRAMES',
    'SAI_PORT_STAT_IF_IN_FEC_NOT_CORRECTABLE_FRAMES',
    'SAI_PORT_STAT_IF_IN_FEC_SYMBOL_ERRORS',
]
# Position -> counter name.
port_counter_bucket_dict = dict(enumerate(port_counter_bucket_list))

# Column headers of the full port table and of the errors-only table.
portstat_header_all = [
    'ASIC', 'PORT', 'STATE',
    'IN_CELL', 'IN_OCTET', 'OUT_CELL', 'OUT_OCTET',
    'CRC', 'FEC_CORRECTABLE', 'FEC_UNCORRECTABLE', 'SYMBOL_ERR',
]
portstat_header_errors_only = [
    'ASIC', 'PORT', 'STATE',
    'CRC', 'FEC_CORRECTABLE', 'FEC_UNCORRECTABLE', 'SYMBOL_ERR',
]

class FabricPortStat(FabricStat):
    """
    Collect, snapshot and print the per-port fabric counters of one namespace.
    """

    def get_cnstat(self):
        """
        Read the counters of every fabric port from COUNTERS_DB.

        Returns an OrderedDict mapping each port name (in natural sort order)
        to a PortStat; it is empty when the fabric port name map is missing.
        """
        counter_port_name_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_FABRIC_PORT_NAME_MAP)
        cnstat_dict = OrderedDict()
        if counter_port_name_map is None:
            return cnstat_dict
        for port_name in natsorted(counter_port_name_map):
            cntr = self.get_counters(port_counter_bucket_dict, counter_port_name_map[port_name])
            cnstat_dict[port_name] = PortStat._make(cntr)
        return cnstat_dict

    def save_fresh_stats(self):
        """
        Write the current port counters to the per-ASIC snapshot file.

        Prints a notice and returns when there is no fabric port name map.
        On an I/O error the error is printed and the process exits with the
        error number.
        """
        counter_port_name_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_FABRIC_PORT_NAME_MAP)
        if counter_port_name_map is None:
            print("No counters require cleaning")
            return
        cnstat_dict = self.get_cnstat()
        asic_name = '0'
        if self.namespace:
            asic_name = multi_asic.get_asic_id_from_name(self.namespace)
        cnstat_fqn_file_port_name = cnstat_fqn_file_port + asic_name
        try:
            # The context manager closes the file even if serialization fails.
            with open(cnstat_fqn_file_port_name, 'w') as cache_file:
                json.dump(cnstat_dict, cache_file, default=json_serial)
        except IOError as e:
            print(e.errno, e)
            sys.exit(e.errno)
        else:
            print("Clear and update saved counters port")

    def cnstat_print(self, cnstat_dict, errors_only=False):
        """
        Print one table row per port in cnstat_dict.

        Every counter is passed through ns_diff() together with the value
        saved in the snapshot file (presumably yielding current minus saved);
        ports without a saved entry are compared against zeros.

        cnstat_dict maps port names to PortStat records. When errors_only is
        true only the CRC/FEC/symbol error columns are shown.
        """
        if len(cnstat_dict) == 0:
            print("Counters %s empty" % self.namespace)
            return

        # Default ASIC name is 0 for single-ASIC systems. For multi-ASIC systems,
        # derive name from namespace.
        asic_name = '0'
        if self.namespace:
            asic_name = multi_asic.get_asic_id_from_name(self.namespace)

        cnstat_fqn_file_port_name = cnstat_fqn_file_port + asic_name
        cnstat_cached_dict = {}
        if os.path.isfile(cnstat_fqn_file_port_name):
            try:
                with open(cnstat_fqn_file_port_name, 'r') as cache_file:
                    cnstat_cached_dict = json.load(cache_file)
            except IOError as e:
                print(e.errno, e)
            except ValueError as e:
                # A truncated or corrupt snapshot must not break the view;
                # fall back to showing the raw counters.
                print("Ignoring unreadable saved counters %s: %s" % (cnstat_fqn_file_port_name, e))

        # The layout does not depend on the port, so pick the header once.
        header = portstat_header_errors_only if errors_only else portstat_header_all
        no_saved_values = ['0'] * len(port_counter_bucket_list)

        table = []
        for key, data in cnstat_dict.items():
            port_id = key[len(PORT_NAME_PREFIX):]
            port_name = PORT_NAME_PREFIX + port_id
            # Saved content for each port, in this order:
            # IN_CELL, IN_OCTET, OUT_CELL, OUT_OCTET, CRC, FEC_CORRECTABLE,
            # FEC_UNCORRECTABLE, SYMBOL_ERR
            # e.g. PORT76 ['0', '0', '36', '6669', '0', '13', '302626', '3']
            diff_cached = cnstat_cached_dict.get(port_name, no_saved_values)

            row = [asic_name, port_id, self.get_port_state(key)]
            if not errors_only:
                row += [ns_diff(data.in_cell, diff_cached[0]),
                        ns_diff(data.in_octet, diff_cached[1]),
                        ns_diff(data.out_cell, diff_cached[2]),
                        ns_diff(data.out_octet, diff_cached[3])]
            row += [ns_diff(data.crc, diff_cached[4]),
                    ns_diff(data.fec_correctable, diff_cached[5]),
                    ns_diff(data.fec_uncorrectable, diff_cached[6]),
                    ns_diff(data.symbol_err, diff_cached[7])]
            table.append(tuple(row))

        print(tabulate(table, header, tablefmt='simple', stralign='right'))
        print()

# One record per fabric port queue. The field order matches the positions in
# queue_counter_bucket_list below.
QueueStat = namedtuple("QueueStat", ["curlevel", "watermarklevel", "curbyte"])

# SAI counter names read for every fabric queue, in QueueStat field order.
queue_counter_bucket_list = [
    'SAI_QUEUE_STAT_CURR_OCCUPANCY_LEVEL',
    'SAI_QUEUE_STAT_WATERMARK_LEVEL',
    'SAI_QUEUE_STAT_CURR_OCCUPANCY_BYTES',
]
# Position -> counter name.
queue_counter_bucket_dict = dict(enumerate(queue_counter_bucket_list))

# Column headers of the queue table.
queuestat_header = [
    'ASIC', 'PORT', 'STATE', 'QUEUE_ID',
    'CURRENT_BYTE', 'CURRENT_LEVEL', 'WATERMARK_LEVEL',
]

class FabricQueueStat(FabricStat):
    """
    Collect, snapshot and print the per-queue fabric counters of one namespace.
    """

    def get_cnstat(self):
        """
        Read the counters of every fabric queue from COUNTERS_DB.

        Returns an OrderedDict mapping each "port:queue" name (in natural
        sort order) to a QueueStat; it is empty when the fabric queue name
        map is missing.
        """
        counter_queue_name_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_FABRIC_QUEUE_NAME_MAP)
        cnstat_dict = OrderedDict()
        if counter_queue_name_map is None:
            return cnstat_dict
        for port_queue_name in natsorted(counter_queue_name_map):
            cntr = self.get_counters(queue_counter_bucket_dict, counter_queue_name_map[port_queue_name])
            cnstat_dict[port_queue_name] = QueueStat._make(cntr)
        return cnstat_dict

    def save_fresh_stats(self):
        """
        Write the current queue counters to the per-ASIC snapshot file.

        Prints a notice and returns when there is no fabric port name map.
        On an I/O error the error is printed and the process exits with the
        error number.
        """
        # NOTE(review): the presence check uses the *port* name map although
        # get_cnstat() reads the queue name map - confirm this is intended.
        counter_port_name_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_FABRIC_PORT_NAME_MAP)
        if counter_port_name_map is None:
            print("No counters require cleaning")
            return
        cnstat_dict = self.get_cnstat()
        asic_name = '0'
        if self.namespace:
            asic_name = multi_asic.get_asic_id_from_name(self.namespace)
        cnstat_fqn_file_queue_name = cnstat_fqn_file_queue + asic_name
        try:
            # The context manager closes the file even if serialization fails.
            with open(cnstat_fqn_file_queue_name, 'w') as cache_file:
                json.dump(cnstat_dict, cache_file, default=json_serial)
        except IOError as e:
            print(e.errno, e)
            sys.exit(e.errno)
        else:
            print("Clear and update saved counters queue")

    def cnstat_print(self, cnstat_dict, errors_only=False):
        """
        Print one table row per queue in cnstat_dict.

        Every counter is passed through ns_diff() together with the value
        saved in the snapshot file (presumably yielding current minus saved);
        queues without a saved entry are compared against zeros.

        cnstat_dict maps "port:queue" names to QueueStat records.
        errors_only is accepted for interface compatibility and not used.
        """
        if len(cnstat_dict) == 0:
            print("Counters %s empty" % self.namespace)
            return

        # Default ASIC name is 0 for single-ASIC systems. For multi-ASIC systems,
        # derive name from namespace.
        asic_name = '0'
        if self.namespace:
            asic_name = multi_asic.get_asic_id_from_name(self.namespace)

        cnstat_fqn_file_queue_name = cnstat_fqn_file_queue + asic_name
        cnstat_cached_dict = {}
        if os.path.isfile(cnstat_fqn_file_queue_name):
            try:
                with open(cnstat_fqn_file_queue_name, 'r') as cache_file:
                    cnstat_cached_dict = json.load(cache_file)
            except IOError as e:
                print(e.errno, e)
            except ValueError as e:
                # A truncated or corrupt snapshot must not break the view;
                # fall back to showing the raw counters.
                print("Ignoring unreadable saved counters %s: %s" % (cnstat_fqn_file_queue_name, e))

        no_saved_values = ['0'] * len(queue_counter_bucket_list)

        table = []
        for key, data in cnstat_dict.items():
            port_name, queue_id = key.split(':')
            # Saved content for each "portName:queueId", in this order:
            # CURRENT_LEVEL, WATERMARK_LEVEL, CURRENT_BYTE
            # e.g. PORT90:0 ['N/A', 'N/A', 'N/A']
            diff_cached = cnstat_cached_dict.get(key, no_saved_values)
            port_id = port_name[len(PORT_NAME_PREFIX):]
            # Columns are printed as byte, level, watermark, which differs
            # from the saved order; hence the indices 2, 0, 1.
            table.append((asic_name, port_id, self.get_port_state(port_name), queue_id,
                          ns_diff(data.curbyte, diff_cached[2]),
                          ns_diff(data.curlevel, diff_cached[0]),
                          ns_diff(data.watermarklevel, diff_cached[1])))

        print(tabulate(table, queuestat_header, tablefmt='simple', stralign='right'))
        print()

class FabricCapacity(FabricStat):
    """
    Fabric capacity summary of one namespace.

    Results are appended to two caller-owned lists so that several
    namespaces can be collected into a single table.
    """

    def __init__(self, namespace, table_cnt, threshold):
        """
        namespace: namespace to query.
        table_cnt: list that receives one row tuple per capacity_print() call.
        threshold: list that receives the warning threshold (as "<value>%")
                   when the database provides one.
        """
        super().__init__(namespace)
        self.table_cnt = table_cnt
        self.threshold = threshold

    def capacity_print(self):
        """
        Read FABRIC_CAPACITY_TABLE|FABRIC_CAPACITY_DATA from STATE_DB and
        append a row (asic, operating links, total links, percentage,
        last event, time since last event) to self.table_cnt.
        """
        # Connect to database
        self.db = multi_asic.connect_to_all_dbs_for_ns(self.namespace)
        # A missing entry is treated like an empty one so the defaults below apply.
        fabric_capacity_data = self.db.get_all(
            self.db.STATE_DB, "FABRIC_CAPACITY_TABLE|FABRIC_CAPACITY_DATA") or {}

        operational_fabric_links = 0
        total_fabric_links = 0
        ratio = 0
        last_event = "None"
        last_time = "Never"

        # Get data from fabric_capacity_data
        if "number_of_links" in fabric_capacity_data:
            total_fabric_links = int(fabric_capacity_data['number_of_links'])
        if "operating_links" in fabric_capacity_data:
            operational_fabric_links = int(fabric_capacity_data['operating_links'])
        if "warning_threshold" in fabric_capacity_data:
            self.threshold.append(fabric_capacity_data['warning_threshold'] + "%")
        if "last_event" in fabric_capacity_data:
            last_event = fabric_capacity_data['last_event']
        if "last_event_time" in fabric_capacity_data:
            last_time = fabric_capacity_data['last_event_time']

        # Calculate the ratio of number of operational links and all links
        if total_fabric_links > 0:
            ratio = operational_fabric_links/total_fabric_links*100

        if last_time != "Never":
            # last_event_time is interpreted as whole seconds since the epoch.
            dt = datetime.fromtimestamp(int(last_time), timezone.utc)
            td = datetime.now(timezone.utc) - dt
            # Drop only the microseconds. td.seconds alone excludes whole
            # days, which would make an event from 25 hours ago look like
            # it happened one hour ago.
            td_without_ms = timedelta(days=td.days, seconds=td.seconds)
            last_time = str(td_without_ms) +" ago"

        asic_name = "asic0"
        if self.namespace:
            asic_name = self.namespace

        # Update the table to print
        self.table_cnt.append((asic_name, operational_fabric_links, total_fabric_links, ratio, last_event, last_time))

class FabricReachability(FabricStat):
    """
    Reachability view: the remote module and link behind each fabric port.
    """

    def reachability_print(self):
        """
        Print a table of the fabric ports that carry REMOTE_PORT data,
        ordered by local port number, preceded by the namespace name when
        one is set.
        """
        # Connect to database
        self.db = multi_asic.connect_to_all_dbs_for_ns(self.namespace)

        # Local port number -> STATE_DB entry, restricted to the ports that
        # report a remote peer.
        links = {}
        for key in self.db.keys(self.db.STATE_DB, FABRIC_PORT_STATUS_TABLE_PREFIX + '*'):
            entry = self.db.get_all(self.db.STATE_DB, key)
            if "REMOTE_PORT" not in entry:
                continue
            local_port = int(key.replace("FABRIC_PORT_TABLE|PORT", ""))
            links[local_port] = entry

        header = ["Local Link", "Remote Module", "Remote Link", "Status"]
        body = []
        for local_port in sorted(links):
            entry = links[local_port]
            row = (local_port, entry["REMOTE_MOD"], entry["REMOTE_PORT"], entry["STATUS"])
            body.append(row)

        if self.namespace:
            print(f"\n{self.namespace}")
        print(tabulate(body, header, tablefmt='simple', stralign='right'))

class FabricIsolation(FabricStat):
    """
    Isolation view: the isolation flags of each connected fabric port.
    """

    def isolation_print(self):
        """
        Print a table with the AUTO_ISOLATED, CONFIG_ISOLATED and ISOLATED
        values of every fabric port that carries REMOTE_PORT data, ordered
        by local port number. A value missing from a port's entry is shown
        as 0.
        """
        # Connect to database
        self.db = multi_asic.connect_to_all_dbs_for_ns(self.namespace)
        # Get the set of all fabric ports
        port_keys = self.db.keys(self.db.STATE_DB, FABRIC_PORT_STATUS_TABLE_PREFIX + '*')
        # Create a new dictionary. The keys are the local port values in integer format.
        # Only fabric ports that have remote port data are added.
        port_dict = {}
        for port_key in port_keys:
            port_data = self.db.get_all(self.db.STATE_DB, port_key)
            if "REMOTE_PORT" in port_data:
                port_number = int(port_key.replace("FABRIC_PORT_TABLE|PORT", ""))
                port_dict[port_number] = port_data

        # Create ordered table of fabric ports.
        header = ["Local Link", "Auto Isolated", "Manual Isolated", "Isolated"]
        body = []
        for port_number in sorted(port_dict):
            port_data = port_dict[port_number]
            # Look each value up per port with its own default, so a port
            # that lacks a field never inherits the previous port's value.
            body.append((port_number,
                         port_data.get("AUTO_ISOLATED", 0),
                         port_data.get("CONFIG_ISOLATED", 0),
                         port_data.get("ISOLATED", 0)))
        if self.namespace:
            print(f"\n{self.namespace}")
        print(tabulate(body, header, tablefmt='simple', stralign='right'))

class FabricRate(FabricStat):
    """
    Rate view: the averaged receive and transmit rate of each fabric port.
    """

    def rate_print(self):
        """
        Print a table with the OLD_RX_RATE_AVG and OLD_TX_RATE_AVG values of
        every fabric port, ordered by port number. A value missing from a
        port's entry is shown as 0.
        """
        # Connect to database
        self.db = multi_asic.connect_to_all_dbs_for_ns(self.namespace)
        # Get the set of all fabric ports
        port_keys = self.db.keys(self.db.STATE_DB, FABRIC_PORT_STATUS_TABLE_PREFIX + '*')
        # Create a new dictionary. The keys are the local port values in integer format.
        # Every fabric port is included, whether or not it has remote port data.
        port_dict = {}
        for port_key in port_keys:
            port_number = int(port_key.replace("FABRIC_PORT_TABLE|PORT", ""))
            port_dict[port_number] = self.db.get_all(self.db.STATE_DB, port_key)

        asic = "asic0"
        if self.namespace:
            asic = self.namespace

        # Create ordered table of fabric ports.
        header = ["ASIC", "Link ID", "Rx Data Mbps", "Tx Data Mbps"]
        body = []
        for port_number in sorted(port_dict):
            port_data = port_dict[port_number]
            # Look each rate up per port with its own default, so a port that
            # lacks a field never inherits the previous port's value.
            body.append((asic, port_number,
                         port_data.get("OLD_RX_RATE_AVG", 0),
                         port_data.get("OLD_TX_RATE_AVG", 0)))
        click.echo()
        click.echo(tabulate(body, header, tablefmt='simple', stralign='right'))

def main():
    """
    Parse the command line and print (or snapshot) the requested fabric view.

    Rebinds the module-level cnstat_dir, cnstat_fqn_file_port and
    cnstat_fqn_file_queue to locations inside the user cache directory.
    """
    global cnstat_dir
    global cnstat_fqn_file_port
    global cnstat_fqn_file_queue

    usage_examples = """
Examples:
    fabricstat
    fabricstat --namespace asic0
    fabricstat -p -n asic0 -e
    fabricstat -q
    fabricstat -q -n asic0
    fabricstat -c
    fabricstat -c -n asic0
    fabricstat -s
    fabricstat -s -n asic0
    fabricstat -C
    fabricstat -D
"""
    parser = argparse.ArgumentParser(
        description='Display the fabric port state and counters',
        formatter_class=argparse.RawTextHelpFormatter,
        epilog=usage_examples,
    )
    parser.add_argument('-q', '--queue', action='store_true',
                        help='Display fabric queue stat, otherwise port stat')
    parser.add_argument('-r', '--reachability', action='store_true',
                        help='Display reachability, otherwise port stat')
    parser.add_argument('-n', '--namespace', default=None,
                        help='Display fabric ports counters for specific namespace')
    parser.add_argument('-e', '--errors', action='store_true',
                        help='Display errors')
    parser.add_argument('-c', '--capacity', action='store_true',
                        help='Display fabric capacity')
    parser.add_argument('-i', '--isolation', action='store_true',
                        help='Display fabric ports isolation status')
    parser.add_argument('-s', '--rate', action='store_true',
                        help='Display fabric counters rate')
    parser.add_argument('-C', '--clear', action='store_true',
                        help='Copy & clear fabric counters')
    parser.add_argument('-D', '--delete', action='store_true',
                        help='Delete saved stats')

    args = parser.parse_args()
    namespace = args.namespace
    errors_only = args.errors

    # Snapshot files live in the user cache directory; the ASIC id is
    # appended to these prefixes by the stat classes.
    cache = UserCache()
    cnstat_dir = cache.get_directory()
    cnstat_fqn_file_port = os.path.join(cnstat_dir, "fabricstatport")
    cnstat_fqn_file_queue = os.path.join(cnstat_dir, "fabricstatqueue")

    if args.delete:
        cache.remove()
        sys.exit(0)

    def nsStat(ns, errors_only):
        """Show (or snapshot) the selected view for one namespace."""
        # Option precedence: queue, reachability, isolation, rate, then port.
        if not args.queue:
            if args.reachability:
                FabricReachability(ns).reachability_print()
                return
            if args.isolation:
                FabricIsolation(ns).isolation_print()
                return
            if args.rate:
                FabricRate(ns).rate_print()
                return
        stat = FabricQueueStat(ns) if args.queue else FabricPortStat(ns)
        cnstat_dict = stat.get_cnstat_dict()
        if args.clear:
            stat.save_fresh_stats()
        else:
            stat.cnstat_print(cnstat_dict, errors_only)

    # Without -n, cover every namespace reported by MultiAsic; otherwise only
    # the requested one.
    if namespace is None:
        namespaces = multi_asic_util.MultiAsic().get_ns_list_based_on_options()
    else:
        namespaces = [namespace]

    if not args.capacity:
        # other show fabric commands
        for ns in namespaces:
            nsStat(ns, errors_only)
        return

    # show fabric capacity command
    capacity_header = ["ASIC", "Operating\nLinks", "Total #\nof Links", "%", "Last Event", "Last Time"]
    table_cnt = []
    threshold = []
    for ns in namespaces:
        FabricCapacity(ns, table_cnt, threshold).capacity_print()

    print_th = threshold[0] if threshold else ""
    click.echo("Monitored fabric capacity threshold: {}".format(print_th))
    click.echo()
    click.echo(tabulate(table_cnt, capacity_header, tablefmt='simple', stralign='right'))

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
