#!python

"""
    Script to show the NAT configuration.
    Example of the output:

    root@sonic:/home/admin# sudo natconfig -s

    Nat Type  IP Protocol Global IP      Global Port     Local IP       Local Port     Twice-NAT Id
    --------  ----------- ------------   --------------  -------------  -------------  ------------
    dnat      all         65.55.45.5     ---             10.0.0.1       ---            ---
    dnat      all         65.55.45.6     ---             10.0.0.2       ---            ---
    dnat      tcp         65.55.45.7     2000            20.0.0.1       4500           1
    snat      tcp         20.0.0.2       4000            65.55.45.8     1030           1

    root@sonic:/home/admin# sudo natconfig -p

    Pool Name      Global IP Range             Global Port Range
    ------------   -------------------------   --------------------
    Pool1          65.55.45.5                  100-200
    Pool2          65.55.45.6-65.55.45.8       ---
    Pool3          65.55.45.10-65.55.45.15     500-1000

    root@sonic:/home/admin# sudo natconfig -b

    Binding Name   Pool Name      Access-List    Nat Type  Twice-Nat Id
    ------------   ------------   ------------   --------  ------------
    Bind1          Pool1          ---            snat      ---
    Bind2          Pool2          1              snat      1
    Bind3          Pool3          1,2            snat      ---

    root@sonic:/home/admin# sudo natconfig -g
    
    Admin Mode     : disabled
    Global Timeout : 600 secs
    TCP Timeout    : 86400 secs
    UDP Timeout    : 300 secs

"""

import argparse
import sys

from tabulate import tabulate
from swsscommon.swsscommon import ConfigDBConnector


class NatConfig(object):
    """
        Read the NAT configuration from CONFIG DB and print it as tables.

        Each fetch_*() method fills one list attribute with tuples and the
        matching display_*() method prints that list. display_global() reads
        and prints in a single step.
    """

    # Placeholder stored for any field that is absent from CONFIG DB.
    _UNSET = "---"

    def __init__(self):
        super(NatConfig, self).__init__()

        # Create every result list up front so that a display_*() call made
        # before the matching fetch_*() prints an empty table instead of
        # raising AttributeError.
        self.static_nat_data = []
        self.static_napt_data = []
        self.nat_pool_data = []
        self.nat_binding_data = []
        self.nat_zone_data = []
        # Never populated by this class; kept only so the attribute exists.
        self.nat_global_data = []

        self.config_db = ConfigDBConnector()
        self.config_db.connect()

    def _get_table(self, name):
        """
            Return CONFIG DB table `name`, or an empty dict when the lookup
            yields nothing.
        """
        return self.config_db.get_table(name) or {}

    def _non_null(self, values, field):
        """
            Return values[field], or the placeholder when the field is absent
            or holds the literal string "NULL".
        """
        value = values.get(field, "NULL")
        return self._UNSET if value == "NULL" else value

    @staticmethod
    def _print_table(header, rows):
        """
            Print `rows` under `header`, with a blank line before and after.
        """
        print("")
        print(tabulate([list(row) for row in rows], header))
        print("")

    def fetch_static_nat(self):
        """
            Fetch Static NAT config from CONFIG DB.

            Fills self.static_nat_data with tuples of
            (nat_type, ip_protocol, global_ip, global_port, local_ip,
            local_port, twice_nat_id), sorted by nat_type. Keys that are not
            plain strings are skipped.
        """
        rows = self.static_nat_data = []

        for global_ip, values in self._get_table('STATIC_NAT').items():
            if not isinstance(global_ip, str):
                continue

            rows.append((
                values.get("nat_type", "dnat"),
                "all",
                global_ip,
                self._UNSET,
                # A missing local_ip is shown as the placeholder rather than
                # aborting the whole listing with a KeyError.
                values.get("local_ip", self._UNSET),
                values.get("local_port", self._UNSET),
                values.get("twice_nat_id", self._UNSET),
            ))

        rows.sort(key=lambda row: row[0])

    def fetch_static_napt(self):
        """
            Fetch Static NAPT config from CONFIG DB.

            Fills self.static_napt_data with tuples laid out like the static
            NAT ones, sorted by nat_type. Only keys that are 3-tuples of
            (global_ip, ip_protocol, global_port) are used.
        """
        rows = self.static_napt_data = []

        for key, values in self._get_table('STATIC_NAPT').items():
            if not isinstance(key, tuple) or len(key) != 3:
                continue

            global_ip, ip_protocol, global_port = key

            rows.append((
                values.get("nat_type", "dnat"),
                ip_protocol,
                global_ip,
                global_port,
                # A missing local_ip is shown as the placeholder rather than
                # aborting the whole listing with a KeyError.
                values.get("local_ip", self._UNSET),
                values.get("local_port", self._UNSET),
                values.get("twice_nat_id", self._UNSET),
            ))

        rows.sort(key=lambda row: row[0])

    def fetch_pool(self):
        """
            Fetch NAT Pool config from CONFIG DB.

            Fills self.nat_pool_data with (pool_name, global_ip, global_port)
            tuples sorted by pool name. A nat_port of "NULL" is shown as the
            placeholder.
        """
        rows = self.nat_pool_data = []

        for pool_name, values in self._get_table('NAT_POOL').items():
            if not isinstance(pool_name, str):
                continue

            rows.append((
                pool_name,
                # A missing nat_ip is shown as the placeholder rather than
                # aborting the whole listing with a KeyError.
                values.get("nat_ip", self._UNSET),
                self._non_null(values, "nat_port"),
            ))

        rows.sort(key=lambda row: row[0])

    def fetch_binding(self):
        """
            Fetch NAT Binding config from CONFIG DB.

            Fills self.nat_binding_data with
            (binding_name, pool_name, access_list, nat_type, twice_nat_id)
            tuples sorted by binding name. A twice_nat_id of "NULL" is shown
            as the placeholder.
        """
        rows = self.nat_binding_data = []

        for binding_name, values in self._get_table('NAT_BINDINGS').items():
            if not isinstance(binding_name, str):
                continue

            rows.append((
                binding_name,
                # A missing nat_pool is shown as the placeholder rather than
                # aborting the whole listing with a KeyError.
                values.get("nat_pool", self._UNSET),
                values.get("access_list", self._UNSET),
                values.get("nat_type", "snat"),
                self._non_null(values, "twice_nat_id"),
            ))

        rows.sort(key=lambda row: row[0])

    def fetch_nat_zone(self):
        """
            Fetch NAT zone config from CONFIG DB.

            Walks the four interface tables and fills self.nat_zone_data with
            (interface_key, zone) tuples sorted by key. Entries without a
            nat_zone field get zone "0"; keys that are not plain strings are
            skipped.
        """
        interface_tables = ['INTERFACE', 'VLAN_INTERFACE', 'PORTCHANNEL_INTERFACE', 'LOOPBACK_INTERFACE']

        rows = self.nat_zone_data = []

        for table in interface_tables:
            for name, values in self._get_table(table).items():
                if not isinstance(name, str):
                    continue

                rows.append((name, values.get("nat_zone", "0")))

        rows.sort(key=lambda row: row[0])

    def display_static(self):
        """
            Display the static nat and napt
        """
        # The leading space in the last title is kept as-is so the printed
        # table stays identical to what this script has always produced.
        HEADER = ['Nat Type', 'IP Protocol', 'Global IP', 'Global Port', 'Local IP', 'Local Port', ' Twice-NAT Id']

        # NAT rows first, then NAPT rows, each already sorted by its fetch.
        self._print_table(HEADER, list(self.static_nat_data) + list(self.static_napt_data))

    def display_pool(self):
        """
            Display the nat pool
        """
        HEADER = ['Pool Name', 'Global IP Range', 'Global Port Range']

        self._print_table(HEADER, self.nat_pool_data)

    def display_binding(self):
        """
            Display the nat binding
        """
        HEADER = ['Binding Name', 'Pool Name', 'Access-List', 'Nat Type', 'Twice-NAT Id']

        self._print_table(HEADER, self.nat_binding_data)

    def display_global(self):
        """
            Fetch NAT Global config from CONFIG DB and Display it.

            Any field missing from the NAT_GLOBAL 'Values' entry, or a missing
            entry altogether, is printed with its default: admin mode
            "disabled" and timeouts of 600 / 86400 / 300 secs.
        """
        global_data = self.config_db.get_entry('NAT_GLOBAL', 'Values') or {}

        print("")
        print("Admin Mode     : {}".format(global_data.get('admin_mode', 'disabled')))
        print("Global Timeout : {} secs".format(global_data.get('nat_timeout', '600')))
        print("TCP Timeout    : {} secs".format(global_data.get('nat_tcp_timeout', '86400')))
        print("UDP Timeout    : {} secs".format(global_data.get('nat_udp_timeout', '300')))
        print("")

    def display_nat_zone(self):
        """
            Display the nat zone
        """
        HEADER = ['Port', 'Zone']

        self._print_table(HEADER, self.nat_zone_data)

def main():
    """
        Parse the command line and print the requested NAT configuration.

        Exactly one view is shown per run. When several options are given the
        first one in the order static, pool, binding, global, zones wins. With
        no option at all the help text is printed and nothing else is done.
        Any exception raised while reading or printing is reported on stdout
        and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description='Display the nat configuration information',
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     epilog="""
    Examples:
    natconfig -s
    natconfig -p
    natconfig -b
    natconfig -g
    natconfig -z
    """)

    parser.add_argument('-s', '--static', action='store_true', help='Show the nat static configuration')
    parser.add_argument('-p', '--pool', action='store_true', help='Show the nat pool configuration')
    parser.add_argument('-b', '--binding', action='store_true', help='Show the nat binding configuration')
    parser.add_argument('-g', '--globalvalues', action='store_true', help='Show the nat global configuration')
    parser.add_argument('-z', '--zones', action='store_true', help='Show the nat zone configuration')

    args = parser.parse_args()

    if not (args.static or args.pool or args.binding or args.globalvalues or args.zones):
        # Nothing was requested: show the usage instead of connecting to
        # CONFIG DB and then exiting without any output.
        parser.print_help()
        return

    try:
        nat = NatConfig()
        if args.static:
            nat.fetch_static_nat()
            nat.fetch_static_napt()
            nat.display_static()
        elif args.pool:
            nat.fetch_pool()
            nat.display_pool()
        elif args.binding:
            nat.fetch_binding()
            nat.display_binding()
        elif args.globalvalues:
            nat.display_global()
        elif args.zones:
            nat.fetch_nat_zone()
            nat.display_nat_zone()
    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit status.
        print(str(e))
        sys.exit(1)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

