#!python

"""
portconfig is the utility to show and change port configuration

usage: portconfig [-h] -p PORT [-l] [-s SPEED] [-f FEC] [-m MTU] [-tp TPID]
                  [-v] [-vv] [-n NAMESPACE] [-an AUTONEG] [-S ADV_SPEEDS]
                  [-t INTERFACE_TYPE] [-T ADV_INTERFACE_TYPES]
                  [-lt LINK_TRAINING] [-P TX_POWER] [-F LASER_FREQ]

optional arguments:
  -h     --help                show this help message and exit
  -v     --version             show program's version number and exit
  -vv    --verbose             verbose output
  -p     --port                port name
  -l     --list                list port parameters
  -s     --speed               port speed in Mbits
  -f     --fec                 port fec mode
  -m     --mtu                 port mtu in bytes
  -tp    --tpid                interface tpid (0x8100, 0x9100, 0x9200, 0x88A8)
  -n     --namespace           Namespace name
  -an    --autoneg             port auto negotiation mode
  -S     --adv-speeds          port advertised speeds
  -t     --interface-type      port interface type
  -T     --adv-interface-types port advertised interface types
  -lt    --link-training       port link training mode
  -P     --tx-power            400G ZR module target Tx output power (dBm)
  -F     --laser-freq          400G ZR module 75GHz grid frequency (GHz)
"""
import os
import sys
import decimal
import argparse

from utilities_common.constants import DEFAULT_SUPPORTED_FECS_LIST

# mock the redis for unit test purposes #
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "1" or os.environ["UTILITIES_UNIT_TESTING"] == "2":
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        test_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, test_path)
        import mock_tables.dbconnector
except KeyError:
    pass

from swsscommon.swsscommon import ConfigDBConnector, SonicV2Connector
from utilities_common.general import load_db_config

# Table names accessed through ConfigDBConnector
PORT_TABLE_NAME = "PORT"
PORT_CHANNEL_TABLE_NAME = "PORTCHANNEL"
PORT_CHANNEL_MBR_TABLE_NAME = "PORTCHANNEL_MEMBER"

# Field names written into PORT / PORTCHANNEL entries
PORT_SPEED_CONFIG_FIELD_NAME = "speed"
PORT_FEC_CONFIG_FIELD_NAME = "fec"
PORT_MTU_CONFIG_FIELD_NAME = "mtu"
PORT_AUTONEG_CONFIG_FIELD_NAME = "autoneg"
PORT_ADV_SPEEDS_CONFIG_FIELD_NAME = "adv_speeds"
PORT_INTERFACE_TYPE_CONFIG_FIELD_NAME = "interface_type"
PORT_ADV_INTERFACE_TYPES_CONFIG_FIELD_NAME = "adv_interface_types"
PORT_LINK_TRAINING_CONFIG_FIELD_NAME = "link_training"
PORT_XCVR_LASER_FREQ_FIELD_NAME = "laser_freq"
PORT_XCVR_TX_POWER_FIELD_NAME = "tx_power"
TPID_CONFIG_FIELD_NAME = "tpid"


# STATE_DB keys, table names and field names
SWITCH_CAPABILITY = "SWITCH_CAPABILITY|switch"
PORT_STATE_TABLE_NAME = "PORT_TABLE"
PORT_STATE_SUPPORTED_SPEEDS = "supported_speeds"
PORT_STATE_SUPPORTED_FECS = "supported_fecs"

# Interface types accepted by set_interface_type / set_adv_interface_types
VALID_INTERFACE_TYPE_SET = {
    'CR', 'CR2', 'CR4', 'CR8',
    'SR', 'SR2', 'SR4', 'SR8',
    'LR', 'LR4', 'LR8',
    'KR', 'KR2', 'KR4', 'KR8',
    'CAUI', 'CAUI4',
    'GMII', 'XGMII',
    'SFI', 'XFI',
    'XAUI', 'XLAUI',
    'none',
}

class portconfig(object):
    """
    Show and change the configuration of a single port or port channel.

    On construction the object connects a ConfigDBConnector and a STATE_DB
    SonicV2Connector (optionally inside a namespace) and classifies ``port``:

    * ``is_lag``       -- the name is a key of the PORTCHANNEL table
    * ``is_lag_mbr``   -- the name is a PORT entry that also appears in the
                          PORTCHANNEL_MEMBER table; ``parent`` then holds the
                          owning port channel (otherwise ``parent`` is ``port``)
    * ``is_rj45_port`` -- STATE_DB TRANSCEIVER_INFO reports type "RJ45"

    Every ``set_*`` method except ``set_tpid`` rejects port channels with
    ``Exception("Invalid port <name>")``.  Value validation failures print a
    message and terminate the process with exit status 1.
    """

    def __init__(self, verbose, port, namespace):
        """
        Connect to the databases and classify ``port``.

        :param verbose: when true, most setters print what they are about to write
        :param port: port or port channel name to operate on
        :param namespace: namespace to connect to, or None for the default one
        :raises Exception: if ``port`` is in neither the PORT nor the
            PORTCHANNEL table
        """
        self.verbose = verbose
        self.namespace = namespace
        self.is_lag = False
        self.is_lag_mbr = False
        self.parent = port
        self.is_rj45_port = False
        # Set up db connections
        if namespace is None:
            self.db = ConfigDBConnector()
            self.state_db = SonicV2Connector(host='127.0.0.1')
        else:
            self.db = ConfigDBConnector(use_unix_socket_path=True, namespace=namespace)
            self.state_db = SonicV2Connector(use_unix_socket_path=True, namespace=namespace)
        self.db.connect()
        self.state_db.connect(self.state_db.STATE_DB, False)

        # check whether table for this port exists
        port_tables = self.db.get_table(PORT_TABLE_NAME)
        lag_tables = self.db.get_table(PORT_CHANNEL_TABLE_NAME)
        lag_mbr_tables = self.db.get_table(PORT_CHANNEL_MBR_TABLE_NAME)
        if port in port_tables:
            port_transceiver_info = self.state_db.get_all(self.state_db.STATE_DB, "TRANSCEIVER_INFO|{}".format(port))
            if port_transceiver_info:
                self.is_rj45_port = port_transceiver_info.get("type") == "RJ45"
            # Each member key is indexed as (port channel, member port)
            for key in lag_mbr_tables:
                if port == key[1]:
                    self.parent = key[0]
                    self.is_lag_mbr = True
                    break
        elif port in lag_tables:
            self.is_lag = True
        else:
            raise Exception("Invalid port %s" % (port))

    def _reject_lag(self, port):
        """Raise ``Exception("Invalid port <port>")`` if this object is a port channel."""
        if self.is_lag:
            raise Exception("Invalid port %s" % (port))

    def _connect_state_db(self):
        """Return a fresh SonicV2Connector connected to STATE_DB for this namespace."""
        if not self.namespace:
            state_db = SonicV2Connector(host="127.0.0.1")
        else:
            state_db = SonicV2Connector(host="127.0.0.1", namespace=self.namespace, use_unix_socket_path=True)
        state_db.connect(state_db.STATE_DB)
        return state_db

    def list_params(self, port):
        """Print the PORT table entry of ``port``; print nothing if it has none."""
        port_tables = self.db.get_table(PORT_TABLE_NAME)
        if port in port_tables:
            print(port_tables[port])

    def set_speed(self, port, speed):
        """
        Write ``speed`` to the port's ``speed`` field.

        If STATE_DB publishes a non-empty supported-speeds list and ``speed``
        is not in it, print the valid speeds and exit with status 1.  When no
        list is published the value is written without validation.

        :raises Exception: if the object is a port channel
        """
        self._reject_lag(port)

        if self.verbose:
            print("Setting speed %s on port %s" % (speed, port))
        supported_speeds_str = self.get_supported_speeds(port) or ''
        supported_speeds = [int(s) for s in supported_speeds_str.split(',') if s]
        if supported_speeds and int(speed) not in supported_speeds:
            print('Invalid speed specified: {}'.format(speed))
            print('Valid speeds:{}'.format(supported_speeds_str))
            sys.exit(1)
        self.db.mod_entry(PORT_TABLE_NAME, port, {PORT_SPEED_CONFIG_FIELD_NAME: speed})

    def set_fec(self, port, fec):
        """
        Write ``fec`` to the port's ``fec`` field.

        Exits with status 1 if ``fec`` is not in ``get_supported_fecs(port)``;
        an empty supported list is reported as FEC not being supported at all.

        :raises Exception: if the object is a port channel
        """
        self._reject_lag(port)

        if self.verbose:
            print("Setting fec %s on port %s" % (fec, port))
        supported_fecs = self.get_supported_fecs(port)
        if fec not in supported_fecs:
            if supported_fecs:
                print('fec {} is not in {}'.format(fec, supported_fecs))
            else:
                print('Setting fec is not supported on port {}'.format(port))
            sys.exit(1)
        self.db.mod_entry(PORT_TABLE_NAME, port, {PORT_FEC_CONFIG_FIELD_NAME: fec})

    def set_mtu(self, port, mtu):
        """
        Write ``mtu`` to the port's ``mtu`` field (no validation is done here).

        :raises Exception: if the object is a port channel
        """
        self._reject_lag(port)

        if self.verbose:
            print("Setting mtu %s on port %s" % (mtu, port))
        self.db.mod_entry(PORT_TABLE_NAME, port, {PORT_MTU_CONFIG_FIELD_NAME: mtu})

    def set_link_training(self, port, mode):
        """
        Write ``mode`` ('on' or 'off') to the port's ``link_training`` field.

        Any other mode prints the valid modes and exits with status 1.

        :raises Exception: if the object is a port channel
        """
        self._reject_lag(port)

        if self.verbose:
            print("Setting link-training %s on port %s" % (mode, port))
        lt_modes = ['on', 'off']
        if mode not in lt_modes:
            print('Invalid mode specified: {}'.format(mode))
            print('Valid modes: {}'.format(','.join(lt_modes)))
            sys.exit(1)
        self.db.mod_entry(PORT_TABLE_NAME, port, {PORT_LINK_TRAINING_CONFIG_FIELD_NAME: mode})

    def set_autoneg(self, port, mode):
        """
        Write the port's ``autoneg`` field.

        'enabled' is stored as 'on'; every other value is stored as 'off'.

        :raises Exception: if the object is a port channel
        """
        self._reject_lag(port)

        if self.verbose:
            print("Setting autoneg %s on port %s" % (mode, port))
        mode = 'on' if mode == 'enabled' else 'off'
        self.db.mod_entry(PORT_TABLE_NAME, port, {PORT_AUTONEG_CONFIG_FIELD_NAME: mode})

    def set_tx_power(self, port, tx_power):
        """
        Write ``tx_power`` (dBm) to the port's ``tx_power`` field.

        The progress message is printed regardless of ``verbose``.

        :raises Exception: if the object is a port channel
        """
        self._reject_lag(port)

        print("Setting target Tx output power to %s dBm on port %s" % (tx_power, port))
        self.db.mod_entry(PORT_TABLE_NAME, port, {PORT_XCVR_TX_POWER_FIELD_NAME: tx_power})

    def set_laser_freq(self, port, laser_freq):
        """
        Write ``laser_freq`` (GHz) to the port's ``laser_freq`` field.

        The progress message is printed regardless of ``verbose``.

        :raises Exception: if the object is a port channel
        """
        self._reject_lag(port)

        print("Setting laser frequency to %s GHz on port %s" % (laser_freq, port))
        self.db.mod_entry(PORT_TABLE_NAME, port, {PORT_XCVR_LASER_FREQ_FIELD_NAME: laser_freq})

    def set_adv_speeds(self, port, adv_speeds):
        """
        Write ``adv_speeds`` to the port's ``adv_speeds`` field.

        ``adv_speeds`` is either 'all' or a comma-separated list.  A list is
        validated only when STATE_DB publishes supported speeds: duplicates or
        unsupported entries print an error and exit with status 1.  The string
        is stored exactly as given.

        :raises Exception: if the object is a port channel
        """
        self._reject_lag(port)

        if self.verbose:
            print("Setting adv_speeds %s on port %s" % (adv_speeds, port))

        if adv_speeds != 'all':
            supported_speeds_str = self.get_supported_speeds(port)
            if supported_speeds_str:
                supported_speeds = set(supported_speeds_str.split(','))
                config_speed_list = [x.strip() for x in adv_speeds.split(',')]
                config_speeds = set(config_speed_list)
                # A set smaller than the list means at least one speed was repeated
                if len(config_speeds) < len(config_speed_list):
                    print('Invalid speed specified: {}. Please remove duplicate speed'.format(adv_speeds))
                    sys.exit(1)
                invalid_speeds = config_speeds - supported_speeds
                if invalid_speeds:
                    print('Invalid speed specified: {}'.format(','.join(invalid_speeds)))
                    print('Valid speeds:{}'.format(supported_speeds_str))
                    sys.exit(1)

        self.db.mod_entry(PORT_TABLE_NAME, port, {PORT_ADV_SPEEDS_CONFIG_FIELD_NAME: adv_speeds})

    def set_interface_type(self, port, interface_type):
        """
        Write ``interface_type`` to the port's ``interface_type`` field.

        Exits with status 1 for RJ45 ports and for types outside
        VALID_INTERFACE_TYPE_SET.

        :raises Exception: if the object is a port channel
        """
        self._reject_lag(port)

        if self.is_rj45_port:
            print("Setting RJ45 ports' type is not supported")
            sys.exit(1)
        if self.verbose:
            print("Setting interface_type %s on port %s" % (interface_type, port))
        if interface_type not in VALID_INTERFACE_TYPE_SET:
            print("Invalid interface type specified: {}".format(interface_type))
            print("Valid interface types:{}".format(','.join(VALID_INTERFACE_TYPE_SET)))
            sys.exit(1)
        self.db.mod_entry(PORT_TABLE_NAME, port, {PORT_INTERFACE_TYPE_CONFIG_FIELD_NAME: interface_type})

    def set_adv_interface_types(self, port, adv_interface_types):
        """
        Write ``adv_interface_types`` to the port's ``adv_interface_types`` field.

        ``adv_interface_types`` is either 'all' or a comma-separated list; a
        list with duplicates or with entries outside VALID_INTERFACE_TYPE_SET
        prints an error and exits with status 1, as do RJ45 ports.  The string
        is stored exactly as given.

        :raises Exception: if the object is a port channel
        """
        self._reject_lag(port)

        if self.is_rj45_port:
            print("Setting RJ45 ports' advertised types is not supported")
            sys.exit(1)
        if self.verbose:
            print("Setting adv_interface_types %s on port %s" % (adv_interface_types, port))

        if adv_interface_types != 'all':
            config_interface_type_list = [x.strip() for x in adv_interface_types.split(',')]
            config_interface_types = set(config_interface_type_list)
            # A set smaller than the list means at least one type was repeated
            if len(config_interface_types) < len(config_interface_type_list):
                print("Invalid interface type specified: {}. Please remove duplicate interface type".format(adv_interface_types))
                sys.exit(1)
            invalid_interface_types = config_interface_types - VALID_INTERFACE_TYPE_SET
            if invalid_interface_types:
                print("Invalid interface type specified: {}".format(','.join(invalid_interface_types)))
                print("Valid interface types:{}".format(','.join(VALID_INTERFACE_TYPE_SET)))
                sys.exit(1)
        self.db.mod_entry(PORT_TABLE_NAME, port, {PORT_ADV_INTERFACE_TYPES_CONFIG_FIELD_NAME: adv_interface_types})

    def get_supported_speeds(self, port):
        """
        Return the ``supported_speeds`` value STATE_DB holds for ``port``.

        The value is returned exactly as ``SonicV2Connector.get`` delivers it;
        callers in this class treat a falsy result as "nothing published".
        """
        state_db = self._connect_state_db()
        return state_db.get(state_db.STATE_DB, '{}|{}'.format(PORT_STATE_TABLE_NAME, port), PORT_STATE_SUPPORTED_SPEEDS)

    def get_supported_fecs(self, port):
        """
        Return the list of FEC modes that may be configured on ``port``.

        * STATE_DB ``supported_fecs`` is 'N/A'      -> ``[]``
        * STATE_DB ``supported_fecs`` is non-empty  -> the value split on ','
        * nothing published                         -> DEFAULT_SUPPORTED_FECS_LIST
        """
        state_db = self._connect_state_db()

        supported_fecs_str = state_db.get(state_db.STATE_DB, '{}|{}'.format(PORT_STATE_TABLE_NAME, port), PORT_STATE_SUPPORTED_FECS)
        if not supported_fecs_str:
            return DEFAULT_SUPPORTED_FECS_LIST
        if supported_fecs_str == 'N/A':
            return []
        return supported_fecs_str.split(',')

    def set_tpid(self, port, tpid):
        """
        Write ``tpid`` to the ``tpid`` field of a port or port channel.

        Checks are made in this order, each failure raising ``Exception``:

        1. STATE_DB must publish PORT_TPID_CAPABLE or LAG_TPID_CAPABLE.
        2. ``tpid`` must be 0x8100, 0x9100, 0x9200 or 0x88A8 ('0x88a8' is
           accepted and stored as '0x88A8').
        3. A port channel needs LAG_TPID_CAPABLE == "true".
        4. A port must not be a port channel member and needs
           PORT_TPID_CAPABLE == "true".
        """
        if self.verbose:
            print("Setting tpid %s on port %s" % (tpid, port))

        # check TPID Config capability
        port_tpid_capable = self.state_db.get(self.state_db.STATE_DB, SWITCH_CAPABILITY, "PORT_TPID_CAPABLE")
        lag_tpid_capable = self.state_db.get(self.state_db.STATE_DB, SWITCH_CAPABILITY, "LAG_TPID_CAPABLE")

        # Neither capability published yet
        if not port_tpid_capable and not lag_tpid_capable:
            raise Exception("System not ready to accept TPID config. Please try again later.")

        # TPID allowed are : 8100, 9100, 9200 and 88A8. Reject all other values.
        if tpid not in ('0x8100', '0x9100', '0x9200', '0x88A8', '0x88a8'):
            raise Exception("TPID %s is not allowed. Allowed: 0x8100, 9100, 9200, or 88A8." % (tpid))
        if tpid == '0x88a8':
            tpid = '0x88A8'

        if self.is_lag:
            if lag_tpid_capable != "true":
                raise Exception("HW is not capable to support PortChannel TPID config.")
            self.db.mod_entry(PORT_CHANNEL_TABLE_NAME, port, {TPID_CONFIG_FIELD_NAME: tpid})
            return

        if self.is_lag_mbr:
            raise Exception("%s is already member of %s. Set TPID NOT allowed." % (port, self.parent))
        if port_tpid_capable != "true":
            raise Exception("HW is not capable to support Port TPID config.")
        self.db.mod_entry(PORT_TABLE_NAME, port, {TPID_CONFIG_FIELD_NAME: tpid})


def main():
    """
    Command-line entry point for portconfig.

    Parses the command line, loads the database configuration, builds a
    ``portconfig`` for the requested port and then either lists its parameters
    (``--list``) or applies every setting that was supplied, in a fixed order.
    When neither ``--list`` nor a setting is given, the help text is printed
    and the process exits with status 1.

    Any ``Exception`` raised while building the ``portconfig`` or applying a
    setting is printed (to stdout when UTILITIES_UNIT_TESTING is "1" or "2",
    to stderr otherwise) and turned into exit status 1.
    """
    parser = argparse.ArgumentParser(description='Set SONiC port parameters',
                                     formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('-p', '--port', type=str, help='port name (e.g. Ethernet0)', required=True, default=None)
    parser.add_argument('-l', '--list', action='store_true', help='list port parametars', default=False)
    parser.add_argument('-s', '--speed', type=int, help='port speed value in Mbit', default=None)
    parser.add_argument('-f', '--fec', type=str, help='port fec mode value (default is: none, rs, fc)', default=None)
    parser.add_argument('-m', '--mtu', type=int, help='port mtu value in bytes', default=None)
    parser.add_argument('-tp', '--tpid', type=str, help='port TPID value in hex (e.g. 0x8100)', default=None)
    parser.add_argument('-v', '--version', action='version', version='%(prog)s 1.0')
    parser.add_argument('-vv', '--verbose', action='store_true', help='Verbose output', default=False)
    parser.add_argument('-n', '--namespace', metavar='namespace details', type=str, required=False,
                        help='The asic namespace whose DB instance we need to connect', default=None)
    parser.add_argument('-an', '--autoneg', type=str, required=False,
                        help='port auto negotiation mode', default=None)
    parser.add_argument('-S', '--adv-speeds', type=str, required=False,
                        help='port advertised speeds', default=None)
    parser.add_argument('-t', '--interface-type', type=str, required=False,
                        help='port interface type', default=None)
    parser.add_argument('-T', '--adv-interface-types', type=str, required=False,
                        help='port advertised interface types', default=None)
    parser.add_argument('-lt', '--link-training', type=str, required=False,
                        help='port link training mode', default=None)
    parser.add_argument('-P', '--tx-power', type=float, required=False,
                        help='Tx output power(dBm)', default=None)
    parser.add_argument('-F', '--laser-freq', type=int, required=False,
                        help='Laser frequency(GHz)', default=None)
    args = parser.parse_args()

    # Load database config files
    load_db_config()
    try:
        port = portconfig(args.verbose, args.port, args.namespace)

        # Test these two against None rather than for truthiness: 0 dBm is a
        # legitimate target power, and a frequency of 0 has to reach the range
        # check below instead of being silently treated as "not supplied".
        has_tx_power = args.tx_power is not None
        has_laser_freq = args.laser_freq is not None

        if args.list:
            port.list_params(args.port)
        elif (args.speed or args.fec or args.mtu or args.link_training or args.autoneg
              or args.adv_speeds or args.interface_type or args.adv_interface_types
              or args.tpid or has_tx_power or has_laser_freq):
            if args.speed:
                port.set_speed(args.port, args.speed)
            if args.fec:
                port.set_fec(args.port, args.fec)
            if args.mtu:
                port.set_mtu(args.port, args.mtu)
            if args.link_training:
                port.set_link_training(args.port, args.link_training)
            if args.autoneg:
                port.set_autoneg(args.port, args.autoneg)
            if args.adv_speeds:
                port.set_adv_speeds(args.port, args.adv_speeds)
            if args.interface_type:
                port.set_interface_type(args.port, args.interface_type)
            if args.adv_interface_types:
                port.set_adv_interface_types(args.port, args.adv_interface_types)
            if args.tpid:
                port.set_tpid(args.port, args.tpid)
            if has_tx_power:
                d = decimal.Decimal(str(args.tx_power))
                # float() accepts "inf"/"nan"; their Decimal exponent is a
                # letter, so they must be rejected before the comparison below.
                if not d.is_finite():
                    print("Error: tx power must be a finite number")
                    sys.exit(1)
                if d.as_tuple().exponent < -1:
                    print("Error: tx power must be with single decimal place")
                    sys.exit(1)
                port.set_tx_power(args.port, args.tx_power)
            if has_laser_freq:
                if args.laser_freq <= 0:
                    print("Error: Frequency must be > 0")
                    sys.exit(1)
                port.set_laser_freq(args.port, args.laser_freq)
        else:
            parser.print_help()
            sys.exit(1)

    except Exception as e:
        # Unit tests capture stdout; everywhere else errors belong on stderr.
        # Any other value of the variable also falls back to stderr so the
        # message is never lost.
        unit_testing = os.environ.get("UTILITIES_UNIT_TESTING") in ("1", "2")
        print(str(e), file=sys.stdout if unit_testing else sys.stderr)

        sys.exit(1)

if __name__ == "__main__":
    main()
