#!python

#####################################################################
#
# queuestat is a tool for summarizing queue statistics of all ports.
#
#####################################################################

import json
import click
import datetime
import os.path
import sys

from collections import namedtuple, OrderedDict
from natsort import natsorted
from tabulate import tabulate
from sonic_py_common import multi_asic

# Mock the redis DB connector for unit-test purposes.
# The mocking is driven by two environment variables. Reading either one
# through os.environ[...] raises KeyError when it is unset; that KeyError is
# deliberately swallowed below so a normal (non-test) run skips this section.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        # Put the parent directory and its "tests" sub-directory at the front
        # of sys.path so that the mock_tables package can be imported.
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        import mock_tables.dbconnector # lgtm [py/unused-import]
    if os.environ["UTILITIES_UNIT_TESTING_TOPOLOGY"] == "multi_asic":
        # NOTE(review): this branch uses mock_tables.dbconnector, which is only
        # imported explicitly by the branch above -- confirm both variables are
        # always set together in the test harness.
        import mock_tables.mock_multi_asic
        mock_tables.dbconnector.load_namespace_config()

except KeyError:
    pass

from swsscommon.swsscommon import SonicV2Connector
from utilities_common.cli import json_serial, UserCache
from utilities_common import constants
import utilities_common.multi_asic as multi_asic_util

# Per-queue record types. get_cnstat() fills a flat list in this field order
# and converts it with _make(); the two types differ only in the last field.
QueueStats = namedtuple("QueueStats", "queueindex, queuetype, totalpacket, totalbytes, droppacket, dropbytes, trimpacket")
VoqStats = namedtuple("VoqStats", "queueindex, queuetype, totalpacket, totalbytes, droppacket, dropbytes, creditWDpkts")
# Column headers handed to tabulate() for each display mode.
std_header = ['Port', 'TxQ', 'Counter/pkts', 'Counter/bytes', 'Drop/pkts', 'Drop/bytes']
all_header = ['Port', 'TxQ', 'Counter/pkts', 'Counter/bytes', 'Drop/pkts', 'Drop/bytes', 'Trim/pkts']
trim_header = ['Port', 'TxQ', 'Trim/pkts']
voq_header = ['Port', 'Voq', 'Counter/pkts', 'Counter/bytes', 'Drop/pkts', 'Drop/bytes', 'Credit-WD-Del/pkts']

# Counter name -> position in the flat field list built by get_cnstat().
# Positions 0 and 1 of that list hold the queue index and the queue type,
# so the counters start at position 2.
counter_bucket_dict = {
    'SAI_QUEUE_STAT_PACKETS': 2,
    'SAI_QUEUE_STAT_BYTES': 3,
    'SAI_QUEUE_STAT_DROPPED_PACKETS': 4,
    'SAI_QUEUE_STAT_DROPPED_BYTES': 5,
}
# Extra counter read for non-VOQ queues (last field of QueueStats).
trim_counter_bucket_dict = {
    'SAI_QUEUE_STAT_TRIM_PACKETS': 6,
}
# Extra counter read for VOQ queues (last field of VoqStats).
voq_counter_bucket_dict = {
    'SAI_QUEUE_STAT_CREDIT_WD_DELETED_PACKETS': 6
}

from utilities_common.cli import json_dump
from utilities_common.netstat import ns_diff, STATUS_NA

# Short queue-type tags; they are prefixed to the queue index to form the
# queue label shown in the table (e.g. tag + "0").
QUEUE_TYPE_MC = 'MC'
QUEUE_TYPE_UC = 'UC'
QUEUE_TYPE_ALL = 'ALL'
QUEUE_TYPE_VOQ = 'VOQ'
# Queue-type values as stored in COUNTERS_QUEUE_TYPE_MAP; get_cnstat()
# translates them to the short tags above.
SAI_QUEUE_TYPE_MULTICAST = "SAI_QUEUE_TYPE_MULTICAST"
SAI_QUEUE_TYPE_UNICAST = "SAI_QUEUE_TYPE_UNICAST"
SAI_QUEUE_TYPE_UNICAST_VOQ = "SAI_QUEUE_TYPE_UNICAST_VOQ"
SAI_QUEUE_TYPE_ALL = "SAI_QUEUE_TYPE_ALL"

# Key prefix and hash names looked up in COUNTERS_DB.
COUNTER_TABLE_PREFIX = "COUNTERS:"
COUNTERS_PORT_NAME_MAP = "COUNTERS_PORT_NAME_MAP"
COUNTERS_SYSTEM_PORT_NAME_MAP = "COUNTERS_SYSTEM_PORT_NAME_MAP"
COUNTERS_QUEUE_NAME_MAP = "COUNTERS_QUEUE_NAME_MAP"
COUNTERS_VOQ_NAME_MAP= "COUNTERS_VOQ_NAME_MAP"
COUNTERS_QUEUE_TYPE_MAP = "COUNTERS_QUEUE_TYPE_MAP"
COUNTERS_QUEUE_INDEX_MAP = "COUNTERS_QUEUE_INDEX_MAP"
COUNTERS_QUEUE_PORT_MAP = "COUNTERS_QUEUE_PORT_MAP"

# Placeholders for the cache directory and the cache file path prefix;
# main() overwrites both before any statistics are read or saved.
cnstat_dir = 'N/A'
cnstat_fqn_file = 'N/A'


def build_json(port, cnstat, all=False, trim=False, voq=False):
    """
    Convert table rows into a dictionary keyed by queue label.

    Each row in `cnstat` is indexed positionally: element 1 is the queue
    label and the counters start at element 2, in the order of the labels
    selected below. `voq` takes precedence over `all`, which takes
    precedence over `trim`. `port` is accepted but not used here.

    Returns a dict mapping queue label -> {counter label: value}. When two
    rows share a queue label, the later row replaces the earlier one.
    """
    generic = ("totalpacket", "totalbytes", "droppacket", "dropbytes")

    if voq:
        # NOTE: this JSON key is spelled "creditWDPkts" (capital P), unlike
        # the VoqStats field "creditWDpkts"; it is part of the output format.
        labels = generic + ("creditWDPkts",)
    elif all:  # All statistics
        labels = generic + ("trimpacket",)
    elif trim:  # Packet Trimming related statistics
        labels = ("trimpacket",)
    else:  # Generic statistics
        labels = generic

    result = {}
    for row in cnstat:
        # Index explicitly (rather than zip) so a short row still raises.
        result[row[1]] = {
            label: row[offset]
            for offset, label in enumerate(labels, start=2)
        }
    return result

class QueuestatWrapper(object):
    """A wrapper to execute queuestat cmd over the correct namespaces"""
    def __init__(self, namespace, all, trim, voq):
        """
        Store the display options and create the multi-asic helper.

        namespace: namespace name given on the command line, or None.
        all, trim, voq: display-mode flags forwarded to Queuestat.
        """
        self.namespace = namespace
        self.all = all
        self.trim = trim
        self.voq = voq

        # Initialize the multi-asic namespace
        self.multi_asic = multi_asic_util.MultiAsic(constants.DISPLAY_ALL, namespace_option=namespace)
        # Starts as None; presumably assigned by the run_on_multi_asic
        # decorator before run() executes -- confirm in utilities_common.
        self.db = None

    @multi_asic_util.run_on_multi_asic
    def run(self, save_fresh_stats, port_to_show_stats, json_opt, non_zero):
        """
        Build a Queuestat for the current namespace and execute one action.

        save_fresh_stats: when true, only save the counters and return.
        port_to_show_stats: port name to display, or None for every port.
        json_opt, non_zero: forwarded to the Queuestat print methods.
        """
        queuestat = Queuestat(self.multi_asic.current_namespace, self.db, self.all, self.trim, self.voq)
        if save_fresh_stats:
            queuestat.save_fresh_stats()
            return

        # Identity comparison is the idiomatic test for "no port given".
        if port_to_show_stats is not None:
            queuestat.get_print_port_stat(port_to_show_stats, json_opt, non_zero)
        else:
            queuestat.get_print_all_stat(json_opt, non_zero)


class Queuestat(object):
    def __init__(self, namespace, db, all=False, trim=False, voq=False):
        """
        Load the port and queue name maps from COUNTERS_DB.

        After construction:
          counter_port_name_map: port name -> port id, as read from the DB
          port_name_map:         port id -> port name (the reverse mapping)
          port_queues_map:       port name -> {queue name: queue id}

        With `voq` set, the system-port and VOQ name maps are read instead
        of the regular port and queue name maps. The process exits with
        status 1 when a name map is returned as None or when a queue has no
        entry in COUNTERS_QUEUE_PORT_MAP.
        """
        self.namespace = namespace
        self.namespace_str = f" for {namespace}" if namespace else ''
        self.db = db
        self.all = all
        self.trim = trim
        self.voq = voq

        counters_db = self.db.COUNTERS_DB

        # Get all ports
        port_map_name = COUNTERS_SYSTEM_PORT_NAME_MAP if voq else COUNTERS_PORT_NAME_MAP
        self.counter_port_name_map = self.db.get_all(counters_db, port_map_name)
        if self.counter_port_name_map is None:
            print(f"COUNTERS_PORT_NAME_MAP is empty{self.namespace_str}!")
            sys.exit(1)

        # Every known port starts with an empty queue table.
        self.port_queues_map = {name: {} for name in self.counter_port_name_map}
        self.port_name_map = {
            self.counter_port_name_map[name]: name
            for name in self.counter_port_name_map
        }

        # Get Queues for each port
        queue_map_name = COUNTERS_VOQ_NAME_MAP if voq else COUNTERS_QUEUE_NAME_MAP
        counter_queue_name_map = self.db.get_all(counters_db, queue_map_name)
        if counter_queue_name_map is None:
            print(f"COUNTERS_QUEUE_NAME_MAP is empty{self.namespace_str}!")
            sys.exit(1)

        for queue in counter_queue_name_map:
            queue_id = counter_queue_name_map[queue]
            # Resolve the port that owns this queue.
            port_id = self.db.get(counters_db, COUNTERS_QUEUE_PORT_MAP, queue_id)
            if port_id is None:
                print(f"Port is not available{self.namespace_str}!", queue_id)
                sys.exit(1)

            owner = self.port_name_map[port_id]
            self.port_queues_map[owner][queue] = queue_id

    def get_cnstat(self, queue_map):
        """
            Read the counters of every queue in queue_map from the database.

            queue_map maps queue name -> queue id, or is None. The result is
            an OrderedDict whose first entry, 'time', holds the current
            datetime; it is followed by one entry per queue (in natural sort
            order of the queue names) holding that queue's QueueStats or
            VoqStats record as a dict. The process exits with status 1 when
            a queue's index or type is missing, or its type is unknown.
        """
        counters_db = self.db.COUNTERS_DB

        # Database queue type -> short tag used in the queue label.
        type_tags = {
            SAI_QUEUE_TYPE_MULTICAST: QUEUE_TYPE_MC,
            SAI_QUEUE_TYPE_UNICAST: QUEUE_TYPE_UC,
            SAI_QUEUE_TYPE_UNICAST_VOQ: QUEUE_TYPE_VOQ,
            SAI_QUEUE_TYPE_ALL: QUEUE_TYPE_ALL,
        }

        def read_queue_index(table_id):
            index = self.db.get(counters_db, COUNTERS_QUEUE_INDEX_MAP, table_id)
            if index is None:
                print(f"Queue index is not available{self.namespace_str}!", table_id)
                sys.exit(1)
            return index

        def read_queue_type(table_id):
            sai_type = self.db.get(counters_db, COUNTERS_QUEUE_TYPE_MAP, table_id)
            if sai_type is None:
                print(f"Queue Type is not available{self.namespace_str}!", table_id)
                sys.exit(1)
            if sai_type not in type_tags:
                print(f"Queue Type is invalid{self.namespace_str}:", table_id, sai_type)
                sys.exit(1)
            return type_tags[sai_type]

        def read_counters(table_id):
            """
                Read one queue's counters and return them as a dict.
            """
            # The index is looked up before the type, then the counters.
            row = [read_queue_index(table_id), read_queue_type(table_id)]

            buckets = dict(counter_bucket_dict)
            buckets.update(voq_counter_bucket_dict if self.voq else trim_counter_bucket_dict)

            # Layout is per QueueStats/VoqStats type definition
            row += ["0"] * len(buckets)

            counters_key = COUNTER_TABLE_PREFIX + table_id
            for counter_name, slot in buckets.items():
                raw = self.db.get(counters_db, counters_key, counter_name)
                if raw is None:
                    row[slot] = STATUS_NA
                elif row[slot] != STATUS_NA:
                    row[slot] = str(int(raw))

            record_type = VoqStats if self.voq else QueueStats
            return record_type._make(row)._asdict()

        # Build a dictionary of the stats
        stats = OrderedDict()
        stats['time'] = datetime.datetime.now()
        if queue_map is not None:
            for queue in natsorted(queue_map):
                stats[queue] = read_counters(queue_map[queue])
        return stats

    def cnstat_print(self, port, cnstat_dict, json_opt, non_zero):
        """
        Show the counters of one port.

        With json_opt set, nothing is printed and a dict of the form
        {port: {'time': ..., queue label: counters}} is returned. Otherwise
        the table is printed (only when it has at least one row) and None
        is returned. With non_zero set, queues whose displayed counters are
        all '0' are left out.
        """
        generic = ('totalpacket', 'totalbytes', 'droppacket', 'dropbytes')

        # Pick the displayed counters and the matching header for the mode.
        if self.voq:
            columns, hdr = generic + ('creditWDpkts',), voq_header
        elif self.all:  # All statistics
            columns, hdr = generic + ('trimpacket',), all_header
        elif self.trim:  # Packet Trimming related statistics
            columns, hdr = ('trimpacket',), trim_header
        else:  # Generic statistics
            columns, hdr = generic, std_header

        json_output = {port: {}}
        table = []

        for key, data in cnstat_dict.items():
            if key == 'time':
                # The timestamp is carried through to the JSON output only.
                if json_opt:
                    json_output[port][key] = data
                continue

            values = tuple(data[name] for name in columns)
            if non_zero and all(value == '0' for value in values):
                continue

            queuetag = data['queuetype'] + str(data['queueindex'])
            table.append((port, queuetag) + values)

        if json_opt:
            json_output[port].update(build_json(port, table, self.all, self.trim, self.voq))
            return json_output

        if table:
            print(f"For namespace {self.namespace}:")
            print(tabulate(table, hdr, tablefmt='simple', stralign='right'))
            print()

    def cnstat_diff_print(self, port, cnstat_new_dict, cnstat_old_dict, json_opt, non_zero):
        """
        Show the difference between the current and the cached counters.

        cnstat_new_dict: current counters of the port (see get_cnstat).
        cnstat_old_dict: previously saved counters of the same port.

        Queues without an entry in cnstat_old_dict are skipped. With
        json_opt set, nothing is printed and a dict of the form
        {port: {'time': ..., queue label: differences}} is returned.
        Otherwise the table is printed (only when it has at least one row)
        and None is returned. With non_zero set, queues whose displayed
        differences are all '0' are left out.
        """
        generic = ('totalpacket', 'totalbytes', 'droppacket', 'dropbytes')

        # Pick the displayed counters and the matching header for the mode.
        if self.voq:
            columns, hdr = generic + ('creditWDpkts',), voq_header
        elif self.all:  # All statistics
            columns, hdr = generic + ('trimpacket',), all_header
        elif self.trim:  # Packet Trimming related statistics
            columns, hdr = ('trimpacket',), trim_header
        else:  # Generic statistics
            columns, hdr = generic, std_header

        table = []
        json_output = {port: {}}

        for key, cntr in cnstat_new_dict.items():
            if key == 'time':
                # The timestamp is carried through to the JSON output only.
                if json_opt:
                    json_output[port][key] = cntr
                continue

            # A single lookup covers both "missing" and "stored as None".
            old_cntr = cnstat_old_dict.get(key)
            if old_cntr is None:
                continue

            # Each difference is computed once and reused for both the
            # non-zero filter and the table row.
            deltas = tuple(ns_diff(cntr[name], old_cntr[name]) for name in columns)
            if non_zero and all(delta == '0' for delta in deltas):
                continue

            queuetag = cntr['queuetype'] + str(cntr['queueindex'])
            table.append((port, queuetag) + deltas)

        if json_opt:
            json_output[port].update(build_json(port, table, self.all, self.trim, self.voq))
            return json_output

        if table:
            print(port + f" Last cached time{self.namespace_str} was " + str(cnstat_old_dict.get('time')))
            print(tabulate(table, hdr, tablefmt='simple', stralign='right'))
            print()

    def get_print_all_stat(self, json_opt, non_zero):
        """
        Get stat for each port
        If JSON option is True, collect data for each port and
        print data in JSON format for all ports

        For a port that has a cache file, the difference against the cached
        counters is shown; otherwise the current counters are shown. An
        IOError while handling the cache file is printed and that port is
        left without counters.
        """
        # The cache file suffix does not depend on the port.
        cache_ns = ''
        if self.voq and self.namespace is not None:
            cache_ns = '-' + self.namespace + '-'

        json_output = {}
        for port in natsorted(self.counter_port_name_map):
            json_output[port] = {}
            cnstat_dict = self.get_cnstat(self.port_queues_map[port])
            cnstat_fqn_file_name = cnstat_fqn_file + cache_ns + port
            if os.path.isfile(cnstat_fqn_file_name):
                try:
                    # Close the cache file as soon as it has been parsed.
                    with open(cnstat_fqn_file_name, 'r') as cache_file:
                        cnstat_cached_dict = json.load(cache_file)
                    if json_opt:
                        json_output[port]["cached_time"] = cnstat_cached_dict.get('time')
                        diff = self.cnstat_diff_print(port, cnstat_dict, cnstat_cached_dict, json_opt, non_zero)
                        # Merge into the existing per-port dict; replacing it
                        # would drop the "cached_time" entry stored above.
                        json_output[port].update(diff[port])
                    else:
                        self.cnstat_diff_print(port, cnstat_dict, cnstat_cached_dict, json_opt, non_zero)
                except IOError as e:
                    print(e.errno, e)
            else:
                if json_opt:
                    json_output.update(self.cnstat_print(port, cnstat_dict, json_opt, non_zero))
                else:
                    self.cnstat_print(port, cnstat_dict, json_opt, non_zero)

        if json_opt:
            print(json_dump(json_output))

    def get_print_port_stat(self, port, json_opt, non_zero):
        """
        Get stat for the port
        If JSON option is True  print data in JSON format

        Exits with status 1 when the port is unknown. When the port has a
        cache file, the difference against the cached counters is shown;
        otherwise the current counters are shown. An IOError while handling
        the cache file is printed and no counters are shown.
        """
        if port not in self.port_queues_map:
            print("Port doesn't exist!", port)
            sys.exit(1)

        # Get stat for the port queried
        cnstat_dict = self.get_cnstat(self.port_queues_map[port])
        cache_ns = ''
        if self.voq and self.namespace is not None:
            cache_ns = '-' + self.namespace + '-'
        cnstat_fqn_file_name = cnstat_fqn_file + cache_ns + port
        json_output = {port: {}}
        if os.path.isfile(cnstat_fqn_file_name):
            try:
                # Close the cache file as soon as it has been parsed.
                with open(cnstat_fqn_file_name, 'r') as cache_file:
                    cnstat_cached_dict = json.load(cache_file)
                if json_opt:
                    json_output[port]["cached_time"] = cnstat_cached_dict.get('time')
                    diff = self.cnstat_diff_print(port, cnstat_dict, cnstat_cached_dict, json_opt, non_zero)
                    # Merge into the existing per-port dict; replacing it
                    # would drop the "cached_time" entry stored above.
                    json_output[port].update(diff[port])
                else:
                    print(f"Last cached time{self.namespace_str} was " + str(cnstat_cached_dict.get('time')))
                    self.cnstat_diff_print(port, cnstat_dict, cnstat_cached_dict, json_opt, non_zero)
            except IOError as e:
                print(e.errno, e)
        else:
            if json_opt:
                json_output.update(self.cnstat_print(port, cnstat_dict, json_opt, non_zero))
            else:
                self.cnstat_print(port, cnstat_dict, json_opt, non_zero)

        if json_opt:
            print(json_dump(json_output))

    def save_fresh_stats(self):
        """
        Get stat for each port and save it to that port's cache file.

        The file name is the cache path prefix followed by the port name;
        in VOQ mode with a namespace set, '-<namespace>-' is inserted
        between the two. On an IOError the error is printed and the process
        exits with the error's errno.
        """
        cache_ns = ''
        if self.voq and self.namespace is not None:
            cache_ns = '-' + self.namespace + '-'
        for port in natsorted(self.counter_port_name_map):
            cnstat_dict = self.get_cnstat(self.port_queues_map[port])
            try:
                # The with-block closes (and thereby flushes) the file inside
                # the try, so write-back errors are reported here as well.
                with open(cnstat_fqn_file + cache_ns + port, 'w') as cache_file:
                    json.dump(cnstat_dict, cache_file, default=json_serial)
            except IOError as e:
                print(e.errno, e)
                sys.exit(e.errno)
            else:
                print("Clear and update saved counters for " + port)


@click.command()
@click.option('-p', '--port', type=str, help='Show the queue conters for just one port', default=None)
@click.option('-c', '--clear', is_flag=True, default=False, help='Clear previous stats and save new ones')
@click.option('-d', '--delete', is_flag=True, default=False, help='Delete saved stats')
@click.option('-j', '--json_opt',  is_flag=True, default=False, help='Print in JSON format')
@click.option('-a', '--all', is_flag=True, default=False, help='Display all the stats counters')
@click.option('-T', '--trim', is_flag=True, default=False, help='Display trimming related statistics')
@click.option('-V', '--voq', is_flag=True, default=False, help='display voq stats')
@click.option('-nz','--non_zero', is_flag=True, default=False, help='Display non-zero queue counters')
@click.option('-n', '--namespace', type=click.Choice(multi_asic.get_namespace_list()), help='Display queuecounters for a specific namespace name or skip for all', default=None)
@click.version_option(version='1.0')
def main(port, clear, delete, json_opt, all, trim, voq, non_zero, namespace):
    """
    Examples:
      queuestat
      queuestat -p Ethernet0
      queuestat -c
      queuestat -d
      queuestat -p Ethernet0 -n asic0
    """

    # The cache location is published through module-level globals because
    # the Queuestat methods read cnstat_fqn_file directly.
    global cnstat_dir, cnstat_fqn_file

    cache = UserCache()
    cnstat_dir = cache.get_directory()
    cnstat_fqn_file = os.path.join(cnstat_dir, 'queuestat')

    # Deleting the saved stats does not end the run; the command still
    # carries on with the clear/display step below.
    if delete:
        cache.remove()

    # -c saves fresh counters; otherwise show one port (-p) or all ports.
    QueuestatWrapper(namespace, all, trim, voq).run(clear, port, json_opt, non_zero)

    sys.exit(0)

if __name__ == "__main__":
    main()
