#!python

#####################################################################
#
# pfcstat is a tool for summarizing Priority-based Flow Control (PFC) statistics.
#
#####################################################################

import json
import argparse
import datetime
import os.path
import sys

from collections import namedtuple, OrderedDict
from copy import deepcopy
from natsort import natsorted
from tabulate import tabulate

from sonic_py_common.multi_asic import get_external_ports

# Mock the redis DB connector for unit test purposes.
# Both switches are read straight from the environment; a missing variable
# raises KeyError, which is swallowed below and leaves the real connectors in
# place. Note that a missing UTILITIES_UNIT_TESTING therefore also skips the
# topology check that follows it.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        # Put the parent directory of this script and its tests/ subdirectory
        # at the front of sys.path before importing mock_tables
        # (presumably that is where the mock_tables package lives).
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        import mock_tables.dbconnector
    if os.environ["UTILITIES_UNIT_TESTING_TOPOLOGY"] == "multi_asic":
        import mock_tables.mock_multi_asic
        # NOTE(review): this uses mock_tables.dbconnector, which is only
        # imported explicitly by the branch above -- confirm it is always
        # available when only the topology variable selects multi_asic.
        mock_tables.dbconnector.load_namespace_config()

except KeyError:
    pass

from utilities_common.netstat import ns_diff, STATUS_NA, format_number_with_comma
from utilities_common import multi_asic as multi_asic_util
from utilities_common import constants
from utilities_common.cli import json_serial, UserCache


# Per-port record with one field per PFC priority: pfc0 .. pfc7.
PStats = namedtuple("PStats", ["pfc{}".format(prio) for prio in range(8)])

# Table headers: the port column followed by one column per PFC priority.
header_Rx = ['Port Rx'] + ['PFC{}'.format(prio) for prio in range(8)]
header_Tx = ['Port Tx'] + ['PFC{}'.format(prio) for prio in range(8)]

# Map each SAI counter name to the position of the PStats field it fills.
counter_bucket_rx_dict = {
    'SAI_PORT_STAT_PFC_{}_RX_PKTS'.format(prio): prio for prio in range(8)
}
counter_bucket_tx_dict = {
    'SAI_PORT_STAT_PFC_{}_TX_PKTS'.format(prio): prio for prio in range(8)
}

# Key prefix of the per-port counter tables, and the name of the table that
# maps port names to the ids of those counter tables.
COUNTER_TABLE_PREFIX = "COUNTERS:"
COUNTERS_PORT_NAME_MAP = "COUNTERS_PORT_NAME_MAP"

class Pfcstat(object):
    """
        Collect per-port PFC counters from COUNTERS_DB and print them.

        Collected counters are kept in self.cnstat_dict, an OrderedDict with
        a 'time' entry (collection timestamp) followed by one entry per port
        mapping 'pfc0'..'pfc7' to the counter value as a string.
    """
    def __init__(self, namespace, display):
        """
            Set up the multi-asic helper for the given namespace and display
            option and start with an empty counter dictionary.
        """
        self.multi_asic = multi_asic_util.MultiAsic(display, namespace)
        # Both connectors start unset; presumably run_on_multi_asic assigns
        # them before the decorated method body runs -- TODO confirm.
        self.db = None
        self.config_db = None
        self.cnstat_dict = OrderedDict()

    @multi_asic_util.run_on_multi_asic
    def collect_cnstat(self, rx):
        """
            Get the counters info from database.

            rx selects the Rx counters when true and the Tx counters
            otherwise. The result is merged into self.cnstat_dict; nothing
            is returned. When the port name map is missing, nothing is
            collected.
        """
        def get_counters(table_id):
            """
                Get the counters from specific table.

                Returns a dict keyed 'pfc0'..'pfc7'; a counter missing from
                the database is reported as STATUS_NA.
            """
            # The table key is the same for every counter of this port.
            full_table_id = COUNTER_TABLE_PREFIX + table_id
            bucket_dict = counter_bucket_rx_dict if rx else counter_bucket_tx_dict
            fields = ["0"] * len(PStats._fields)
            for counter_name, pos in bucket_dict.items():
                counter_data = self.db.get(
                    self.db.COUNTERS_DB, full_table_id, counter_name
                )
                if counter_data is None:
                    fields[pos] = STATUS_NA
                else:
                    fields[pos] = str(int(counter_data))
            return PStats._make(fields)._asdict()

        # Get the info from database
        counter_port_name_map = self.db.get_all(
            self.db.COUNTERS_DB, COUNTERS_PORT_NAME_MAP
        )
        if counter_port_name_map is None:
            return

        display_ports_set = set(counter_port_name_map.keys())
        if self.multi_asic.display_option == constants.DISPLAY_EXTERNAL:
            display_ports_set = get_external_ports(
                display_ports_set, self.multi_asic.current_namespace
            )

        # Build a dictionary of the stats, ports in natural sort order
        cnstat_dict = OrderedDict()
        cnstat_dict['time'] = datetime.datetime.now()
        for port in natsorted(counter_port_name_map):
            if port in display_ports_set:
                cnstat_dict[port] = get_counters(counter_port_name_map[port])
        self.cnstat_dict.update(cnstat_dict)

    def get_cnstat(self, rx):
        """
            Get the counters info from database.

            Clears the previously collected counters, collects again for the
            requested direction and returns self.cnstat_dict itself (not a
            copy), so a later call overwrites what an earlier call returned.
        """
        self.cnstat_dict.clear()
        self.collect_cnstat(rx)
        return self.cnstat_dict

    def _print_table(self, table, rx):
        """
            Print the rows with the Rx header when rx is true, else the Tx one.
        """
        header = header_Rx if rx else header_Tx
        print(tabulate(table, header, tablefmt='simple', stralign='right'))

    def cnstat_print(self, cnstat_dict, rx):
        """
            Print the cnstat.

            One row per port with the eight counters passed through
            format_number_with_comma; the 'time' entry is skipped.
        """
        table = [
            (port,) + tuple(
                format_number_with_comma(data[field]) for field in PStats._fields
            )
            for port, data in cnstat_dict.items()
            if port != 'time'
        ]
        self._print_table(table, rx)

    def cnstat_diff_print(self, cnstat_new_dict, cnstat_old_dict, rx):
        """
            Print the difference between two cnstat results.

            Ports present in cnstat_old_dict are printed via ns_diff(new, old);
            ports without an old entry are printed as plain formatted values.
            The 'time' entry is skipped.
        """
        table = []

        for port, cntr in cnstat_new_dict.items():
            if port == 'time':
                continue

            old_cntr = cnstat_old_dict.get(port)
            if old_cntr is not None:
                counters = tuple(
                    ns_diff(cntr[field], old_cntr[field])
                    for field in PStats._fields
                )
            else:
                counters = tuple(
                    format_number_with_comma(cntr[field])
                    for field in PStats._fields
                )
            table.append((port,) + counters)

        self._print_table(table, rx)

def _print_cnstat(pfcstat, cnstat_dict, cnstat_fqn_file, rx):
    """
        Print the counters of one direction (rx when rx is true, else tx).

        When cnstat_fqn_file exists it is loaded as the saved snapshot and
        the counters are printed via cnstat_diff_print against it; otherwise
        they are printed via cnstat_print. An IOError while reading or
        printing is reported on stdout and not re-raised.
    """
    if not os.path.isfile(cnstat_fqn_file):
        pfcstat.cnstat_print(cnstat_dict, rx)
        return

    try:
        # 'with' closes the cache file even when json.load raises.
        with open(cnstat_fqn_file, 'r') as cache_fp:
            cnstat_cached_dict = json.load(cache_fp)
        print("Last cached time was " + str(cnstat_cached_dict.get('time')))
        pfcstat.cnstat_diff_print(cnstat_dict, cnstat_cached_dict, rx)
    except IOError as e:
        print(e.errno, e)


def main():
    """
        Command line entry point.

        Parses the options, collects the rx and tx PFC counters and then
        either saves them as the new snapshot (-c) or prints them, relative
        to a saved snapshot when one exists. With -d the saved stats are
        removed before the counters are collected. Always leaves through
        sys.exit(): 0 on success, the errno of the IOError when saving fails.
    """
    parser = argparse.ArgumentParser(description='Display the pfc counters',
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     epilog="""
Examples:
  pfcstat
  pfcstat -c
  pfcstat -d
  pfcstat -n asic1
  pfcstat -s all -n asic0
""")

    parser.add_argument('-c', '--clear', action='store_true',
        help='Clear previous stats and save new ones'
    )
    parser.add_argument(
        '-d', '--delete', action='store_true', help='Delete saved stats'
    )
    parser.add_argument('-s', '--show', default=constants.DISPLAY_EXTERNAL,
        help='Display all interfaces or only external interfaces'
    )
    parser.add_argument('-n', '--namespace', default=None,
        help='Display interfaces for specific namespace'
    )
    parser.add_argument('-v', '--version', action='version', version='%(prog)s 1.0')
    args = parser.parse_args()

    save_fresh_stats = args.clear
    delete_all_stats = args.delete

    cache = UserCache()
    cnstat_file = 'pfcstat'

    cnstat_dir = cache.get_directory()
    cnstat_fqn_file_rx = os.path.join(cnstat_dir, "{}rx".format(cnstat_file))
    cnstat_fqn_file_tx = os.path.join(cnstat_dir, "{}tx".format(cnstat_file))

    # if '-c' option is provided get stats from all (frontend and backend) ports
    if save_fresh_stats:
        args.namespace = None
        args.show = constants.DISPLAY_ALL

    pfcstat = Pfcstat(args.namespace, args.show)

    if delete_all_stats:
        cache.remove()

    # get_cnstat() returns the collector's own dictionary, which the next
    # call clears, so each direction is copied before collecting the other.
    cnstat_dict_rx = deepcopy(pfcstat.get_cnstat(True))
    cnstat_dict_tx = deepcopy(pfcstat.get_cnstat(False))

    if save_fresh_stats:
        try:
            # 'with' closes each snapshot file as soon as it has been written.
            with open(cnstat_fqn_file_rx, 'w') as rx_fp:
                json.dump(cnstat_dict_rx, rx_fp, default=json_serial)
            with open(cnstat_fqn_file_tx, 'w') as tx_fp:
                json.dump(cnstat_dict_tx, tx_fp, default=json_serial)
        except IOError as e:
            print(e.errno, e)
            sys.exit(e.errno)
        else:
            print("Clear saved counters")
            sys.exit(0)

    # Print the counters of pfc rx counter
    _print_cnstat(pfcstat, cnstat_dict_rx, cnstat_fqn_file_rx, True)

    print("")

    # Print the counters of pfc tx counter
    _print_cnstat(pfcstat, cnstat_dict_tx, cnstat_fqn_file_tx, False)

    sys.exit(0)

# Run the command line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
