#!python

#####################################################################
#
# watermarkstat is a tool for displaying watermarks.
#
#####################################################################
import click
import json
import os
import sys

from natsort import natsorted
from tabulate import tabulate
from sonic_py_common import multi_asic
import utilities_common.multi_asic as multi_asic_util

# Unit-test scaffolding: when UTILITIES_UNIT_TESTING is "2", make the mock
# database connector from the tests directory importable and load it.
# NOTE(review): this sits above the SonicV2Connector import, presumably so the
# mock is installed before the real connector is imported - confirm.
#
# Every os.environ[...] lookup below raises KeyError when the variable is not
# set; the single `except KeyError` at the bottom swallows that, so the first
# missing variable silently skips everything that follows it in this block
# (e.g. an unset UTILITIES_UNIT_TESTING_TOPOLOGY also skips the
# WATERMARKSTAT_UNIT_TESTING section).
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        # Put the parent directory and its "tests" subdirectory at the front
        # of sys.path so that `mock_tables` and `tests.*` resolve from there.
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        from mock_tables import dbconnector

        if os.environ["UTILITIES_UNIT_TESTING_TOPOLOGY"] == "multi_asic":
            # Imported for its side effects only; the name is not used here.
            import tests.mock_tables.mock_multi_asic
            dbconnector.load_namespace_config()

        if os.environ["WATERMARKSTAT_UNIT_TESTING"] == "1":
            input_path = os.path.join(tests_path, "wm_input")
            mock_db_path = os.path.join(input_path, "mock_db")

            # Register every "<name>.json" file found in mock_db under the
            # upper-cased <name>. The registered path deliberately omits the
            # ".json" extension (it is built from `name`, not `mock_db`).
            for mock_db in os.listdir(mock_db_path):
                name, ext = os.path.splitext(mock_db)
                if ext == ".json":
                    dbconnector.dedicated_dbs[name.upper()] = os.path.join(mock_db_path, name)
except KeyError:
    # A required environment variable is not set: not running under the
    # unit-test harness (or only partially configured), so do nothing more.
    pass

from swsscommon.swsscommon import SonicV2Connector


# Column headers of the table printed for the 'buffer_pool' and
# 'headroom_pool' watermark types.
headerBufferPool = ['Pool', 'Bytes']


# Placeholder shown in a table cell when a watermark value cannot be read.
STATUS_NA = 'N/A'
# Not referenced anywhere in the code visible in this file.
STATUS_INVALID = 'INVALID'

# Short queue-type labels used inside this tool ...
QUEUE_TYPE_MC = 'MC'
QUEUE_TYPE_UC = 'UC'
QUEUE_TYPE_ALL = 'ALL'
# ... and the queue-type strings they are matched against when read from
# COUNTERS_QUEUE_TYPE_MAP.
SAI_QUEUE_TYPE_MULTICAST = "SAI_QUEUE_TYPE_MULTICAST"
SAI_QUEUE_TYPE_UNICAST = "SAI_QUEUE_TYPE_UNICAST"
SAI_QUEUE_TYPE_ALL = "SAI_QUEUE_TYPE_ALL"

# Key prefixes; a prefix is concatenated with an object id to form the
# COUNTERS_DB key that a watermark value is read from. Only the PERSISTENT
# and USER prefixes are selected by the code visible in this file.
COUNTER_TABLE_PREFIX = "COUNTERS:"
PERSISTENT_TABLE_PREFIX = "PERSISTENT_WATERMARKS:"
PERIODIC_TABLE_PREFIX = "PERIODIC_WATERMARKS:"
USER_TABLE_PREFIX = "USER_WATERMARKS:"

# Names of the lookup tables read from COUNTERS_DB (ports, queues, priority
# groups and buffer pools).
COUNTERS_PORT_NAME_MAP = "COUNTERS_PORT_NAME_MAP"
COUNTERS_QUEUE_NAME_MAP = "COUNTERS_QUEUE_NAME_MAP"
COUNTERS_QUEUE_TYPE_MAP = "COUNTERS_QUEUE_TYPE_MAP"
COUNTERS_QUEUE_INDEX_MAP = "COUNTERS_QUEUE_INDEX_MAP"
COUNTERS_QUEUE_PORT_MAP = "COUNTERS_QUEUE_PORT_MAP"
COUNTERS_PG_NAME_MAP = "COUNTERS_PG_NAME_MAP"
COUNTERS_PG_PORT_MAP = "COUNTERS_PG_PORT_MAP"
COUNTERS_PG_INDEX_MAP = "COUNTERS_PG_INDEX_MAP"
COUNTERS_BUFFER_POOL_NAME_MAP = "COUNTERS_BUFFER_POOL_NAME_MAP"


class WatermarkstatWrapper(object):
    """Drive a Watermarkstat show/clear request for the selected namespace(s)."""

    def __init__(self, namespace):
        """Remember the requested namespace and build the multi-asic helper.

        Args:
            namespace: value of the -n/--namespace option; None when omitted.
        """
        self.namespace = namespace
        # No connection yet.
        # NOTE(review): presumably assigned by run_on_multi_asic before the
        # body of run() executes - confirm against utilities_common.
        self.db = None
        self.multi_asic = multi_asic_util.MultiAsic(namespace_option=namespace)

    @multi_asic_util.run_on_multi_asic
    def run(self, clear, persistent, wm_type):
        """Clear or print the watermarks of type wm_type.

        Args:
            clear: when true, publish a clear request instead of printing.
            persistent: when true, act on the persistent watermarks;
                otherwise on the user watermarks.
            wm_type: watermark type key, e.g. 'pg_shared'.
        """
        stat = Watermarkstat(self.db, self.multi_asic.current_namespace)

        if not clear:
            prefix = USER_TABLE_PREFIX
            if persistent:
                prefix = PERSISTENT_TABLE_PREFIX
            stat.print_all_stat(prefix, wm_type)
            return

        scope = "PERSISTENT" if persistent else "USER"
        stat.send_clear_notification((scope, wm_type.upper()))


class Watermarkstat(object):
    """Collect and print watermark counters read from COUNTERS_DB.

    On construction the port, queue, priority-group and buffer-pool name maps
    are read from COUNTERS_DB and indexed per port. print_all_stat() then
    renders one table for a requested watermark type, and
    send_clear_notification() publishes a clear request.

    Any lookup that is required but missing prints a message to stderr and
    terminates the process with exit status 1.
    """

    def __init__(self, db, namespace):
        """Load the object maps from COUNTERS_DB.

        Args:
            db: connector object providing get(), get_all(), publish() and a
                COUNTERS_DB attribute.
            namespace: namespace name, only used in the printed table title.

        Raises:
            SystemExit: when a name map or a per-object lookup is unavailable.
            KeyError: when a queue or PG refers to a port id that is not in
                COUNTERS_PORT_NAME_MAP.
        """
        self.namespace = namespace
        self.db = db

        # Get all ports (port name -> port object id)
        self.counter_port_name_map = self._get_required_map(COUNTERS_PORT_NAME_MAP)

        self.port_uc_queues_map = {}
        self.port_mc_queues_map = {}
        self.port_all_queues_map = {}
        self.port_pg_map = {}
        # Reverse lookup: port object id -> port name
        self.port_name_map = {}

        for port, port_oid in self.counter_port_name_map.items():
            self.port_uc_queues_map[port] = {}
            self.port_mc_queues_map[port] = {}
            self.port_all_queues_map[port] = {}
            self.port_pg_map[port] = {}
            self.port_name_map[port_oid] = port

        # Get Queues for each port
        counter_queue_name_map = self._get_required_map(COUNTERS_QUEUE_NAME_MAP)

        queue_maps_by_type = {
            QUEUE_TYPE_UC: self.port_uc_queues_map,
            QUEUE_TYPE_MC: self.port_mc_queues_map,
            QUEUE_TYPE_ALL: self.port_all_queues_map,
        }
        for queue, queue_oid in counter_queue_name_map.items():
            port = self.port_name_map[self._get_queue_port(queue_oid)]
            # Resolve the queue type once per queue; each resolution is a
            # database read.
            queue_type = self._get_queue_type(queue_oid)
            queue_maps_by_type[queue_type][port][queue] = queue_oid

        # Get PGs for each port
        counter_pg_name_map = self._get_required_map(COUNTERS_PG_NAME_MAP)

        for pg, pg_oid in counter_pg_name_map.items():
            port = self.port_name_map[self._get_pg_port(pg_oid)]
            self.port_pg_map[port][pg] = pg_oid

        # Get all buffer pools
        self.buffer_pool_name_to_oid_map = self._get_required_map(COUNTERS_BUFFER_POOL_NAME_MAP)

        # Per watermark type: table title, the per-port object map to walk,
        # the function turning an object id into its column index, the
        # counter field to read and the column header prefix. Buffer pool
        # types carry a fixed header instead of an object map.
        self.watermark_types = {
            "pg_headroom": {
                "message": "Ingress headroom per PG:",
                "obj_map": self.port_pg_map,
                "idx_func": self.get_pg_index,
                "wm_name": "SAI_INGRESS_PRIORITY_GROUP_STAT_XOFF_ROOM_WATERMARK_BYTES",
                "header_prefix": "PG",
            },
            "pg_shared": {
                "message": "Ingress shared pool occupancy per PG:",
                "obj_map": self.port_pg_map,
                "idx_func": self.get_pg_index,
                "wm_name": "SAI_INGRESS_PRIORITY_GROUP_STAT_SHARED_WATERMARK_BYTES",
                "header_prefix": "PG",
            },
            "q_shared_uni": {
                "message": "Egress shared pool occupancy per unicast queue:",
                "obj_map": self.port_uc_queues_map,
                "idx_func": self.get_queue_index,
                "wm_name": "SAI_QUEUE_STAT_SHARED_WATERMARK_BYTES",
                "header_prefix": "UC",
            },
            "q_shared_multi": {
                "message": "Egress shared pool occupancy per multicast queue:",
                "obj_map": self.port_mc_queues_map,
                "idx_func": self.get_queue_index,
                "wm_name": "SAI_QUEUE_STAT_SHARED_WATERMARK_BYTES",
                "header_prefix": "MC",
            },
            "q_shared_all": {
                "message": "Egress shared pool occupancy per all queues:",
                "obj_map": self.port_all_queues_map,
                "idx_func": self.get_queue_index,
                "wm_name": "SAI_QUEUE_STAT_SHARED_WATERMARK_BYTES",
                "header_prefix": "ALL",
            },
            "buffer_pool": {
                "message": "Shared pool maximum occupancy:",
                "wm_name": "SAI_BUFFER_POOL_STAT_WATERMARK_BYTES",
                "header": headerBufferPool,
            },
            "headroom_pool": {
                "message": "Headroom pool maximum occupancy:",
                "wm_name": "SAI_BUFFER_POOL_STAT_XOFF_ROOM_WATERMARK_BYTES",
                "header": headerBufferPool,
            },
        }

    def _get_required_map(self, map_name):
        """Return the whole COUNTERS_DB table map_name, exiting if it is None."""
        table = self.db.get_all(self.db.COUNTERS_DB, map_name)
        if table is None:
            print("{} is empty!".format(map_name), file=sys.stderr)
            sys.exit(1)

        return table

    def _get_required_field(self, map_name, table_id, description):
        """Return field table_id of COUNTERS_DB table map_name, exiting if missing.

        description names the looked-up value in the error message.
        """
        value = self.db.get(self.db.COUNTERS_DB, map_name, table_id)
        if value is None:
            print("{} is not available in table '{}'".format(description, table_id), file=sys.stderr)
            sys.exit(1)

        return value

    def _get_queue_type(self, table_id):
        """Return QUEUE_TYPE_UC/MC/ALL for the queue table_id, exiting if unknown."""
        queue_type = self._get_required_field(COUNTERS_QUEUE_TYPE_MAP, table_id, "Queue Type")
        short_names = {
            SAI_QUEUE_TYPE_MULTICAST: QUEUE_TYPE_MC,
            SAI_QUEUE_TYPE_UNICAST: QUEUE_TYPE_UC,
            SAI_QUEUE_TYPE_ALL: QUEUE_TYPE_ALL,
        }
        if queue_type not in short_names:
            print("Queue Type '{}' in table '{}' is invalid".format(queue_type, table_id), file=sys.stderr)
            sys.exit(1)

        return short_names[queue_type]

    def _get_queue_port(self, table_id):
        """Return the port object id owning the queue table_id."""
        return self._get_required_field(COUNTERS_QUEUE_PORT_MAP, table_id, "Port")

    def _get_pg_port(self, table_id):
        """Return the port object id owning the priority group table_id."""
        return self._get_required_field(COUNTERS_PG_PORT_MAP, table_id, "Port")

    def get_queue_index(self, table_id):
        """Return the index stored for queue table_id in COUNTERS_QUEUE_INDEX_MAP."""
        return self._get_required_field(COUNTERS_QUEUE_INDEX_MAP, table_id, "Queue index")

    def get_pg_index(self, table_id):
        """Return the index stored for PG table_id in COUNTERS_PG_INDEX_MAP."""
        return self._get_required_field(COUNTERS_PG_INDEX_MAP, table_id, "Priority group index")

    def build_header(self, wm_type, counter_type):
        """Build the table header for a per-port watermark type.

        Sets self.header_list ('Port' followed by one column per distinct
        object index), self.header_idx_to_pos (object index -> column
        position, not counting the 'Port' column) and self.min_idx.

        Args:
            wm_type: one entry of self.watermark_types carrying "obj_map"
                and "header_prefix".
            counter_type: the watermark type key; only used to decide how an
                empty object map is reported.

        Raises:
            SystemExit: status 1 when wm_type is None or the object map has
                no entries; status 0 when it is empty for 'q_shared_multi'.
        """
        if wm_type is None:
            print("Header info is not available!", file=sys.stderr)
            sys.exit(1)

        self.header_list = ['Port']

        # Object names look like "<port>:<index>"; collect the distinct
        # indexes over all ports.
        header_idx_set = {
            int(element.split(':')[1])
            for port_objs in wm_type["obj_map"].values()
            for element in port_objs
        }

        if not header_idx_set:
            if counter_type != 'q_shared_multi':
                print("Object map is empty!", file=sys.stderr)
                sys.exit(1)
            else:
                print("Object map from the COUNTERS_DB is empty because the multicast queues are not configured in the CONFIG_DB!")
                sys.exit(0)

        header_idx_list = sorted(header_idx_set)

        self.header_idx_to_pos = {idx: pos for pos, idx in enumerate(header_idx_list)}
        self.min_idx = header_idx_list[0]
        self.header_list += ["{}{}".format(wm_type["header_prefix"], idx) for idx in header_idx_list]

    def get_counters(self, table_prefix, port_obj, idx_func, watermark):
        """Return one table row of watermark values for a single port.

        Args:
            table_prefix: key prefix prepended to each object id.
            port_obj: mapping of object name -> object id for the port.
            idx_func: callable returning the index of an object id.
            watermark: name of the counter field to read.

        Returns:
            A list with one string per header column after 'Port': "0" for
            columns without an object, STATUS_NA when a value is missing or
            empty, otherwise the counter as a decimal string.
        """
        # header list contains the port name followed by the queues/pgs. fields is used to populate the queue/pg values
        fields = ["0"] * (len(self.header_list) - 1)
        if not fields:
            # counters are not enabled.
            return fields

        for obj_id in port_obj.values():
            pos = self.header_idx_to_pos[int(idx_func(obj_id))]
            counter_data = self.db.get(self.db.COUNTERS_DB, table_prefix + obj_id, watermark)
            if counter_data is None or counter_data == '':
                fields[pos] = STATUS_NA
            elif fields[pos] != STATUS_NA:
                # Once a column is N/A it stays N/A, even if another object
                # mapped to the same column has a value.
                fields[pos] = str(int(counter_data))
        return fields

    def print_all_stat(self, table_prefix, key):
        """Print the watermark table for the watermark type key.

        Args:
            table_prefix: key prefix selecting which watermark table to read
                (e.g. USER_TABLE_PREFIX or PERSISTENT_TABLE_PREFIX).
            key: a key of self.watermark_types.
        """
        table = []
        wm_info = self.watermark_types[key]
        if key in ('buffer_pool', 'headroom_pool'):
            self.header_list = wm_info['header']
            # Get stats for each buffer pool
            for buf_pool, bp_oid in natsorted(self.buffer_pool_name_to_oid_map.items()):
                # The headroom table only lists pools whose name contains
                # 'ingress_lossless'.
                if key == 'headroom_pool' and 'ingress_lossless' not in buf_pool:
                    continue

                data = self.db.get(self.db.COUNTERS_DB, table_prefix + bp_oid, wm_info["wm_name"])
                if data is None:
                    data = STATUS_NA
                table.append((buf_pool, data))
        else:
            self.build_header(wm_info, key)
            # Get stat for each port
            for port in natsorted(self.counter_port_name_map):
                data = self.get_counters(table_prefix,
                                         wm_info["obj_map"][port], wm_info["idx_func"], wm_info["wm_name"])
                table.append(tuple([port] + data))

        namespace_str = f" (Namespace {self.namespace})" if multi_asic.is_multi_asic() else ''
        print(wm_info["message"] + namespace_str)
        print(tabulate(table, self.header_list, tablefmt='simple', stralign='right'))

    def send_clear_notification(self, data):
        """Publish data, compactly JSON-encoded, on WATERMARK_CLEAR_REQUEST."""
        msg = json.dumps(data, separators=(',', ':'))
        self.db.publish('APPL_DB', 'WATERMARK_CLEAR_REQUEST', msg)

@click.command()
@click.option('-c', '--clear', is_flag=True, help='Clear watermarks request')
@click.option('-p', '--persistent', is_flag=True, help='Do the operations on the persistent watermark')
@click.option('-t', '--type', 'wm_type', type=click.Choice(['pg_headroom', 'pg_shared', 'q_shared_uni', 'q_shared_multi', 'buffer_pool', 'headroom_pool', 'q_shared_all']), help='The type of watermark', required=True)
@click.option('-n', '--namespace', type=click.Choice(multi_asic.get_namespace_list()), help='Namespace name or skip for all', default=None)
@click.version_option(version='1.0')
def main(clear, persistent, wm_type, namespace):
    """
       Display the watermark counters

       Examples:
       watermarkstat -t pg_headroom
       watermarkstat -t pg_shared
       watermarkstat -t q_shared_all
       watermarkstat -p -t q_shared_all
       watermarkstat -t q_shared_all -c
       watermarkstat -t q_shared_uni -c
       watermarkstat -t q_shared_multi -c
       watermarkstat -p -t pg_shared
       watermarkstat -p -t q_shared_multi -c
       watermarkstat -t buffer_pool
       watermarkstat -t buffer_pool -c
       watermarkstat -p -t buffer_pool -c
       watermarkstat -t pg_headroom -n asic0
       watermarkstat -p -t buffer_pool -c -n asic1
    """
    # NOTE: click presumably renders the docstring above as the command's
    # --help text, so it is user-facing output; keep the examples copy-paste
    # correct.

    # clear / persistent are the -c / -p flags, wm_type is the -t choice and
    # namespace is the -n choice (None when omitted). The wrapper performs
    # the show or clear action for the selected namespace(s).
    namespace_context = WatermarkstatWrapper(namespace)
    namespace_context.run(clear, persistent, wm_type)
    sys.exit(0)

# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
