#!python

import argparse
import os
import json
import sys

from natsort import natsorted
from tabulate import tabulate

# Mock the redis DB connector for unit test purposes.
# Both environment lookups below use os.environ[...], so a missing variable raises
# KeyError, which is swallowed at the bottom. Note that when UTILITIES_UNIT_TESTING is
# not set at all, the KeyError is raised by the first lookup and the topology check is
# skipped as well.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        # Make the parent directory of this script and its "tests" sub directory
        # importable so that the mock_tables package can be found.
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        import mock_tables.dbconnector
    if os.environ["UTILITIES_UNIT_TESTING_TOPOLOGY"] == "multi_asic":
        # NOTE(review): this branch uses mock_tables.dbconnector, which is only imported
        # explicitly by the branch above - presumably both variables are always set
        # together in the unit tests; confirm.
        import mock_tables.mock_multi_asic
        mock_tables.dbconnector.load_namespace_config()

except KeyError:
    # Not running under unit test: keep the real DB connector.
    pass

import utilities_common.multi_asic as multi_asic_util
from flow_counter_util.route import build_route_pattern, extract_route_pattern, exit_if_route_flow_counter_not_support, DEFAULT_VRF, COUNTERS_ROUTE_TO_PATTERN_MAP
from utilities_common import constants
from utilities_common.netstat import format_number_with_comma, table_as_json, ns_diff, format_prate
from utilities_common.cli import UserCache

# Flow counter meta data, new type of flow counters can extend this dictionary to reuse existing logic.
# Keyed by the value of the "--type" command line option. For each type:
#   'headers':  column headers of the default output table
#   'name_map': COUNTERS_DB table that maps a counter name to its counter OID
flow_counter_meta = {
    'trap': {
        'headers': ['Trap Name', 'Packets', 'Bytes', 'PPS'],
        'name_map': 'COUNTERS_TRAP_NAME_MAP'
    },
    'route': {
        'headers': ['Route pattern', 'VRF', 'Matched routes', 'Packets', 'Bytes'],
        'name_map': 'COUNTERS_ROUTE_NAME_MAP'
    }
}
# Fields read from the "COUNTERS:<oid>" table, in the order they are stored in the values list
flow_counters_fields = ['SAI_COUNTER_STAT_PACKETS', 'SAI_COUNTER_STAT_BYTES']

# Only do diff for 'Packets' and 'Bytes'. These are indexes into the values list,
# matching the order of flow_counters_fields above.
diff_column_positions = set([0, 1])

# Key prefix of the per-counter statistic table in COUNTERS_DB
FLOW_COUNTER_TABLE_PREFIX = "COUNTERS:"
# Key prefix of the per-counter rate table in COUNTERS_DB and the field read from it
RATES_TABLE_PREFIX = 'RATES:'
PPS_FIELD = 'RX_PPS'


class FlowCounterStats(object):
    """Show and clear flow counter statistic read from COUNTERS_DB.

    "Clear" does not touch the ASIC. It stores a snapshot of the current counters in a
    per-user cache file; a later "show" prints the difference between the live counters
    and that snapshot.
    """

    def __init__(self, args):
        """Initialize meta data and cache location from the command line arguments.

        Args:
            args (argparse.Namespace): Parsed arguments. This class reads 'type',
                'namespace', 'delete' and 'json'.
        """
        # NOTE(review): presumably assigned by multi_asic_util.run_on_multi_asic before
        # the decorated method body runs - confirm against the decorator.
        self.db = None
        self.multi_asic = multi_asic_util.MultiAsic(namespace_option=args.namespace)
        self.args = args
        meta_data = flow_counter_meta[args.type]
        self.name_map = meta_data['name_map']
        self.headers = meta_data['headers']
        self.cache = UserCache()
        self.data_file = os.path.join(self.cache.get_directory(), "flow-counter-stats")
        if self.args.delete:
            self.cache.remove()
        # {<namespace>: {<name>: [<pkts>, <bytes>, <rx_pps>, <counter_oid>]}}
        self.data = {}

    def show(self):
        """Show flow counter statistic
        """
        self._collect_and_diff()
        headers, table = self._prepare_show_data()
        self._print_data(headers, table)

    def _collect_and_diff(self):
        """Collect statistic from db and diff from old data if any
        """
        self._collect()
        old_data = self._load()
        need_update_cache = self._diff(old_data, self.data)
        if need_update_cache:
            # _diff adjusted old_data in place (baseline reset), persist the adjusted snapshot
            self._save(old_data)

    def _adjust_headers(self, headers):
        """Adjust table headers based on platforms

        Args:
            headers (list): Original headers

        Returns:
            headers (list): Headers with 'ASIC ID' column if it is a multi ASIC platform
        """
        return ['ASIC ID'] + headers if self.multi_asic.is_multi_asic else headers

    def _prepare_show_data(self):
        """Prepare headers and table data for output

        Returns:
            headers (list): Table headers
            table (list): Table data, one row per counter name, sorted naturally by
                          namespace and then by name
        """
        table = []
        headers = self._adjust_headers(self.headers)

        for ns, stats in natsorted(self.data.items()):
            if self.args.namespace is not None and self.args.namespace != ns:
                continue
            for name, values in natsorted(stats.items()):
                # The namespace column only exists on multi ASIC platforms
                row = [ns] if self.multi_asic.is_multi_asic else []
                row.extend([name,
                            format_number_with_comma(values[0]),
                            format_number_with_comma(values[1]),
                            format_prate(values[2])])
                table.append(row)

        return headers, table

    def _print_data(self, headers, table):
        """Print statistic data based on output format

        Args:
            headers (list): Table headers
            table (list): Table data
        """
        if self.args.json:
            print(table_as_json(table, headers))
        else:
            print(tabulate(table, headers, tablefmt='simple', stralign='right'))

    def clear(self):
        """Clear flow counter statistic. This function does not clear data from ASIC. Instead, it saves flow counter statistic to a file. When user
           issue show command after clear, it does a diff between new data and saved data.
        """
        self._collect()
        self._save(self.data)
        print('Flow Counters were successfully cleared')

    @multi_asic_util.run_on_multi_asic
    def _collect(self):
        """Collect flow counter statistic from DB. This function is called on a multi ASIC context.

        The result of each call is merged into self.data under the current namespace.
        """
        self.data.update(self._get_stats_from_db())

    def _get_stats_from_db(self):
        """Get flow counter statistic of the current namespace from DB.

        Returns:
            dict: A dictionary. E.g: {<namespace>: {<trap_name>: [<value_in_pkts>, <value_in_bytes>, <rx_pps>, <counter_oid>]}}
        """
        ns = self.multi_asic.current_namespace
        name_map = self.db.get_all(self.db.COUNTERS_DB, self.name_map)
        data = {ns: {}}
        if not name_map:
            return data

        for name, counter_oid in name_map.items():
            values = self._get_stats_value(counter_oid)

            # A missing rate is reported as '0'
            rate = self.db.get(self.db.COUNTERS_DB, RATES_TABLE_PREFIX + counter_oid, PPS_FIELD)
            values.append('0' if rate is None else rate)
            # The counter OID is kept as the last item so that _diff can detect a re-created counter
            values.append(counter_oid)
            data[ns][name] = values
        return data

    def _get_stats_value(self, counter_oid):
        """Get statistic value from COUNTERS_DB COUNTERS table

        Args:
            counter_oid (string): OID of a generic counter

        Returns:
            values (list): A list of statistics value, one per entry of flow_counters_fields.
                           A field that is missing in DB is reported as '0'.
        """
        values = []
        full_table_id = FLOW_COUNTER_TABLE_PREFIX + counter_oid
        for field in flow_counters_fields:
            counter_data = self.db.get(self.db.COUNTERS_DB, full_table_id, field)
            values.append('0' if counter_data is None else counter_data)
        return values

    def _save(self, data):
        """Save flow counter statistic to a file

        Args:
            data (dict): Statistic to be dumped as JSON to the cache file. A failure is
                         printed instead of raised.
        """
        try:
            if os.path.exists(self.data_file):
                os.remove(self.data_file)

            with open(self.data_file, 'w') as f:
                json.dump(data, f)
        except IOError as e:
            print('Failed to save statistic - {}'.format(repr(e)))

    def _load(self):
        """Load flow counter statistic from a file

        Returns:
            dict: A dictionary. E.g: {<namespace>: {<trap_name>: [<value_in_pkts>, <value_in_bytes>, <rx_pps>, <counter_oid>]}}
                  None if the cache file does not exist, cannot be read or does not contain valid JSON.
        """
        if not os.path.exists(self.data_file):
            return None

        try:
            with open(self.data_file, 'r') as f:
                return json.load(f)
        except (IOError, ValueError) as e:
            # ValueError covers a corrupted or truncated cache file (JSON decode error).
            # Treat it like a missing cache instead of crashing the show command.
            print('Failed to load statistic - {}'.format(repr(e)))
            return None

    @staticmethod
    def _diff_counter_values(old_values, values):
        """Replace the diff columns of one counter with the difference to its cached values.

        The cached baseline is reset to '0' when the counter OID changed (the counter was
        removed and added again, so its value restarted from 0) or when any new value is
        smaller than the cached one. This avoids a negative diff.

        Args:
            old_values (list): Cached values of the counter, updated in place on reset
            values (list): New values of the counter, diff columns are updated in place

        Returns:
            bool: True if old_values was reset and the cache need to be updated
        """
        oid_changed = values[-1] != old_values[-1]
        if oid_changed:
            old_values[-1] = values[-1]

        need_reset = oid_changed or any(int(values[i]) < int(old_values[i]) for i in diff_column_positions)
        if need_reset:
            for i in diff_column_positions:
                old_values[i] = '0'

        for i in diff_column_positions:
            values[i] = ns_diff(values[i], old_values[i])
        return need_reset

    def _diff(self, old_data, new_data):
        """Do a diff between new data and old data.

        Counters that only exist in new_data are left untouched.

        Args:
            old_data (dict): E.g: {<namespace>: {<trap_name>: [<value_in_pkts>, <value_in_bytes>, <rx_pps>, <counter_oid>]}}
            new_data (dict): E.g: {<namespace>: {<trap_name>: [<value_in_pkts>, <value_in_bytes>, <rx_pps>, <counter_oid>]}}

        Returns:
            bool: True if cache need to be updated
        """
        if not old_data:
            return False

        need_update_cache = False
        for ns, stats in new_data.items():
            if ns not in old_data:
                continue
            old_stats = old_data[ns]
            for name, values in stats.items():
                if name not in old_stats:
                    continue

                if self._diff_counter_values(old_stats[name], values):
                    need_update_cache = True

        return need_update_cache


class RouteFlowCounterStats(FlowCounterStats):
    """Show and clear route flow counter statistic.

    Compared to the base class, the statistic has one more level: counters are grouped by
    route pattern, i.e. {<namespace>: {<route_pattern>: {<prefix>: [<pkts>, <bytes>, <counter_oid>]}}}.
    """

    # Headers used when a single prefix is requested via "--prefix"
    SHOW_BY_PREFIX_HEADERS = ['Route', 'VRF', 'Route pattern', 'Packets', 'Bytes']

    def __init__(self, args):
        super(RouteFlowCounterStats,self).__init__(args)

    def _print_data(self, headers, table):
        """Print statistic data based on output format

        Args:
            headers (list): Table headers
            table (list): Table data
        """
        if self.args.json:
            # The first column of the table might have duplicate value, have to
            # add an extra index field to make table_as_json work
            print(table_as_json(([i] + line for i, line in enumerate(table)), ['Index'] + headers))
        else:
            print(tabulate(table, headers, tablefmt='simple', stralign='right'))

    def _prepare_show_data(self):
        """Prepare table headers and table data for output. If "--prefix" is specified, fetch data that matches
           the given prefix; if "--prefix_pattern" is specified, fetch data that matches the given pattern;
           otherwise, fetch all data.
        Returns:
            headers (list): Table headers
            table (list): Table data
        """
        if self.args.prefix:
            return self._prepare_show_data_by_prefix()
        else:
            return self._prepare_show_data_by_pattern()

    def _prepare_show_data_by_prefix(self):
        """Prepare table header and table data by given prefix
        Returns:
            headers (list): Table headers
            table (list): Table data
        """
        table = []

        headers = self._adjust_headers(self.SHOW_BY_PREFIX_HEADERS)
        # Namespaces are visited in natural order, same as when showing by pattern
        for ns, pattern_entry in natsorted(self.data.items()):
            if self.args.namespace is not None and self.args.namespace != ns:
                continue
            for route_pattern, prefix_entry in pattern_entry.items():
                if self.args.prefix not in prefix_entry:
                    continue

                vrf, prefix_pattern = extract_route_pattern(route_pattern)
                if vrf != self.args.vrf:
                    continue

                values = prefix_entry[self.args.prefix]
                row = [ns] if self.multi_asic.is_multi_asic else []
                row.extend([self.args.prefix,
                            self.args.vrf,
                            prefix_pattern,
                            format_number_with_comma(values[0]),
                            format_number_with_comma(values[1])])
                table.append(row)

        return headers, table

    def _prepare_show_data_by_pattern(self):
        """Prepare table header and table data by given pattern. If pattern is not specified, show all data.
        Returns:
            headers (list): Table headers
            table (list): Table data
        """
        table = []

        headers = self._adjust_headers(self.headers)
        # The requested route pattern does not depend on the namespace, build it only once
        expect_route_pattern = None
        if self.args.prefix_pattern:
            expect_route_pattern = build_route_pattern(self.args.vrf, self.args.prefix_pattern)

        for ns, pattern_entries in natsorted(self.data.items()):
            if self.args.namespace is not None and self.args.namespace != ns:
                continue
            if self.args.prefix_pattern:
                if expect_route_pattern in pattern_entries:
                    self._fill_table_for_prefix_pattern(table, ns, pattern_entries[expect_route_pattern], self.args.prefix_pattern, self.args.vrf)
                # Keep looking at the remaining namespaces: on a multi ASIC platform
                # the pattern may be configured on more than one namespace, or not on
                # the first one.
                continue

            for route_pattern, prefix_entries in natsorted(pattern_entries.items()):
                vrf, prefix_pattern = extract_route_pattern(route_pattern)
                self._fill_table_for_prefix_pattern(table, ns, prefix_entries, prefix_pattern, vrf)

        return headers, table

    def _fill_table_for_prefix_pattern(self, table, ns, prefix_entries, prefix_pattern, vrf):
        """Fill table data for prefix pattern
        Args:
            table (list): Table data to fill
            ns (str): Namespace
            prefix_entries (dict): Prefix to value map
            prefix_pattern (str): Prefix pattern
            vrf (str): VRF
        """
        is_first_row = True
        for prefix, values in natsorted(prefix_entries.items()):
            # In table output the namespace, pattern and VRF are only printed on the first
            # row of a pattern; JSON output repeats them on every row.
            show_group_columns = is_first_row or self.args.json
            row = []
            if self.multi_asic.is_multi_asic:
                row.append(ns if show_group_columns else '')
            row.extend([prefix_pattern if show_group_columns else '',
                        vrf if show_group_columns else '',
                        prefix,
                        format_number_with_comma(values[0]),
                        format_number_with_comma(values[1])])
            table.append(row)
            is_first_row = False

    def clear(self):
        """Clear statistic based on arguments.
           1. If "--prefix" is specified, clear the data matching the prefix
           2. If "--prefix_pattern" is specified, clear the data matching the pattern
           3. Otherwise, clear all data
           If "--namespace" is not specified, clear the data belongs to the default namespace.
           If "--vrf" is not specified, use default VRF.
        """
        if self.args.prefix:
            self.clear_by_prefix()
        elif self.args.prefix_pattern:
            self.clear_by_pattern()
        else:
            super(RouteFlowCounterStats, self).clear()

    @multi_asic_util.run_on_multi_asic
    def clear_by_prefix(self):
        """Clear the data matching the prefix

        Only the cached snapshot of the given prefix is replaced, all other cached
        entries are kept as they are.
        """
        ns = constants.DEFAULT_NAMESPACE if self.args.namespace is None else self.args.namespace
        # Do nothing unless this invocation runs for the target namespace
        if ns != self.multi_asic.current_namespace:
            return

        name_map = self.db.get_all(self.db.COUNTERS_DB, self.name_map)
        prefix_vrf = build_route_pattern(self.args.vrf, self.args.prefix)
        if not name_map or prefix_vrf not in name_map:
            print('Cannot find {} in COUNTERS_DB {} table'.format(self.args.prefix, self.name_map))
            return

        route_to_pattern_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_ROUTE_TO_PATTERN_MAP)
        if not route_to_pattern_map or prefix_vrf not in route_to_pattern_map:
            print('Cannot find {} in {} table'.format(self.args.prefix, COUNTERS_ROUTE_TO_PATTERN_MAP))
            return

        # Start from the existing snapshot (if any)
        self.data = self._load() or {}
        route_pattern = route_to_pattern_map[prefix_vrf]
        prefix_entries = self.data.setdefault(ns, {}).setdefault(route_pattern, {})

        counter_oid = name_map[prefix_vrf]
        values = self._get_stats_value(counter_oid)
        values.append(counter_oid)
        prefix_entries[self.args.prefix] = values
        self._save(self.data)
        print('Flow Counters of the specified route were successfully cleared')

    @multi_asic_util.run_on_multi_asic
    def clear_by_pattern(self):
        """Clear the data matching the specified pattern

        Only the cached snapshots of the routes matching the pattern are replaced, all
        other cached entries are kept as they are.
        """
        ns = constants.DEFAULT_NAMESPACE if self.args.namespace is None else self.args.namespace
        # Do nothing unless this invocation runs for the target namespace
        if ns != self.multi_asic.current_namespace:
            return

        # Treat a missing table as an empty one instead of failing on it
        route_to_pattern_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_ROUTE_TO_PATTERN_MAP) or {}
        expect_route_pattern = build_route_pattern(self.args.vrf, self.args.prefix_pattern)
        matching_prefix_vrf_list = [prefix_vrf for prefix_vrf, route_pattern in route_to_pattern_map.items() if route_pattern == expect_route_pattern]
        if not matching_prefix_vrf_list:
            print('Cannot find {} in COUNTERS_DB {} table'.format(self.args.prefix_pattern, COUNTERS_ROUTE_TO_PATTERN_MAP))
            return

        data_to_update = {}
        name_map = self.db.get_all(self.db.COUNTERS_DB, self.name_map) or {}
        for prefix_vrf in matching_prefix_vrf_list:
            if prefix_vrf not in name_map:
                print('Warning: cannot find {} in {}'.format(prefix_vrf, self.name_map))
                continue

            counter_oid = name_map[prefix_vrf]
            values = self._get_stats_value(counter_oid)
            values.append(counter_oid)
            _, prefix = extract_route_pattern(prefix_vrf)
            data_to_update[prefix] = values

        # Start from the existing snapshot (if any)
        self.data = self._load() or {}
        self.data.setdefault(ns, {}).setdefault(expect_route_pattern, {}).update(data_to_update)
        self._save(self.data)

    def _get_stats_from_db(self):
        """Get flow counter statistic of the current namespace from DB.
        Returns:
            dict: A dictionary. E.g: {<namespace>: {(<route_pattern>): {<prefix>: [<value_in_pkts>, <value_in_bytes>, <counter_oid>]}}}
        """
        ns = self.multi_asic.current_namespace
        name_map = self.db.get_all(self.db.COUNTERS_DB, self.name_map)
        data = {ns: {}}
        if not name_map:
            return data

        route_to_pattern_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_ROUTE_TO_PATTERN_MAP)
        if not route_to_pattern_map:
            return data

        for prefix_vrf, route_pattern in route_to_pattern_map.items():
            # The two maps are read by separate DB calls, so a route may be present in
            # one of them only. Skip such a route instead of failing with KeyError.
            if prefix_vrf not in name_map:
                continue

            counter_oid = name_map[prefix_vrf]
            values = self._get_stats_value(counter_oid)
            # The counter OID is kept as the last item so that _diff can detect a re-created counter
            values.append(counter_oid)
            _, prefix = extract_route_pattern(prefix_vrf)
            data[ns].setdefault(route_pattern, {})[prefix] = values

        return data

    @staticmethod
    def _diff_route_values(old_values, values):
        """Replace the diff columns of one route with the difference to its cached values.

        The cached baseline is reset to '0' when the counter OID changed (the generic
        counter was removed and added again, so its value restarted from 0) or when any
        new value is smaller than the cached one. This avoids a negative diff.

        Args:
            old_values (list): Cached values of the route, updated in place on reset
            values (list): New values of the route, diff columns are updated in place

        Returns:
            bool: True if old_values was reset and the cache need to be updated
        """
        oid_changed = values[-1] != old_values[-1]
        if oid_changed:
            old_values[-1] = values[-1]

        need_reset = oid_changed or any(int(values[i]) < int(old_values[i]) for i in diff_column_positions)
        if need_reset:
            for i in diff_column_positions:
                old_values[i] = '0'

        for i in diff_column_positions:
            values[i] = ns_diff(values[i], old_values[i])
        return need_reset

    def _diff(self, old_data, new_data):
        """Do a diff between new data and old data.

        Routes that only exist in new_data are left untouched.

        Args:
            old_data (dict): E.g: {<namespace>: {(<route_pattern>): {<prefix>: [<value_in_pkts>, <value_in_bytes>, <counter_oid>]}}}
            new_data (dict): E.g: {<namespace>: {(<route_pattern>): {<prefix>: [<value_in_pkts>, <value_in_bytes>, <counter_oid>]}}}

        Returns:
            bool: True if cache need to be updated
        """
        need_update_cache = False
        if not old_data:
            return need_update_cache

        for ns, stats in new_data.items():
            if ns not in old_data:
                continue

            old_stats = old_data[ns]
            for route_pattern, prefix_entries in stats.items():
                if route_pattern not in old_stats:
                    continue

                old_prefix_entries = old_stats[route_pattern]
                for name, values in prefix_entries.items():
                    if name not in old_prefix_entries:
                        continue

                    if self._diff_route_values(old_prefix_entries[name], values):
                        need_update_cache = True

        return need_update_cache


def main():
    """Parse the command line and show or clear the requested type of flow counters.

    With "-c" the counters are cleared, otherwise they are shown. "--prefix",
    "--prefix_pattern" and "--vrf" are only used by the route flow counter.
    """
    parser = argparse.ArgumentParser(description='Display the flow counters',
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     epilog="""
Examples:
  flow_counters_stat -c -t trap
  flow_counters_stat -t trap
  flow_counters_stat -d -t trap
""")
    parser.add_argument('-c', '--clear', action='store_true', help='Copy & clear stats')
    parser.add_argument('-d', '--delete', action='store_true', help='Delete saved stats')
    parser.add_argument('-j', '--json', action='store_true', help='Display in JSON format')
    parser.add_argument('-n', '--namespace', default=None, help='Display flow counters for specific namespace')
    parser.add_argument('-t', '--type', required=True, choices=['trap', 'route'], help='Flow counters type')
    # --prefix_pattern and --prefix cannot be used together
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--prefix_pattern', help='Prefix pattern')  # for route flow counter only, ignored by other type
    group.add_argument('--prefix', help='Prefix')  # for route flow counter only, ignored by other type
    parser.add_argument('--vrf', help='VRF name', default=DEFAULT_VRF)  # for route flow counter only, ignored by other type

    args = parser.parse_args()

    if args.type == 'trap':
        stats = FlowCounterStats(args)
    elif args.type == 'route':
        exit_if_route_flow_counter_not_support()
        stats = RouteFlowCounterStats(args)
    else:
        # Defensive only: argparse "choices" already rejects any other type.
        print('Invalid flow counter type: {}'.format(args.type))
        # sys.exit instead of the exit() helper, which is injected by the site module
        # and is not guaranteed to exist (e.g. when running with "python -S").
        sys.exit(1)

    if args.clear:
        stats.clear()
    else:
        stats.show()


if __name__ == '__main__':
    main()
