#!python

#####################################################################
#
# intfstat is a tool for summarizing L3 (router interface) network statistics.
#
#####################################################################

import json
import argparse
import datetime
import sys
import os
import time

# Mock the redis for unit test purposes: when UTILITIES_UNIT_TESTING is "2",
# put the parent directory and its "tests" directory on sys.path and import
# the mock DB connector.
# os.environ.get() is used instead of try/except KeyError so that a KeyError
# raised while importing the mock module is not silently swallowed.
if os.environ.get("UTILITIES_UNIT_TESTING") == "2":
    modules_path = os.path.join(os.path.dirname(__file__), "..")
    test_path = os.path.join(modules_path, "tests")
    sys.path.insert(0, modules_path)
    sys.path.insert(0, test_path)
    import mock_tables.dbconnector

from collections import namedtuple, OrderedDict
from natsort import natsorted
from tabulate import tabulate
from utilities_common.netstat import ns_diff, table_as_json, STATUS_NA, format_brate, format_prate, format_number_with_comma
from utilities_common.cli import json_serial, UserCache
from swsscommon.swsscommon import SonicV2Connector

# Names of the per-interface counter fields. The order matters: get_counters()
# pairs entry N of this tuple with entry N of `counter_names` below.
nstat_fields = (
    "rx_b_ok", "rx_p_ok", "tx_b_ok", "tx_p_ok",
    "rx_b_err", "rx_p_err", "tx_b_err", "tx_p_err",
)
NStats = namedtuple("NStats", nstat_fields)

# Column titles of the summary table.
header = ['IFACE', 'RX_OK', 'RX_BPS', 'RX_PPS', 'RX_ERR', 'TX_OK', 'TX_BPS', 'TX_PPS', 'TX_ERR']

# Rate columns (RX_BPS, RX_PPS, TX_BPS, TX_PPS) are the header columns ending
# in "PS"; the RateStats fields are the same names in lower case.
rates_key_list = [column for column in header if column.endswith('PS')]
ratestat_fields = tuple(column.lower() for column in rates_key_list)
RateStats = namedtuple("RateStats", ratestat_fields)

# Counter names looked up in the DB, all sharing the SAI router interface
# statistics prefix; listed in the same order as `nstat_fields`.
_SAI_RIF_STAT_PREFIX = 'SAI_ROUTER_INTERFACE_STAT_'
counter_names = tuple(
    _SAI_RIF_STAT_PREFIX + suffix
    for suffix in (
        'IN_OCTETS',
        'IN_PACKETS',
        'OUT_OCTETS',
        'OUT_PACKETS',
        'IN_ERROR_OCTETS',
        'IN_ERROR_PACKETS',
        'OUT_ERROR_OCTETS',
        'OUT_ERROR_PACKETS',
    )
)

# Key prefixes and the name-map key used when reading the counters DB.
RATES_TABLE_PREFIX = "RATES:"
COUNTER_TABLE_PREFIX = "COUNTERS:"
COUNTERS_RIF_NAME_MAP = "COUNTERS_RIF_NAME_MAP"

class Intfstat(object):
    """
        Reads router interface (RIF) counters and rates from the counters
        database and prints them as a table, as JSON or as a
        single-interface report.
    """

    def __init__(self):
        """
            Create the DB connector and connect to COUNTERS_DB and APPL_DB.
        """
        self.db = SonicV2Connector(use_unix_socket_path=False)
        self.db.connect(self.db.COUNTERS_DB)
        self.db.connect(self.db.APPL_DB)

    def get_cnstat(self, rif=None):
        """
            Get the counters info from database.

            rif: name of a single interface to read; when empty or None,
                 every interface in COUNTERS_RIF_NAME_MAP is read in
                 natural sort order.

            Returns a (cnstat_dict, ratestat_dict) pair of OrderedDicts keyed
            by interface name. cnstat_dict additionally holds the collection
            timestamp under the 'time' key.

            Exits with status 1 when the name map lookup returns None and
            with status 2 when `rif` is given but is not in the name map.
        """
        def get_counters(table_id):
            """
                Get the counters from specific table.
            """
            # The table key does not depend on the counter, build it once
            full_table_id = COUNTER_TABLE_PREFIX + table_id
            fields = [STATUS_NA] * len(nstat_fields)
            for pos, counter_name in enumerate(counter_names):
                counter_data = self.db.get(self.db.COUNTERS_DB, full_table_id, counter_name)
                # Counters that are missing (or empty) stay at STATUS_NA
                if counter_data:
                    fields[pos] = str(counter_data)
            return NStats._make(fields)._asdict()

        def get_rates(table_id):
            """
                Get the rates from specific table.
            """
            full_table_id = RATES_TABLE_PREFIX + table_id
            fields = []
            for name in rates_key_list:
                rate = self.db.get(self.db.COUNTERS_DB, full_table_id, name)
                # A rate that is absent from the DB is reported as STATUS_NA
                fields.append(STATUS_NA if rate is None else float(rate))
            return RateStats._make(fields)

        # Build a dictionary of the stats
        cnstat_dict = OrderedDict()
        cnstat_dict['time'] = datetime.datetime.now()
        ratestat_dict = OrderedDict()

        # Get the info from database
        counter_rif_name_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_RIF_NAME_MAP)

        if counter_rif_name_map is None:
            print("No %s in the DB!" % COUNTERS_RIF_NAME_MAP)
            sys.exit(1)

        if rif and rif not in counter_rif_name_map:
            print("Interface %s missing from %s! Make sure it exists" % (rif, COUNTERS_RIF_NAME_MAP))
            sys.exit(2)

        # Either the one requested interface or all of them, naturally sorted
        rif_names = [rif] if rif else natsorted(counter_rif_name_map)
        for rif_name in rif_names:
            table_id = counter_rif_name_map[rif_name]
            cnstat_dict[rif_name] = get_counters(table_id)
            ratestat_dict[rif_name] = get_rates(table_id)
        return cnstat_dict, ratestat_dict

    @staticmethod
    def _table_row(name, cntr, rates, old_cntr=None):
        """
            Build one row of the summary table, in `header` column order.

            Packet counters are shown as the difference to `old_cntr` when it
            is given, otherwise as absolute values. Rates are always taken
            from `rates` as they are.
        """
        if old_cntr is not None:
            def show(field):
                return ns_diff(cntr[field], old_cntr[field])
        else:
            def show(field):
                return format_number_with_comma(cntr[field])

        return (name,
                show('rx_p_ok'),
                format_brate(rates.rx_bps),
                format_prate(rates.rx_pps),
                show('rx_p_err'),
                show('tx_p_ok'),
                format_brate(rates.tx_bps),
                format_prate(rates.tx_pps),
                show('tx_p_err'))

    @staticmethod
    def _print_table(table, use_json):
        """
            Print the rows either as JSON or as a right-aligned text table.
        """
        if use_json:
            print(table_as_json(table, header))
        else:
            print(tabulate(table, header, tablefmt='simple', stralign='right'))

    def cnstat_print(self, cnstat_dict, ratestat_dict, use_json):
        """
            Print the cnstat.

            cnstat_dict:   counters per interface, as returned by get_cnstat
                           (the 'time' entry is skipped).
            ratestat_dict: rates per interface; interfaces without an entry
                           get STATUS_NA for every rate.
            use_json:      print JSON instead of a text table.
        """
        na_rates = RateStats._make([STATUS_NA] * len(rates_key_list))
        table = []

        for key, data in cnstat_dict.items():
            if key == 'time':
                continue
            table.append(self._table_row(key, data, ratestat_dict.get(key, na_rates)))

        self._print_table(table, use_json)

    def cnstat_diff_print(self, cnstat_new_dict, cnstat_old_dict, ratestat_dict, use_json):
        """
            Print the difference between two cnstat results.

            Interfaces that have no (non-None) entry in cnstat_old_dict are
            printed with their absolute counters from cnstat_new_dict.
        """
        na_rates = RateStats._make([STATUS_NA] * len(rates_key_list))
        table = []

        for key, cntr in cnstat_new_dict.items():
            if key == 'time':
                continue
            table.append(self._table_row(key,
                                         cntr,
                                         ratestat_dict.get(key, na_rates),
                                         cnstat_old_dict.get(key)))

        self._print_table(table, use_json)

    def cnstat_single_interface(self, rif, cnstat_new_dict, cnstat_old_dict):
        """
            Print the detailed packet/byte counters of one interface.

            rif:             name of the interface to print.
            cnstat_new_dict: current counters; must contain `rif`.
            cnstat_old_dict: earlier counters or None. When it holds a
                             non-empty entry for `rif`, the difference to it
                             is printed instead of the absolute values.
        """
        # Named `title` so that the module-level table `header` is not shadowed
        title = rif + '\n' + '-'*len(rif)
        body = """
        RX:
        %10s packets
        %10s bytes
        %10s error packets
        %10s error bytes
        TX:
        %10s packets
        %10s bytes
        %10s error packets
        %10s error bytes"""

        # Counter fields in the order of the placeholders in `body`
        field_order = ('rx_p_ok', 'rx_b_ok', 'rx_p_err', 'rx_b_err',
                       'tx_p_ok', 'tx_b_ok', 'tx_p_err', 'tx_b_err')

        cntr = cnstat_new_dict.get(rif)
        old_cntr = cnstat_old_dict.get(rif) if cnstat_old_dict else None

        if old_cntr:
            values = tuple(ns_diff(cntr[field], old_cntr[field]) for field in field_order)
        else:
            values = tuple(format_number_with_comma(cntr[field]) for field in field_order)

        print(title)
        print(body % values)


def _load_stats_file(path):
    """Return the parsed JSON content of the saved stats file at `path`."""
    with open(path, 'r') as stats_file:
        return json.load(stats_file)


def _write_stats_file(path, stats):
    """Write `stats` as JSON to `path`, closing (and flushing) the file afterwards."""
    with open(path, 'w') as stats_file:
        json.dump(stats, stats_file, default=json_serial)


def _merge_into_stats_file(path, stats):
    """Overlay `stats` on the entries already saved in `path` and write the result back."""
    saved_stats = _load_stats_file(path)
    saved_stats.update(stats)
    _write_stats_file(path, saved_stats)


def main():
    """
    Entry point of the intfstat command line tool.

    Parses the command line, reads the current counters through Intfstat and
    then, depending on the options:
      * -c: saves the counters to the cache file(s), prints
        "Cleared counters" and exits with status 0 (or exits with the errno
        of an I/O failure);
      * -p N (N > 0): waits N seconds, reads the counters again and prints
        the difference;
      * otherwise: prints the counters relative to the saved ones when a
        cache file exists, or the absolute counters when none exists and no
        tag was given.
    -d / -D ask the UserCache to remove saved stats before anything is read.
    """
    parser  = argparse.ArgumentParser(description='Display the interfaces state and counters',
                                        formatter_class=argparse.RawTextHelpFormatter,
                                        epilog="""
        Examples:
        intfstat -c -t test
        intfstat -t test
        intfstat -d -t test
        intfstat
        intfstat -p 20
        intfstat -i Vlan1000
        """)

    parser.add_argument('-c', '--clear', action='store_true', help='Copy & clear stats')
    parser.add_argument('-d', '--delete', action='store_true', help='Delete saved stats, either the uid or the specified tag')
    parser.add_argument('-D', '--delete-all', action='store_true', help='Delete all saved stats')
    parser.add_argument('-j', '--json', action='store_true', help='Display in JSON format')
    parser.add_argument('-t', '--tag', type=str, help='Save stats with name TAG', default=None)
    parser.add_argument('-i', '--interface', type=str, help='Show stats for a single interface', required=False)
    parser.add_argument('-p', '--period', type=int, help='Display stats over a specified period (in seconds).', default=0)
    parser.add_argument('-v', '--version', action='version', version='%(prog)s 1.0')
    args = parser.parse_args()

    # time.sleep() raises ValueError for a negative value; report a usage
    # error up front instead of failing with a traceback later on.
    if args.period < 0:
        parser.error("argument -p/--period: must not be negative")

    save_fresh_stats = args.clear
    delete_saved_stats = args.delete
    delete_all_stats = args.delete_all
    use_json = args.json
    tag_name = args.tag
    wait_time_in_seconds = args.period
    interface_name = args.interface or ""

    cnstat_file = "intfstat"

    # One cache for the given tag and one general (untagged) cache
    cache = UserCache(tag=tag_name)
    cache_general = UserCache()
    cnstat_dir = cache.get_directory()
    cnstat_general_dir = cache_general.get_directory()

    cnstat_fqn_general_file = cnstat_general_dir + "/" + cnstat_file
    cnstat_fqn_file = cnstat_dir + "/" + cnstat_file

    if delete_all_stats:
        cache.remove_all()

    if delete_saved_stats:
        cache.remove()

    intfstat = Intfstat()
    cnstat_dict, ratestat_dict = intfstat.get_cnstat(rif=interface_name)

    if save_fresh_stats:
        try:
            # Add the information also to the general file - i.e. without the tag name
            if tag_name is not None and os.path.isfile(cnstat_fqn_general_file):
                _merge_into_stats_file(cnstat_fqn_general_file, cnstat_dict)
            # Add the information also to tag specific file
            if os.path.isfile(cnstat_fqn_file):
                _merge_into_stats_file(cnstat_fqn_file, cnstat_dict)
            else:
                _write_stats_file(cnstat_fqn_file, cnstat_dict)
        except IOError as e:
            sys.exit(e.errno)
        else:
            print("Cleared counters")
            sys.exit(0)

    if wait_time_in_seconds == 0:
        if os.path.isfile(cnstat_fqn_file) or os.path.isfile(cnstat_fqn_general_file):
            try:
                # The tag specific file takes precedence over the general one
                if os.path.isfile(cnstat_fqn_file):
                    cnstat_cached_dict = _load_stats_file(cnstat_fqn_file)
                else:
                    cnstat_cached_dict = _load_stats_file(cnstat_fqn_general_file)

                print("Last cached time was " + str(cnstat_cached_dict.get('time')))
                if interface_name:
                    intfstat.cnstat_single_interface(interface_name, cnstat_dict, cnstat_cached_dict)
                else:
                    intfstat.cnstat_diff_print(cnstat_dict, cnstat_cached_dict, ratestat_dict, use_json)
            except IOError as e:
                print(e.errno, e)
        else:
            if tag_name:
                print("\nFile '%s' does not exist" % cnstat_fqn_file)
                print("Did you run 'intfstat -c -t %s' to record the counters via tag %s?\n" % (tag_name, tag_name))
            else:
                if interface_name:
                    intfstat.cnstat_single_interface(interface_name, cnstat_dict, None)
                else:
                    intfstat.cnstat_print(cnstat_dict, ratestat_dict, use_json)
    else:
        #wait for the specified time and then gather the new stats and output the difference.
        time.sleep(wait_time_in_seconds)
        print("The rates are calculated within %s seconds period" % wait_time_in_seconds)
        cnstat_new_dict, ratestat_new_dict = intfstat.get_cnstat(rif=interface_name)
        if interface_name:
            intfstat.cnstat_single_interface(interface_name, cnstat_new_dict, cnstat_dict)
        else:
            intfstat.cnstat_diff_print(cnstat_new_dict, cnstat_dict, ratestat_new_dict, use_json)

# Run the tool only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
