#!python

#####################################################################
#
# dropstat is a tool for displaying drop counters.
#
#####################################################################

# FUTURE IMPROVEMENTS
# - Add the ability to filter by group and type
# - Refactor calls to COUNTERS_DB to reduce redundancy
# - Cache DB queries to reduce # of expensive queries

import json
import argparse
import os
import socket
import sys

from collections import OrderedDict
from natsort import natsorted
from tabulate import tabulate

# Unit-test bootstrap: when UTILITIES_UNIT_TESTING=1, make the repository's
# modules and tests importable and load the mock DB connector before the real
# connector classes are imported below.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "1":
        # The parent of this script's directory and its "tests" subdirectory
        # are put at the front of sys.path so they win over installed copies.
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        test_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, test_path)
        # Imported for its side effects only.
        # NOTE(review): presumably patches the DB connectors to serve test
        # fixtures instead of a live redis - confirm in tests/mock_tables.
        import mock_tables.dbconnector
        # Pin the hostname so the switch-level table (which prints
        # socket.gethostname()) is deterministic under test.
        socket.gethostname = lambda: 'sonic_drops_test'
        # Pin the uid as well.
        # NOTE(review): presumably consumed by UserCache when it picks the
        # cache directory - confirm in utilities_common.cli.
        os.getuid = lambda: 27
except KeyError:
    # UTILITIES_UNIT_TESTING is not set: normal (non-test) execution.
    pass

from swsscommon.swsscommon import SonicV2Connector, ConfigDBConnector
from utilities_common.cli import UserCache


# COUNTERS_DB Tables
# Maps of configured debug counter name -> stat, one for port-level counters
# and one for switch-level counters.
DEBUG_COUNTER_PORT_STAT_MAP = 'COUNTERS_DEBUG_NAME_PORT_STAT_MAP'
DEBUG_COUNTER_SWITCH_STAT_MAP = 'COUNTERS_DEBUG_NAME_SWITCH_STAT_MAP'
# Map of port name -> object ID; the object ID is appended to
# COUNTER_TABLE_PREFIX to find that port's counter values.
COUNTERS_PORT_NAME_MAP = 'COUNTERS_PORT_NAME_MAP'
COUNTER_TABLE_PREFIX = 'COUNTERS:'

# ASIC_DB Tables
# Key prefix of the switch object; whatever follows the prefix is used as the
# switch ID.
ASIC_SWITCH_INFO_PREFIX = 'ASIC_STATE:SAI_OBJECT_TYPE_SWITCH:'

# APPL_DB Tables
PORT_STATUS_TABLE_PREFIX = "PORT_TABLE:"
PORT_OPER_STATUS_FIELD = "oper_status"
PORT_ADMIN_STATUS_FIELD = "admin_status"
# Status values are compared after upper-casing what the database returns.
PORT_STATUS_VALUE_UP = 'UP'
PORT_STATUS_VALUE_DOWN = 'DOWN'
# Not referenced by the code visible in this file.
PORT_SPEED_FIELD = "speed"

# Values shown in the STATE column of the port-level table:
# U = admin up and oper up, D = admin up but oper down,
# X = admin down, N/A = state missing or not recognized.
PORT_STATE_UP = 'U'
PORT_STATE_DOWN = 'D'
PORT_STATE_DISABLED = 'X'
PORT_STATE_NA = 'N/A'

# CONFIG_DB Tables
# Holds the per-counter settings read by this tool: 'alias', 'group', 'type'.
DEBUG_COUNTER_CONFIG_TABLE = 'DEBUG_COUNTER'

# Standard Port-Level Counters
# These are always candidates for display, in addition to any configured
# debug counters. RX counters match type PORT_INGRESS_DROPS and TX counters
# match type PORT_EGRESS_DROPS when filtering by type.
std_port_rx_counters = ['SAI_PORT_STAT_IF_IN_ERRORS', 'SAI_PORT_STAT_IF_IN_DISCARDS']
std_port_tx_counters = ['SAI_PORT_STAT_IF_OUT_ERRORS', 'SAI_PORT_STAT_IF_OUT_DISCARDS']

# Standard Port-Level Headers
# Leading description columns, followed by one column per counter.
std_port_description_header = ['IFACE', 'STATE']
# Column titles for the standard counters; configured counters use their
# alias (or counter name) instead.
std_port_headers_map = {
    'SAI_PORT_STAT_IF_IN_ERRORS':    'RX_ERR',
    'SAI_PORT_STAT_IF_IN_DISCARDS':  'RX_DROPS',
    'SAI_PORT_STAT_IF_OUT_ERRORS':   'TX_ERR',
    'SAI_PORT_STAT_IF_OUT_DISCARDS': 'TX_DROPS'
}

# Standard Switch-Level Headers
std_switch_description_header = ['DEVICE']


def get_dropstat_dir():
    """
        Returns the directory in which the drop counter checkpoint files
        are kept, as reported by UserCache.
    """
    user_cache = UserCache()
    return user_cache.get_directory()


class DropStat(object):
    """
        Gathers, displays and checkpoints packet drop counters at the port
        level and at the switch level.

        Counter values and the name maps are read from COUNTERS_DB, port
        state from APPL_DB, the switch ID from ASIC_DB and the per-counter
        settings (alias, group, type) from CONFIG_DB.

        "Clearing" does not modify any database: the current values are saved
        to checkpoint files in the directory returned by get_dropstat_dir(),
        and every later display subtracts that checkpoint from the live
        values.
    """

    def __init__(self):
        """
            Connects to the databases and determines the checkpoint file
            paths.
        """
        self.config_db = ConfigDBConnector()
        self.config_db.connect()

        self.db = SonicV2Connector(use_unix_socket_path=False)
        self.db.connect(self.db.COUNTERS_DB)
        self.db.connect(self.db.ASIC_DB)
        self.db.connect(self.db.APPL_DB)

        dropstat_dir = get_dropstat_dir()
        self.port_drop_stats_file = os.path.join(dropstat_dir, 'port-stats')
        # The file name must be its own path component. Concatenating it onto
        # the directory name yields "<dir>switch-stats", i.e. a sibling of the
        # cache directory rather than a file inside it. A checkpoint written
        # at that old location is no longer read; the next clear recreates it
        # in the right place.
        self.switch_drop_stats_file = os.path.join(dropstat_dir, 'switch-stats')

        # Per stat-map caches, keyed by the COUNTERS_DB map name:
        # counter name -> stat, and stat -> counter name.
        self.stat_lookup = {}
        self.reverse_stat_lookup = {}

    def show_drop_counts(self, group, counter_type):
        """
            Prints out the current drop counts at the port-level and
            switch-level, separated by a blank line.

            group and counter_type restrict the output to matching counters;
            a falsy value disables the respective filter.
        """

        self.show_port_drop_counts(group, counter_type)
        print('')
        self.show_switch_drop_counts(group, counter_type)

    def clear_drop_counts(self):
        """
            Clears the current drop counts by saving the current values of
            all (unfiltered) port-level and switch-level counters to the
            checkpoint files.

            On an I/O error the error is printed and the process exits with
            the error's errno (or 1 if it carries none).
        """

        try:
            port_counters = self.gather_counters(std_port_rx_counters + std_port_tx_counters,
                                                 DEBUG_COUNTER_PORT_STAT_MAP)
            port_counts = self.get_counts_table(port_counters, COUNTERS_PORT_NAME_MAP)
            # Context managers make sure the checkpoint is flushed and closed
            # even if serialization fails half way.
            with open(self.port_drop_stats_file, 'w') as ckpt_file:
                json.dump(port_counts, ckpt_file)

            switch_counters = self.gather_counters([], DEBUG_COUNTER_SWITCH_STAT_MAP)
            switch_id = self.get_switch_id()
            # Without a switch object there is nothing to read; an empty
            # checkpoint is equivalent to an all-zero baseline.
            switch_counts = self.get_counts(switch_counters, switch_id) if switch_id is not None else {}
            with open(self.switch_drop_stats_file, 'w') as ckpt_file:
                json.dump(switch_counts, ckpt_file)
        except IOError as e:
            print(e)
            # errno may be None, and sys.exit(None) would report success.
            sys.exit(e.errno or 1)
        print("Cleared drop counters")

    def show_port_drop_counts(self, group, counter_type):
        """
            Prints out the drop counts at the port level, if such counts exist.

            One row is printed per port found in COUNTERS_PORT_NAME_MAP, with
            the port state followed by each selected counter minus its
            checkpointed value. Nothing is printed if no counter passes the
            group/type filters or no port is known.
        """

        counters = self.gather_counters(std_port_rx_counters + std_port_tx_counters, DEBUG_COUNTER_PORT_STAT_MAP, group, counter_type)
        if not counters:
            return

        headers = std_port_description_header + self.gather_headers(counters, DEBUG_COUNTER_PORT_STAT_MAP)

        # Grab the latest clear checkpoint, if it exists
        port_drop_ckpt = self._load_checkpoint(self.port_drop_stats_file)

        table = []
        for port, port_counts in self.get_counts_table(counters, COUNTERS_PORT_NAME_MAP).items():
            # Ports that appeared after the last clear have no checkpoint
            # entry, so their baseline is zero.
            port_ckpt = port_drop_ckpt.get(port, {})
            row = [port, self.get_port_state(port)]
            for counter in counters:
                row.append(port_counts.get(counter, 0) - port_ckpt.get(counter, 0))
            table.append(row)

        if table:
            print(tabulate(table, headers, tablefmt='simple', stralign='right'))

    def show_switch_drop_counts(self, group, counter_type):
        """
            Prints out the drop counts at the switch level, if such counts exist.

            A single row is printed, labelled with the host name, holding
            each selected counter minus its checkpointed value. Nothing is
            printed if no counter passes the group/type filters or no switch
            object is present.
        """

        counters = self.gather_counters([], DEBUG_COUNTER_SWITCH_STAT_MAP, group, counter_type)
        if not counters:
            return

        headers = std_switch_description_header + self.gather_headers(counters, DEBUG_COUNTER_SWITCH_STAT_MAP)

        switch_id = self.get_switch_id()
        if switch_id is None:
            return

        switch_stats = self.get_counts(counters, switch_id)
        if not switch_stats:
            return

        # Grab the latest clear checkpoint, if it exists
        switch_drop_ckpt = self._load_checkpoint(self.switch_drop_stats_file)

        row = [socket.gethostname()]
        for counter in counters:
            row.append(switch_stats.get(counter, 0) - switch_drop_ckpt.get(counter, 0))

        print(tabulate([row], headers, tablefmt='simple', stralign='right'))

    def _load_checkpoint(self, ckpt_path):
        """
            Returns the checkpoint stored as JSON at ckpt_path, or an empty
            dictionary if no such file exists.
        """

        if not os.path.isfile(ckpt_path):
            return {}

        with open(ckpt_path, 'r') as ckpt_file:
            return json.load(ckpt_file)

    def gather_counters(self, std_counters, object_stat_map, group=None, counter_type=None):
        """
            Gather the list of counters to be counted, filtering out those that are not in
            the group or not the right counter type.

            The result lists std_counters first, followed by the configured
            counters of object_stat_map, each kept only if it passes both
            filters. A falsy group or counter_type disables that filter.
        """

        configured_counters = self.get_configured_counters(object_stat_map)
        counters = std_counters + configured_counters
        return [ctr for ctr in counters
                if self.in_group(ctr, object_stat_map, group) and
                self.is_type(ctr, object_stat_map, counter_type)]

    def gather_headers(self, counters, object_stat_map):
        """
            Gather the list of headers that are needed to display the given counters.

            Standard port counters use their fixed column title; every other
            counter must be a configured stat of object_stat_map and uses its
            alias (or counter name if it has none).
        """

        headers = []
        counter_names = self.get_reverse_stat_lookup(object_stat_map)

        for counter in counters:
            if counter in std_port_headers_map:
                headers.append(std_port_headers_map[counter])
            else:
                headers.append(self.get_alias(counter_names[counter]))

        return headers

    def get_counts(self, counters, oid):
        """
            Get the drop counts of a single object (a port or the switch).

            Returns a dictionary mapping each of the given counters to its
            integer value in the COUNTERS_DB entry of the object with ID oid.
            Counters missing from the entry are reported as 0.
        """

        counts = {}

        table_id = COUNTER_TABLE_PREFIX + oid
        for counter in counters:
            counter_data = self.db.get(self.db.COUNTERS_DB, table_id, counter)
            counts[counter] = 0 if counter_data is None else int(counter_data)
        return counts

    def get_counts_table(self, counters, object_table):
        """
            Returns a dictionary containing a mapping from an object (like a port)
            to its drop counts. Drop counts are contained in a dictionary that maps
            counter name to its counts.

            object_table names the COUNTERS_DB map of object name -> object
            ID. The result is ordered naturally by object name and is empty
            if the map is missing or empty.
        """

        counter_object_name_map = self.db.get_all(self.db.COUNTERS_DB, object_table)
        current_stat_dict = OrderedDict()

        if not counter_object_name_map:
            return current_stat_dict

        for obj in natsorted(counter_object_name_map):
            current_stat_dict[obj] = self.get_counts(counters, counter_object_name_map[obj])
        return current_stat_dict

    def get_switch_id(self):
        """
            Returns the ID of the current switch, or None if ASIC_DB holds no
            switch object.

            The ID is the first matching ASIC_DB switch key with the key
            prefix stripped.
        """

        switch_keys = self.db.keys(self.db.ASIC_DB, ASIC_SWITCH_INFO_PREFIX + '*')
        if not switch_keys:
            return None

        return switch_keys[0][len(ASIC_SWITCH_INFO_PREFIX):]

    def get_stat_lookup(self, object_stat_map):
        """
            Retrieves the mapping from counter name -> object stat for
            the given object type.

            Returns None if the map is missing or empty. The result is
            cached, so COUNTERS_DB is queried at most once per map.
        """

        # Test for key presence rather than truthiness so that an empty
        # result is cached too instead of being re-queried on every call.
        if object_stat_map not in self.stat_lookup:
            stats_map = self.db.get_all(self.db.COUNTERS_DB, object_stat_map)
            self.stat_lookup[object_stat_map] = stats_map if stats_map else None

        return self.stat_lookup[object_stat_map]

    def get_reverse_stat_lookup(self, object_stat_map):
        """
            Retrieves the mapping from object stat -> counter name for
            the given object type.

            Returns None if the map is missing or empty. The result is
            cached per map.
        """

        if object_stat_map not in self.reverse_stat_lookup:
            stats_map = self.get_stat_lookup(object_stat_map)
            if stats_map:
                self.reverse_stat_lookup[object_stat_map] = {v: k for k, v in stats_map.items()}
            else:
                self.reverse_stat_lookup[object_stat_map] = None

        return self.reverse_stat_lookup[object_stat_map]

    def get_configured_counters(self, object_stat_map):
        """
            Returns the list of counters that have been configured to
            track packet drops, i.e. the stats listed in object_stat_map.
            The list is empty if none are configured.
        """

        counters = self.get_stat_lookup(object_stat_map)

        if not counters:
            return []

        return list(counters.values())

    def get_counter_name(self, object_stat_map, counter_stat):
        """
            Gets the name of the counter associated with the given
            counter stat, or None if no configured counter uses that stat.
        """

        lookup_table = self.get_reverse_stat_lookup(object_stat_map)

        if not lookup_table:
            return None

        return lookup_table.get(counter_stat, None)

    def get_alias(self, counter_name):
        """
            Gets the alias for the given counter name. If the counter
            has no alias then the counter name is returned.
        """

        alias_query = self.config_db.get_entry(DEBUG_COUNTER_CONFIG_TABLE, counter_name)

        if not alias_query:
            return counter_name

        return alias_query.get('alias', counter_name)

    def _get_counter_config(self, object_stat_map, counter_stat):
        """
            Returns the CONFIG_DB entry of the counter that uses the given
            counter stat, or None if no configured counter uses the stat or
            the counter has no entry.
        """

        counter_name = self.get_counter_name(object_stat_map, counter_stat)

        # A stat without a configured counter has no name to look up;
        # skip the query instead of asking CONFIG_DB for the key None.
        if counter_name is None:
            return None

        return self.config_db.get_entry(DEBUG_COUNTER_CONFIG_TABLE, counter_name) or None

    def in_group(self, counter_stat, object_stat_map, group):
        """
            Checks whether the given counter_stat is part of the
            given group.

            If no group is provided this method will return True.
            Standard port counters belong to no group.
        """

        if not group:
            return True

        if counter_stat in std_port_rx_counters or counter_stat in std_port_tx_counters:
            return False

        group_query = self._get_counter_config(object_stat_map, counter_stat)

        if not group_query:
            return False

        return group == group_query.get('group', None)

    def is_type(self, counter_stat, object_stat_map, counter_type):
        """
            Checks whether the type of the given counter_stat is the same as
            counter_type.

            If no counter_type is provided this method will return True.
            Standard RX port counters are of type PORT_INGRESS_DROPS and
            standard TX port counters of type PORT_EGRESS_DROPS.
        """

        if not counter_type:
            return True

        if counter_stat in std_port_rx_counters and counter_type == 'PORT_INGRESS_DROPS':
            return True

        if counter_stat in std_port_tx_counters and counter_type == 'PORT_EGRESS_DROPS':
            return True

        type_query = self._get_counter_config(object_stat_map, counter_stat)

        if not type_query:
            return False

        return counter_type == type_query.get('type', None)

    def get_port_state(self, port_name):
        """
            Get the state of the given port.

            Returns PORT_STATE_DISABLED if the port is admin down,
            PORT_STATE_UP / PORT_STATE_DOWN according to the oper status if
            it is admin up, and PORT_STATE_NA if either status is missing
            from APPL_DB or holds an unexpected value.
        """
        full_table_id = PORT_STATUS_TABLE_PREFIX + port_name
        admin_state = self.db.get(self.db.APPL_DB, full_table_id, PORT_ADMIN_STATUS_FIELD)
        oper_state = self.db.get(self.db.APPL_DB, full_table_id, PORT_OPER_STATUS_FIELD)
        if admin_state is None or oper_state is None:
            return PORT_STATE_NA

        # The comparison is case-insensitive.
        admin_state = admin_state.upper()
        oper_state = oper_state.upper()

        if admin_state == PORT_STATUS_VALUE_DOWN:
            return PORT_STATE_DISABLED

        if admin_state == PORT_STATUS_VALUE_UP:
            if oper_state == PORT_STATUS_VALUE_UP:
                return PORT_STATE_UP
            if oper_state == PORT_STATUS_VALUE_DOWN:
                return PORT_STATE_DOWN

        return PORT_STATE_NA


def main():
    """
        Command line entry point: parses the arguments and shows or clears
        the drop counters.

        The command is validated before any database connection is made. An
        unrecognized or missing command prints "Command not recognized" and
        exits with status 1.
    """
    parser = argparse.ArgumentParser(description='Display drop counters',
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     epilog="""
Examples:
  dropstat -c show
  dropstat -c show -g <group>
  dropstat -c show -t PORT_INGRESS_DROPS
  dropstat -c clear
""")

    # Version
    parser.add_argument('-v', '--version', action='version', version='%(prog)s 1.0')

    # Actions
    parser.add_argument('-c', '--command', type=str, help='Desired action to perform (show or clear)')

    # Variables
    parser.add_argument('-g', '--group',   type=str, help='The group of the target drop counter', default=None)
    parser.add_argument('-t', '--type',    type=str, help='The type of the target drop counter', default=None)

    args = parser.parse_args()

    command = args.command

    group = args.group
    counter_type = args.type

    # Reject bad input up front: constructing DropStat connects to several
    # databases, which is pointless (and may fail) for an invalid command.
    if command not in ('clear', 'show'):
        print("Command not recognized")
        sys.exit(1)

    dcstat = DropStat()
    if command == 'clear':
        dcstat.clear_drop_counts()
    else:
        dcstat.show_drop_counts(group, counter_type)

if __name__ == '__main__':
    main()
