#!python

"""
    Script to show NAT entry counts, NAT translations and NAT statistics in a summary view

    Example of the output:
    root@sonic:/home/admin# sudo natshow -c

    Static NAT Entries         ..................... 2
    Static NAPT Entries        ..................... 3
    Dynamic NAT Entries        ..................... 0
    Dynamic NAPT Entries       ..................... 0
    Static Twice NAT Entries   ..................... 0
    Static Twice NAPT Entries  ..................... 2
    Dynamic Twice NAT Entries  ..................... 0
    Dynamic Twice NAPT Entries ..................... 0
    Total SNAT/SNAPT Entries   ..................... 7
    Total DNAT/DNAPT Entries   ..................... 0
    Total Entries              ..................... 7

    root@sonic:/home/admin# sudo natshow -t

    Static NAT Entries         ..................... 2
    Static NAPT Entries        ..................... 3
    Dynamic NAT Entries        ..................... 0
    Dynamic NAPT Entries       ..................... 0
    Static Twice NAT Entries   ..................... 0
    Static Twice NAPT Entries  ..................... 2
    Dynamic Twice NAT Entries  ..................... 0
    Dynamic Twice NAPT Entries ..................... 0
    Total SNAT/SNAPT Entries   ..................... 7
    Total DNAT/DNAPT Entries   ..................... 0
    Total Entries              ..................... 7

    Protocol  Source             Destination         Translated Source   Translated Destination
    --------  -----------------  ------------------  ------------------  ----------------------
    all       10.0.0.1           ---                 65.55.45.5          ---  
    all       10.0.0.2           ---                 65.55.45.6          ---
    tcp       20.0.0.1:4500      ---                 65.55.45.7:2000     ---
    udp       20.0.0.1:4000      ---                 65.55.45.7:1030     ---
    tcp       20.0.0.1:6000      ---                 65.55.45.7:1024     ---
    udp       20.0.0.1:7000      65.55.45.8:1200     65.55.45.7:1100     20.0.0.2:8000
    tcp       20.0.0.1:6000      65.55.45.8:1500     65.55.45.7:1300     20.0.0.3:9000

    root@sonic:/home/admin# sudo natshow -s

    Protocol  Source             Destination         Packets    Bytes
    --------  -----------------  ------------------  --------   ---------
    all       10.0.0.1           ---                      802     1009280
    all       10.0.0.2           ---                       23        5590
    tcp       20.0.0.1:4500      ---                      110       12460
    udp       20.0.0.1:4000      ---                     1156      789028
    tcp       20.0.0.1:6000      ---                       30       34800
    udp       20.0.0.1:7000      65.55.45.8:1200          128      110204
    tcp       20.0.0.1:6000      65.55.45.8:1500            8        3806

"""

import argparse
import json
import sys
import re

from swsscommon.swsscommon import SonicV2Connector
from tabulate import tabulate

class NatShow(object):
    """
        Read NAT information from the SONiC Redis databases and print it.

        Every ``fetch_*`` method stores its result on the instance and the
        matching ``display_*`` method prints it:

            fetch_count()        -> the ``*_entries`` attributes
            fetch_translations() -> ``nat_entries_list``
            fetch_statistics()   -> ``nat_statistics_list``

        Both lists hold 5-tuples whose first element is the protocol name.
    """

    # COUNTERS_DB hash that holds the global NAT entry counts.
    _GLOBAL_COUNTERS_KEY = 'COUNTERS_GLOBAL_NAT:Values'

    # (field of the global counters hash, attribute that receives its value)
    _COUNT_FIELDS = (
        ('STATIC_NAT_ENTRIES', 'static_nat_entries'),
        ('DYNAMIC_NAT_ENTRIES', 'dynamic_nat_entries'),
        ('STATIC_NAPT_ENTRIES', 'static_napt_entries'),
        ('DYNAMIC_NAPT_ENTRIES', 'dynamic_napt_entries'),
        ('STATIC_TWICE_NAT_ENTRIES', 'static_twice_nat_entries'),
        ('DYNAMIC_TWICE_NAT_ENTRIES', 'dynamic_twice_nat_entries'),
        ('STATIC_TWICE_NAPT_ENTRIES', 'static_twice_napt_entries'),
        ('DYNAMIC_TWICE_NAPT_ENTRIES', 'dynamic_twice_napt_entries'),
        ('SNAT_ENTRIES', 'snat_entries'),
        ('DNAT_ENTRIES', 'dnat_entries'),
    )

    # (APPL_DB table, COUNTERS_DB table, twice, has_ports), in the order the
    # tables are walked by fetch_statistics().
    _STATISTICS_TABLES = (
        ("NAT_TABLE", "COUNTERS_NAT", False, False),
        ("NAPT_TABLE", "COUNTERS_NAPT", False, True),
        ("NAT_TWICE_TABLE", "COUNTERS_TWICE_NAT", True, False),
        ("NAPT_TWICE_TABLE", "COUNTERS_TWICE_NAPT", True, True),
    )

    # IP protocol numbers (as found in the ASIC_DB key) that get a name;
    # anything else is displayed as "all".
    _PROTOCOL_NAMES = {"6": "tcp", "17": "udp"}

    def __init__(self):
        super(NatShow, self).__init__()
        # One connector per database; each is connected lazily by the
        # fetch_* method that needs it.
        self.asic_db = SonicV2Connector()
        self.appl_db = SonicV2Connector()
        self.counters_db = SonicV2Connector()

    def fetch_count(self):
        """
            Fetch NAT entries count from COUNTERS DB.

            Every ``*_entries`` attribute is reset to 0 and then overwritten
            with the value stored in COUNTERS_GLOBAL_NAT:Values for each field
            that is present there. Values are kept exactly as the database
            returns them; display_count() converts them with int().
        """
        self.counters_db.connect(self.counters_db.COUNTERS_DB)
        self.static_nat_entries = 0
        self.dynamic_nat_entries = 0
        self.static_napt_entries = 0
        self.dynamic_napt_entries = 0
        self.static_twice_nat_entries = 0
        self.dynamic_twice_nat_entries = 0
        self.static_twice_napt_entries = 0
        self.dynamic_twice_napt_entries = 0
        self.snat_entries = 0
        self.dnat_entries = 0

        if not self.counters_db.exists(self.counters_db.COUNTERS_DB, self._GLOBAL_COUNTERS_KEY):
            return

        counter_entry = self.counters_db.get_all(self.counters_db.COUNTERS_DB, self._GLOBAL_COUNTERS_KEY)
        for field, attribute in self._COUNT_FIELDS:
            if field in counter_entry:
                setattr(self, attribute, counter_entry[field])

    @staticmethod
    def _join_endpoint(ip, port):
        """
            Return "ip:port", or only ip when port is the string "0".
        """
        if port != "0":
            return ip + ":" + port
        return ip

    @staticmethod
    def _translated_endpoint(ent, ip_attr, port_attr):
        """
            Build the translated address from the attributes of an ASIC_DB NAT
            entry: "ip:port" when port_attr is present in ent, else the bare ip.
            Raises KeyError when ip_attr is missing.
        """
        ip = ent[ip_attr]
        if port_attr in ent:
            return ip + ":" + ent[port_attr]
        return ip

    def fetch_translations(self):
        """
            Fetch NAT entries from ASIC DB.

            Fills ``nat_entries_list`` with tuples
            (protocol, source, destination, translated source,
            translated destination), ordered by protocol name.
        """
        self.asic_db.connect(self.asic_db.ASIC_DB)
        self.nat_entries_list = []

        nat_keys = self.asic_db.keys(self.asic_db.ASIC_DB, "ASIC_STATE:SAI_OBJECT_TYPE_NAT_ENTRY:*")
        if not nat_keys:
            return

        for nat_key in nat_keys:
            # Everything after the second ':' of the Redis key is a JSON
            # document describing the entry.
            nat = json.loads(nat_key.split(":", 2)[-1])
            if not nat:
                continue

            nat_type = nat['nat_type']
            translates_src = nat_type in ("SAI_NAT_TYPE_SOURCE_NAT", "SAI_NAT_TYPE_DOUBLE_NAT")
            translates_dst = nat_type in ("SAI_NAT_TYPE_DESTINATION_NAT", "SAI_NAT_TYPE_DOUBLE_NAT")
            if not (translates_src or translates_dst):
                # Unknown NAT type: nothing to show, so do not read it at all.
                continue

            ent = self.asic_db.get_all(self.asic_db.ASIC_DB, nat_key, blocking=True)
            if not ent:
                # NOTE(review): presumably the entry was removed after keys()
                # listed it; skip it instead of failing the whole listing.
                continue

            translated_src = "---"
            translated_dst = "---"
            if translates_dst:
                translated_dst = self._translated_endpoint(
                    ent, "SAI_NAT_ENTRY_ATTR_DST_IP", "SAI_NAT_ENTRY_ATTR_L4_DST_PORT")
            if translates_src:
                translated_src = self._translated_endpoint(
                    ent, "SAI_NAT_ENTRY_ATTR_SRC_IP", "SAI_NAT_ENTRY_ATTR_L4_SRC_PORT")

            match_key = nat['nat_data']['key']
            source_ip = match_key["src_ip"]
            destination_ip = match_key["dst_ip"]

            # An all-zero address is displayed as "---".
            if source_ip == "0.0.0.0":
                source_ip = "---"
            if destination_ip == "0.0.0.0":
                destination_ip = "---"

            source = self._join_endpoint(source_ip, match_key["l4_src_port"])
            destination = self._join_endpoint(destination_ip, match_key["l4_dst_port"])
            ip_protocol = self._PROTOCOL_NAMES.get(match_key["proto"], "all")

            self.nat_entries_list.append((ip_protocol, source, destination, translated_src, translated_dst))

        # Stable sort on the protocol only: entries of one protocol keep the
        # order in which the database returned them.
        self.nat_entries_list.sort(key=lambda x: x[0])

    def _collect_statistics(self, appl_table, counters_table, twice, has_ports):
        """
            Append one row to ``nat_statistics_list`` for every entry of
            appl_table (APPL_DB) that has a hash in counters_table (COUNTERS_DB).

            The entry name is split on ':' and read as follows:
                has_ports=False, twice=False : ip
                has_ports=True,  twice=False : proto:ip:port
                has_ports=False, twice=True  : src_ip:dst_ip
                has_ports=True,  twice=True  : proto:src_ip:src_port:dst_ip:dst_port

            For non-twice tables the single address is shown as the source
            when the APPL_DB field 'nat_type' is "snat", otherwise as the
            destination. Entries that disappear while being read are skipped.
        """
        table_keys = self.appl_db.keys(self.appl_db.APPL_DB, appl_table + ":*")
        if not table_keys:
            return

        for table_key in table_keys:
            entry = table_key.split(':', 1)[-1].strip()
            if not entry:
                continue

            counters_key = '{}:{}'.format(counters_table, entry)
            if not self.counters_db.exists(self.counters_db.COUNTERS_DB, counters_key):
                continue

            fields = entry.split(':')
            if has_ports:
                ip_protocol = fields[0].lower()
                endpoints = [fields[1] + ':' + fields[2]]
                if twice:
                    endpoints.append(fields[3] + ':' + fields[4])
            else:
                ip_protocol = "all"
                endpoints = [fields[0]]
                if twice:
                    endpoints.append(fields[1])

            if twice:
                source, destination = endpoints
            else:
                source = "---"
                destination = "---"
                nat_values = self.appl_db.get_all(self.appl_db.APPL_DB, '{}:{}'.format(appl_table, entry))
                if not nat_values:
                    # Entry removed after keys() listed it.
                    continue
                if nat_values['nat_type'] == "snat":
                    source = endpoints[0]
                else:
                    destination = endpoints[0]

            counter_entry = self.counters_db.get_all(self.counters_db.COUNTERS_DB, counters_key)
            if not counter_entry:
                # Counters removed between exists() and get_all().
                continue
            packets = counter_entry['NAT_TRANSLATIONS_PKTS']
            byte = counter_entry['NAT_TRANSLATIONS_BYTES']

            self.nat_statistics_list.append((ip_protocol, source, destination, packets, byte))

    def fetch_statistics(self):
        """
            Fetch NAT statistics from Counters DB.

            Fills ``nat_statistics_list`` with tuples
            (protocol, source, destination, packets, bytes), ordered by
            protocol name. The entries come from NAT_TABLE, NAPT_TABLE,
            NAT_TWICE_TABLE and NAPT_TWICE_TABLE of APPL_DB, in that order.
        """
        self.appl_db.connect(self.appl_db.APPL_DB)
        self.counters_db.connect(self.counters_db.COUNTERS_DB)
        self.nat_statistics_list = []

        for appl_table, counters_table, twice, has_ports in self._STATISTICS_TABLES:
            self._collect_statistics(appl_table, counters_table, twice, has_ports)

        # Stable sort on the protocol only, so the table order above is kept
        # within one protocol.
        self.nat_statistics_list.sort(key=lambda x: x[0])

    def display_count(self):
        """
            Display the nat entries count

            Requires fetch_count() to have been called. "Total Entries" is the
            sum of the eight static/dynamic counts; the SNAT/DNAT totals are
            printed as read from the database and are not part of that sum.
        """

        totalEntries = sum(int(count) for count in (
            self.static_nat_entries,
            self.dynamic_nat_entries,
            self.static_napt_entries,
            self.dynamic_napt_entries,
            self.static_twice_nat_entries,
            self.dynamic_twice_nat_entries,
            self.static_twice_napt_entries,
            self.dynamic_twice_napt_entries,
        ))

        print("")
        print("Static NAT Entries         ..................... {}".format(self.static_nat_entries))
        print("Static NAPT Entries        ..................... {}".format(self.static_napt_entries))
        print("Dynamic NAT Entries        ..................... {}".format(self.dynamic_nat_entries))
        print("Dynamic NAPT Entries       ..................... {}".format(self.dynamic_napt_entries))
        print("Static Twice NAT Entries   ..................... {}".format(self.static_twice_nat_entries))
        print("Static Twice NAPT Entries  ..................... {}".format(self.static_twice_napt_entries))
        print("Dynamic Twice NAT Entries  ..................... {}".format(self.dynamic_twice_nat_entries))
        print("Dynamic Twice NAPT Entries ..................... {}".format(self.dynamic_twice_napt_entries))
        print("Total SNAT/SNAPT Entries   ..................... {}".format(self.snat_entries))
        print("Total DNAT/DNAPT Entries   ..................... {}".format(self.dnat_entries))
        print("Total Entries              ..................... {}".format(totalEntries))
        print("")

    def display_translations(self):
        """
            Display the nat translations

            Requires fetch_translations() to have been called. No leading
            blank line is printed here.
        """

        HEADER = ['Protocol', 'Source', 'Destination', 'Translated Source', 'Translated Destination']
        output = [[nat[0], nat[1], nat[2], nat[3], nat[4]] for nat in self.nat_entries_list]

        print(tabulate(output, HEADER))
        print("")

    def display_statistics(self):
        """
            Display the nat statistics

            Requires fetch_statistics() to have been called.
        """

        HEADER = ['Protocol', 'Source', 'Destination', 'Packets', 'Bytes']
        output = [[nat[0], nat[1], nat[2], nat[3], nat[4]] for nat in self.nat_statistics_list]

        print("")
        print(tabulate(output, HEADER))
        print("")

def main():
    """
        Command line entry point.

        Parses -t/--translations, -s/--statistics and -c/--count and shows
        the first requested view, checked in that order. Without any option
        nothing is printed. Any exception raised while fetching or displaying
        is printed and the process exits with status 1.
    """
    usage_examples = """
    Examples:
    natshow -t
    natshow -s
    natshow -c
    """
    cli = argparse.ArgumentParser(description='Display the nat information',
                                  formatter_class=argparse.RawTextHelpFormatter,
                                  epilog=usage_examples)

    cli.add_argument('-t', '--translations', action='store_true', help='Show the nat translations')
    cli.add_argument('-s', '--statistics', action='store_true', help='Show the nat statistics')
    cli.add_argument('-c', '--count', action='store_true', help='Show the nat translations count')

    options = cli.parse_args()

    # Nothing requested: leave without touching the databases.
    if not (options.translations or options.statistics or options.count):
        return

    try:
        viewer = NatShow()
        if options.translations:
            # The translations view is preceded by the entry counts.
            viewer.fetch_count()
            viewer.fetch_translations()
            viewer.display_count()
            viewer.display_translations()
        elif options.statistics:
            viewer.fetch_statistics()
            viewer.display_statistics()
        else:
            viewer.fetch_count()
            viewer.display_count()
    except Exception as err:
        print(str(err))
        sys.exit(1)

# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

