#!python

#####################################################################
#
# pg-drop is a tool for show/clear ingress pg dropped packet stats.
#
#####################################################################
from importlib import reload
import json
import argparse
import os
import sys
from collections import OrderedDict

from natsort import natsorted
from tabulate import tabulate
from utilities_common.general import load_db_config
from sonic_py_common import multi_asic

# Mock the redis database for unit test purposes.
# Every environment lookup below uses os.environ[...], so an unset variable
# raises KeyError; that is swallowed by the handler at the bottom and the
# script continues without any mocking.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        # Make the parent directory and its "tests" directory importable so
        # that the mock_tables package can be found.
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        import mock_tables.dbconnector
    # Only reached when UTILITIES_UNIT_TESTING is set (otherwise the lookup
    # above has already raised KeyError).
    if os.environ["UTILITIES_UNIT_TESTING_TOPOLOGY"] == "multi_asic":
        import mock_tables.mock_multi_asic
        mock_tables.dbconnector.load_namespace_config()
except KeyError:
    # At least one of the test environment variables is not set.
    pass

from utilities_common.cli import UserCache
from swsscommon.swsscommon import ConfigDBConnector, SonicV2Connector

# Placeholder shown in the output table when a counter value cannot be read
STATUS_NA = 'N/A'

# Prefix prepended to an object ID to form the name of its counters table
COUNTER_TABLE_PREFIX = "COUNTERS:"

# Names of the COUNTERS_DB lookup tables read by PgDropStat:
# port name -> port object ID
COUNTERS_PORT_NAME_MAP = "COUNTERS_PORT_NAME_MAP"
# PG name (its text after the first ':' is parsed as the PG index) -> PG object ID
COUNTERS_PG_NAME_MAP = "COUNTERS_PG_NAME_MAP"
# PG object ID -> port object ID
COUNTERS_PG_PORT_MAP = "COUNTERS_PG_PORT_MAP"
# PG object ID -> PG index
COUNTERS_PG_INDEX_MAP = "COUNTERS_PG_INDEX_MAP"

def get_dropstat_dir():
    """
        Return the directory reported by UserCache; the PG drop checkpoint
        file ('pg_drop_stats') is stored inside it.
    """
    user_cache = UserCache()
    return user_cache.get_directory()

class PgDropStat(object):
    """
        Show and clear ingress priority group (PG) dropped packet counters.

        Counter values are read from COUNTERS_DB of every namespace in
        self.ns_list. "Clearing" does not touch the database: it stores a
        snapshot of the current values in a checkpoint file and later reads
        subtract that snapshot from the live values.
    """

    def __init__(self, namespace):
        """
            Connect to CONFIG_DB and build the port/PG lookup tables.

            namespace - namespace handed to multi_asic.get_namespace_list()
                        and ConfigDBConnector

            Exits with status 1 when COUNTERS_PORT_NAME_MAP or
            COUNTERS_PG_NAME_MAP is empty, or when a PG has no port.
        """
        self.namespace = namespace
        self.ns_list = multi_asic.get_namespace_list(namespace)
        self.configdb = ConfigDBConnector(namespace=namespace)
        self.configdb.connect()
        dropstat_dir = get_dropstat_dir()
        self.port_drop_stats_file = os.path.join(dropstat_dir, 'pg_drop_stats')

        # Get all ports
        self.counter_port_name_map = self.get_counters_mapall(COUNTERS_PORT_NAME_MAP)
        if not self.counter_port_name_map:
            print("COUNTERS_PORT_NAME_MAP is empty!")
            sys.exit(1)

        # port name -> {PG name -> PG object ID}
        self.port_pg_map = {}
        # namespace -> {port object ID -> port name}
        self.port_name_map = {}

        for ns in self.ns_list:
            self.port_name_map[ns] = {}
            # get_counters_mapall() omits namespaces without data, so fall
            # back to an empty map instead of raising KeyError
            ns_ports = self.counter_port_name_map.get(ns, {})
            for port in ns_ports:
                self.port_pg_map[port] = {}
                self.port_name_map[ns][ns_ports[port]] = port

        # Get PGs for each port
        counter_pg_name_map = self.get_counters_mapall(COUNTERS_PG_NAME_MAP)
        if not counter_pg_name_map:
            print("COUNTERS_PG_NAME_MAP is empty!")
            sys.exit(1)

        for ns in self.ns_list:
            ns_pgs = counter_pg_name_map.get(ns, {})
            for pg in ns_pgs:
                pg_oid = ns_pgs[pg]
                port = self.port_name_map[ns][self._get_port_id(pg_oid, ns)]
                self.port_pg_map[port][pg] = pg_oid

        self.pg_drop_types = {
            "pg_drop"       : {"message" : "Ingress PG dropped packets:",
                               "obj_map" : self.port_pg_map,
                               "idx_func": self.get_pg_index,
                               "counter_name" : "SAI_INGRESS_PRIORITY_GROUP_STAT_DROPPED_PACKETS",
                               "header_prefix": "PG"},
        }

    def _get_port_id(self, oid, namespace):
        """
            Get port ID using pg object ID; exits with status 1 when the
            PG has no entry in COUNTERS_PG_PORT_MAP
        """
        port_id = self.get_counters_mapdata(COUNTERS_PG_PORT_MAP, oid, namespace)
        if not port_id:
            print("Port is not available for oid '{}'".format(oid))
            sys.exit(1)
        return port_id

    def _get_counters_db(self, namespace):
        """
            Return a connector to COUNTERS_DB of the given namespace.

            The connector is created on first use and then reused, so that
            the many per-PG lookups do not each open a new connection.
        """
        # The cache is created lazily so that it does not depend on __init__
        db_cache = self.__dict__.setdefault('_counters_db_cache', {})
        if namespace not in db_cache:
            counters_db = SonicV2Connector(namespace=namespace)
            counters_db.connect(counters_db.COUNTERS_DB)
            db_cache[namespace] = counters_db
        return db_cache[namespace]

    def _load_checkpoint(self):
        """
            Return the content of the clear checkpoint file, or an empty
            dictionary when the file does not exist
        """
        if not os.path.isfile(self.port_drop_stats_file):
            return {}
        with open(self.port_drop_stats_file, 'r') as ckpt_file:
            return json.load(ckpt_file)

    @staticmethod
    def _checkpoint_value(checkpoint, name, table_id):
        """
            Return the checkpointed count of table_id for object name.

            Returns 0 when the checkpoint holds no such entry (for example
            an object or table id that is not part of the checkpoint), so
            that the live counter value is reported unchanged.
        """
        return checkpoint.get(name, {}).get(table_id, 0)

    def get_counters_mapdata(self, tablemap, index, namespace):
        """
            Return field 'index' of table 'tablemap' in COUNTERS_DB of the
            given namespace
        """
        counters_db = self._get_counters_db(namespace)
        return counters_db.get(counters_db.COUNTERS_DB, tablemap, index)

    def get_counters_mapall(self, tablemap):
        """
            Return a dictionary mapping each namespace of self.ns_list to
            the full content of table 'tablemap' in its COUNTERS_DB.

            Namespaces for which the table is empty are left out, so an
            empty result means that no namespace has any data.
        """
        mapdata = {}
        for ns in self.ns_list:
            counters_db = self._get_counters_db(ns)
            map_result = counters_db.get_all(counters_db.COUNTERS_DB, tablemap)
            if map_result:
                mapdata[ns] = map_result
        return mapdata

    def get_pg_index(self, oid, namespace):
        """
            return PG index as stored in COUNTERS_PG_INDEX_MAP
            (documented by the original author as 0-7)

            oid - object ID for entry in redis

            Exits with status 1 when no index is stored for the object ID.
        """
        pg_index = self.get_counters_mapdata(COUNTERS_PG_INDEX_MAP, oid, namespace)
        if not pg_index:
            print("Priority group index is not available for oid '{}'".format(oid))
            sys.exit(1)
        return pg_index

    def build_header(self, pg_drop_type):
        """
            Construct header for table with PG counters.

            pg_drop_type - entry of self.pg_drop_types; its "obj_map" and
                           "header_prefix" items are used

            Sets self.header_list (column titles), self.header_idx_to_pos
            (PG index -> column position) and self.min_idx (lowest PG
            index). Exits with status 1 when no PG is known.
        """
        if pg_drop_type is None:
            print("Header info is not available!")
            sys.exit(1)

        self.header_list = ['Port']
        header_map = pg_drop_type["obj_map"]

        # The PG index is the text after the first ':' of the PG name
        header_idx_set = {
            int(element.split(':')[1])
            for port_pgs in header_map.values()
            for element in port_pgs
        }

        if not header_idx_set:
            print("Header info is not available!")
            sys.exit(1)

        header_idx_list = sorted(header_idx_set)

        self.header_idx_to_pos = {idx: pos for pos, idx in enumerate(header_idx_list)}

        self.min_idx = header_idx_list[0]
        self.header_list += ["{}{}".format(pg_drop_type["header_prefix"], idx) for idx in header_idx_list]

    def get_counters(self, table_prefix, port_obj, idx_func, counter_name, namespace, checkpoint=None):
        """
            Get the counters of a specific table.

            table_prefix - prefix put in front of an object ID to form the
                           counters table name
            port_obj     - mapping of PG name to PG object ID for one port
            idx_func     - callable(oid, namespace) returning the PG index
            counter_name - field to read from each counters table
            namespace    - namespace the port belongs to
            checkpoint   - already loaded clear checkpoint; when None the
                           checkpoint file is read

            Returns one string per PG column of the current header: the
            counter value minus the checkpointed value, STATUS_NA when the
            counter cannot be read, or "0" for a PG the port does not have.
            build_header() must have been called before.
        """
        # Grab the latest clear checkpoint, if it exists
        if checkpoint is None:
            checkpoint = self._load_checkpoint()

        # Header list contains the port name followed by the PGs. Fields is used to populate the pg values
        fields = ["0"] * (len(self.header_list) - 1)

        for name, obj_id in port_obj.items():
            full_table_id = table_prefix + obj_id
            old_collected_data = self._checkpoint_value(checkpoint, name, full_table_id)
            idx = int(idx_func(obj_id, namespace))
            pos = self.header_idx_to_pos[idx]
            counter_data = self.get_counters_mapdata(full_table_id, counter_name, namespace)
            if counter_data is None:
                fields[pos] = STATUS_NA
            elif fields[pos] != STATUS_NA:
                fields[pos] = str(int(counter_data) - old_collected_data)
        return fields

    def print_all_stat(self, table_prefix, key):
        """
            Print table that show stats per PG

            table_prefix - prefix of the counters table names
            key          - key into self.pg_drop_types
        """
        table = []
        drop_type = self.pg_drop_types[key]
        self.build_header(drop_type)
        # Read the checkpoint once instead of once per port
        checkpoint = self._load_checkpoint()
        # Get stat for each port
        for ns in self.ns_list:
            for port in natsorted(self.counter_port_name_map.get(ns, {})):
                data = self.get_counters(table_prefix, drop_type["obj_map"][port], drop_type["idx_func"],
                                         drop_type["counter_name"], ns, checkpoint)
                table.append((port, *data))

        print(drop_type["message"])
        print(tabulate(table, self.header_list, tablefmt='simple', stralign='right'))

    def get_counts(self, counters, oid, namespace):
        """
            Get the PG drop counts for an individual counter.

            Returns a dictionary with the counters table id as its only
            key; a counter that cannot be read is stored as 0. When several
            counter names are given, the value of the last one is kept.
        """
        counts = {}
        table_id = COUNTER_TABLE_PREFIX + oid
        for counter in counters:
            counter_data = self.get_counters_mapdata(table_id, counter, namespace)
            counts[table_id] = 0 if counter_data is None else int(counter_data)
        return counts

    def get_counts_table(self, counters, object_table):
        """
            Returns a dictionary containing a mapping from an object (like a port)
            to its PG drop counts. Counts are contained in a dictionary that maps
            counter oid to its counts.
        """
        counter_object_name_map = self.get_counters_mapall(object_table)
        current_stat_dict = OrderedDict()

        if not counter_object_name_map:
            return current_stat_dict

        for ns in self.ns_list:
            # Namespaces without data are absent from the map
            ns_objects = counter_object_name_map.get(ns, {})
            for obj in natsorted(ns_objects):
                current_stat_dict[obj] = self.get_counts(counters, ns_objects[obj], ns)
        return current_stat_dict

    def clear_drop_counts(self):
        """
            Clears the current PG drop counter.

            The current counter values are written to the checkpoint file.
            On an I/O error the error is printed and the process exits with
            the error number.
        """

        counter_pg_drop_array = ["SAI_INGRESS_PRIORITY_GROUP_STAT_DROPPED_PACKETS"]
        try:
            # Collect the counts before the file is opened (and truncated)
            current_counts = self.get_counts_table(counter_pg_drop_array, COUNTERS_PG_NAME_MAP)
            with open(self.port_drop_stats_file, 'w') as ckpt_file:
                json.dump(current_counts, ckpt_file)
        except IOError as e:
            print(e)
            sys.exit(e.errno)
        print("Cleared PG drop counter")

    def check_if_stats_enabled(self):
        """
            Print a warning and exit with status 0 when CONFIG_DB has a
            FLEX_COUNTER_TABLE|PG_DROP entry whose FLEX_COUNTER_STATUS is
            missing or "disable"
        """
        pg_drop_info = self.configdb.get_entry('FLEX_COUNTER_TABLE', 'PG_DROP')
        if not pg_drop_info:
            return
        status = pg_drop_info.get("FLEX_COUNTER_STATUS", 'disable')
        if status == "disable":
            print("Warning: PG counters are disabled. Use 'counterpoll pg-drop enable' to enable polling")
            sys.exit(0)

def main():
    """
        Parse the command line and run the requested pg-drop action.

        'show' prints the PG drop counters of the selected namespace (all
        namespaces when none is given); 'clear' always checkpoints the
        counters of all namespaces. Any other command prints
        "Command not recognized". The process always ends through
        sys.exit(): status 1 for an unknown namespace, the error number
        when the cache directory cannot be created, otherwise 0.
    """
    parser = argparse.ArgumentParser(description='Display PG drop counter',
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     epilog="""
Examples:
pg-drop -c show
pg-drop -c show --namespace asic0
pg-drop -c clear
""")

    parser.add_argument('-c', '--command', type=str, help='Desired action to perform')
    parser.add_argument('-n', '--namespace', type=str, help='Namespace name or skip for all', default=None)

    args = parser.parse_args()
    command = args.command

    dropstat_dir = get_dropstat_dir()
    # Create the directory to hold clear results. exist_ok avoids the race of
    # a separate existence check followed by the creation.
    try:
        os.makedirs(dropstat_dir, exist_ok=True)
    except OSError as e:
        print(e)
        sys.exit(e.errno)

    # Load database config files
    load_db_config()
    namespaces = multi_asic.get_namespace_list()
    if args.namespace and args.namespace not in namespaces:
        namespacelist = ', '.join(namespaces)
        print(f"Input value for '--namespace' / '-n'. Choose from one of ({namespacelist})")
        sys.exit(1)

    # Reject an unknown command before any database connection is opened.
    # Exit status 0 is kept for backward compatibility.
    if command not in ('clear', 'show'):
        print("Command not recognized")
        sys.exit(0)

    # For 'clear' command force applying to all namespaces
    pgdropstat = PgDropStat(args.namespace if command != 'clear' else None)

    if command == 'clear':
        pgdropstat.clear_drop_counts()
    else:
        pgdropstat.check_if_stats_enabled()
        pgdropstat.print_all_stat(COUNTER_TABLE_PREFIX, "pg_drop")
    sys.exit(0)


# Run the command line interface only when executed as a script
if __name__ == "__main__":
    main()
