#!python

import os
import subprocess
import sys

import netaddr
import netifaces
from natsort import natsorted
from tabulate import tabulate

from sonic_py_common import multi_asic
from swsscommon import swsscommon
from utilities_common import constants
from utilities_common.general import load_db_config
from utilities_common import multi_asic as multi_asic_util


# Unit-test bootstrap. Runs only when UTILITIES_UNIT_TESTING is "2"; both
# environment variables are read with os.environ[...], so an unset variable
# raises KeyError, which the handler at the bottom deliberately ignores.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":

        # Make the parent of this script's directory and its "tests"
        # subdirectory importable so the mock_tables package can be found.
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        # Imported for its side effects (presumably installs mocked DB
        # connectors -- confirm against tests/mock_tables).
        import mock_tables.dbconnector
        # NOTE: if UTILITIES_UNIT_TESTING_TOPOLOGY is unset, the KeyError is
        # raised here, after mock_tables.dbconnector has been imported, and
        # neither topology mock below is loaded.
        if os.environ["UTILITIES_UNIT_TESTING_TOPOLOGY"] == "multi_asic":
            import mock_tables.mock_multi_asic
            mock_tables.dbconnector.load_namespace_config()
        else:
            import mock_tables.mock_single_asic
            # Flag consumed by the single-asic mock module.
            mock_tables.mock_single_asic.add_unknown_intf=True
except KeyError:
    # Not running under the unit-test harness: use the real modules.
    pass


def get_bgp_peer():
    """
    Collect the local address and BGP neighbor details from CONFIG_DB.

    Reads the 'BGP_NEIGHBOR' table and returns a dict in the format below
    {
        'local_addr1':['neighbor_device1_name', 'neighbor_device1_ip'],
        'local_addr2':['neighbor_device2_name', 'neighbor_device2_ip']
    }

    When several neighbors share the same local address, only the first
    one encountered is kept. Neighbor entries that do not define both
    'local_addr' and 'name' are skipped instead of aborting the command.
    """
    config_db = swsscommon.ConfigDBConnector()
    config_db.connect()
    neighbors = config_db.get_table('BGP_NEIGHBOR')

    bgp_peer = {}
    for neighbor_ip, attrs in neighbors.items():
        try:
            local_addr = attrs['local_addr']
            neighbor_name = attrs['name']
        except KeyError:
            # Without a local address or a name there is nothing to map this
            # neighbor to; ignore it rather than crash the whole display.
            continue
        bgp_peer.setdefault(local_addr, [neighbor_name, neighbor_ip])
    return bgp_peer


def skip_ip_intf_display(interface, display_option):
    """
    Decide whether an interface should be hidden from the output.

    Nothing is hidden when display_option is constants.DISPLAY_ALL.
    Otherwise the interface is hidden when its name starts with
    'Loopback4096', 'eth0' or 'veth', when it is an 'Ethernet' port that
    multi_asic reports as internal, or a 'PortChannel' that multi_asic
    reports as internal.

    Returns True to skip the interface, False to display it.
    """
    if display_option == constants.DISPLAY_ALL:
        return False

    # Prefixes that are always hidden in the filtered view.
    if interface.startswith(('Loopback4096', 'eth0', 'veth')):
        return True

    if interface.startswith('Ethernet'):
        return bool(multi_asic.is_port_internal(interface))

    if interface.startswith('PortChannel'):
        return bool(multi_asic.is_port_channel_internal(interface))

    return False


def get_if_admin_state(iface, namespace):
    """
    Return the admin state of an interface as reported by the kernel.

    The interface's sysfs 'flags' file is read with `cat`; for any
    namespace other than constants.DEFAULT_NAMESPACE the command is run
    through `sudo ip netns exec <namespace>`.

    Returns "up" when bit 0x1 of the hexadecimal flags value is set,
    "down" when it is clear, and "error" when the command cannot be
    started or its output is not a hexadecimal number.
    """
    netns_prefix = []
    if namespace != constants.DEFAULT_NAMESPACE:
        netns_prefix = ["sudo", "ip", "netns", "exec", namespace]
    command = netns_prefix + ["cat", "/sys/class/net/{0}/flags".format(iface)]

    try:
        reader = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            # stderr is folded into stdout, so a failing `cat` yields
            # non-hex text that is reported as "error" below.
            stderr=subprocess.STDOUT,
            text=True)
        raw_flags = reader.communicate()[0]
        reader.wait()
    except OSError:
        print("Error: unable to get admin state for {}".format(iface))
        return "error"

    try:
        flag_bits = int(raw_flags.rstrip(), 16)
    except ValueError:
        return "error"

    return "up" if flag_bits & 0x1 else "down"


def get_if_oper_state(iface, namespace):
    """
    Return the oper state of an interface as reported by the kernel.

    The interface's sysfs 'carrier' file is read with `cat`; for any
    namespace other than constants.DEFAULT_NAMESPACE the command is run
    through `sudo ip netns exec <namespace>`.

    Returns "up" when the output (trailing whitespace removed) is exactly
    "1", "down" for any other output, and "error" when the command cannot
    be started.
    """
    netns_prefix = []
    if namespace != constants.DEFAULT_NAMESPACE:
        netns_prefix = ["sudo", "ip", "netns", "exec", namespace]
    command = netns_prefix + ["cat", "/sys/class/net/{0}/carrier".format(iface)]

    try:
        reader = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            # stderr is folded into stdout, so a failing `cat` produces
            # text other than "1" and is reported as "down".
            stderr=subprocess.STDOUT,
            text=True)
        carrier_text = reader.communicate()[0]
        reader.wait()
    except OSError:
        print("Error: unable to get oper state for {}".format(iface))
        return "error"

    return "up" if carrier_text.rstrip() == "1" else "down"


def get_if_master(iface):
    """
    Return the name of the interface's master device, or "" if it has none.

    The master is taken from the final path component of the resolved
    /sys/class/net/<iface>/master entry; a missing entry means no master.
    """
    master_entry = "/sys/class/net/{0}/master".format(iface)
    if not os.path.exists(master_entry):
        return ""
    resolved = os.path.realpath(master_entry)
    return os.path.basename(resolved)


def get_ip_intfs_in_namespace(af, namespace, display):
    """
    Get all the ip interfaces from the kernel for the given namespace

    Args:
        af: address family to report, compared against netifaces.AF_INET
            to choose how the netmask is rendered.
        namespace: network namespace to query.
        display: display option forwarded to skip_ip_intf_display().

    Returns:
        dict keyed by interface name. Each value is a dict with the keys
        "vrf", "ipaddr" (naturally sorted list of ["", "<ip>/<mask>"]),
        "admin", "oper", "bgp_neighs" (maps "<ip>/<mask>" to
        [neighbor_name, neighbor_ip], 'N/A' when unknown) and "ns".
        Interfaces without any address in the requested family are omitted.
    """
    ip_intfs = {}
    interfaces = multi_asic_util.multi_asic_get_ip_intf_from_ns(namespace)
    bgp_peer = get_bgp_peer()
    for iface in interfaces:
        # Display filtering is only applied inside non-default namespaces.
        if namespace != constants.DEFAULT_NAMESPACE and skip_ip_intf_display(iface, display):
            continue
        try:
            ipaddresses = multi_asic_util.multi_asic_get_ip_intf_addr_from_ns(namespace, iface)
        except ValueError:
            continue
        if af not in ipaddresses:
            continue

        ifaddresses = []
        bgp_neighs = {}
        for ipaddr in ipaddresses[af]:
            local_ip = str(ipaddr['addr'])
            if af == netifaces.AF_INET:
                # Dotted-quad mask -> prefix length.
                netmask = netaddr.IPAddress(ipaddr['netmask']).netmask_bits()
            else:
                # Keep only the text after the '/' (presumably the prefix
                # length in netifaces' "mask/len" form -- confirm).
                netmask = ipaddr['netmask'].split('/', 1)[-1]
            local_ip_with_mask = "{}/{}".format(local_ip, str(netmask))
            ifaddresses.append(["", local_ip_with_mask])

            peer = bgp_peer.get(local_ip, ['N/A', 'N/A'])
            bgp_neighs[local_ip_with_mask] = [peer[0], peer[1]]

        if not ifaddresses:
            # Previously the state variables below were left unassigned (or
            # stale from the prior interface) in this case, and the empty
            # address list would break the display code. Skip the interface.
            continue

        admin = get_if_admin_state(iface, namespace)
        oper = get_if_oper_state(iface, namespace)
        master = get_if_master(iface)

        ip_intfs[iface] = {
            "vrf": master,
            "ipaddr": natsorted(ifaddresses),
            "admin": admin,
            "oper": oper,
            "bgp_neighs": bgp_neighs,
            "ns": namespace
        }
    return ip_intfs


def display_ip_intfs(ip_intfs, address_family):
    """
    Print the collected ip interfaces as a table.

    ip_intfs maps interface name to a dict with the keys 'vrf', 'ipaddr',
    'admin', 'oper' and 'bgp_neighs'. The first address of an interface is
    printed together with its master and admin/oper state; every further
    address gets its own row with those columns left blank. The address
    column is titled for IPv6 when address_family is 'ipv6', otherwise IPv4.
    """
    if address_family == 'ipv6':
        addr_title = 'IPv6 address/mask'
    else:
        addr_title = 'IPv4 address/mask'
    header = ['Interface', 'Master', addr_title,
              'Admin/Oper', 'BGP Neighbor', 'Neighbor IP']

    rows = []
    for intf_name, attrs in natsorted(ip_intfs.items()):
        addresses = attrs['ipaddr']
        primary = addresses[0][1]
        neighbors = attrs['bgp_neighs']

        # Full row for the first address of the interface.
        peer_name = neighbors[primary][0]
        peer_ip = neighbors[primary][1]
        rows.append([intf_name, attrs['vrf'], primary,
                     attrs['admin'] + "/" + attrs['oper'], peer_name, peer_ip])

        # Continuation rows for any additional addresses.
        for extra in addresses[1:]:
            extra_addr = extra[1]
            peer_name = neighbors[extra_addr][0]
            peer_ip = neighbors[extra_addr][1]
            rows.append(["", "", extra_addr, "", peer_name, peer_ip])

    print(tabulate(rows, header, tablefmt="simple", stralign='left', missingval=""))


def get_ip_intfs(af, namespace, display):
    """
    Get all the ip interfaces present on the device.

    This includes ip interfaces on the host as well as ip interfaces in
    each network namespace. Returns a dict keyed by interface name whose
    values are the per-interface dicts produced by
    get_ip_intfs_in_namespace().
    """
    device = multi_asic_util.MultiAsic(namespace_option=namespace,
                                       display_option=display)
    ns_list = device.get_ns_list_based_on_options()

    # Single asic devices have only DEFAULT_NAMESPACE. Multi asic devices
    # have one network namespace per asic plus the host's, so add the host.
    if device.is_multi_asic:
        ns_list.append(constants.DEFAULT_NAMESPACE)

    merged = {}
    for ns in ns_list:
        found = get_ip_intfs_in_namespace(af, ns, display)

        if not device.is_multi_asic:
            merged.update(found)
            continue

        # A multi asic device can have the same ip interface in different
        # namespaces: keep the first occurrence and fold in later ones only
        # when their address list differs.
        for intf_name, attrs in found.items():
            if intf_name not in merged:
                merged[intf_name] = attrs
                continue
            known = merged[intf_name]
            if attrs['ipaddr'] != known['ipaddr']:
                known['ipaddr'] += attrs['ipaddr']
                known['bgp_neighs'].update(attrs['bgp_neighs'])
    return merged


def main():
    """
    Command entry point: parse the options, collect the ip interfaces for
    the requested address family and print them, then exit with status 0.

    Exits with an error message when not run as root (unless
    UTILITIES_UNIT_TESTING is "2") or when -a is neither ipv4 nor ipv6.
    """
    # The ip interfaces are read from different linux network namespaces,
    # which can only be done as the root user.
    under_unit_test = os.environ.get("UTILITIES_UNIT_TESTING", "0") == "2"
    if os.geteuid() != 0 and not under_unit_test:
        sys.exit("Root privileges required for this operation")

    parser = multi_asic_util.multi_asic_args()
    parser.add_argument('-a', '--address_family', type=str, help='ipv4 or ipv6 interfaces', default="ipv4")
    args = parser.parse_args()

    family_by_name = {
        "ipv4": netifaces.AF_INET,
        "ipv6": netifaces.AF_INET6,
    }
    if args.address_family not in family_by_name:
        sys.exit("Invalid argument -a {}".format(args.address_family))
    af = family_by_name[args.address_family]

    load_db_config()
    ip_intfs = get_ip_intfs(af, args.namespace, args.display)
    display_ip_intfs(ip_intfs, args.address_family)

    sys.exit(0)


# Run the command only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
