#!python

"""
Utility for blocking and unblocking traffic from given source ip address on ACL tables.

The block operation will insert a DENY rule at the top of the table. The unblock operation
will remove an existing DENY rule that has been created by the block operation (i.e. it does
NOT insert an ALLOW rule, only removes DENY rules).

Since SONiC allows multiple ACL rules to share the same priority, all ACL rules created by null_route_helper
use the highest priority (9999).

Example:

Block traffic from 10.2.3.4:
./null_route_helper block acl_table_name 10.2.3.4

Unblock all traffic from 10.2.3.4:
./null_route_helper unblock acl_table_name 10.2.3.4

List all ACL rules added by this script:
./null_route_helper list acl_table_name
"""


from __future__ import print_function

import syslog
import sys
import click
import ipaddress
import tabulate

from swsscommon.swsscommon import ConfigDBConnector


# Names of the CONFIG_DB tables referenced by this script
CONFIG_DB_ACL_TABLE_TABLE = "ACL_TABLE"
CONFIG_DB_ACL_RULE_TABLE = "ACL_RULE"
CONFIG_DB_VLAN_TABLE = "VLAN"

# Action selectors passed from the CLI commands to null_route_helper()
ACTION_ALLOW = "FORWARD"
ACTION_DENY = "DROP"
ACTION_LIST = "LIST"

# SONiC allows multiple ACL rules to share the same priority, so every rule
# created by this script uses 9999 (the highest)
ACL_RULE_PRIORITY = 9999
# Rules created by this script are keyed as ACL_RULE_PREFIX followed by the IP prefix
ACL_RULE_PREFIX = 'BLOCK_RULE_'

# Internet Protocol version 4 EtherType (written into rules as its decimal string)
ETHER_TYPE_IPV4 = 0x0800

def notice(msg):
    """
    Report an informational message.

    The message is sent to syslog at LOG_NOTICE level and then printed to stdout.
    """
    syslog.syslog(syslog.LOG_NOTICE, msg)
    print(msg, file=sys.stdout)


def error(msg):
    """
    Report a fatal error and terminate the program.

    The message is sent to syslog at LOG_ERR level and printed to stderr,
    after which the process exits with status 1 (SystemExit is raised).
    """
    exit_status = 1
    syslog.syslog(syslog.LOG_ERR, msg)
    print(msg, file=sys.stderr)
    sys.exit(exit_status)


def ip_ver(ip_prefix):
    """
    Return the IP version (4 or 6) of the given address or prefix.

    The value is parsed non-strictly, so a prefix with host bits set is accepted.
    Raises ValueError if ip_prefix is not a valid IPv4/IPv6 address or network.
    """
    network = ipaddress.ip_network(ip_prefix, strict=False)
    return network.version


def confirm_required_table_existence(configdb, sub_table_name):
    """
    Verify that the named ACL table has an entry in the ACL_TABLE config table.

    Returns True when the entry is present. When it is missing (empty/None),
    the problem is reported through error(), which terminates the program.
    """
    if configdb.get_entry(CONFIG_DB_ACL_TABLE_TABLE, sub_table_name):
        return True

    error("Table {} not found, exiting...".format(sub_table_name))
    return True


def get_acl_rule_key(ip_prefix):
    """
    Build the rule name used for the ACL rule that blocks traffic from a source ip.
    All rules created by this script share one priority, so the priority cannot
    identify a rule. Instead the blocked source IP prefix is appended to
    ACL_RULE_PREFIX to give each rule a unique name.
    """
    return "{}{}".format(ACL_RULE_PREFIX, ip_prefix)


def get_all_acl_rules(configdb, table_name):
    """
    Return a dict of the existing ACL rules in table_name that were created by this script,
    i.e. those whose rule name (second element of the key) starts with ACL_RULE_PREFIX.
    {(u'NULL_ROUTE_TABLE', u'BLOCK_RULE_1.1.1.1/32'): {'PRIORITY': '9999', 'PACKET_ACTION': 'FORWARD', 'SRC_IP': '1.1.1.1/32'},...}
    """
    # Query with "ACL_RULE|<table_name>" so only rules of the requested ACL table are fetched
    table_key = CONFIG_DB_ACL_RULE_TABLE + '|' + table_name
    return {
        rule_key: rule
        for rule_key, rule in configdb.get_table(table_key).items()
        if rule_key[1].startswith(ACL_RULE_PREFIX)
    }


def validate_input(ip_address):
    """
    Validate the user supplied IP and normalize it to 'address/prefixlen' form.

    Only single hosts are accepted: the prefix length must be 32 for IPv4 and
    128 for IPv6 (a bare address is treated as a host prefix). Any other input
    is reported through error(), which terminates the program.
    """
    try:
        network = ipaddress.ip_network(ip_address, strict=False)
        # max_prefixlen is 32 for IPv4 networks and 128 for IPv6 networks
        is_single_host = network.prefixlen == network.max_prefixlen
        if is_single_host:
            return network.with_prefixlen

        error("Prefix length must be 32 (IPv4) or 128 (IPv6)")
    except ValueError as e:
        error("Could not parse {} as a valid IP address; exception={}".format(ip_address, e))


def build_acl_rule(priority, src_ip):
    """
    Build the field dict of a DROP rule for the given src_ip and priority.

    IPv4 sources are matched with ETHER_TYPE + SRC_IP, IPv6 sources with
    IP_TYPE + SRC_IPV6. The priority is stored as a string.
    """
    rule = {"PRIORITY": str(priority), "PACKET_ACTION": "DROP"}

    if ip_ver(src_ip) == 4:
        match_fields = {'ETHER_TYPE': str(ETHER_TYPE_IPV4), 'SRC_IP': src_ip}
    else:
        match_fields = {'IP_TYPE': 'IPV6ANY', 'SRC_IPV6': src_ip}

    rule.update(match_fields)
    return rule


def get_rule(configdb, table_name, ip_prefix):
    """
    Look up the ACL rule created by this script that matches the given ip_prefix.

    The rule's 'SRC_IP' field (IPv4) or 'SRC_IPV6' field (IPv6) is compared
    against ip_prefix as-is, so ip_prefix should already be normalized.

    Returns a single-item dict {rule_key: rule_fields} for the first match,
    or None when no rule matches. Raises ValueError if ip_prefix is not a
    valid IP address/prefix.
    """
    # ip_ver() rejects None/empty input up front, so no extra emptiness check
    # is needed inside the loop.
    key_name = 'SRC_IP' if ip_ver(ip_prefix) == 4 else 'SRC_IPV6'
    for key, rule in get_all_acl_rules(configdb, table_name).items():
        if ip_prefix == rule.get(key_name):
            return {key: rule}

    return None


def update_acl_table(configdb, acl_table_name, ip_prefix, action):
    """
    Apply a block/unblock request for ip_prefix to the given ACL table.

    'action' is expected to be ACTION_DENY or ACTION_ALLOW; any value other than
    ACTION_ALLOW is handled as a deny request.
    For ACTION_DENY, a 'DROP' rule for ip_prefix is added if none exists; an existing
    matching rule with a different PACKET_ACTION is switched to 'DROP'.
    For ACTION_ALLOW, the existing matching rule is removed, and nothing is changed
    if there is none.
    The program exits through error() if the ACL table does not exist.
    """
    confirm_required_table_existence(configdb, acl_table_name)

    existing = get_rule(configdb, acl_table_name, ip_prefix)
    if existing:
        # get_rule returns a single-item dict: {rule_key: rule_fields}
        rule_key, rule_value = next(iter(existing.items()))
    else:
        rule_key = rule_value = None

    if action == ACTION_ALLOW:
        if existing:
            # Passing None as data deletes the entry
            configdb.mod_entry(CONFIG_DB_ACL_RULE_TABLE, rule_key, None)
        return

    if not existing:
        # No rule for this prefix yet: create a fresh DROP rule
        priority = ACL_RULE_PRIORITY
        new_rule_key = (acl_table_name, get_acl_rule_key(ip_prefix))
        new_rule_value = build_acl_rule(priority, ip_prefix)
        configdb.set_entry(CONFIG_DB_ACL_RULE_TABLE, new_rule_key, new_rule_value)
        return

    if rule_value['PACKET_ACTION'] != 'DROP':
        # A rule exists with another action (e.g. forward): change it to 'DROP'
        rule_value['PACKET_ACTION'] = 'DROP'
        configdb.mod_entry(CONFIG_DB_ACL_RULE_TABLE, rule_key, rule_value)


def list_all_null_route_rules(configdb, table_name):
    """
    Print a table of all rules in table_name that were added by this script.

    Each row shows the table, rule name, priority, action and the matched
    source address; fields missing from a rule are shown as "N/A".
    The program exits through error() if the ACL table does not exist.
    """
    confirm_required_table_existence(configdb, table_name)
    rules = get_all_acl_rules(configdb, table_name)

    rows = []
    for (_, rule_id), rule in rules.items():
        # First source-match field present in the rule, IPv4 checked before IPv6
        matched = next(
            (rule[field] for field in ("SRC_IP", "SRC_IPV6") if field in rule),
            "N/A",
        )
        rows.append([
            table_name,
            rule_id,
            rule.get("PRIORITY", "N/A"),
            rule.get("PACKET_ACTION", "N/A"),
            matched,
        ])

    columns = ("Table", "Rule", "Priority", "Action", "Match")
    print(tabulate.tabulate(rows, headers=columns, tablefmt="simple", missingval=""))


def null_route_helper(table_name, action, ip_prefix=None):
    """
    Common entry point behind the 'click' commands.

    Connects to the config DB, then either lists the rules of table_name
    (ACTION_LIST) or validates ip_prefix and applies the requested action
    to the ACL table.
    """
    configdb = ConfigDBConnector()
    configdb.connect()

    if action == ACTION_LIST:
        list_all_null_route_rules(configdb, table_name)
        return

    validated_prefix = validate_input(ip_prefix)
    update_acl_table(configdb, table_name, validated_prefix, action)


# Root click command group; the block/unblock/list sub-commands are registered on it.
# NOTE: click uses a command's docstring as its help text, so none is added here.
@click.group()
def cli():
    pass


# Usage: ./null_route_helper block table_name 1.2.3.4
# NOTE: the docstring below doubles as the click help text for this command.
@cli.command('block')
@click.argument("table_name", type=click.STRING, required=True)
@click.argument("ip_prefix", type=click.STRING, required=True)
def block(table_name, ip_prefix):
    """
    Block traffic from given src ip prefix
    """
    # ip_prefix is validated inside null_route_helper before any rule is written
    null_route_helper(table_name, ACTION_DENY, ip_prefix)


# Usage: ./null_route_helper unblock table_name 1.2.3.4
# NOTE: the docstring below doubles as the click help text for this command.
@cli.command('unblock')
@click.argument("table_name", type=click.STRING, required=True)
@click.argument("ip_prefix", type=click.STRING, required=True)
def unblock(table_name, ip_prefix):
    """
    Unblock traffic from given src ip prefix
    """
    # ACTION_ALLOW removes the matching rule; it does not add a forwarding rule
    null_route_helper(table_name, ACTION_ALLOW, ip_prefix)


# Usage: ./null_route_helper list table_name
# NOTE: the docstring below doubles as the click help text for this command.
@cli.command('list')
@click.argument("table_name", type=click.STRING, required=True)
def list_rules(table_name):
    """
    List all rules *added by this script*
    """
    # No ip_prefix is needed for listing; it defaults to None in null_route_helper
    null_route_helper(table_name, ACTION_LIST)


# Hand control to the click command group when executed as a script
if __name__ == "__main__":
    cli()

