#!python

"""
mmuconfig is the utility to show and change mmu configuration

usage: mmuconfig [-h] [-v] [-l] [-p PROFILE] [-a ALPHA] [-s STATICTH] [-t on|off] [-n NAMESPACE] [-vv]

optional arguments:
  -h     --help            show this help message and exit
  -v     --version         show program's version number and exit
  -vv    --verbose         verbose output
  -l     --list            show mmu configuration
  -p     --profile         specify buffer profile name
  -a     --alpha           set n for dynamic threshold alpha 2^(n)
  -s     --staticth        set static threshold
  -t     --trim            configure packet trimming eligibility (on|off)
  -n     --namespace       namespace name, or omit for all namespaces

"""

import os
import sys
import click
import tabulate
import traceback
import json
from utilities_common.general import load_db_config
from sonic_py_common import multi_asic
from utilities_common import multi_asic as multi_asic_util

# Table names handed to MmuConfig.get_table(); in the non-config (STATE_DB)
# path they are used as a key prefix (tablename + '*').
BUFFER_POOL_TABLE_NAME = "BUFFER_POOL"
BUFFER_PROFILE_TABLE_NAME = "BUFFER_PROFILE"
DEFAULT_LOSSLESS_BUFFER_PARAMETER_NAME = "DEFAULT_LOSSLESS_BUFFER_PARAMETER"

# Buffer profile field set by -a/--alpha, and the inclusive range of n
# accepted by that option.
DYNAMIC_THRESHOLD = "dynamic_th"
DYNAMIC_THRESHOLD_MIN = -8
DYNAMIC_THRESHOLD_MAX = 8
# Buffer profile field set by -s/--staticth, and the smallest accepted value.
STATIC_THRESHOLD = "static_th"
STATIC_THRESHOLD_MIN = 0
# Buffer profile field set by -t/--trim. "on"/"off" are the command-line
# spellings; TrimTypeValidator converts them to "trim"/"drop" before storing.
PACKET_TRIMMING = "packet_discard_action"
PACKET_TRIMMING_ENABLED = "on"
PACKET_TRIMMING_DISABLED = "off"

# Mock the redis database for unit test purposes.
# Both environment variables are read with os.environ[...]; when either one is
# not set the resulting KeyError is swallowed below, so a normal (non-test) run
# skips the mocking entirely.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        # Make the parent directory of this script and its tests/ subdirectory
        # importable so that the mock_tables package can be found.
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        import mock_tables.dbconnector
    # NOTE(review): this check is not nested under the one above, so it is also
    # evaluated when UTILITIES_UNIT_TESTING is set to a value other than "2";
    # in that case mock_tables has not been imported and the error raised here
    # is not a KeyError -- confirm this is intended.
    if os.environ["UTILITIES_UNIT_TESTING_TOPOLOGY"] == "multi_asic":
        import mock_tables.mock_multi_asic
        mock_tables.dbconnector.load_namespace_config()
    else:
        mock_tables.dbconnector.load_database_config()

except KeyError:
    # One of the unit-test environment variables is absent: nothing to mock.
    pass

from swsscommon.swsscommon import SonicV2Connector, ConfigDBConnector

class MmuConfig(object):
    """Show and modify MMU buffer configuration.

    ``list`` and ``set`` are wrapped with
    ``multi_asic_util.run_on_multi_asic``. ``self.db`` and ``self.config_db``
    are initialised to None in the constructor and used by the methods below;
    presumably the decorator connects them for the namespace being processed
    -- confirm against utilities_common.multi_asic.
    """

    def __init__(self, verbose, config, filename, namespace):
        """Store the run options.

        Args:
            verbose: when true, print extra output (totals, applied changes).
            config: when true, get_table() reads through self.config_db;
                otherwise it scans STATE_DB keys through self.db.
            filename: when not None, set() dumps the updated BUFFER_PROFILE
                table to this path as JSON (used by unit tests).
            namespace: forwarded to MultiAsic as namespace_option.
        """
        # Database handles; not connected here.
        self.db = None
        self.config_db = None

        self.namespace = namespace
        self.multi_asic = multi_asic_util.MultiAsic(namespace_option=namespace)

        self.verbose = verbose
        self.config = config
        self.filename = filename

        # For unit testing: namespace -> BUFFER_PROFILE table after a set()
        self.updated_profile_table = {}

    def get_table(self, tablename):
        """Return the entries of `tablename` as a dict keyed by entry name.

        With self.config set, the table is fetched from self.config_db.
        Otherwise every STATE_DB key matching ``tablename + '*'`` is read and
        keyed by the text between the first and second '|' of the key; None
        is returned when no key matches.
        """
        if self.config:
            return self.config_db.get_table(tablename)

        matching_keys = self.db.keys(self.db.STATE_DB, tablename + '*')
        if not matching_keys:
            return None

        return {
            db_key.split('|')[1]: self.db.get_all(self.db.STATE_DB, db_key)
            for db_key in matching_keys
        }

    @staticmethod
    def _print_entry(title, fields):
        """Print `title`, then the field/value pairs as a table plus a blank line."""
        print(title)
        rows = [[name, content] for name, content in fields.items()]
        print(tabulate.tabulate(rows) + "\n")

    @multi_asic_util.run_on_multi_asic
    def list(self):
        """Print the lossless traffic pattern, buffer pools and buffer profiles."""
        if multi_asic.is_multi_asic():
            ns_suffix = f" for namespace {self.multi_asic.current_namespace}"
        else:
            ns_suffix = ''

        # Lossless traffic pattern: printed only when present, no fallback text
        traffic_patterns = self.get_table(DEFAULT_LOSSLESS_BUFFER_PARAMETER_NAME)
        if traffic_patterns:
            for fields in traffic_patterns.values():
                self._print_entry(f"Lossless traffic pattern{ns_suffix}:", fields)

        # Buffer pools
        pools = self.get_table(BUFFER_POOL_TABLE_NAME)
        if not pools:
            print(f"No buffer pool information available{ns_suffix}")
        else:
            for name, fields in pools.items():
                self._print_entry(f"Pool{ns_suffix}: " + name, fields)
            if self.verbose:
                print("Total pools: %d\n\n" % len(pools))

        # Buffer profiles
        profiles = self.get_table(BUFFER_PROFILE_TABLE_NAME)
        if not profiles:
            print(f"No buffer profile information available{ns_suffix}")
        else:
            for name, fields in profiles.items():
                self._print_entry(f"Profile{ns_suffix}: " + name, fields)
            if self.verbose:
                print("Total profiles: %d" % len(profiles))

    @multi_asic_util.run_on_multi_asic
    def set(self, profile, field, value):
        """Write `field` = `value` into buffer profile `profile`.

        Exits the process when not running as root (unless
        UTILITIES_UNIT_TESTING is "2"), when `field` is not one of the
        supported fields, or when a threshold field is requested for an
        existing profile that does not already carry that field.
        A profile that is not present in the table is not rejected.
        """
        if multi_asic.is_multi_asic():
            ns_suffix = f" for namespace {self.multi_asic.current_namespace}"
        else:
            ns_suffix = ''

        under_unit_test = os.environ.get("UTILITIES_UNIT_TESTING", "0") == "2"
        if not under_unit_test and os.geteuid() != 0:
            sys.exit("Root privileges required for this operation")

        profiles = self.config_db.get_table(BUFFER_PROFILE_TABLE_NAME)

        # Threshold fields may only be changed on profiles already using them
        mismatch_messages = {
            DYNAMIC_THRESHOLD: "%s not using dynamic thresholding",
            STATIC_THRESHOLD: "%s not using static threshold",
        }
        if field in mismatch_messages:
            if profile in profiles and field not in profiles[profile]:
                sys.exit(mismatch_messages[field] % (profile))
        elif field != PACKET_TRIMMING:
            # Packet trimming eligibility works for both dynamic/static buffer
            # model, so it needs no check; anything else is rejected.
            sys.exit("Set field %s not supported" % (field))

        if self.verbose:
            print("Setting %s %s value to %s%s" % (profile, field, value, ns_suffix))
        self.config_db.mod_entry(BUFFER_PROFILE_TABLE_NAME, profile, {field: value})

        if self.filename is None:
            return

        # Unit-test hook: record the resulting table for this namespace
        current_ns = self.multi_asic.current_namespace
        self.updated_profile_table[current_ns] = self.config_db.get_table(BUFFER_PROFILE_TABLE_NAME)
        with open(self.filename, "w") as fd:
            json.dump(self.updated_profile_table, fd)


class TrimTypeValidator(click.ParamType):
    """Click parameter type for the packet trimming option.

    Accepts the command-line spellings "on" / "off" and converts them to the
    values "trim" / "drop".
    """
    name = "text"

    def get_metavar(self, param, ctx=None):
        """Return the placeholder shown for this option in the help text.

        `ctx` is optional so that the method works both with click releases
        that call get_metavar(param) and with newer ones (8.2+) that also
        pass the context.
        """
        return "[on|off]"

    def convert(self, value, param, ctx):
        """Validate `value` and map it to the stored action.

        Validation is delegated to click.Choice, which reports a usage error
        for anything other than "on" or "off".

        Returns:
            "trim" when `value` is "on", otherwise "drop".
        """
        # Only the validation side effect is needed; the comparison below
        # uses the raw value.
        click.Choice([PACKET_TRIMMING_ENABLED, PACKET_TRIMMING_DISABLED]).convert(value, param, ctx)

        return "trim" if value == PACKET_TRIMMING_ENABLED else "drop"


@click.command(help='Show and change: mmu configuration')
@click.option('-l', '--list', 'show_config', is_flag=True, help='show mmu configuration')
@click.option('-p', '--profile', type=str, help='specify buffer profile name', default=None)
@click.option('-a', '--alpha', type=click.IntRange(DYNAMIC_THRESHOLD_MIN, DYNAMIC_THRESHOLD_MAX), help='set n for dyanmic threshold alpha 2^(n)', default=None)
@click.option('-s', '--staticth', type=click.IntRange(min=STATIC_THRESHOLD_MIN), help='set static threshold', default=None)
@click.option('-t', '--trim', 'trim', type=TrimTypeValidator(), help='Configures packet trimming eligibility')
@click.option('-n', '--namespace', type=click.Choice(multi_asic.get_namespace_list()), help='Namespace name or skip for all', default=None)
@click.option('-vv', '--verbose', is_flag=True, help='verbose output', default=False)
@click.version_option(version='1.0')
def main(show_config, profile, alpha, staticth, trim, namespace, verbose):
    """Entry point shared by the mmuconfig and buffershow commands.

    With --list the configuration is printed. When invoked as "mmuconfig"
    with --profile, the first of --alpha, --staticth or --trim that was
    supplied is applied to that profile; if none was supplied nothing is
    changed. In every other case the help text is printed and the process
    exits with status 1. Any exception is reported on stderr with a
    traceback and also results in exit status 1.
    """
    # A test file created for unit test purposes
    filename = None
    if os.environ.get("UTILITIES_UNIT_TESTING", "0") == "2":
        filename = '/tmp/mmuconfig'

    # Buffershow and mmuconfig cmds share this script
    # Buffershow cmd cannot modify configs hence config is set to False
    config = sys.argv[0].split('/')[-1] == "mmuconfig"

    try:
        load_db_config()
        mmu_cfg = MmuConfig(verbose, config, filename, namespace)

        # Both mmuconfig and buffershow have access to show_config option
        if show_config:
            mmu_cfg.list()
        # Buffershow cannot modify profiles
        elif config and profile:
            # Compare against None rather than relying on truthiness: 0 is a
            # valid value for both --alpha (2^0) and --staticth, and would
            # otherwise be silently ignored.
            if alpha is not None:
                mmu_cfg.set(profile, DYNAMIC_THRESHOLD, alpha)
            elif staticth is not None:
                mmu_cfg.set(profile, STATIC_THRESHOLD, staticth)
            elif trim is not None:
                mmu_cfg.set(profile, PACKET_TRIMMING, trim)
        else:
            ctx = click.get_current_context()
            click.echo(ctx.get_help())
            sys.exit(1)

    except Exception as e:
        # sys.exit() raises SystemExit, which is not an Exception subclass,
        # so the exits above pass through this handler untouched.
        print("Exception caught: ", str(e), file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)

if __name__ == "__main__":
    main()
