#!python

"""
mmuconfig is the utility to show and change mmu configuration

usage: mmuconfig [-h] [-v] [-l] [-p PROFILE] [-a ALPHA] [-vv] [-s staticth]

optional arguments:
  -h     --help            show this help message and exit
  -v     --version         show program's version number and exit
  -vv    --verbose         verbose output
  -l     --list            show mmu configuration
  -p     --profile         specify buffer profile name
  -a     --alpha           set n for dynamic threshold alpha 2^(n)
  -s     --staticth        set static threshold

"""

import os
import sys
import argparse
import tabulate
import traceback
import json

# Names of the database tables this utility reads (and, for buffer
# profiles, modifies).
BUFFER_POOL_TABLE_NAME = "BUFFER_POOL"
BUFFER_PROFILE_TABLE_NAME = "BUFFER_PROFILE"
DEFAULT_LOSSLESS_BUFFER_PARAMETER_NAME = "DEFAULT_LOSSLESS_BUFFER_PARAMETER"

# Field names inside a buffer profile entry that hold its threshold.
DYNAMIC_THRESHOLD = "dynamic_th"
STATIC_THRESHOLD = "static_th"

# Maps the command-line alias of a settable option ("alpha", "staticth")
# to the buffer profile field it updates.
BUFFER_PROFILE_FIELDS = {
    "alpha": DYNAMIC_THRESHOLD,
    "staticth" : STATIC_THRESHOLD
}

# Mock the redis for unit test purposes.
#
# When the environment variable UTILITIES_UNIT_TESTING is exactly "2", the
# parent directory of this script and its "tests" sub-directory are put at
# the front of sys.path and mock_tables.dbconnector is imported from there.
# NOTE(review): presumably importing that module patches the database
# connectors as a side effect - confirm against tests/mock_tables.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        import mock_tables.dbconnector

except KeyError:
    # UTILITIES_UNIT_TESTING is not set: normal (non-test) execution.
    pass

from swsscommon.swsscommon import SonicV2Connector, ConfigDBConnector

class MmuConfig(object):
    """Show and change MMU buffer settings.

    When ``config`` is true the object works on the configuration database
    through ``ConfigDBConnector`` and supports both ``list()`` and ``set()``.
    Otherwise it reads STATE_DB through ``SonicV2Connector`` and only
    ``list()`` is meaningful.
    """

    def __init__(self, verbose, config, filename):
        """Store the options and open the database connection.

        Args:
            verbose: when true, print extra information (totals, the value
                being set).
            config: true to operate on the configuration database, false to
                read STATE_DB.
            filename: optional path; when not None, ``set()`` dumps the
                buffer profile table to it as JSON after each change.
        """
        self.verbose = verbose
        self.config = config
        self.filename = filename

        # Set up db connections
        if self.config:
            self.db = ConfigDBConnector()
            self.db.connect()
        else:
            self.db = SonicV2Connector(use_unix_socket_path=False)
            self.db.connect(self.db.STATE_DB, False)

    def get_table(self, tablename):
        """Return the entries of ``tablename`` as ``{entry_name: fields}``.

        In config mode this is whatever the connector's ``get_table``
        returns.  In state mode every STATE_DB key starting with
        ``tablename`` is read, and the part of the key after the first
        ``|`` is used as the entry name.

        Returns:
            A dict of entries, or None in state mode when no key matches.
        """
        if self.config:
            return self.db.get_table(tablename)

        keys = self.db.keys(self.db.STATE_DB, tablename + '*')
        if not keys:
            return None

        entries = {}
        for key in keys:
            entries[key.split('|')[1]] = self.db.get_all(self.db.STATE_DB, key)

        return entries

    @staticmethod
    def _print_entry(heading, fields):
        """Print ``heading`` followed by a field/value table and a blank line."""
        print(heading)
        rows = [[field, value] for field, value in fields.items()]
        print(tabulate.tabulate(rows) + "\n")

    def list(self):
        """Print the lossless traffic pattern, buffer pools and buffer profiles.

        A section whose table is empty prints a "No ... information
        available" line instead (the lossless traffic pattern section is
        simply omitted).  In verbose mode the number of pools and profiles
        is printed after the respective section.
        """
        lossless_traffic_pattern = self.get_table(DEFAULT_LOSSLESS_BUFFER_PARAMETER_NAME)
        if lossless_traffic_pattern:
            for pattern in lossless_traffic_pattern.values():
                self._print_entry("Lossless traffic pattern:", pattern)

        buf_pools = self.get_table(BUFFER_POOL_TABLE_NAME)
        if buf_pools:
            for pool_name, pool_data in buf_pools.items():
                self._print_entry("Pool: " + pool_name, pool_data)
            if self.verbose:
                print("Total pools: %d\n\n" % len(buf_pools))
        else:
            print("No buffer pool information available")

        buf_profs = self.get_table(BUFFER_PROFILE_TABLE_NAME)
        if buf_profs:
            for prof_name, prof_data in buf_profs.items():
                self._print_entry("Profile: " + prof_name, prof_data)
            if self.verbose:
                print("Total profiles: %d" % len(buf_profs))
        else:
            print("No buffer profile information available")

    def set(self, profile, field_alias, value):
        """Set the dynamic or static threshold of a buffer profile.

        Args:
            profile: name of the buffer profile to modify.
            field_alias: "alpha" or "staticth" (a key of
                BUFFER_PROFILE_FIELDS).
            value: the new value as an integer string.  For "alpha" it must
                be within -8..8, for "staticth" it must not be negative.

        Exits the process (``sys.exit`` with a message) when the caller is
        not root outside unit testing, when the value is not a valid integer
        or is out of range, or when the existing profile does not use the
        requested kind of threshold.

        Raises:
            KeyError: if ``field_alias`` is not in BUFFER_PROFILE_FIELDS.
        """
        if os.geteuid() != 0 and os.environ.get("UTILITIES_UNIT_TESTING", "0") != "2":
            sys.exit("Root privileges required for this operation")

        field = BUFFER_PROFILE_FIELDS[field_alias]
        if field == DYNAMIC_THRESHOLD:
            lowest, highest = -8, 8
            invalid_msg = "Invalid alpha value: 2^(%s)" % (value,)
            wrong_mode_msg = "%s not using dynamic thresholding" % (profile,)
        elif field == STATIC_THRESHOLD:
            lowest, highest = 0, None
            invalid_msg = "Invalid static threshold value: (%s)" % (value,)
            wrong_mode_msg = "%s not using static threshold" % (profile,)
        else:
            sys.exit("Set field %s not supported" % (field))

        # A non-numeric value is a user error, not an internal failure:
        # report it with the same message as an out-of-range value instead
        # of letting ValueError escape as a traceback.
        try:
            number = int(value)
        except (TypeError, ValueError):
            sys.exit(invalid_msg)

        if number < lowest or (highest is not None and number > highest):
            sys.exit(invalid_msg)

        buf_profs = self.db.get_table(BUFFER_PROFILE_TABLE_NAME)
        if profile in buf_profs and field not in buf_profs[profile]:
            sys.exit(wrong_mode_msg)

        # Store the canonical decimal form of the validated integer, so
        # spellings that int() accepts ("+3", " 3", "1_000") never reach the
        # database verbatim.
        canonical = str(number)

        if self.verbose:
            print("Setting %s %s value to %s" % (profile, field, canonical))
        self.db.mod_entry(BUFFER_PROFILE_TABLE_NAME, profile, {field: canonical})
        if self.filename is not None:
            prof_table = self.db.get_table(BUFFER_PROFILE_TABLE_NAME)
            with open(self.filename, "w") as fd:
                json.dump(prof_table, fd)


def main(config):
    """Parse the command line and run the requested action.

    Args:
        config: true when invoked as the configuration utility (options
            -l, -p, -a, -s are available); false when invoked as the
            read-only state viewer (only -l is available).

    Exactly one action is performed: ``-l`` lists the tables; otherwise, in
    config mode, ``-p PROFILE`` together with ``-a`` or ``-s`` changes a
    threshold (``-a`` wins when both are given).  Anything else, including
    ``-p`` without a value to set, prints the help text and exits with
    status 1.  Any exception raised while performing the action is reported
    on stderr with a traceback and also ends the process with status 1.
    """
    if config:
        parser = argparse.ArgumentParser(description='Show and change: mmu configuration',
                                         formatter_class=argparse.RawTextHelpFormatter)

        parser.add_argument('-l', '--list', action='store_true', help='show mmu configuration')
        parser.add_argument('-p', '--profile', type=str, help='specify buffer profile name', default=None)
        parser.add_argument('-a', '--alpha', type=str, help='set n for dyanmic threshold alpha 2^(n)', default=None)
        parser.add_argument('-s', '--staticth', type=str, help='set static threshold', default=None)
        parser.add_argument('-v', '--version', action='version', version='%(prog)s 1.0')
    else:
        parser = argparse.ArgumentParser(description='Show buffer state',
                                         formatter_class=argparse.RawTextHelpFormatter)

        parser.add_argument('-l', '--list', action='store_true', help='show buffer state')
        parser.add_argument('-v', '--version', action='version', version='%(prog)s 1.0')

    parser.add_argument('-vv', '--verbose', action='store_true', help='verbose output', default=False)
    parser.add_argument('-f', '--filename', help='file used by mock tests', type=str, default=None)

    # Work on a private copy of the arguments so that unit-test mode does not
    # keep appending to the process-wide sys.argv on every call.
    argv = sys.argv[1:]
    if os.environ.get("UTILITIES_UNIT_TESTING", "0") == "2":
        argv.extend(['-f', '/tmp/mmuconfig'])

    args = parser.parse_args(argv)

    # Decide what to do before touching the database, so a usage error is
    # reported as such even when no database connection is available.
    setting = None
    if not args.list and config and args.profile:
        # Compare against None rather than truthiness: an explicitly given
        # value must never be silently ignored.
        if args.alpha is not None:
            setting = ("alpha", args.alpha)
        elif args.staticth is not None:
            setting = ("staticth", args.staticth)

    if not args.list and setting is None:
        parser.print_help()
        sys.exit(1)

    try:
        mmu_cfg = MmuConfig(args.verbose, config, args.filename)
        if args.list:
            mmu_cfg.list()
        else:
            field_alias, value = setting
            mmu_cfg.set(args.profile, field_alias, value)

    except Exception as e:
        print("Exception caught: ", str(e), file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)

if __name__ == "__main__":
    # The same script serves two commands: when invoked under the name
    # "mmuconfig" it runs in configuration mode, under any other name it
    # runs in read-only mode.
    invoked_as = sys.argv[0].rsplit('/', 1)[-1]
    main(invoked_as == "mmuconfig")
