#!python

#####################################################################
#
# portstat is a tool for summarizing network statistics.
#
#####################################################################

import json
import argparse
import os.path
import sys
import time
from collections import OrderedDict

# Unit-test scaffolding: swap in mocked database tables and platform probes
# when the UTILITIES_UNIT_TESTING* environment variables request it.
# Nothing in this block takes effect unless those variables are set.
try:
    if os.environ.get("UTILITIES_UNIT_TESTING", "") == "2":
        # Put the parent directory and its "tests" sub-directory at the front
        # of sys.path (presumably so that the mock_tables package imported
        # below resolves to the test copy -- confirm against the repo layout).
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        import mock_tables.dbconnector

        if os.environ.get("UTILITIES_UNIT_TESTING_IS_SUP", "") == "1":
            import mock
            import sonic_py_common
            from swsscommon.swsscommon import SonicV2Connector
            # Make device_info report a supervisor card.
            sonic_py_common.device_info.is_supervisor = mock.MagicMock(return_value=True)
            # NOTE(review): unlike the .get() lookups above, this is a direct
            # index. If UTILITIES_UNIT_TESTING_IS_PACKET_CHASSIS is unset it
            # raises KeyError, which the handler at the bottom swallows --
            # skipping the chassis-type mocks, the delete_all_by_pattern mock
            # and the multi-asic setup below. Confirm this is intended.
            if os.environ["UTILITIES_UNIT_TESTING_IS_PACKET_CHASSIS"] == "1":
                sonic_py_common.device_info.is_voq_chassis = mock.MagicMock(return_value=False)
                sonic_py_common.device_info.is_packet_chassis = mock.MagicMock(return_value=True)
            else:
                sonic_py_common.device_info.is_voq_chassis = mock.MagicMock(return_value=True)
            # Replace the connector's pattern-delete with a no-op mock.
            SonicV2Connector.delete_all_by_pattern = mock.MagicMock()
    # This check is independent of the UTILITIES_UNIT_TESTING == "2" branch.
    # NOTE(review): load_namespace_config() is reached through
    # mock_tables.dbconnector, which is only explicitly imported in the branch
    # above -- presumably mock_multi_asic imports it as well; verify.
    if os.environ.get("UTILITIES_UNIT_TESTING_TOPOLOGY", "") == "multi_asic":
        import mock_tables.mock_multi_asic
        mock_tables.dbconnector.load_namespace_config()

except KeyError:
    # A missing environment variable (see the direct lookup above) aborts the
    # remaining mock setup silently.
    pass

from utilities_common import constants
from utilities_common.intf_filter import parse_interface_in_filter

from utilities_common.cli import json_serial, UserCache
from utilities_common.portstat import Portstat

def main():
    """Parse the command line and display, save or delete port counters.

    The behaviour is selected by the command-line flags:

    * ``-D`` / ``-d`` remove saved counter snapshots through the user cache
      before anything else happens; processing then continues normally.
    * ``-r`` prints the current counters as-is and exits with status 0.
    * ``-c`` writes the current counters to the cache file as JSON and exits
      (status 0 on success, the ``errno`` of the I/O error otherwise).
    * With ``-p 0`` (the default) the current counters are printed relative
      to the cached snapshot when one exists; without a snapshot they are
      printed directly, unless a tag was given, in which case a hint about
      the missing file is printed instead.
    * With a non-zero ``-p`` the counters are sampled twice, that many
      seconds apart, and the difference is printed.
    """
    parser  = argparse.ArgumentParser(description='Display the ports state and counters',
                                      formatter_class=argparse.RawTextHelpFormatter,
                                      epilog="""
Port state: (U)-Up (D)-Down (X)-Disabled
Examples:
  portstat -c -t test
  portstat -t test
  portstat -d -t test
  portstat -e
  portstat
  portstat -r
  portstat -R
  portstat -a
  portstat -p 20
  portstat -l -i Ethernet4,Ethernet8,Ethernet12-20,PortChannel100-102
""")

    parser.add_argument('-a', '--all', action='store_true', help='Display all the stats counters')
    parser.add_argument('-c', '--clear', action='store_true', help='Copy & clear stats')
    parser.add_argument('-d', '--delete', action='store_true', help='Delete saved stats, either the uid or the specified tag')
    parser.add_argument('-D', '--delete-all', action='store_true', help='Delete all saved stats')
    parser.add_argument('-e', '--errors', action='store_true', help='Display interface errors')
    parser.add_argument('-f', '--fec-stats', action='store_true', help='Display FEC related statistics')
    parser.add_argument('-j', '--json', action='store_true', help='Display in JSON format')
    parser.add_argument('-r', '--raw', action='store_true', help='Raw stats (unmodified output of netstat)')
    parser.add_argument('-R', '--rate', action='store_true', help='Display interface rates')
    parser.add_argument('-T', '--trim', action='store_true', help='Display trimming related statistics')
    parser.add_argument('-t', '--tag', type=str, help='Save stats with name TAG', default=None)
    parser.add_argument('-p', '--period', type=int, help='Display stats over a specified period (in seconds).', default=0)
    parser.add_argument('-i', '--interface', type=str, help='Display stats for interface lists.', default=None)
    parser.add_argument('-s','--show',   default=constants.DISPLAY_EXTERNAL, help='Display all interfaces or only external interfaces')
    parser.add_argument('-n','--namespace', default=None, help='Display interfaces for specific namespace')
    parser.add_argument('-v', '--version', action='version', version='%(prog)s 1.0')
    parser.add_argument('-l', '--detail', action='store_true', help='Display detailed statistics.')
    args = parser.parse_args()

    save_fresh_stats = args.clear
    delete_saved_stats = args.delete
    delete_all_stats = args.delete_all
    errors_only = args.errors
    fec_stats_only = args.fec_stats
    rates_only = args.rate
    use_json = args.json
    raw_stats = args.raw
    trim_stats_only = args.trim
    tag_name = args.tag
    wait_time_in_seconds = args.period
    print_all = args.all
    intf_fs = args.interface
    namespace = args.namespace
    display_option = args.show
    detail = args.detail

    # The cache directory depends on the optional tag; the snapshot itself
    # always lives in a file named "portstat" inside that directory.
    cache = UserCache(tag=tag_name)

    cnstat_file = "portstat"
    cnstat_dir = cache.get_directory()
    cnstat_fqn_file = cnstat_dir + "/" + cnstat_file

    if delete_all_stats:
        cache.remove_all()

    if delete_saved_stats:
        cache.remove()

    intf_list = parse_interface_in_filter(intf_fs)

    # When saving counters to the file, save counters
    # for all ports(Internal and External)
    if save_fresh_stats:
        namespace = None
        display_option = constants.DISPLAY_ALL

    portstat = Portstat(namespace, display_option)
    cnstat_dict, ratestat_dict = portstat.get_cnstat_dict()

    # Now decide what information to display
    if raw_stats:
        # NOTE(review): `detail` is not forwarded on this path, unlike the
        # other cnstat_print call below -- confirm that is intentional.
        portstat.cnstat_print(cnstat_dict, ratestat_dict, intf_list, use_json, print_all, errors_only, fec_stats_only, rates_only, trim_stats_only)
        sys.exit(0)

    if save_fresh_stats:
        try:
            # Use a context manager so the snapshot is flushed and the handle
            # closed before success is reported, instead of relying on
            # garbage collection of an anonymous file object.
            with open(cnstat_fqn_file, 'w') as cnstat_fp:
                json.dump(cnstat_dict, cnstat_fp, default=json_serial)
        except IOError as e:
            sys.exit(e.errno)
        else:
            print("Cleared counters")
            sys.exit(0)

    if wait_time_in_seconds == 0:
        cnstat_cached_dict = OrderedDict()
        if os.path.isfile(cnstat_fqn_file):
            try:
                # Close the snapshot file as soon as it has been parsed.
                with open(cnstat_fqn_file, 'r') as cnstat_fp:
                    cnstat_cached_dict = json.load(cnstat_fp)
                if not detail:
                    print("Last cached time was " + str(cnstat_cached_dict.get('time')))
                portstat.cnstat_diff_print(cnstat_dict, cnstat_cached_dict, ratestat_dict, intf_list, use_json, print_all, errors_only, fec_stats_only, rates_only, trim_stats_only, detail)
            except IOError as e:
                print(e.errno, e)
        else:
            if tag_name:
                # A tag was requested but no snapshot exists for it: explain
                # how to create one rather than printing absolute counters.
                print("\nFile '%s' does not exist" % cnstat_fqn_file)
                print("Did you run 'portstat -c -t %s' to record the counters via tag %s?\n" % (tag_name, tag_name))
            else:
                portstat.cnstat_print(cnstat_dict, ratestat_dict, intf_list, use_json, print_all, errors_only, fec_stats_only, rates_only, trim_stats_only, detail)
    else:
        #wait for the specified time and then gather the new stats and output the difference.
        time.sleep(wait_time_in_seconds)
        print("The rates are calculated within %s seconds period" % wait_time_in_seconds)
        cnstat_new_dict, ratestat_new_dict = portstat.get_cnstat_dict()
        portstat.cnstat_diff_print(cnstat_new_dict, cnstat_dict, ratestat_new_dict, intf_list, use_json, print_all, errors_only, fec_stats_only, rates_only, trim_stats_only, detail)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
