#!python

"""
ecnconfig is the utility to

1) show and change ECN configuration

usage: ecnconfig [-h] [-v] [-l] [-p PROFILE] [-gmin GREEN_MIN] [-n NAMESPACE]
                 [-gmax GREEN_MAX] [-ymin YELLOW_MIN] [-ymax YELLOW_MAX]
                 [-rmin RED_MIN] [-rmax RED_MAX] [-gdrop GREEN_DROP_PROB]
                 [-ydrop YELLOW_DROP_PROB] [-rdrop RED_DROP_PROB] [-vv]

optional arguments:
  -h     --help                show this help message and exit
  -v     --version             show program's version number and exit
  -vv    --verbose             verbose output
  -l     --list                show ECN WRED configuration
  -p     --profile             specify WRED profile name
  -n     --namespace           show ECN configuration for specified namespace
  -gmin  --green-min           set min threshold for packets marked green
  -gmax  --green-max           set max threshold for packets marked green
  -ymin  --yellow-min          set min threshold for packets marked yellow
  -ymax  --yellow-max          set max threshold for packets marked yellow
  -rmin  --red-min             set min threshold for packets marked red
  -rmax  --red-max             set max threshold for packets marked red
  -gdrop --green-drop-prob     set max drop/mark probability for packets marked green
  -ydrop --yellow-drop-prob    set max drop/mark probability for packets marked yellow
  -rdrop --red-drop-prob       set max drop/mark probability for packets marked red

2) show and change ECN on/off status on queues

usage: ecnconfig [-h] [-v] [-q QUEUE_INDEX] [{on,off}] [-vv]

positional arguments:
  {on,off}                  turn on/off ecn

optional arguments:
  -h     --help                show this help message and exit
  -v     --version             show program's version number and exit
  -vv    --verbose             verbose output
  -q     --queue               specify queue index list

Sample outputs:
$ecnconfig -q 3 on -vv
Enable ECN on Ethernet0,Ethernet4,Ethernet8,Ethernet12,Ethernet16,Ethernet20,Ethernet24,Ethernet28,Ethernet32,Ethernet36,Ethernet40,Ethernet44,Ethernet48,Ethernet52,Ethernet56,Ethernet60,Ethernet64,Ethernet68,Ethernet72,Ethernet76,Ethernet80,Ethernet84,Ethernet88,Ethernet92,Ethernet96,Ethernet100,Ethernet104,Ethernet108,Ethernet112,Ethernet116,Ethernet120,Ethernet124 queue 3

$ecnconfig -q 3
ECN status:
queue 3: on
"""
import click
import json
import os
import sys

from tabulate import tabulate

# Unit-test hook: swap in the mocked database layer when the test environment
# variables are set.
#   UTILITIES_UNIT_TESTING == "2"
#       put the parent directory and its "tests" directory at the front of
#       sys.path, then import mock_tables.dbconnector
#   UTILITIES_UNIT_TESTING_TOPOLOGY == "multi_asic"
#       additionally import the multi-ASIC mock and load the namespace config
# Both lookups share one try block, so when UTILITIES_UNIT_TESTING is unset the
# resulting KeyError also skips the topology check.
# NOTE(review): presumably this has to run before the swsscommon import below
# so the mocks are already in place - confirm against mock_tables.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        import mock_tables.dbconnector
    if os.environ["UTILITIES_UNIT_TESTING_TOPOLOGY"] == "multi_asic":
        import mock_tables.mock_multi_asic
        mock_tables.dbconnector.load_namespace_config()
except KeyError:
    # A missing environment variable is not an error: whatever has not been
    # imported by this point is simply skipped.
    pass

from swsscommon.swsscommon import SonicV2Connector, ConfigDBConnector

from sonic_py_common import multi_asic
from utilities_common import multi_asic as multi_asic_util
from utilities_common.general import load_db_config

# CONFIG_DB table holding the WRED profiles shown and edited by EcnConfig.
WRED_PROFILE_TABLE_NAME = "WRED_PROFILE"
# Maps the short option name used on the command line (e.g. "gmax") to the
# field name stored in a WRED_PROFILE entry.
WRED_CONFIG_FIELDS = {
    "gmax": "green_max_threshold",
    "gmin": "green_min_threshold",
    "ymax": "yellow_max_threshold",
    "ymin": "yellow_min_threshold",
    "rmax": "red_max_threshold",
    "rmin": "red_min_threshold",
    "gdrop": "green_drop_probability",
    "ydrop": "yellow_drop_probability",
    "rdrop": "red_drop_probability"
}

# CONFIG_DB table whose entries are keyed "<port>|<queue index>" by EcnQ.
QUEUE_TABLE_NAME           = "QUEUE"
# Table whose keys EcnQ.gen_ports_key uses as the list of ports to act on.
DEVICE_NEIGHBOR_TABLE_NAME = "DEVICE_NEIGHBOR"
# Field of a QUEUE entry that EcnQ sets or removes to turn ECN on or off.
FIELD                      = "wred_profile"
# Value written to FIELD when ECN is enabled; any other value reads as "off".
ON                         = "AZURE_LOSSLESS"

def chk_exec_privilege():
    """
    Exit unless the process may modify the configuration.

    Returns normally when the effective user id is 0, or when the
    UTILITIES_UNIT_TESTING environment variable equals "2". Otherwise
    terminates through sys.exit with an explanatory message.
    """
    if os.geteuid() == 0:
        return
    if os.environ.get("UTILITIES_UNIT_TESTING", "0") == "2":
        return
    sys.exit("Root privileges required for this operation")

class EcnConfig(object):
    """
    Show and modify WRED profile (ECN) configuration.

    list, set_wred_threshold and set_wred_prob are wrapped by
    multi_asic_util.run_on_multi_asic and read self.config_db and
    self.multi_asic.current_namespace. This class only initialises config_db
    to None.
    NOTE(review): presumably the decorator assigns the connection and the
    current namespace before each call - confirm in utilities_common.multi_asic.
    """
    def __init__(self, test_filename, verbose, namespace):
        """
        test_filename: path of the unit-test dump file; a falsy value
                       disables change recording
        verbose:       print extra information while changing settings
        namespace:     namespace selected on the command line, or None
        """
        self.ports = []
        self.queues = []
        self.verbose = verbose
        self.namespace = namespace
        self.multi_asic = multi_asic_util.MultiAsic(namespace_option=namespace)
        self.config_db = None
        self.num_wred_profiles = 0

        # For unit testing
        self.test_filename = test_filename
        self.updated_profile_tables = {}

    @multi_asic_util.run_on_multi_asic
    def list(self):
        """
        List all WRED profiles.

        Prints one table per profile and adds the number of profiles found
        to self.num_wred_profiles.
        """
        wred_profiles = self.config_db.get_table(WRED_PROFILE_TABLE_NAME)
        for profile_name, profile_data in wred_profiles.items():
            print("Profile: " + profile_name)
            rows = [[field, value] for field, value in profile_data.items()]
            print(tabulate(rows) + "\n")
        self.num_wred_profiles += len(wred_profiles)

    def get_profile_data(self, profile):
        """
        Get parameters of a WRED profile

        Returns the profile's field dictionary, or None when no profile with
        that name exists in the table that was read.
        """
        if self.namespace or not multi_asic.is_multi_asic():
            # A namespace was given, or this is a single-ASIC system: read
            # the table through a dedicated connection.
            db = ConfigDBConnector(namespace=self.namespace)
            db.connect()
            wred_profiles = db.get_table(WRED_PROFILE_TABLE_NAME)
        else:
            wred_profiles = multi_asic.get_table(WRED_PROFILE_TABLE_NAME)

        for profile_name, profile_data in wred_profiles.items():
            if profile_name == profile:
                return profile_data

        return None

    @staticmethod
    def _is_non_negative_int(value):
        """
        Return True when value is a non-empty string of ASCII digits 0-9.

        str.isdigit() is deliberately not used: it also accepts characters
        such as superscript digits that int() cannot convert.
        """
        return (isinstance(value, str) and value != ""
                and all(ch in "0123456789" for ch in value))

    def validate_profile_data(self, profile_data):
        """
        Validate threshold, probability and color values.

        Prints one message per invalid value. Returns True when every WRED
        field present in profile_data is valid, False otherwise.
        """
        result = True

        # check if thresholds are non-negative integers
        # check if probabilities are non-negative integers in [0, 100]
        for key, field in WRED_CONFIG_FIELDS.items():
            if field not in profile_data:
                continue
            value = profile_data[field]

            if 'threshold' in field:
                if not self._is_non_negative_int(value):
                    print("Invalid %s (%s). %s should be an non-negative integer." % (key, value, key))
                    result = False

            elif 'probability' in field:
                if not self._is_non_negative_int(value) or int(value) > 100:
                    print("Invalid %s (%s). %s should be an integer between 0 and 100." % (key, value, key))
                    result = False

        # The range comparison below converts with int(), so stop here when
        # any individual value is malformed.
        if not result:
            return result

        # check if min threshold is no larger than max threshold
        for color in ('g', 'y', 'r'):
            min_field = WRED_CONFIG_FIELDS[color + 'min']
            max_field = WRED_CONFIG_FIELDS[color + 'max']
            if min_field in profile_data and max_field in profile_data:
                min_thresh = int(profile_data[min_field])
                max_thresh = int(profile_data[max_field])

                if min_thresh > max_thresh:
                    print("Invalid %s (%d) and %s (%d). %s should be smaller than %s" %
                          (color + 'min', min_thresh, color + 'max', max_thresh, color + 'min', color + 'max'))
                    result = False

        return result

    def _record_update(self, profile, field, value):
        """
        Record a changed field for the unit-test dump.

        Does nothing unless self.test_filename is set. The first change seen
        for a namespace stores the whole WRED profile table as read back
        after the write; later changes update the stored copy under the
        database field name.
        """
        if not self.test_filename:
            return

        namespace = self.multi_asic.current_namespace
        if namespace in self.updated_profile_tables:
            self.updated_profile_tables[namespace][profile][field] = value
        else:
            self.updated_profile_tables[namespace] = self.config_db.get_table(WRED_PROFILE_TABLE_NAME)

    @multi_asic_util.run_on_multi_asic
    def set_wred_threshold(self, profile, threshold, value):
        """
        Single asic behaviour:
        Set threshold value on default namespace

        Multi asic behaviour:
        Set threshold value on the specified namespace.
        If no namespace is provided, set on all namespaces.

        threshold is the short option name (a key of WRED_CONFIG_FIELDS),
        for example "gmax".
        """
        chk_exec_privilege()

        # Modify the threshold
        field = WRED_CONFIG_FIELDS[threshold]
        if self.verbose:
            namespace_str = f" for namespace {self.multi_asic.current_namespace}" if multi_asic.is_multi_asic() else ''
            print("Setting %s value to %s%s" % (field, value, namespace_str))
        self.config_db.mod_entry(WRED_PROFILE_TABLE_NAME, profile, {field: value})

        # Record the change for unit testing, keyed by the database field
        # name so the recorded table matches what was written.
        self._record_update(profile, field, value)

    @multi_asic_util.run_on_multi_asic
    def set_wred_prob(self, profile, drop_color, value):
        """
        Single asic behaviour:
        Set drop probability on default namespace

        Multi asic behaviour:
        Set drop probability value on the specified namespace.
        If no namespace is provided, set on all namespaces.

        drop_color is the short option name (a key of WRED_CONFIG_FIELDS),
        for example "gdrop".
        """
        chk_exec_privilege()

        # Modify the drop probability
        field = WRED_CONFIG_FIELDS[drop_color]
        if self.verbose:
            namespace_str = f" for namespace {self.multi_asic.current_namespace}" if multi_asic.is_multi_asic() else ''
            print("Setting %s value to %s%%%s" % (field, value, namespace_str))
        self.config_db.mod_entry(WRED_PROFILE_TABLE_NAME, profile, {field: value})

        # Record the change for unit testing
        self._record_update(profile, field, value)

class EcnQ(object):
    """
    Query or toggle ECN on a comma-separated list of queue indices.
    """
    def __init__(self, queues, test_filename, verbose, namespace):
        """
        queues:        comma-separated queue indices, e.g. "3,4"
        test_filename: path of the unit-test dump file, or None
        verbose:       include the port list in the printed output
        namespace:     namespace selected on the command line, or None
        """
        self.ports_key = []
        self.queues = queues.split(',')
        self.verbose = verbose
        self.namespace = namespace
        self.multi_asic = multi_asic_util.MultiAsic(namespace_option=namespace)
        # Both handles start out unset; set() and get() use them later.
        self.config_db = None
        self.db = None

        # For unit testing
        self.updated_q_table = {}
        self.test_filename = test_filename

    def gen_ports_key(self):
        """
        Fill self.ports_key with the sorted keys of the DEVICE_NEIGHBOR table.

        Raises Exception when the table has no keys.
        """
        neighbor_table = self.config_db.get_table(DEVICE_NEIGHBOR_TABLE_NAME)
        self.ports_key = [*neighbor_table.keys()]

        # At least one port is needed to do anything useful
        if not self.ports_key:
            raise Exception("No active ports detected in table '{}'".format(DEVICE_NEIGHBOR_TABLE_NAME))

        def port_order(name):
            # In multi-ASIC platforms backend ethernet ports are identified
            # as 'Ethernet-BPxy'. Add 1024 to sort backend ports to the end.
            if "BP" in name:
                return int(name[11:]) + 1024
            return int(name[8:])

        self.ports_key.sort(key=port_order)

    def dump_table_info(self):
        """
        Write the current QUEUE table to the unit-test dump file.

        Does nothing when no test file name was given. The dump is keyed by
        namespace, and each table key is stored as its repr().
        """
        if self.test_filename is None:
            return

        queue_table = self.config_db.get_table(QUEUE_TABLE_NAME)
        with open(self.test_filename, "w") as dump_file:
            current_ns = self.multi_asic.current_namespace
            self.updated_q_table[current_ns] = {
                repr(table_key): entry for table_key, entry in queue_table.items()
            }
            json.dump(self.updated_q_table, dump_file)

    @multi_asic_util.run_on_multi_asic
    def set(self, enable):
        """
        Single asic behaviour:
        Enable or disable queues on default namespace

        Multi asic behaviour:
        Enable or disable queues on a specified namespace.
        If no namespace is provided, set on all namespaces.
        """
        chk_exec_privilege()

        self.gen_ports_key()
        action = "Enable" if enable else "Disable"
        for queue in self.queues:
            if self.verbose:
                print("%s ECN on %s queue %s" % (action, ','.join(self.ports_key), queue))
            for port in self.ports_key:
                queue_key = '|'.join((port, queue))
                entry = self.config_db.get_entry(QUEUE_TABLE_NAME, queue_key)
                if enable:
                    entry[FIELD] = ON
                else:
                    # Drop the field so the change propagates to SAI
                    entry.pop(FIELD, None)

                if entry:
                    self.config_db.set_entry(QUEUE_TABLE_NAME, queue_key, entry)
                else:
                    # Nothing left in the entry: remove the key altogether
                    self.config_db.mod_entry(QUEUE_TABLE_NAME, queue_key, None)
        # For unit testing
        self.dump_table_info()

    @multi_asic_util.run_on_multi_asic
    def get(self):
        """
        Single asic behaviour:
        Get status of queues on default namespace

        Multi asic behaviour:
        Get status of queues on a specified namespace.
        If no namespace is provided, get queue status on all namespaces.
        """
        self.gen_ports_key()
        namespace = self.multi_asic.current_namespace
        namespace_str = f" for namespace {namespace}" if namespace else ''
        print(f"ECN status{namespace_str}:")

        # The ECN state of a queue index is the same on every port, so only
        # the first port is looked up.
        first_port = self.ports_key[0]
        for queue in self.queues:
            label = ' '.join(('queue', queue))
            if self.verbose:
                label = ' '.join((','.join(self.ports_key), label))

            db_key = '|'.join((QUEUE_TABLE_NAME, first_port, queue))
            profile = self.db.get(self.db.CONFIG_DB, db_key, FIELD)

            template = "%s: on" if profile == ON else "%s: off"
            print(template % (label))
        # For unit testing
        self.dump_table_info()

@click.command(help='Show and change: ECN WRED configuration\nECN on/off status on queues')
@click.argument('command', type=click.Choice(['on', 'off'], case_sensitive=False), required=False, default=None)
@click.option('-l', '--list', 'show_config', is_flag=True, help='show ECN WRED configuration')
@click.option('-p', '--profile', type=str, help='specify WRED profile name', default=None)
@click.option('-gmin', '--green-min', type=str, help='set min threshold for packets marked \'green\'', default=None)
@click.option('-gmax', '--green-max', type=str, help='set max threshold for packets marked \'green\'', default=None)
@click.option('-ymin', '--yellow-min', type=str, help='set min threshold for packets marked \'yellow\'', default=None)
@click.option('-ymax', '--yellow-max', type=str, help='set max threshold for packets marked \'yellow\'', default=None)
@click.option('-rmin', '--red-min', type=str, help='set min threshold for packets marked \'red\'', default=None)
@click.option('-rmax', '--red-max', type=str, help='set max threshold for packets marked \'red\'', default=None)
@click.option('-gdrop', '--green-drop-prob', type=str, help='set max drop/mark probability for packets marked \'green\'', default=None)
@click.option('-ydrop', '--yellow-drop-prob', type=str, help='set max drop/mark probability for packets marked \'yellow\'', default=None)
@click.option('-rdrop', '--red-drop-prob', type=str, help='set max drop/mark probability for packets marked \'red\'', default=None)
@click.option('-n', '--namespace', type=click.Choice(multi_asic.get_namespace_list()), help='Namespace name or skip for all', default=None)
@click.option('-vv', '--verbose', is_flag=True, help='Verbose output', default=False)
@click.option('-q', '--queue', type=str, help='specify queue index list: 3,4', default=None)
@click.version_option(version='1.0')
def main(command, show_config, profile, green_min,
         green_max, yellow_min, yellow_max, red_min,
         red_max, green_drop_prob, yellow_drop_prob,
         red_drop_prob, namespace, verbose, queue):
    """
    Entry point of ecnconfig.

    Dispatches to one of three modes:
      * -l            list the WRED profiles
      * -p PROFILE    validate and apply the given threshold/probability options
      * -q QUEUES     show (no command) or set (on/off) ECN on the queues

    Any Exception is reported on stderr and the process exits with status 1.
    With none of -l, -p or -q given, the process exits with status 1.
    """
    test_filename = None
    if os.environ.get("UTILITIES_UNIT_TESTING", "0") == "2":
        test_filename = '/tmp/ecnconfig'

    # Short option name and supplied value, listed in the order the settings
    # are merged and applied. Options that were not given are dropped.
    wred_options = [
        ("gmax", green_max),
        ("gmin", green_min),
        ("ymax", yellow_max),
        ("ymin", yellow_min),
        ("rmax", red_max),
        ("rmin", red_min),
        ("gdrop", green_drop_prob),
        ("ydrop", yellow_drop_prob),
        ("rdrop", red_drop_prob),
    ]
    requested = [(option, value) for option, value in wred_options if value]

    try:
        load_db_config()
        if show_config or profile:
            prof_cfg = EcnConfig(test_filename, verbose, namespace)
            if show_config:
                if requested:
                    raise Exception("Input arguments error. No set options allowed when -l[ist] specified")

                prof_cfg.list()
                if verbose:
                    print("Total profiles: %d" % prof_cfg.num_wred_profiles)

            else:
                if not requested:
                    raise Exception("Input arguments error. Specify at least one threshold parameter to set")

                # get current configuration data
                wred_profile_data = prof_cfg.get_profile_data(profile)
                if wred_profile_data is None:
                    raise Exception("Input arguments error. Invalid WRED profile %s for namespace %s" % (profile, namespace))

                # overlay the requested values on the current profile so the
                # result is validated as a whole (e.g. new min vs. stored max)
                for option, value in requested:
                    wred_profile_data[WRED_CONFIG_FIELDS[option]] = value

                # validate new configuration data
                if not prof_cfg.validate_profile_data(wred_profile_data):
                    raise Exception("Input arguments error. Invalid WRED profile parameters")

                # apply new configuration
                # the following parameters can be combined in one run
                for option, value in requested:
                    if option.endswith("drop"):
                        prof_cfg.set_wred_prob(profile, option, value)
                    else:
                        prof_cfg.set_wred_threshold(profile, option, value)

                # Dump the current config in the file for unit tests
                if test_filename:
                    with open(test_filename, "w") as fd:
                        json.dump(prof_cfg.updated_profile_tables, fd)

        elif queue is not None:
            # Reject an empty list and empty items such as "3," or ",":
            # an empty index would otherwise build a "<port>|" QUEUE key.
            if not all(index.strip() for index in queue.split(',')):
                raise Exception("Input arguments error. Specify at least one queue by index")
            q_ecn = EcnQ(queue, test_filename, verbose, namespace)
            if command is None:
                q_ecn.get()
            else:
                q_ecn.set(enable=(command == 'on'))
        else:
            sys.exit(1)

    except Exception as e:
        print("Exception caught: ", str(e), file=sys.stderr)
        sys.exit(1)

# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
