#!python

#####################################################################
#
# wredstat is a tool for summarizing wred queue statistics of all ports.
#
#####################################################################

import json
import argparse
import datetime
import os.path
import sys

from collections import namedtuple, OrderedDict
from natsort import natsorted
from tabulate import tabulate
from sonic_py_common import multi_asic

# Mock the redis database connector when running under the unit tests.
#
# Both settings are read from the environment with os.environ[...], which
# raises KeyError when the variable is not set. That KeyError is swallowed
# below, so outside the test environment this whole block is a no-op. Note
# that an unset UTILITIES_UNIT_TESTING also skips the topology check, because
# the first lookup already leaves the try body.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        # Make the parent directory of this script and its "tests"
        # sub-directory importable so that mock_tables can be found.
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        import mock_tables.dbconnector # lgtm [py/unused-import]
    if os.environ["UTILITIES_UNIT_TESTING_TOPOLOGY"] == "multi_asic":
        import mock_tables.mock_multi_asic
        # NOTE(review): relies on mock_tables.dbconnector being loaded by the
        # branch above (or by the mock_multi_asic import) - confirm.
        mock_tables.dbconnector.load_namespace_config()

except KeyError:
    # One of the environment variables is not set: not running unit tests.
    pass

from swsscommon.swsscommon import SonicV2Connector
from utilities_common.cli import json_serial, UserCache
from utilities_common import constants
import utilities_common.multi_asic as multi_asic_util
from utilities_common.cli import json_dump
from utilities_common.netstat import ns_diff, STATUS_NA

# One row of per-queue data. The field order matters: positions 2-5 are the
# indices referenced by counter_bucket_dict below.
QueueStats = namedtuple("QueueStats", "queueindex, queuetype, wredDrppacket, wredDrpbytes, ecnpacket, ecnbytes")
# Column titles passed to tabulate for regular queues and for VOQs.
header = ['Port', 'TxQ', 'WredDrp/pkts', 'WredDrp/bytes', 'EcnMarked/pkts', 'EcnMarked/bytes']
voq_header = ['Port', 'Voq', 'WredDrp/pkts', 'WredDrp/bytes', 'EcnMarked/pkts', 'EcnMarked/bytes']

# Counter field name in the per-queue COUNTERS table -> position of that value
# in the QueueStats field list.
counter_bucket_dict = {
    'SAI_QUEUE_STAT_WRED_DROPPED_PACKETS': 2,
    'SAI_QUEUE_STAT_WRED_DROPPED_BYTES': 3,
    'SAI_QUEUE_STAT_WRED_ECN_MARKED_PACKETS': 4,
    'SAI_QUEUE_STAT_WRED_ECN_MARKED_BYTES': 5,
}

# Short queue type labels; they are prefixed to the queue index in the output
# (e.g. type + index).
QUEUE_TYPE_MC = 'MC'
QUEUE_TYPE_UC = 'UC'
QUEUE_TYPE_ALL = 'ALL'
QUEUE_TYPE_VOQ = 'VOQ'
# Queue type values as stored in COUNTERS_QUEUE_TYPE_MAP; each one is
# translated to the matching short label above.
SAI_QUEUE_TYPE_MULTICAST = "SAI_QUEUE_TYPE_MULTICAST"
SAI_QUEUE_TYPE_UNICAST = "SAI_QUEUE_TYPE_UNICAST"
SAI_QUEUE_TYPE_UNICAST_VOQ = "SAI_QUEUE_TYPE_UNICAST_VOQ"
SAI_QUEUE_TYPE_ALL = "SAI_QUEUE_TYPE_ALL"

# Key prefix of the per-object counter tables and the names of the lookup
# maps read from COUNTERS_DB.
COUNTER_TABLE_PREFIX = "COUNTERS:"
COUNTERS_PORT_NAME_MAP = "COUNTERS_PORT_NAME_MAP"
COUNTERS_SYSTEM_PORT_NAME_MAP = "COUNTERS_SYSTEM_PORT_NAME_MAP"
COUNTERS_QUEUE_NAME_MAP = "COUNTERS_QUEUE_NAME_MAP"
COUNTERS_VOQ_NAME_MAP= "COUNTERS_VOQ_NAME_MAP"
COUNTERS_QUEUE_TYPE_MAP = "COUNTERS_QUEUE_TYPE_MAP"
COUNTERS_QUEUE_INDEX_MAP = "COUNTERS_QUEUE_INDEX_MAP"
COUNTERS_QUEUE_PORT_MAP = "COUNTERS_QUEUE_PORT_MAP"

# Placeholders; main() replaces them with the user cache directory and the
# path prefix of the per-port snapshot files (the port name is appended).
cnstat_dir = 'N/A'
cnstat_fqn_file = 'N/A'


def build_json(port, cnstat):
    """
    Convert table rows into the per-queue JSON structure.

    Each row is indexed positionally: element 1 is the queue label and
    elements 2-5 are the WRED drop packet/byte and ECN marked packet/byte
    values. Element 0 and the ``port`` argument are not used here.

    Returns a dict keyed by queue label; when two rows share a label the
    later row wins.
    """
    return {
        row[1]: {
            "wreddroppacket": row[2],
            "wreddropbytes": row[3],
            "ecnmarkedpacket": row[4],
            "ecnmarkedbytes": row[5],
        }
        for row in cnstat
    }


class Wredstat(object):
    """
    Collects and reports per-queue WRED drop and ECN mark counters.

    On construction the port name map and the queue name map are read from
    COUNTERS_DB and combined into ``port_queues_map``
    (port name -> {queue name -> queue table id}). The remaining methods read
    the counters of those queues and print them as a table or as JSON,
    optionally relative to a snapshot saved earlier by save_fresh_stats().
    """

    def __init__(self, namespace, voq=False):
        """
        Connect to the database and build the port -> queues mapping.

        Args:
            namespace: Namespace to read from, or None to use a plain
                SonicV2Connector connected to COUNTERS_DB.
            voq: When True, the system port and VOQ name maps are used
                instead of the port and queue name maps, and tables are
                printed with the VOQ header.

        Exits the process with status 1 when a name map lookup returns None
        or when a queue has no entry in COUNTERS_QUEUE_PORT_MAP.
        """
        self.db = None
        self.voq = voq
        self.multi_asic = multi_asic_util.MultiAsic(constants.DISPLAY_ALL, namespace)
        if namespace is not None:
            # NOTE(review): each iteration overwrites self.db, so only the
            # connector of the last namespace in the list is kept.
            for ns in self.multi_asic.get_ns_list_based_on_options():
                self.db = multi_asic.connect_to_all_dbs_for_ns(ns)
        else:
            self.db = SonicV2Connector(use_unix_socket_path=False)
            self.db.connect(self.db.COUNTERS_DB)

        # Get all ports
        port_map_name = COUNTERS_SYSTEM_PORT_NAME_MAP if voq else COUNTERS_PORT_NAME_MAP
        self.counter_port_name_map = self.db.get_all(self.db.COUNTERS_DB, port_map_name)
        if self.counter_port_name_map is None:
            print("COUNTERS_PORT_NAME_MAP is empty!")
            sys.exit(1)

        # port_queues_map: port name -> {queue name -> queue table id}
        # port_name_map:   port table id -> port name (reverse lookup)
        self.port_queues_map = {}
        self.port_name_map = {}
        for port in self.counter_port_name_map:
            self.port_queues_map[port] = {}
            self.port_name_map[self.counter_port_name_map[port]] = port

        # Get Queues for each port
        queue_map_name = COUNTERS_VOQ_NAME_MAP if voq else COUNTERS_QUEUE_NAME_MAP
        counter_queue_name_map = self.db.get_all(self.db.COUNTERS_DB, queue_map_name)
        if counter_queue_name_map is None:
            print("COUNTERS_QUEUE_NAME_MAP is empty!")
            sys.exit(1)

        for queue in counter_queue_name_map:
            queue_table_id = counter_queue_name_map[queue]
            port_table_id = self._lookup_or_exit(COUNTERS_QUEUE_PORT_MAP, queue_table_id,
                                                 "Port is not available!")
            port = self.port_name_map[port_table_id]
            self.port_queues_map[port][queue] = queue_table_id

    def _lookup_or_exit(self, map_name, table_id, error_text):
        """
        Return the ``table_id`` field of ``map_name`` in COUNTERS_DB.

        Prints ``error_text`` followed by the table id and exits with
        status 1 when the lookup returns None.
        """
        value = self.db.get(self.db.COUNTERS_DB, map_name, table_id)
        if value is None:
            print(error_text, table_id)
            sys.exit(1)
        return value

    def _get_queue_type(self, table_id):
        """
        Return the short queue type label (MC/UC/VOQ/ALL) of a queue.

        Exits with status 1 when the type is missing or not one of the
        known SAI queue type values.
        """
        sai_type = self._lookup_or_exit(COUNTERS_QUEUE_TYPE_MAP, table_id,
                                        "Queue Type is not available!")
        short_labels = {
            SAI_QUEUE_TYPE_MULTICAST: QUEUE_TYPE_MC,
            SAI_QUEUE_TYPE_UNICAST: QUEUE_TYPE_UC,
            SAI_QUEUE_TYPE_UNICAST_VOQ: QUEUE_TYPE_VOQ,
            SAI_QUEUE_TYPE_ALL: QUEUE_TYPE_ALL,
        }
        if sai_type not in short_labels:
            print("Queue Type is invalid:", table_id, sai_type)
            sys.exit(1)
        return short_labels[sai_type]

    def _get_counters(self, table_id):
        """
        Read the index, type and WRED/ECN counters of one queue.

        Returns the values as a mapping keyed by the QueueStats field names.
        A counter that is absent from the database is reported as STATUS_NA;
        present counters are normalised through int() and stored as strings.
        """
        fields = ["0"] * len(QueueStats._fields)
        fields[0] = self._lookup_or_exit(COUNTERS_QUEUE_INDEX_MAP, table_id,
                                         "Queue index is not available!")
        fields[1] = self._get_queue_type(table_id)

        # The previous implementation opened a new STATE_DB connection here
        # for every queue and never used it; that connection is no longer
        # created.
        full_table_id = COUNTER_TABLE_PREFIX + table_id
        for counter_name, pos in counter_bucket_dict.items():
            counter_data = self.db.get(self.db.COUNTERS_DB, full_table_id, counter_name)
            if counter_data is None:
                fields[pos] = STATUS_NA
            elif fields[pos] != STATUS_NA:
                fields[pos] = str(int(counter_data))
        return QueueStats._make(fields)._asdict()

    def get_cnstat(self, queue_map):
        """
        Get the counters info from database.

        Args:
            queue_map: Mapping of queue name -> queue table id, or None.

        Returns an OrderedDict whose first entry is 'time' (the current
        datetime) followed by one entry per queue in natural sort order.
        With ``queue_map`` None only the 'time' entry is returned.
        """
        # Build a dictionary of the stats
        cnstat_dict = OrderedDict()
        cnstat_dict['time'] = datetime.datetime.now()
        if queue_map is None:
            return cnstat_dict
        for queue in natsorted(queue_map):
            cnstat_dict[queue] = self._get_counters(queue_map[queue])
        return cnstat_dict

    @staticmethod
    def _queue_label(cntr):
        """Return the display label of a queue: its type followed by its index."""
        return cntr['queuetype'] + str(cntr['queueindex'])

    def _emit_table(self, port, table, json_output, json_opt):
        """
        Finish a report: merge the rows into ``json_output`` and return it
        when ``json_opt`` is set, otherwise print the table followed by an
        empty line and return None.
        """
        if json_opt:
            json_output[port].update(build_json(port, table))
            return json_output
        hdr = voq_header if self.voq else header
        print(tabulate(table, hdr, tablefmt='simple', stralign='right'))
        print()
        return None

    def cnstat_print(self, port, cnstat_dict, json_opt):
        """
        Print the cnstat. If JSON option is True, return data in
        JSON format.

        Returns ``{port: {...}}`` when ``json_opt`` is True, otherwise None.
        """
        table = []
        json_output = {port: {}}

        for key, data in cnstat_dict.items():
            if key == 'time':
                # The timestamp is only carried over into the JSON output.
                if json_opt:
                    json_output[port][key] = data
                continue
            table.append((port, self._queue_label(data),
                          data['wredDrppacket'], data['wredDrpbytes'],
                          data['ecnpacket'], data['ecnbytes']))

        return self._emit_table(port, table, json_output, json_opt)

    def cnstat_diff_print(self, port, cnstat_new_dict, cnstat_old_dict, json_opt):
        """
        Print the difference between two cnstat results. If JSON
        option is True, return data in JSON format.

        Queues that have no (non-None) entry in ``cnstat_old_dict`` are
        reported with their current values instead of a difference.
        Returns ``{port: {...}}`` when ``json_opt`` is True, otherwise None.
        """
        table = []
        json_output = {port: {}}

        for key, cntr in cnstat_new_dict.items():
            if key == 'time':
                if json_opt:
                    json_output[port][key] = cntr
                continue

            old_cntr = cnstat_old_dict.get(key)
            if old_cntr is not None:
                table.append((port, self._queue_label(cntr),
                              ns_diff(cntr['wredDrppacket'], old_cntr['wredDrppacket']),
                              ns_diff(cntr['wredDrpbytes'], old_cntr['wredDrpbytes']),
                              ns_diff(cntr['ecnpacket'], old_cntr['ecnpacket']),
                              ns_diff(cntr['ecnbytes'], old_cntr['ecnbytes'])))
            else:
                table.append((port, self._queue_label(cntr),
                              cntr['wredDrppacket'], cntr['wredDrpbytes'],
                              cntr['ecnpacket'], cntr['ecnbytes']))

        return self._emit_table(port, table, json_output, json_opt)

    def _report_port(self, port, json_opt, cached_msg_prefix=""):
        """
        Read the counters of one port and report them.

        When a saved snapshot file exists for the port, the difference to the
        snapshot is reported (preceded, in table mode, by a
        "Last cached time was ..." line starting with ``cached_msg_prefix``);
        otherwise the raw counters are reported.

        Returns the ``{port: {...}}`` dict in JSON mode. Returns None in table
        mode, and also when the snapshot could not be read (the IOError is
        printed and otherwise ignored).
        """
        cnstat_dict = self.get_cnstat(self.port_queues_map[port])
        cnstat_fqn_file_name = cnstat_fqn_file + port

        if not os.path.isfile(cnstat_fqn_file_name):
            return self.cnstat_print(port, cnstat_dict, json_opt)

        try:
            # Context manager so the snapshot file is closed deterministically.
            with open(cnstat_fqn_file_name, 'r') as cache_file:
                cnstat_cached_dict = json.load(cache_file)
            # NOTE(review): the JSON branch used to store a "cached_time"
            # entry first, but it was replaced by the dict returned from
            # cnstat_diff_print and never reached the output. The dead write
            # is omitted; the emitted JSON is unchanged.
            if not json_opt:
                print(cached_msg_prefix + "Last cached time was " + str(cnstat_cached_dict.get('time')))
            return self.cnstat_diff_print(port, cnstat_dict, cnstat_cached_dict, json_opt)
        except IOError as e:
            print(e.errno, e)
            return None

    def get_print_all_stat(self, json_opt):
        """
        Get stat for each port
        If JSON option is True, collect data for each port and
        print data in JSON format for all ports
        """
        json_output = {}
        for port in natsorted(self.counter_port_name_map):
            # A port whose snapshot cannot be read keeps this empty entry.
            json_output[port] = {}
            port_output = self._report_port(port, json_opt, port + " ")
            if json_opt and port_output is not None:
                json_output.update(port_output)

        if json_opt:
            print(json_dump(json_output))

    def get_print_port_stat(self, port, json_opt):
        """
        Get stat for the port
        If JSON option is True  print data in JSON format

        Exits with status 1 when ``port`` is not a known port name.
        """
        if port not in self.port_queues_map:
            print("Port does not exist!", port)
            sys.exit(1)

        # Get stat for the port queried
        json_output = {port: {}}
        port_output = self._report_port(port, json_opt)

        if json_opt:
            if port_output is not None:
                json_output.update(port_output)
            print(json_dump(json_output))

    def save_fresh_stats(self):
        """
        Get stat for each port and save it as the new snapshot.

        One JSON file per port is written to ``cnstat_fqn_file + port``.
        On an IOError the error is printed and the process exits with the
        error number.
        """
        for port in natsorted(self.counter_port_name_map):
            cnstat_dict = self.get_cnstat(self.port_queues_map[port])
            try:
                # Closing inside the try makes flush/close failures surface
                # as the IOError handled below instead of being lost.
                with open(cnstat_fqn_file + port, 'w') as cache_file:
                    json.dump(cnstat_dict, cache_file, default=json_serial)
            except IOError as e:
                print(e.errno, e)
                sys.exit(e.errno)
            else:
                print("Clear and update saved counters for " + port)

def main():
    """
    Command line entry point.

    Parses the options, points the module-level cache path globals at the
    user cache directory, optionally removes the cache (-d), and then either
    saves a fresh snapshot (-c) or prints the counters for one port (-p) or
    for all ports. Always terminates through sys.exit(0) on the normal paths.
    """
    global cnstat_dir
    global cnstat_fqn_file

    usage_examples = """
Examples:
  wredstat
  wredstat -p Ethernet0
  wredstat -c
  wredstat -d
"""
    parser = argparse.ArgumentParser(description='Display the wred queue counters',
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     epilog=usage_examples)

    parser.add_argument('-p', '--port', type=str, help='Show the wred queue counters for just one port', default=None)
    parser.add_argument('-c', '--clear', action='store_true', help='Clear previous stats and save new ones')
    parser.add_argument('-d', '--delete', action='store_true', help='Delete saved stats')
    parser.add_argument('-v', '--version', action='version', version='%(prog)s 1.0')
    parser.add_argument('-j', '--json_opt', action='store_true', help='Print in JSON format')
    parser.add_argument('-V', '--voq', action='store_true', help='display voq stats')
    parser.add_argument('-n', '--namespace', default=None, help='Display queue counters for specific namespace')
    args = parser.parse_args()

    # Snapshot files are stored as <cache dir>/wredstat<port name>.
    cache = UserCache()
    cnstat_dir = cache.get_directory()
    cnstat_fqn_file = os.path.join(cnstat_dir, 'wredstat')

    if args.delete:
        cache.remove()

    wredstat = Wredstat(args.namespace, args.voq)

    if args.clear:
        wredstat.save_fresh_stats()
        sys.exit(0)

    if args.port is None:
        wredstat.get_print_all_stat(args.json_opt)
    else:
        wredstat.get_print_port_stat(args.port, args.json_opt)

    sys.exit(0)

# Run the command line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
