#!python

#####################################################################
#
# tunnelstat is a tool for summarizing various tunnel statistics.
#
#####################################################################

import json
import argparse
import datetime
import sys
import os
import sys
import time

# Unit-test hook: when the environment variable UTILITIES_UNIT_TESTING is
# set to "2", put the parent directory of this script and its "tests"
# sub-directory at the front of sys.path and import mock_tables.dbconnector.
# NOTE(review): importing that module presumably swaps the real redis
# connector for a mock one - confirm against tests/mock_tables.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        test_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, test_path)
        import mock_tables.dbconnector
except KeyError:
    # The variable is not set at all: normal (non-test) execution.
    pass

from collections import namedtuple, OrderedDict
from natsort import natsorted
from tabulate import tabulate
from utilities_common.netstat import ns_diff, table_as_json, STATUS_NA, format_prate
from utilities_common.cli import json_serial, UserCache
from swsscommon.swsscommon import SonicV2Connector


# Field names of one counter record. The order is significant: field N is
# filled from counter_names[N] in Tunnelstat.get_cnstat().
nstat_fields = ("rx_b_ok", "rx_p_ok", "tx_b_ok", "tx_p_ok")
NStats = namedtuple("NStats", nstat_fields)

# Column titles used for both the tabulated and the JSON output.
header = ['IFACE', 'RX_PKTS', 'RX_BYTES', 'RX_PPS','TX_PKTS', 'TX_BYTES', 'TX_PPS']

# Field names read from a "RATES:<id>" entry. The order is significant:
# key N is stored in ratestat_fields[N].
rates_key_list = [ 'RX_BPS', 'RX_PPS', 'TX_BPS', 'TX_PPS']
ratestat_fields = ("rx_bps", "rx_pps", "tx_bps", "tx_pps")
RateStats = namedtuple("RateStats", ratestat_fields)


# Field names read from a "COUNTERS:<id>" entry, positionally matched
# with nstat_fields above.
counter_names = (
    'SAI_TUNNEL_STAT_IN_OCTETS',
    'SAI_TUNNEL_STAT_IN_PACKETS',
    'SAI_TUNNEL_STAT_OUT_OCTETS',
    'SAI_TUNNEL_STAT_OUT_PACKETS'
)

# Maps the tunnel type accepted on the command line (-T/--type) to the
# value it is compared against in COUNTERS_TUNNEL_TYPE_MAP.
counter_types = {
    "vxlan": "SAI_TUNNEL_TYPE_VXLAN"
}

# Key prefixes prepended to a tunnel's table id, and the names of the
# COUNTERS_DB hashes that are fetched whole to resolve tunnel names/types.
COUNTER_TABLE_PREFIX = "COUNTERS:"
RATES_TABLE_PREFIX = "RATES:"
COUNTERS_TUNNEL_NAME_MAP = "COUNTERS_TUNNEL_NAME_MAP"
COUNTERS_TUNNEL_TYPE_MAP = "COUNTERS_TUNNEL_TYPE_MAP"


class Tunnelstat(object):
    """
        Read per-tunnel counters and rates from COUNTERS_DB and print them
        as a table, as JSON, or as a single-tunnel summary.
    """

    def __init__(self):
        """
            Open the connections to COUNTERS_DB and APPL_DB.
        """
        self.db = SonicV2Connector(use_unix_socket_path=False)
        self.db.connect(self.db.COUNTERS_DB)
        self.db.connect(self.db.APPL_DB)

    def get_cnstat(self, tunnel=None, tun_type=None):
        """
            Get the counters info from database.

            tunnel:   name of a single tunnel to report; a falsy value
                      selects every tunnel in COUNTERS_TUNNEL_NAME_MAP.
            tun_type: key of counter_types used as a filter; a falsy value
                      disables the filter.

            Returns (cnstat_dict, ratestat_dict). cnstat_dict maps 'time' to
            the moment of the snapshot and each tunnel name to a dict keyed
            by nstat_fields; ratestat_dict maps each tunnel name to a
            RateStats tuple.

            Exits with status 1 when a map is missing from the DB or the
            type is unknown, and with status 2 when the tunnel is unknown or
            its type differs from the requested one.
        """
        def get_counters(table_id):
            """
                Get the counters from specific table.
            """
            full_table_id = COUNTER_TABLE_PREFIX + table_id
            fields = [STATUS_NA] * len(nstat_fields)
            for pos, counter_name in enumerate(counter_names):
                counter_data = self.db.get(self.db.COUNTERS_DB, full_table_id, counter_name)
                # A missing/empty value keeps the STATUS_NA placeholder.
                if counter_data:
                    fields[pos] = str(counter_data)
            return NStats._make(fields)._asdict()

        def get_rates(table_id):
            """
                Get the rates from specific table.
            """
            full_table_id = RATES_TABLE_PREFIX + table_id
            fields = []
            for name in rates_key_list:
                counter_data = self.db.get(self.db.COUNTERS_DB, full_table_id, name)
                fields.append(STATUS_NA if counter_data is None else float(counter_data))
            return RateStats._make(fields)

        def get_tunnel_type(table_id):
            """
                Return the type recorded for table_id, or None when the
                type map has no entry for it.
            """
            # Only membership test and indexing are used, i.e. the same
            # operations this method already relies on for the name map.
            if table_id in counter_tunnel_type_map:
                return counter_tunnel_type_map[table_id]
            return None

        # Build a dictionary of the stats
        cnstat_dict = OrderedDict()
        ratestat_dict = OrderedDict()
        cnstat_dict['time'] = datetime.datetime.now()

        # Get the info from database
        counter_tunnel_name_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_TUNNEL_NAME_MAP)
        counter_tunnel_type_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_TUNNEL_TYPE_MAP)

        if counter_tunnel_name_map is None:
            print("No %s in the DB!" % COUNTERS_TUNNEL_NAME_MAP)
            sys.exit(1)

        if counter_tunnel_type_map is None:
            print("No %s in the DB!" % COUNTERS_TUNNEL_TYPE_MAP)
            sys.exit(1)

        if tun_type and tun_type not in counter_types:
            print("Unknown tunnel type %s" % tun_type)
            sys.exit(1)

        if tunnel and tunnel not in counter_tunnel_name_map:
            print("Interface %s missing from %s! Make sure it exists" % (tunnel, COUNTERS_TUNNEL_NAME_MAP))
            sys.exit(2)

        # None means "no type filter requested".
        wanted_type = counter_types[tun_type] if tun_type else None

        if tunnel:
            table_id = counter_tunnel_name_map[tunnel]
            actual_type = get_tunnel_type(table_id)
            if wanted_type is not None and wanted_type != actual_type:
                print("Mismtch in tunnel type. Requested type %s actual type %s" % (
                      wanted_type, actual_type))
                sys.exit(2)
            cnstat_dict[tunnel] = get_counters(table_id)
            ratestat_dict[tunnel] = get_rates(table_id)
            return cnstat_dict, ratestat_dict

        for name in natsorted(counter_tunnel_name_map):
            table_id = counter_tunnel_name_map[name]
            # A tunnel without a recorded type never matches a type filter.
            if wanted_type is None or wanted_type == get_tunnel_type(table_id):
                cnstat_dict[name] = get_counters(table_id)
                ratestat_dict[name] = get_rates(table_id)
        return cnstat_dict, ratestat_dict

    def cnstat_print(self, cnstat_dict, ratestat_dict, use_json):
        """
            Print the cnstat.

            One row per tunnel of cnstat_dict (the 'time' entry is skipped);
            tunnels absent from ratestat_dict get STATUS_NA rates. Output is
            JSON when use_json is true, a plain table otherwise.
        """
        table = []
        no_rates = RateStats._make([STATUS_NA] * len(rates_key_list))

        for key, data in cnstat_dict.items():
            if key == 'time':
                continue

            rates = ratestat_dict.get(key, no_rates)
            table.append((key, data['rx_p_ok'], data['rx_b_ok'], format_prate(rates.rx_pps),
                               data['tx_p_ok'], data['tx_b_ok'], format_prate(rates.tx_pps)))

        if use_json:
            print(table_as_json(table, header))
        else:
            print(tabulate(table, header, tablefmt='simple', stralign='right'))

    def cnstat_diff_print(self, cnstat_new_dict, cnstat_old_dict, ratestat_dict, use_json):
        """
            Print the difference between two cnstat results.

            Tunnels that have no entry in cnstat_old_dict are printed with
            their current counters instead of a difference.
        """
        table = []
        no_rates = RateStats._make([STATUS_NA] * len(rates_key_list))

        for key, cntr in cnstat_new_dict.items():
            if key == 'time':
                continue

            old_cntr = cnstat_old_dict.get(key)
            rates = ratestat_dict.get(key, no_rates)
            if old_cntr is not None:
                table.append((key,
                            ns_diff(cntr['rx_p_ok'], old_cntr['rx_p_ok']),
                            ns_diff(cntr['rx_b_ok'], old_cntr['rx_b_ok']),
                            format_prate(rates.rx_pps),
                            ns_diff(cntr['tx_p_ok'], old_cntr['tx_p_ok']),
                            ns_diff(cntr['tx_b_ok'], old_cntr['tx_b_ok']),
                            format_prate(rates.tx_pps)))
            else:
                table.append((key,
                            cntr['rx_p_ok'],
                            cntr['rx_b_ok'],
                            format_prate(rates.rx_pps),
                            cntr['tx_p_ok'],
                            cntr['tx_b_ok'],
                            format_prate(rates.tx_pps)))

        if use_json:
            print(table_as_json(table, header))
        else:
            print(tabulate(table, header, tablefmt='simple', stralign='right'))

    def cnstat_single_tunnel(self, tunnel, cnstat_new_dict, cnstat_old_dict):
        """
            Print the RX/TX packet and byte counters of one tunnel.

            tunnel:          tunnel name; must be a key of cnstat_new_dict.
            cnstat_new_dict: current snapshot.
            cnstat_old_dict: earlier snapshot, or None/empty when there is
                             nothing to compare against.

            When the earlier snapshot has an entry for the tunnel the values
            printed are ns_diff(new, old); otherwise the current counters
            are printed as they are.
        """
        header = tunnel + '\n' + '-'*len(tunnel)
        body = """
        RX:
        %10s packets
        %10s bytes
        TX:
        %10s packets
        %10s bytes"""

        cntr = cnstat_new_dict.get(tunnel)
        old_cntr = cnstat_old_dict.get(tunnel) if cnstat_old_dict else None

        if old_cntr:
            values = (ns_diff(cntr['rx_p_ok'], old_cntr['rx_p_ok']),
                      ns_diff(cntr['rx_b_ok'], old_cntr['rx_b_ok']),
                      ns_diff(cntr['tx_p_ok'], old_cntr['tx_p_ok']),
                      ns_diff(cntr['tx_b_ok'], old_cntr['tx_b_ok']))
        else:
            # Also reached when an old snapshot exists but does not contain
            # this tunnel; without this fallback the raw "%10s" template
            # would be printed.
            values = (cntr['rx_p_ok'], cntr['rx_b_ok'], cntr['tx_p_ok'], cntr['tx_b_ok'])

        print(header)
        print(body % values)


def main():
    """
        Command line entry point.

        Parses the arguments, optionally removes cached snapshots (-d/-D),
        reads the current tunnel counters and then does one of:
          * -c: store the snapshot in the user cache, print
            "Cleared counters" and exit with status 0 (or with the errno of
            the IOError if the snapshot cannot be written);
          * -p N (N != 0): sleep N seconds, read the counters again and
            print the difference against the first reading;
          * otherwise: print the difference against the cached snapshot if
            one exists, else the current counters.
    """
    parser  = argparse.ArgumentParser(description='Display the tunnels state and counters',
                                        formatter_class=argparse.RawTextHelpFormatter,
                                        epilog="""
        Port state: (U)-Up (D)-Down (X)-Disabled
        Examples:
        tunnelstat -c -t test
        tunnelstat -t test
        tunnelstat -d -t test
        tunnelstat
        tunnelstat -p 20
        tunnelstat -i Vlan1000
        """)

    parser.add_argument('-c', '--clear', action='store_true', help='Copy & clear stats')
    parser.add_argument('-d', '--delete', action='store_true', help='Delete saved stats, either the uid or the specified tag')
    parser.add_argument('-D', '--delete-all', action='store_true', help='Delete all saved stats')
    parser.add_argument('-j', '--json', action='store_true', help='Display in JSON format')
    parser.add_argument('-t', '--tag', type=str, help='Save stats with name TAG', default=None)
    parser.add_argument('-i', '--tunnel', type=str, help='Show stats for a single tunnel', required=False)
    parser.add_argument('-p', '--period', type=int, help='Display stats over a specified period (in seconds).', default=0)
    parser.add_argument('-v', '--version', action='version', version='%(prog)s 1.0')
    parser.add_argument('-T', '--type', type=str, help ='Display Vxlan tunnel stats', required=False)
    args = parser.parse_args()

    save_fresh_stats = args.clear
    delete_saved_stats = args.delete
    delete_all_stats = args.delete_all
    use_json = args.json
    # Unset string options are normalised to "" so they can be tested
    # with plain truthiness below.
    tag_name = args.tag if args.tag else ""
    wait_time_in_seconds = args.period
    tunnel_name = args.tunnel if args.tunnel else ""
    tunnel_type = args.type if args.type else ""

    cache = UserCache(tag=tag_name)

    cnstat_file = "tunnelstat"

    cnstat_dir = cache.get_directory()
    cnstat_fqn_file = os.path.join(cnstat_dir, cnstat_file)

    if delete_all_stats:
        cache.remove_all()

    if delete_saved_stats:
        cache.remove()

    tunnelstat = Tunnelstat()
    cnstat_dict, ratestat_dict = tunnelstat.get_cnstat(tunnel=tunnel_name, tun_type=tunnel_type)

    if save_fresh_stats:
        try:
            # "with" guarantees the snapshot is flushed and the handle
            # closed before the process exits.
            with open(cnstat_fqn_file, 'w') as cache_file:
                json.dump(cnstat_dict, cache_file, default=json_serial)
        except IOError as e:
            sys.exit(e.errno)
        else:
            print("Cleared counters")
            sys.exit(0)

    if wait_time_in_seconds == 0:
        if os.path.isfile(cnstat_fqn_file):
            try:
                with open(cnstat_fqn_file, 'r') as cache_file:
                    cnstat_cached_dict = json.load(cache_file)
                print("Last cached time was " + str(cnstat_cached_dict.get('time')))
                if tunnel_name:
                    tunnelstat.cnstat_single_tunnel(tunnel_name, cnstat_dict, cnstat_cached_dict)
                else:
                    tunnelstat.cnstat_diff_print(cnstat_dict, cnstat_cached_dict, ratestat_dict, use_json)
            except IOError as e:
                print(e.errno, e)
        else:
            if tag_name:
                print("\nFile belonging to tag %s does not exist" % tag_name)
                print("Did you run 'tunnelstat -c -t %s' to record the counters via tag %s?\n" % (tag_name, tag_name))
            else:
                if tunnel_name:
                    tunnelstat.cnstat_single_tunnel(tunnel_name, cnstat_dict, None)
                else:
                    tunnelstat.cnstat_print(cnstat_dict, ratestat_dict, use_json)
    else:
        #wait for the specified time and then gather the new stats and output the difference.
        time.sleep(wait_time_in_seconds)
        cnstat_new_dict, ratestat_dict = tunnelstat.get_cnstat(tunnel=tunnel_name, tun_type=tunnel_type)
        if tunnel_name:
            tunnelstat.cnstat_single_tunnel(tunnel_name, cnstat_new_dict, cnstat_dict)
        else:
            tunnelstat.cnstat_diff_print(cnstat_new_dict, cnstat_dict, ratestat_dict, use_json)

# Run the command line interface only when executed as a script, so that
# importing this file has no side effects beyond the definitions above.
if __name__ == "__main__":
    main()
