#!python

"""
Use aclshow to display SONiC switch ACL rules and counters

usage: aclshow [-h] [-v] [-vv] [-c] [-a] [-r RULES] [-t TABLES]

Display SONiC switch ACL Counters/status

optional arguments:
  -h,  --help                 show this help message and exit
  -v,  --version              show program's version number and exit
  -vv, --verbose              verbose output (progress, etc)
  -c,  --clear                clear ACL counters statistics
  -a,  --all                  show all ACL counters
  -r RULES,  --rules RULES    action by specific rules list: Rule_1,Rule_2
  -t TABLES, --tables TABLES  action by specific tables list: Table_1,Table_2
"""

import argparse
import json
import os
import sys

from swsscommon.swsscommon import SonicV2Connector, ConfigDBConnector
from utilities_common.cli import UserCache

from tabulate import tabulate

# Key prefix of the per-counter entries read from COUNTERS_DB
# (full key: COUNTERS + <db separator> + <counter oid>).
COUNTERS = "COUNTERS"
# COUNTERS_DB key whose fields map "<table><db separator><rule>" to a counter oid.
ACL_COUNTER_RULE_MAP = "ACL_COUNTER_RULE_MAP"

# Column titles of the table printed by AclStat.display_acl_stat()
ACL_HEADER = ["RULE NAME", "TABLE NAME", "PRIO", "PACKETS COUNT", "BYTES COUNT"]

# Field names looked up inside a counter entry for the packet and byte readings.
COUNTER_PACKETS_ATTR = "SAI_ACL_COUNTER_ATTR_PACKETS"
COUNTER_BYTES_ATTR = "SAI_ACL_COUNTER_ATTR_BYTES"

# Location of the file that stores the counter snapshot taken by "clear".
# NOTE(review): the directory is whatever UserCache.get_directory() returns;
# presumably a per-user cache directory -- confirm against utilities_common.cli.
USER_CACHE = UserCache()
COUNTERS_CACHE_DIR = USER_CACHE.get_directory()
COUNTERS_CACHE = os.path.join(COUNTERS_CACHE_DIR, 'aclstat')

class AclStat(object):
    """
    Read, display and "clear" the ACL rule counters of a SONiC switch.

    ACL tables and rules are read through the config DB connector, counter
    readings through the COUNTERS_DB connection. Clearing never modifies the
    database: the current readings are written to the COUNTERS_CACHE file and
    subtracted from the live readings the next time counters are displayed.
    """

    ACL_TABLE = "ACL_TABLE"
    ACL_RULE = "ACL_RULE"

    def __init__(self, rules, tables):
        """
        Parse the optional filters and connect to the databases.

        rules:  comma separated list of rule names to restrict to (None/"" = no filter)
        tables: comma separated list of table names to restrict to (None/"" = no filter)
        """
        self.acl_tables = {}
        self.acl_rules = {}
        # (table, rule) -> counter fields as read from COUNTERS_DB
        self.acl_counters = {}
        # (table, rule) -> counter fields saved by the last clear_counters()
        self.saved_acl_counters = {}
        self.rule_list = rules.split(",") if rules else []
        self.table_list = tables.split(",") if tables else []

        # Set up db connections
        self.db = SonicV2Connector(use_unix_socket_path=False)
        self.db.connect(self.db.COUNTERS_DB)

        self.configdb = ConfigDBConnector()
        self.configdb.connect()

    def previous_counters(self):
        """
        Load the counter readings saved by the last clear action, if any.

        The cache is a best-effort convenience: when the file is missing,
        unreadable or malformed, no baseline is used and the raw counter
        values are displayed.
        """
        def remap_keys(entries):
            # JSON cannot store tuple keys, so the file holds a list of
            # {'key': [table, rule], 'value': {...}} records; rebuild the dict.
            return {(e['key'][0], e['key'][1]): e['value'] for e in entries}

        if not os.path.isfile(COUNTERS_CACHE):
            return

        try:
            with open(COUNTERS_CACHE) as fp:
                self.saved_acl_counters = remap_keys(json.load(fp))
        except (OSError, ValueError, KeyError, IndexError, TypeError):
            # OSError: file unreadable; ValueError: invalid JSON/encoding;
            # the rest: JSON that does not have the expected record layout.
            self.saved_acl_counters = {}

    def intersect(self, a, b):
        """
        return the elements common to both iterables as a list (order not guaranteed)
        """
        return list(set(a) & set(b))

    def redis_acl_read(self, verboseflag):
        """
        read redis database for acl tables, rules and counters

        verboseflag: when true, print progress information to stdout
        """

        def get_acl_rule_counter_map():
            """
            Return ACL_COUNTER_RULE_MAP, or an empty dict when it does not exist
            """
            if self.db.exists(self.db.COUNTERS_DB, ACL_COUNTER_RULE_MAP):
                return self.db.get_all(self.db.COUNTERS_DB, ACL_COUNTER_RULE_MAP)
            return {}

        def fetch_acl_tables():
            """
            Get ACL tables from the DB
            """
            all_tables = self.configdb.get_table(self.ACL_TABLE)

            if verboseflag:
                print("Total number of ACL Tables: %d" % len(all_tables))

            # Without an explicit table filter only 'DATAACL' is kept here.
            # Rule selection in fetch_acl_rules() does not depend on this.
            wanted_tables = self.table_list if self.table_list else ['DATAACL']
            self.acl_tables = {table: content
                               for table, content in all_tables.items()
                               if table in wanted_tables}

        def fetch_acl_rules():
            """
            Get ACL rules from the DB, applying the table and rule filters
            """
            self.acl_rules = self.configdb.get_table(self.ACL_RULE)

            if verboseflag:
                print("Total number of ACL Rules: %d" % len(self.acl_rules))

            if self.table_list:
                self.acl_rules = {(table, rule): content
                                  for (table, rule), content in self.acl_rules.items()
                                  if table in self.table_list}
            if self.rule_list:
                self.acl_rules = {(table, rule): content
                                  for (table, rule), content in self.acl_rules.items()
                                  if rule in self.rule_list}

        def fetch_acl_counters():
            """
            Get ACL counters from the DB for every selected rule
            """
            counters_db_separator = self.db.get_db_separator(self.db.COUNTERS_DB)
            rule_to_counter_map = get_acl_rule_counter_map()

            # Nothing can be resolved without the rule -> counter oid map,
            # so the per-rule lookups are skipped entirely in that case.
            if rule_to_counter_map:
                for table, rule in self.acl_rules:
                    rule_identifier = table + counters_db_separator + rule
                    counter_oid = rule_to_counter_map.get(rule_identifier)
                    if not counter_oid:
                        continue
                    counters_db_key = COUNTERS + counters_db_separator + counter_oid
                    cnt_props = self.db.get_all(self.db.COUNTERS_DB, counters_db_key)
                    if cnt_props:
                        self.acl_counters[table, rule] = cnt_props

            if verboseflag:
                print()

        if verboseflag:
            print("Reading ACL info...")
        fetch_acl_tables()
        fetch_acl_rules()
        fetch_acl_counters()

    def get_counter_value(self, key, type):
        """
        return the counter value or the difference comparing with the saved value in string format

        key:  (table, rule) tuple identifying the rule
        type: counter field name, e.g. COUNTER_PACKETS_ATTR or COUNTER_BYTES_ATTR

        Returns 'N/A' when there is no reading for the rule or the reading
        lacks the requested field. When a saved baseline exists and the
        difference is non-negative the difference is returned; otherwise
        (no baseline, unusable baseline, or negative difference) the raw
        reading is returned.
        """
        counters = self.acl_counters.get(key)
        if not counters or type not in counters:
            return 'N/A'

        current = counters[type]
        baseline = self.saved_acl_counters.get(key)
        if baseline is not None:
            try:
                delta = int(current) - int(baseline[type])
            except (KeyError, TypeError, ValueError):
                # Baseline saved with other field names or with non-numeric
                # content: ignore it instead of aborting the whole display.
                delta = None
            if delta is not None and delta >= 0:
                return str(delta)

        return str(current)

    def display_acl_stat(self, display_all):
        """
        print out ACL rules and counters

        display_all: when false, rules whose packet counter is '0' or 'N/A'
                     are left out of the table
        """
        aclstat = []
        for rule_key, rule in self.acl_rules.items():
            packets = self.get_counter_value(rule_key, COUNTER_PACKETS_ATTR)
            if not display_all and packets in ('0', 'N/A'):
                continue

            # The priority field name is matched case-insensitively;
            # -1 is shown when the rule has no priority field.
            rule_priority = -1
            for field, val in rule.items():
                if field.upper() == "PRIORITY":
                    rule_priority = val

            # rule_key is (table, rule); the table lists the rule name first
            aclstat.append([rule_key[1], rule_key[0],
                            rule_priority,
                            packets,
                            self.get_counter_value(rule_key, COUNTER_BYTES_ATTR)])

        # sort the list with table name first and then descending priority
        aclstat.sort(key=lambda row: (row[1], -int(row[2])))
        print(tabulate(aclstat, ACL_HEADER))

    def clear_counters(self):
        """
        clear counters -- write current counters to the COUNTERS_CACHE file

        The database is not modified; the saved readings become the baseline
        that previous_counters() loads on the next display.
        """
        def remap_keys(counters):
            # tuple keys are not valid JSON object keys, so store records instead
            return [{'key': k, 'value': v} for k, v in counters.items()]

        with open(COUNTERS_CACHE, 'w') as fp:
            json.dump(remap_keys(self.acl_counters), fp)

def main():
    """
    Command line entry point: parse the options, read the ACL data and either
    save the current counters (--clear) or print the counters table.

    Any exception is reported on stderr and turns into exit status 1.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description='Display SONiC switch Acl Rules and Counters')
    parser.add_argument('-a', '--all', action='store_true',
                        help='Show all ACL counters')
    parser.add_argument('-c', '--clear', action='store_true',
                        help='Clear ACL counters statistics')
    parser.add_argument('-r', '--rules', type=str,
                        help='action by specific rules list: Rule1_Name,Rule2_Name')
    parser.add_argument('-t', '--tables', type=str,
                        help='action by specific tables list: Table1_Name,Table2_Name')
    parser.add_argument('-v', '--version', action='version',
                        version='%(prog)s 1.0')
    parser.add_argument('-vv', '--verbose', action='store_true',
                        help='Verbose output')
    options = parser.parse_args()

    try:
        stat = AclStat(options.rules, options.tables)
        stat.redis_acl_read(options.verbose)
        if not options.clear:
            # normal run: apply the saved baseline (if any) and print the table
            stat.previous_counters()
            stat.display_acl_stat(options.all)
        else:
            # --clear only snapshots the current readings; nothing is displayed
            stat.clear_counters()
    except Exception as err:
        print(str(err), file=sys.stderr)
        sys.exit(1)

# Run the command line tool only when this file is executed as a script.
if __name__ == "__main__":
    main()
