#!python

########################################
#
# Neighbor Advertiser
#
########################################


import argparse
import json
import os
import requests
import subprocess
import sys
import time
import traceback
import warnings

from sonic_py_common import logger
from swsscommon.swsscommon import SonicV2Connector, ConfigDBConnector
from netaddr import IPAddress, IPNetwork


#
# Config
#

# Sent to Ferret as 'durationInSec' in the slice's respondingSchemes.
DEFAULT_DURATION = 300
# Timeout handed to requests.post() for every Ferret query.
DEFAULT_REQUEST_TIMEOUT = 2
# Maximum number of POST attempts per VIP while Ferret keeps answering with a
# DIP that lies inside one of this device's VLAN subnets.
DEFAULT_FERRET_QUERY_RETRIES = 3
# Pause between the mirror session write and the dependent ACL rule writes
# (and the reverse on reset).
DEFAULT_CONFIG_DB_WAIT_TIME = 3 # seconds
# NOTE(review): not referenced anywhere in the visible source; the logger
# below is created without it -- confirm whether that is intended.
SYSLOG_IDENTIFIER = 'neighbor_advertiser'


#
# Tunnel setup
#

# Key of the MIRROR_SESSION entry written by this script; the ACL rules refer
# to it through their 'mirror_action' field.
MIRROR_SESSION_NAME = 'neighbor_advertiser'
# find_mirror_table_name() accepts the two table names below either bare or
# with this prefix.
MIRROR_ACL_TABLE_PREFIX = 'SONIC_'
MIRROR_ACL_TABLE_NAME = 'EVERFLOW'
MIRROR_ACL_TABLEV6_NAME = 'EVERFLOWV6'
# ACL_RULE names installed in the v4 and v6 mirror tables respectively.
MIRROR_ACL_RULE_NAME = 'rule_arp'
MIRROR_ACL_RULEV6_NAME = 'rule_nd'
# Default VXLAN_TUNNEL key. check_existing_tunnel() rebinds this global to the
# name of an already configured tunnel when one is found.
VXLAN_TUNNEL_NAME = 'neigh_adv'
# VXLAN_TUNNEL_MAP entries are named with this prefix plus a 1-based index.
VXLAN_TUNNEL_MAP_PREFIX = 'map_'


#
# Path
#

# An interface MAC is read from
# SYS_NET_PATH + <interface name> + VLAN_MAC_ADDRESS_PATH.
SYS_NET_PATH = '/sys/class/net/'
VLAN_MAC_ADDRESS_PATH = '/address'
# Local JSON copies of the request sent to Ferret and of the accepted response.
NEIGHBOR_ADVERTISER_REQUEST_SLICE_PATH = '/tmp/neighbor_advertiser/request_slice.json'
NEIGHBOR_ADVERTISER_RESPONSE_CONFIG_PATH = '/tmp/neighbor_advertiser/response_config.json'
# URL path on the Ferret service; the switch hostname is appended to it.
FERRET_NEIGHBOR_ADVERTISER_API_PREFIX = '/Ferret/NeighborAdvertiser/Slices/'


#
# Global logger instance
#
# Shared by every function in this module. Created with the Logger defaults;
# SYSLOG_IDENTIFIER is not passed in.
log = logger.Logger()


#
# Global variable of config_db
#

# Set by connect_config_db(); every config-table helper below reads it.
config_db = None


def connect_config_db():
    """Create a ConfigDBConnector, store it in the global ``config_db`` and connect it."""
    global config_db
    config_db = ConfigDBConnector()
    config_db.connect()

#
# Global variable of app_db
#
# Set by connect_app_db(); used to read SWITCH_TABLE from APPL_DB.
appl_db = None

def connect_app_db():
    """Create a SonicV2Connector for 127.0.0.1, store it in the global ``appl_db`` and connect it to APPL_DB."""
    global appl_db
    appl_db = SonicV2Connector(host="127.0.0.1")
    appl_db.connect(appl_db.APPL_DB)

#
# Check if a DIP returned from ferret is in any of this switch's VLANs
#

# Cached VLAN_INTERFACE table; None means "not loaded yet".
vlan_interface_query = None


def is_dip_in_device_vlan(ferret_dip):
    """Return True if ``ferret_dip`` falls inside any VLAN interface subnet.

    The VLAN_INTERFACE table is read from the config DB on the first call and
    cached in the module-level ``vlan_interface_query`` for later calls.

    Args:
        ferret_dip: IP address returned by Ferret, in any form accepted by
            ``netaddr.IPAddress``.

    Returns:
        True when a VLAN interface subnet of the same IP version contains the
        address, False otherwise.
    """
    global vlan_interface_query

    # Lazy load the vlan interfaces the first time we run this check. Test for
    # None rather than truthiness so that an empty table is cached as well
    # instead of being queried again on every call.
    if vlan_interface_query is None:
        vlan_interface_query = config_db.get_table('VLAN_INTERFACE')

    ferret_dip = IPAddress(ferret_dip)

    for vlan_interface in vlan_interface_query:
        # Only tuple keys, (interface name, ip prefix), carry a subnet.
        if not is_ip_prefix_in_key(vlan_interface):
            log.log_info('{} does not have a subnet, skipping...'.format(vlan_interface))
            continue

        vlan_subnet = IPNetwork(vlan_interface[1])

        if ferret_dip.version != vlan_subnet.version:
            log.log_info('{} version (IPv{}) does not match provided DIP version (IPv{}), skipping...'.format(vlan_interface[0], vlan_subnet.version, ferret_dip.version))
            continue

        if ferret_dip in vlan_subnet:
            return True

    return False


#
# Get switch info and intf addr
#

def get_switch_name():
    """Return the 'hostname' field of DEVICE_METADATA 'localhost'."""
    return config_db.get_table('DEVICE_METADATA')['localhost']['hostname']


def get_switch_hwsku():
    """Return the 'hwsku' field of DEVICE_METADATA 'localhost'."""
    return config_db.get_table('DEVICE_METADATA')['localhost']['hwsku']


def extract_ip_ver_addr(ip_prefix):
    """Split an IP prefix string into ``(ip version, address string)``."""
    network = IPNetwork(ip_prefix)
    host = network.ip
    return (host.version, str(host))


def is_ip_prefix_in_key(key):
    """Tell whether a config DB interface key carries an IP prefix.

    Keys that include an IP prefix are tuples; keys without one are plain
    strings, so the check is simply whether ``key`` is a tuple.
    """
    carries_prefix = isinstance(key, tuple)
    return carries_prefix


def get_loopback_addr(ip_ver):
    """Return the first Loopback0 address of IP version ``ip_ver``.

    Scans the LOOPBACK_INTERFACE table and returns the address as a string,
    or '' when Loopback0 has no address of that version.
    """
    for key in config_db.get_table('LOOPBACK_INTERFACE'):
        # Keys without an IP prefix (plain strings) carry no address.
        if not (is_ip_prefix_in_key(key) and 'Loopback0' in key):
            continue
        version, address = extract_ip_ver_addr(key[1])
        if version == ip_ver:
            return address

    return ''


def get_vlan_interfaces():
    """Return the names in the VLAN table that also appear in VLAN_INTERFACE.

    VLANs without a VLAN_INTERFACE entry (L2 VLANs) are left out.
    """
    vlan_table = config_db.get_table('VLAN')
    vlan_intf_table = config_db.get_table('VLAN_INTERFACE')
    return [name for name in vlan_table if name in vlan_intf_table]


def get_vlan_interface_members(vlan_intf_name):
    """Return the member names listed in VLAN_MEMBER for ``vlan_intf_name``.

    Each VLAN_MEMBER key is indexed as (vlan name, member name).
    """
    member_keys = config_db.get_table('VLAN_MEMBER')
    return [key[1] for key in member_keys if key[0] == vlan_intf_name]


def get_vlan_interface_mac_address(vlan_intf_name):
    """Read the interface MAC address from sysfs.

    Returns the first line of the interface's sysfs 'address' file with
    surrounding whitespace removed, or '' when that file does not exist.
    """
    sysfs_file = SYS_NET_PATH + vlan_intf_name + VLAN_MAC_ADDRESS_PATH

    if not os.path.isfile(sysfs_file):
        return ''

    with open(sysfs_file) as handle:
        first_line = handle.readline()
    return first_line.strip()


def get_vlan_interface_vlan_id(vlan_intf_name):
    """Return the interface name with its first four characters dropped.

    Presumably names look like 'Vlan<id>', so the result is the id as a
    string; the prefix itself is not checked.
    """
    prefix_length = len('Vlan')
    return vlan_intf_name[prefix_length:]


def get_vlan_interface_vxlan_id(vlan_intf_name):
    """Return the VXLAN id for a VLAN interface.

    This is the interface name with its first four characters dropped, the
    same value the VLAN id helper returns.
    """
    prefix_length = len('Vlan')
    return vlan_intf_name[prefix_length:]


def get_vlan_addr_prefix(vlan_intf_name, ip_ver):
    """Collect the addresses of one IP version configured on a VLAN interface.

    Returns two parallel lists: the address strings and their prefix lengths,
    taken from the VLAN_INTERFACE keys that carry an IP prefix.
    """
    addresses = []
    prefix_lengths = []

    for key in config_db.get_table('VLAN_INTERFACE'):
        # Skip keys without a prefix and keys of other interfaces.
        if not is_ip_prefix_in_key(key) or vlan_intf_name not in key:
            continue
        network = IPNetwork(key[1])
        if network.ip.version != ip_ver:
            continue
        addresses.append(str(network.ip))
        prefix_lengths.append(network.prefixlen)

    return addresses, prefix_lengths


def get_link_local_addr(vlan_interface):
    """Return the IPv6 link-local address of ``vlan_interface`` from the OS.

    Runs ``ip -6 addr show <interface>`` and returns the first 'inet6'
    address whose text starts with "fe80".

    Returns:
        The address as a string, or None when no such address is listed or
        the lookup fails (the failure is logged; this is best effort).
    """
    try:
        out = subprocess.check_output(['ip', '-6', 'addr', 'show', vlan_interface])
        for line in out.decode('UTF-8').splitlines():
            fields = line.split()
            # Skip blank or truncated lines. Indexing them would raise
            # IndexError, which the handler below would swallow, aborting
            # the scan before a later valid address is seen.
            if len(fields) < 2 or fields[0] != 'inet6':
                continue
            ip = IPNetwork(fields[1])
            if str(ip.ip).startswith("fe80"):
                # Link local ipv6 address
                return str(ip.ip)
    except Exception:
        log.log_error('failed to get %s addresses from o.s.' % vlan_interface)

    return None


def get_vlan_addresses(vlan_interface):
    """Gather the ids, addresses and MAC needed to advertise a VLAN interface.

    Returns:
        A tuple ``(vlan_id, vxlan_id, ipv4_addr, ipv4_prefix, ipv6_addr,
        ipv6_prefix, mac_addr)``. The address and prefix lists are parallel.
        ``mac_addr`` comes from DEVICE_METADATA 'localhost'/'mac', falling
        back to the interface's sysfs address; it may be '' if neither is
        available.
    """
    vlan_id = get_vlan_interface_vlan_id(vlan_interface)
    vxlan_id = get_vlan_interface_vxlan_id(vlan_interface)

    ipv4_addr, ipv4_prefix = get_vlan_addr_prefix(vlan_interface, 4)
    ipv6_addr, ipv6_prefix = get_vlan_addr_prefix(vlan_interface, 6)

    # Add the link-local address only when the interface already has IPv6
    # configured and does not list that address itself.
    if ipv6_addr:
        link_local_addr = get_link_local_addr(vlan_interface)
        if link_local_addr and link_local_addr not in ipv6_addr:
            ipv6_addr.append(link_local_addr)
            ipv6_prefix.append('128')

    metadata = config_db.get_table('DEVICE_METADATA')
    # Use .get so that a missing 'mac' field reaches the sysfs fallback below
    # instead of raising KeyError.
    mac_addr = metadata.get('localhost', {}).get('mac')
    if not mac_addr:
        mac_addr = get_vlan_interface_mac_address(vlan_interface)

    return vlan_id, vxlan_id, ipv4_addr, ipv4_prefix, ipv6_addr, ipv6_prefix, mac_addr

#
# Set up neighbor advertiser slice on Ferret
#

def construct_neighbor_advertiser_slice():
    """Build the slice object that is posted to Ferret.

    The result has three parts: 'switchInfo' (hostname, Loopback0 addresses,
    hwsku), 'vlanInterfaces' (one entry per VLAN interface that has a MAC and
    at least one address) and 'respondingSchemes' (the advertise duration).
    """

    def build_mappings(intf_name, addresses, prefix_lengths, host_prefix_len, mac):
        # One mapping per address. The configured prefix length is used only
        # when proxy_arp is 'enabled' on the interface; otherwise the host
        # prefix length of the address family is advertised.
        mappings = []
        for position, address in enumerate(addresses):
            if vlan_intf_table[intf_name].get('proxy_arp') == 'enabled':
                prefix_len = str(prefix_lengths[position])
            else:
                prefix_len = host_prefix_len
            mappings.append({
                'ipAddr': address,
                'ipPrefixLen': prefix_len,
                'macAddr': mac
            })
        return mappings

    switch_info_obj = {
        'name': get_switch_name(),
        'ipv4Addr': get_loopback_addr(4),
        'ipv6Addr': get_loopback_addr(6),
        'hwSku': get_switch_hwsku()
    }
    responding_schemes_obj = {'durationInSec': DEFAULT_DURATION}

    all_vlan_interfaces = get_vlan_interfaces()
    vlan_intf_table = config_db.get_table('VLAN_INTERFACE')
    vxlanPort = appl_db.get(appl_db.APPL_DB, 'SWITCH_TABLE:switch', 'vxlan_port')

    vlan_interfaces_obj = []
    for vlan_interface in all_vlan_interfaces:
        (vlan_id, vxlan_id,
         ipv4_addr, ipv4_prefix,
         ipv6_addr, ipv6_prefix,
         mac_addr) = get_vlan_addresses(vlan_interface)

        if not mac_addr:
            log.log_warning('Cannot find mac addr of vlan interface {}'.format(vlan_interface))
            continue

        ipv4_mappings = build_mappings(vlan_interface, ipv4_addr, ipv4_prefix, '32', mac_addr)
        ipv6_mappings = build_mappings(vlan_interface, ipv6_addr, ipv6_prefix, '128', mac_addr)

        # Interfaces without any address are not advertised.
        if not (ipv4_mappings or ipv6_mappings):
            continue

        vlan_interface_obj = {
            'vlanId': vlan_id,
            'vxlanId': vxlan_id,
            'ipv4AddrMappings': ipv4_mappings,
            'ipv6AddrMappings': ipv6_mappings
        }
        # The VXLAN port is included only when SWITCH_TABLE provides one.
        if vxlanPort:
            vlan_interface_obj['vxlanPort'] = vxlanPort

        vlan_interfaces_obj.append(vlan_interface_obj)

    return {
        'switchInfo': switch_info_obj,
        'vlanInterfaces': vlan_interfaces_obj,
        'respondingSchemes': responding_schemes_obj
    }


def wrapped_ferret_request(request_slice, https_endpoint):
    """POST ``request_slice`` as JSON to ``https_endpoint`` and return the response.

    Args:
        request_slice: JSON-serializable request body.
        https_endpoint: Full URL of the Ferret endpoint.

    Returns:
        The ``requests`` response of a successful request.

    Raises:
        RuntimeError: if no response object was obtained.
        requests.exceptions.RequestException: on transport errors, timeouts
            and, via ``raise_for_status()``, unsuccessful HTTP status codes.
    """
    response = None

    # NOTE: While we transition to HTTPS we're disabling the verify field. We
    # need to add a way to fetch certificates in this script ASAP.
    # Warnings are silenced only around the request itself.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        response = requests.post(https_endpoint,
                                 json=request_slice,
                                 timeout=DEFAULT_REQUEST_TIMEOUT,
                                 verify=False)

        # A requests Response is falsy for every 4xx/5xx status, so compare
        # against None here. Otherwise an HTTP error would be reported as
        # "no response" and raise_for_status() would never run.
        if response is None:
            raise RuntimeError("No response obtained from HTTPS endpoint")

        # If the request is unsuccessful (e.g. has a non 2xx response code),
        # we'll consider it failed
        response.raise_for_status()

    return response


def post_neighbor_advertiser_slice(ferret_service_vip):
    """Register this switch's slice with Ferret at ``ferret_service_vip``.

    The request body is saved locally, then posted up to
    DEFAULT_FERRET_QUERY_RETRIES times. A response is accepted only when the
    returned IPv4 DIP lies outside all device VLAN subnets; the accepted
    response is saved locally as well.

    Returns:
        The Ferret server IPv4 address (DIP) on success, or None when the
        request fails or every returned DIP was inside a device VLAN.
    """
    request_slice = construct_neighbor_advertiser_slice()
    save_as_json(request_slice, NEIGHBOR_ADVERTISER_REQUEST_SLICE_PATH)

    https_endpoint = "https://{}:448{}{}".format(ferret_service_vip, FERRET_NEIGHBOR_ADVERTISER_API_PREFIX, get_switch_name())

    for attempt in range(1, DEFAULT_FERRET_QUERY_RETRIES + 1):
        # A failed request is not retried; the caller moves on.
        try:
            response = wrapped_ferret_request(request_slice, https_endpoint)
        except Exception as e:
            log.log_error("The request failed, vip: {}, error: {}".format(ferret_service_vip, e))
            return None

        neighbor_advertiser_configuration = json.loads(response.content)
        dip = neighbor_advertiser_configuration['ipv4Addr']

        if not is_dip_in_device_vlan(dip):
            # All preceding checks passed: keep the response and hand back the DIP.
            save_as_json(neighbor_advertiser_configuration, NEIGHBOR_ADVERTISER_RESPONSE_CONFIG_PATH)
            log.log_info("Successfully set up neighbor advertiser slice, vip: {}, dip: {}".format(ferret_service_vip, dip))
            return dip

        # The DIP is inside a device VLAN; ask again.
        log.log_info("Failed to set up neighbor advertiser slice, vip: {}, dip {} is in device VLAN (attempt {}/{})".format(ferret_service_vip, dip, attempt, DEFAULT_FERRET_QUERY_RETRIES))

    log.log_error("Failed to set up neighbor advertiser slice, vip: {}, returned dips were in device VLAN".format(ferret_service_vip))
    return None


def save_as_json(obj, file_path):
    """Write ``obj`` to ``file_path`` as JSON with sorted keys.

    Missing parent directories are created first. An existing file is
    overwritten.

    Args:
        obj: JSON-serializable object.
        file_path: Destination path; may be a bare file name.

    Raises:
        OSError: if the directory or file cannot be created.
        TypeError: if ``obj`` is not JSON-serializable.
    """
    dir_path = os.path.dirname(file_path)
    # A bare file name has an empty dirname and os.makedirs('') raises, so
    # only create a directory when there is one. exist_ok also removes the
    # window between checking for the directory and creating it.
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)

    with open(file_path, 'w') as outfile:
        json.dump(obj, outfile, sort_keys=True)


# Locate the v4 and v6 mirror ACL tables in the config DB. Each may be named
# with or without the SONIC_ prefix: EVERFLOW / SONIC_EVERFLOW and
# EVERFLOWV6 / SONIC_EVERFLOWV6.
def find_mirror_table_name():
    """Return ``(v4_table, v6_table)``; a name is "" when no table matches.

    A missing table is logged as an error. If several keys match the same
    family, the last one in key order is returned.
    """
    v4_names = (MIRROR_ACL_TABLE_NAME,
                MIRROR_ACL_TABLE_PREFIX + MIRROR_ACL_TABLE_NAME)
    v6_names = (MIRROR_ACL_TABLEV6_NAME,
                MIRROR_ACL_TABLE_PREFIX + MIRROR_ACL_TABLEV6_NAME)

    v4_table = ""
    v6_table = ""
    for table in config_db.get_keys("ACL_TABLE"):
        if table in v4_names:
            v4_table = table
        if table in v6_names:
            v6_table = table

    if not v4_table:
        log.log_error(MIRROR_ACL_TABLE_NAME + " table does not exist")
    if not v6_table:
        log.log_error(MIRROR_ACL_TABLEV6_NAME + " table does not exist")

    return (v4_table, v6_table)


#
# Set mirror tunnel
#


def add_mirror_session(dst_ipv4_addr):
    """Write the mirror session from the Loopback0 IPv4 address to ``dst_ipv4_addr``."""
    config_db.set_entry(
        'MIRROR_SESSION',
        MIRROR_SESSION_NAME,
        {'src_ip': get_loopback_addr(4), 'dst_ip': dst_ipv4_addr},
    )


def add_mirror_acl_rule():
    """Install the mirror ACL rules in whichever mirror tables exist.

    The v4 table gets a rule matching ether type 2054 (ARP); the v6 table
    gets one matching ICMPv6 type 135 (neighbor solicitation). Both use the
    neighbor advertiser mirror session as their action.
    """
    v4_table, v6_table = find_mirror_table_name()

    rules = (
        (v4_table, MIRROR_ACL_RULE_NAME, {
            'PRIORITY': '8888',
            'ether_type': '2054',
            'mirror_action': MIRROR_SESSION_NAME
        }),
        (v6_table, MIRROR_ACL_RULEV6_NAME, {
            'PRIORITY': '8887',
            'ICMPV6_TYPE': '135',
            'mirror_action': MIRROR_SESSION_NAME
        }),
    )

    for table, rule_name, rule_info in rules:
        # An empty table name means the table was not found.
        if table:
            config_db.set_entry('ACL_RULE', (table, rule_name), rule_info)


def set_mirror_tunnel(ferret_server_ip):
    """Create the mirror session towards ``ferret_server_ip``, then the ACL rules that use it."""
    add_mirror_session(ferret_server_ip)
    # Ensure the mirror session is created before creating the rules
    time.sleep(DEFAULT_CONFIG_DB_WAIT_TIME)
    add_mirror_acl_rule()
    log.log_info('Finish setting mirror tunnel; Ferret: {}'.format(ferret_server_ip))


#
# Reset mirror tunnel
#

def remove_mirror_session():
    """Clear the neighbor advertiser MIRROR_SESSION entry by setting it to None."""
    config_db.set_entry('MIRROR_SESSION', MIRROR_SESSION_NAME, None)


def remove_mirror_acl_rule():
    """Clear the mirror ACL rules from whichever mirror tables exist."""
    v4_table, v6_table = find_mirror_table_name()

    targets = (
        (v4_table, MIRROR_ACL_RULE_NAME),
        (v6_table, MIRROR_ACL_RULEV6_NAME),
    )
    for table, rule_name in targets:
        # An empty table name means the table was not found.
        if table:
            config_db.set_entry('ACL_RULE', (table, rule_name), None)


def reset_mirror_tunnel():
    """Remove the mirror ACL rules, then the mirror session they refer to."""
    remove_mirror_acl_rule()
    # Ensure the rules are removed before removing the mirror session
    time.sleep(DEFAULT_CONFIG_DB_WAIT_TIME)
    remove_mirror_session()
    log.log_info('Finish resetting mirror tunnel')


#
# Set vxlan tunnel
#

def check_existing_tunnel():
    """Report whether a VXLAN tunnel is already configured.

    When the VXLAN_TUNNEL table is not empty, the global VXLAN_TUNNEL_NAME is
    rebound to its first key and True is returned; otherwise the global is
    left alone and False is returned.
    """
    global VXLAN_TUNNEL_NAME

    configured_tunnels = config_db.get_table('VXLAN_TUNNEL')
    if not len(configured_tunnels):
        return False

    VXLAN_TUNNEL_NAME = list(configured_tunnels.keys())[0]
    return True

def add_vxlan_tunnel(dst_ipv4_addr):
    """Write the VXLAN tunnel from the Loopback0 IPv4 address to ``dst_ipv4_addr``."""
    config_db.set_entry(
        'VXLAN_TUNNEL',
        VXLAN_TUNNEL_NAME,
        {'src_ip': get_loopback_addr(4), 'dst_ip': dst_ipv4_addr},
    )


def add_vxlan_tunnel_map():
    """Write one VXLAN_TUNNEL_MAP entry per VLAN interface.

    Entries are keyed (tunnel name, 'map_<n>') with n counting from 1 in the
    order the VLAN interfaces are returned.
    """
    for position, vlan_name in enumerate(get_vlan_interfaces(), start=1):
        map_info = {
            'vni': get_vlan_interface_vxlan_id(vlan_name),
            'vlan': vlan_name
        }
        map_key = (VXLAN_TUNNEL_NAME, VXLAN_TUNNEL_MAP_PREFIX + str(position))
        config_db.set_entry('VXLAN_TUNNEL_MAP', map_key, map_info)


def set_vxlan_tunnel(ferret_server_ip):
    """Make sure a VXLAN tunnel exists, then write the per-VLAN tunnel maps."""
    # A tunnel is created only when none is configured; otherwise
    # check_existing_tunnel() has switched VXLAN_TUNNEL_NAME to the existing
    # one, which the maps then attach to.
    if not check_existing_tunnel():
        add_vxlan_tunnel(ferret_server_ip)
    add_vxlan_tunnel_map()
    log.log_info('Finish setting vxlan tunnel; Ferret: {}'.format(ferret_server_ip))


#
# Reset vxlan tunnel
#

def remove_vxlan_tunnel():
    """Clear the VXLAN_TUNNEL entry named by the current VXLAN_TUNNEL_NAME by setting it to None."""
    config_db.set_entry('VXLAN_TUNNEL', VXLAN_TUNNEL_NAME, None)


def remove_vxlan_tunnel_map():
    """Clear the per-VLAN VXLAN_TUNNEL_MAP entries.

    The maps are looked up under the first configured tunnel when the
    VXLAN_TUNNEL table is not empty, otherwise under VXLAN_TUNNEL_NAME. One
    'map_<n>' key is cleared per current VLAN interface, n counting from 1.
    """
    configured_tunnels = config_db.get_table('VXLAN_TUNNEL')
    if len(configured_tunnels):
        tunnel_name = list(configured_tunnels.keys())[0]
    else:
        tunnel_name = VXLAN_TUNNEL_NAME

    vlan_count = len(get_vlan_interfaces())
    for position in range(1, vlan_count + 1):
        map_key = (tunnel_name, VXLAN_TUNNEL_MAP_PREFIX + str(position))
        config_db.set_entry('VXLAN_TUNNEL_MAP', map_key, None)


def reset_vxlan_tunnel():
    """Remove the VXLAN tunnel maps first, then the tunnel entry."""
    remove_vxlan_tunnel_map()
    remove_vxlan_tunnel()
    log.log_info('Finish resetting vxlan tunnel')


#
# Main function
#

def main():
    """Parse the command line and run the requested mode.

    'set' tries each VIP of the comma-separated -s list in turn until one
    yields a usable Ferret DIP, then configures the VXLAN and mirror tunnels.
    'reset' removes both tunnels. The function always ends through
    sys.exit(): 0 on success, 1 on a usage error or when no VIP worked.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', dest='vips', metavar='vips', type = str, required = False, help = 'ferret service vip list, required in set mode')
    parser.add_argument('-m', dest='mode', metavar='mode (set, reset)', type = str, required = True, choices=['set', 'reset'], help = 'operation mode')
    args = parser.parse_args()

    operation_mode = args.mode
    ferret_service_vips = args.vips

    if ferret_service_vips is None and operation_mode == 'set':
        log.log_warning('ferret service vip is required in set mode')
        sys.exit(1)

    connect_config_db()
    connect_app_db()

    if operation_mode == 'set':
        for ferret_service_vip in ferret_service_vips.split(','):
            ferret_server_ip = post_neighbor_advertiser_slice(ferret_service_vip)
            if not ferret_server_ip:
                continue

            # The first VIP that returns a DIP wins.
            set_vxlan_tunnel(ferret_server_ip)
            set_mirror_tunnel(ferret_server_ip)
            break
        else:
            # The loop ended without a break: no VIP produced a DIP.
            log.log_error('Failed to set up neighbor advertiser slice, tried all vips in {}'.format(ferret_service_vips))
            sys.exit(1)

    if operation_mode == 'reset':
        reset_mirror_tunnel()
        reset_vxlan_tunnel()

    sys.exit(0)


# Script entry point.
if __name__ == '__main__':
    try:
        # main() always finishes by calling sys.exit(). SystemExit is not a
        # subclass of Exception, so it passes through the handler below.
        main()
    except Exception as e:
        # Any other error is logged with its traceback and mapped to exit code 1.
        log.log_error('! [Failure] {} {}'.format(e, traceback.format_exc()))
        sys.exit(1)

