#!python

"""
ecnconfig is the utility to

1) show and change ECN configuration

usage: ecnconfig [-h] [-v] [-l] [-p PROFILE] [-gmin GREEN_MIN]
                 [-gmax GREEN_MAX] [-ymin YELLOW_MIN] [-ymax YELLOW_MAX]
                 [-rmin RED_MIN] [-rmax RED_MAX] [-gdrop GREEN_DROP_PROB]
                 [-ydrop YELLOW_DROP_PROB] [-rdrop RED_DROP_PROB] [-vv]

optional arguments:
  -h     --help                show this help message and exit
  -v     --version             show program's version number and exit
  -vv    --verbose             verbose output
  -l     --list                show ECN WRED configuration
  -p     --profile             specify WRED profile name
  -gmin  --green-min           set min threshold for packets marked green
  -gmax  --green-max           set max threshold for packets marked green
  -ymin  --yellow-min          set min threshold for packets marked yellow
  -ymax  --yellow-max          set max threshold for packets marked yellow
  -rmin  --red-min             set min threshold for packets marked red
  -rmax  --red-max             set max threshold for packets marked red
  -gdrop --green-drop-prob     set max drop/mark probability for packets marked green
  -ydrop --yellow-drop-prob    set max drop/mark probability for packets marked yellow
  -rdrop --red-drop-prob       set max drop/mark probability for packets marked red

2) show and change ECN on/off status on queues

usage: ecnconfig [-h] [-v] [-q QUEUE_INDEX] [{on,off}] [-vv]

positional arguments:
  {on,off}                  turn on/off ecn

optional arguments:
  -h     --help                show this help message and exit
  -v     --version             show program's version number and exit
  -vv    --verbose             verbose output
  -q     --queue               specify comma-separated queue index list (e.g. 3,4)

Sample outputs:
$ecnconfig -q 3 on -vv
Enable ECN on Ethernet0,Ethernet4,Ethernet8,Ethernet12,Ethernet16,Ethernet20,Ethernet24,Ethernet28,Ethernet32,Ethernet36,Ethernet40,Ethernet44,Ethernet48,Ethernet52,Ethernet56,Ethernet60,Ethernet64,Ethernet68,Ethernet72,Ethernet76,Ethernet80,Ethernet84,Ethernet88,Ethernet92,Ethernet96,Ethernet100,Ethernet104,Ethernet108,Ethernet112,Ethernet116,Ethernet120,Ethernet124 queue 3

$ecnconfig -q 3
ECN status:
queue 3: on
"""
import argparse
import json
import os
import sys

from tabulate import tabulate

# Unit-test hook: when the environment variable UTILITIES_UNIT_TESTING is "2",
# the parent directory of this script and its "tests" sub-directory are put at
# the front of sys.path and mock_tables.dbconnector is imported.
# NOTE(review): presumably that import swaps the redis-backed DB connectors for
# mocks -- confirm against the tests package.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        # tests_path is inserted last, so it ends up ahead of modules_path
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        import mock_tables.dbconnector

except KeyError:
    # os.environ[...] raises KeyError when the variable is unset: normal run,
    # nothing to mock
    pass

from swsscommon.swsscommon import SonicV2Connector, ConfigDBConnector


# Config DB table that holds the WRED profiles.
WRED_PROFILE_TABLE_NAME = "WRED_PROFILE"

# Short option key -> field name inside a WRED profile entry.
# The insertion order is the order in which fields are validated.
WRED_CONFIG_FIELDS = {
    # per-colour thresholds
    "gmax": "green_max_threshold",
    "gmin": "green_min_threshold",
    "ymax": "yellow_max_threshold",
    "ymin": "yellow_min_threshold",
    "rmax": "red_max_threshold",
    "rmin": "red_min_threshold",
    # per-colour drop/mark probabilities
    "gdrop": "green_drop_probability",
    "ydrop": "yellow_drop_probability",
    "rdrop": "red_drop_probability",
}

# Config DB table names.
PORT_TABLE_NAME = "PORT"
QUEUE_TABLE_NAME = "QUEUE"
DEVICE_NEIGHBOR_TABLE_NAME = "DEVICE_NEIGHBOR"

# Field of a QUEUE entry that carries the ECN setting, and the value that
# means "ECN on".
FIELD = "wred_profile"
ON = "AZURE_LOSSLESS"

def chk_exec_privilege():
    """
    Exit the program unless it may perform privileged operations.

    The call returns normally when the effective user id is 0, or when the
    environment variable UTILITIES_UNIT_TESTING is "2". Otherwise it calls
    sys.exit() with an explanatory message.
    """
    running_as_root = os.geteuid() == 0
    if running_as_root:
        return

    if os.environ.get("UTILITIES_UNIT_TESTING", "0") == "2":
        return

    sys.exit("Root privileges required for this operation")

class EcnConfig(object):
    """
    Show, validate and modify the WRED profiles kept in the WRED_PROFILE
    table, accessed through a ConfigDBConnector.
    """
    def __init__(self, filename, verbose):
        """
        filename: path the WRED profile table is dumped to (as JSON) after
                  each modification, or None to skip the dump
        verbose:  when true, print additional progress information
        """
        self.ports = []
        self.queues = []
        self.filename = filename
        self.verbose = verbose

        # Set up db connections
        self.db = ConfigDBConnector()
        self.db.connect()

    def list(self):
        """Print every WRED profile as a table of field/value rows."""
        wred_profiles = self.db.get_table(WRED_PROFILE_TABLE_NAME)
        for profile_name, profile_data in wred_profiles.items():
            print("Profile: " + profile_name)
            config = [[field, value] for field, value in profile_data.items()]
            print(tabulate(config) + "\n")
        if self.verbose:
            print("Total profiles: %d" % len(wred_profiles))

    # get parameters of a WRED profile
    def get_profile_data(self, profile):
        """Return the data of the WRED profile named `profile`, or None if absent."""
        return self.db.get_table(WRED_PROFILE_TABLE_NAME).get(profile)

    @staticmethod
    def _is_non_negative_int(value):
        """Return True if `value` is a non-empty string of ASCII digits only."""
        # str.isdigit() is not sufficient here: it also accepts characters
        # such as superscript digits that int() cannot parse, which would
        # turn a validation failure into a ValueError later on.
        return (isinstance(value, str) and value != "" and
                all(ch in "0123456789" for ch in value))

    def validate_profile_data(self, profile_data):
        """
        Check the WRED fields present in `profile_data`.

        Thresholds must be non-negative integers, probabilities must be
        integers in [0, 100], and for each colour the min threshold must not
        exceed the max threshold when both are present. A message is printed
        for every violation.

        Returns True when the data is valid, False otherwise.
        """
        result = True

        # check if thresholds are non-negative integers
        # check if probabilities are non-negative integers in [0, 100]
        for key, field in WRED_CONFIG_FIELDS.items():
            if field not in profile_data:
                continue
            value = profile_data[field]

            if 'threshold' in field:
                if not self._is_non_negative_int(value):
                    print("Invalid %s (%s). %s should be an non-negative integer." % (key, value, key))
                    result = False

            elif 'probability' in field:
                if not self._is_non_negative_int(value) or int(value) > 100:
                    print("Invalid %s (%s). %s should be an integer between 0 and 100." % (key, value, key))
                    result = False

        # the range check below needs values that int() can parse
        if not result:
            return result

        # check if min threshold is no larger than max threshold
        for color in ('g', 'y', 'r'):
            min_key = color + 'min'
            max_key = color + 'max'
            min_field = WRED_CONFIG_FIELDS[min_key]
            max_field = WRED_CONFIG_FIELDS[max_key]

            if min_field in profile_data and max_field in profile_data:
                min_thresh = int(profile_data[min_field])
                max_thresh = int(profile_data[max_field])

                if min_thresh > max_thresh:
                    print("Invalid %s (%d) and %s (%d). %s should be smaller than %s" %
                          (min_key, min_thresh, max_key, max_thresh, min_key, max_key))
                    result = False

        return result

    def _dump_profile_table(self):
        """Write the current WRED profile table to self.filename as JSON, if set."""
        if self.filename is not None:
            prof_table = self.db.get_table(WRED_PROFILE_TABLE_NAME)
            with open(self.filename, "w") as fd:
                json.dump(prof_table, fd)

    def set_wred_threshold(self, profile, threshold, value):
        """
        Set one threshold of a WRED profile.

        profile:   name of the WRED profile to modify
        threshold: key of WRED_CONFIG_FIELDS selecting the threshold field
        value:     new value to store

        Exits the program (via chk_exec_privilege) without root privileges.
        """
        chk_exec_privilege()

        field = WRED_CONFIG_FIELDS[threshold]
        if self.verbose:
            print("Setting %s value to %s" % (field, value))
        self.db.mod_entry(WRED_PROFILE_TABLE_NAME, profile, {field: value})
        self._dump_profile_table()

    def set_wred_prob(self, profile, drop_color, value):
        """
        Set one drop/mark probability of a WRED profile.

        profile:    name of the WRED profile to modify
        drop_color: key of WRED_CONFIG_FIELDS selecting the probability field
        value:      new value to store

        Exits the program (via chk_exec_privilege) without root privileges.
        """
        chk_exec_privilege()

        field = WRED_CONFIG_FIELDS[drop_color]
        if self.verbose:
            print("Setting %s value to %s%%" % (field, value))
        self.db.mod_entry(WRED_PROFILE_TABLE_NAME, profile, {field: value})
        self._dump_profile_table()

class EcnQ(object):
    """
    Show or switch ECN on a list of queue indexes, across every port listed
    in the DEVICE_NEIGHBOR table.
    """
    def __init__(self, queues, filename, verbose):
        """
        queues:   comma-separated queue indexes, e.g. "3,4"
        filename: path the QUEUE table is dumped to (as JSON), or None
        verbose:  when true, include the port list in the output
        """
        self.verbose = verbose
        self.filename = filename
        self.queues = queues.split(',')
        self.ports_key = []

        # Set up db connections
        self.config_db = ConfigDBConnector()
        self.config_db.connect()

        self.db = SonicV2Connector(use_unix_socket_path=False)
        self.db.connect(self.db.CONFIG_DB)

        self.gen_ports_key()

    def gen_ports_key(self):
        """
        Load the port names (keys of the DEVICE_NEIGHBOR table) into
        self.ports_key, sorted by port number.

        Raises Exception when the table has no keys.
        """
        if self.ports_key is None:
            return

        neighbor_table = self.config_db.get_table(DEVICE_NEIGHBOR_TABLE_NAME)
        self.ports_key = list(neighbor_table.keys())

        # Verify at least one port is available
        if not self.ports_key:
            raise Exception("No active ports detected in table '{}'".format(DEVICE_NEIGHBOR_TABLE_NAME))

        def port_order(name):
            # In multi-ASIC platforms backend ethernet ports are identified as
            # 'Ethernet-BPxy'. Add 1024 to sort backend ports to the end.
            if "BP" in name:
                return int(name[11:]) + 1024
            return int(name[8:])

        self.ports_key.sort(key=port_order)

    def dump_table_info(self):
        """Write the QUEUE table to self.filename as JSON (keys passed through repr), if set."""
        if self.filename is None:
            return

        queue_table = self.config_db.get_table(QUEUE_TABLE_NAME)
        with open(self.filename, "w") as out_file:
            printable = {}
            for table_key, table_value in queue_table.items():
                printable[repr(table_key)] = table_value
            json.dump(printable, out_file)

    def set(self, enable):
        """
        Enable (enable is true) or disable ECN on every configured queue of
        every port, then dump the QUEUE table.

        Exits the program (via chk_exec_privilege) without root privileges.
        """
        chk_exec_privilege()

        action = "Enable" if enable else "Disable"
        for queue in self.queues:
            if self.verbose:
                print("%s ECN on %s queue %s" % (action, ','.join(self.ports_key), queue))

            for port in self.ports_key:
                key = '|'.join((port, queue))
                entry = self.config_db.get_entry(QUEUE_TABLE_NAME, key)

                if enable:
                    entry[FIELD] = ON
                else:
                    # Remove entry to propagate SAI change
                    entry.pop(FIELD, None)

                if entry != {}:
                    self.config_db.set_entry(QUEUE_TABLE_NAME, key, entry)
                else:
                    # Nothing left in the entry: remove the key itself
                    self.config_db.mod_entry(QUEUE_TABLE_NAME, key, None)

        self.dump_table_info()

    def get(self):
        """Print the on/off ECN status of every configured queue, then dump the QUEUE table."""
        print("ECN status:")
        for queue in self.queues:
            label = "queue " + queue
            if self.verbose:
                label = "%s %s" % (','.join(self.ports_key), label)

            # ecn on/off status on a queue index is homogeneous among all ports
            # checking one port is sufficient
            key = '|'.join((QUEUE_TABLE_NAME, self.ports_key[0], queue))
            val = self.db.get(self.db.CONFIG_DB, key, FIELD)

            template = "%s: on" if val == ON else "%s: off"
            print(template % (label))

        self.dump_table_info()

def main():
    """
    Command-line entry point of ecnconfig.

    Parses sys.argv and dispatches to one of three modes:
      -l           list all WRED profiles
      -p PROFILE   update thresholds/probabilities of one WRED profile
      -q QUEUES    show, or with 'on'/'off' change, the ECN status of queues

    With none of these options the help text is printed and the program
    exits with status 1. Any exception raised while processing is reported
    on stderr and the program exits with status 1.
    """
    parser = argparse.ArgumentParser(description='Show and change:\n'
                                                 '1) ECN WRED configuration\n'
                                                 '2) ECN on/off status on queues',
                                     formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument('-l', '--list', action='store_true', help='show ECN WRED configuration')
    parser.add_argument('-p', '--profile', type=str, help='specify WRED profile name', default=None)
    parser.add_argument('-gmin', '--green-min', type=str, help='set min threshold for packets marked \'green\'', default=None)
    parser.add_argument('-gmax', '--green-max', type=str, help='set max threshold for packets marked \'green\'', default=None)
    parser.add_argument('-ymin', '--yellow-min', type=str, help='set min threshold for packets marked \'yellow\'', default=None)
    parser.add_argument('-ymax', '--yellow-max', type=str, help='set max threshold for packets marked \'yellow\'', default=None)
    parser.add_argument('-rmin', '--red-min', type=str, help='set min threshold for packets marked \'red\'', default=None)
    parser.add_argument('-rmax', '--red-max', type=str, help='set max threshold for packets marked \'red\'', default=None)
    parser.add_argument('-gdrop', '--green-drop-prob', type=str, help='set max drop/mark probability for packets marked \'green\'', default=None)
    parser.add_argument('-ydrop', '--yellow-drop-prob', type=str, help='set max drop/mark probability for packets marked \'yellow\'', default=None)
    parser.add_argument('-rdrop', '--red-drop-prob', type=str, help='set max drop/mark probability for packets marked \'red\'', default=None)
    parser.add_argument('-v', '--version', action='version', version='%(prog)s 1.0')
    parser.add_argument('-vv', '--verbose', action='store_true', help='Verbose output', default=False)

    parser.add_argument('command', nargs='?', choices=['on', 'off'], type=str, help='turn on/off ecn', default=None)
    parser.add_argument('-q', '--queue', type=str, help='specify queue index list: 3,4', default=None)
    parser.add_argument('-f', '--filename', help='file used by mock tests', type=str, default=None)

    # Under unit test, always dump the resulting table to a fixed file
    if os.environ.get("UTILITIES_UNIT_TESTING", "0") == "2":
        sys.argv.extend(['-f', '/tmp/ecnconfig'])

    args = parser.parse_args()

    # The sanity checks below count raw sys.argv tokens. Tokens that do not
    # belong to the mode being checked are accounted for here, identically for
    # every mode: '-vv' is one token, '-f FILE' is two.
    extra_tokens = 0
    if args.verbose:
        extra_tokens += 1
    if args.filename:
        extra_tokens += 2

    # (attribute of the parsed args, key of WRED_CONFIG_FIELDS), listed in the
    # order in which the settings are merged and applied
    wred_options = (
        ("green_max", "gmax"),
        ("green_min", "gmin"),
        ("yellow_max", "ymax"),
        ("yellow_min", "ymin"),
        ("red_max", "rmax"),
        ("red_min", "rmin"),
        ("green_drop_prob", "gdrop"),
        ("yellow_drop_prob", "ydrop"),
        ("red_drop_prob", "rdrop"),
    )

    try:
        if args.list or args.profile:
            prof_cfg = EcnConfig(args.filename, args.verbose)
            if args.list:
                # program name + '-l' and nothing else
                if len(sys.argv) > 2 + extra_tokens:
                    raise Exception("Input arguments error. No set options allowed when -l[ist] specified")
                prof_cfg.list()
            else:
                # program name + '-p' + profile name + at least one more token
                if len(sys.argv) < 4 + extra_tokens:
                    raise Exception("Input arguments error. Specify at least one threshold parameter to set")

                # get current configuration data
                wred_profile_data = prof_cfg.get_profile_data(args.profile)
                if wred_profile_data is None:
                    raise Exception("Input arguments error. Invalid WRED profile %s" % (args.profile))

                # settings given on the command line, in application order
                requested = [(key, getattr(args, attr))
                             for attr, key in wred_options
                             if getattr(args, attr)]

                # overlay them on the current data so the validation sees the
                # profile as it would look after the change
                for key, value in requested:
                    wred_profile_data[WRED_CONFIG_FIELDS[key]] = value

                # validate new configuration data
                if not prof_cfg.validate_profile_data(wred_profile_data):
                    raise Exception("Input arguments error. Invalid WRED profile parameters")

                # apply new configuration
                # the parameters can be combined in one run
                for key, value in requested:
                    if key.endswith("drop"):
                        prof_cfg.set_wred_prob(args.profile, key, value)
                    else:
                        prof_cfg.set_wred_threshold(args.profile, key, value)

        elif args.queue:
            # program name + '-q' + queue list
            if len(sys.argv) < 3 + extra_tokens:
                raise Exception("Input arguments error. Specify at least one queue by index")

            q_ecn = EcnQ(args.queue, args.filename, args.verbose)
            if not args.command:
                q_ecn.get()
            else:
                q_ecn.set(enable=(args.command == 'on'))
        else:
            parser.print_help()
            sys.exit(1)

    except Exception as e:
        print("Exception caught: ", str(e), file=sys.stderr)
        sys.exit(1)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
