#!python
## Script to list partition information and set the boot partition

import re
import os
import sys
import argparse
import logging
import tempfile
from sonic_py_common.general import getstatusoutput_noshell

## Only warnings and errors are reported by default; main() lowers this
## module's logger to INFO when -v/--verbose is given
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

## Regex fragments: a fixed-point number, and a single uppercase hex digit
re_fixedp = r"\d*\.\d+|\d+"
re_hex = r"[0-9A-F]"

def runcmd(cmd):
    """Run *cmd* through getstatusoutput_noshell and return its output.

    Args:
        cmd: The command to execute; every caller in this file passes an
            argv-style list of strings.

    Returns:
        The output of the command as a string (may be empty) when the
        return code is zero, or None when the return code is non-zero.
        A non-zero return code is also logged as an error.
    """
    logger.info('runcmd: {0}'.format(cmd))
    status, output = getstatusoutput_noshell(cmd)
    if status != 0:
        logger.error('Failed to run: {0}\n{1}'.format(cmd, output))
        return None
    return output

## Matches one hex escape (a backslash, 'x' and two hex digits) as written by
## lsblk in raw (-r) mode for blanks and non-printable characters
_LSBLK_HEX_ESCAPE = re.compile(rb'\\x([0-9A-Fa-f]{2})')


def _unescape_label(label):
    """Turn the hex escapes of a raw lsblk field back into real characters."""
    ## Work on bytes so that multi-byte UTF-8 sequences, which lsblk escapes
    ## one byte at a time, are reassembled correctly before decoding
    raw = _LSBLK_HEX_ESCAPE.sub(
        lambda m: bytes([int(m.group(1), 16)]), label.encode('utf-8'))
    return raw.decode('utf-8', errors='replace')


def print_partitions(blkdev):
    """Print the index and partition label of every partition on *blkdev*.

    Args:
        blkdev: Name of the block device as it appears in the NAME column
            of lsblk, i.e. without the '/dev/' prefix. Must be non-empty.

    Each partition is printed on its own line as "<index> <label>". Nothing
    is printed when lsblk fails; runcmd logs that failure.
    """
    assert blkdev
    out = runcmd(['sudo', 'lsblk', '-r', '-o', 'PARTLABEL,NAME'])
    if out is None:
        ## runcmd has already logged the failure; there is nothing to parse
        return

    ## Parse command output and print
    found_table = False
    for line in out.splitlines():
        if not found_table:
            ## Skip everything up to and including the table head
            found_table = re.match(r'PARTLABEL NAME', line) is not None
            if found_table:
                logger.info('Got the table head!')
            continue

        ## Parse the table
        m = re.match(r'(\S*) ([a-z]+)(\d*)', line)
        if not m:
            logger.warning('Unexpected lsblk output: %s', line)
            break
        ## Note: label is hex-escaped (\x<code>)
        label = _unescape_label(m.group(1))
        name = m.group(2)
        ## Note: index may be empty for the block device
        index = m.group(3)
        if name != blkdev or not index:
            continue
        print(index, label)

def get_boot_partition(blkdev):
    """Get the current boot partition index.

    Args:
        blkdev: Device node prefix (main passes e.g. '/dev/sda'); it is
            inserted verbatim into the regex matched against each line of
            /proc/mounts.

    Returns:
        The partition number, as an int, of the first /proc/mounts entry
        whose device is *blkdev* followed by digits and whose mount point
        is '/'. None when /proc/mounts cannot be read or no entry matches.
    """
    mounts = runcmd(['cat', '/proc/mounts'])
    if mounts is None:
        return None

    ## Look for the entry of the root filesystem: "<blkdev><index> / ..."
    root_pattern = r'{0}(\d+) / .*'.format(blkdev)
    for entry in mounts.splitlines():
        hit = re.match(root_pattern, entry)
        if hit:
            return int(hit.group(1))

    logger.error('Unexpected /proc/mounts output: %s', mounts)
    return None

def _release_mount(mntpath):
    """Unmount *mntpath*, killing the processes keeping it busy if needed."""
    if runcmd(['sudo', 'umount', mntpath]) is not None:
        return True
    ## The plain umount failed, most likely because the mount is busy:
    ## kill whatever is still using it and try once more
    runcmd(['sudo', 'fuser', '-km', mntpath])
    return runcmd(['sudo', 'umount', mntpath]) is not None


def set_boot_partition(blkdev, index):
    """Make partition *index* of *blkdev* the partition GRUB boots from.

    The partition is mounted on a temporary directory and grub-install is
    run on *blkdev* with that directory as its boot directory. The mount
    and the temporary directory are cleaned up before returning.

    Args:
        blkdev: Device node of the whole disk, e.g. '/dev/sda'.
        index: Partition number; must not be None.

    Returns:
        True when both the mount and grub-install succeeded, else False.
    """
    ## Mount the partition
    assert index is not None
    devnode = blkdev + str(index)
    mntpath = tempfile.mkdtemp()
    mounted = False
    try:
        out = runcmd(['sudo', 'mount', devnode, mntpath])
        logger.info('mount out={0}'.format(out))
        if out is None:
            return False
        mounted = True
        ## Set GRUB bootable
        out = runcmd(['sudo', 'grub-install', '--boot-directory=' + mntpath,
                      '--recheck', blkdev])
        return out is not None
    finally:
        ## Cleanup. Only touch the mount when it really exists: on a plain
        ## directory 'fuser -km' would act on the filesystem containing the
        ## temporary directory and kill the processes using it.
        if mounted:
            _release_mount(mntpath)
        try:
            os.rmdir(mntpath)
        except OSError as err:
            ## Best effort: a leftover temporary directory must not mask
            ## the result of the operation itself
            logger.error('Failed to remove {0}: {1}'.format(mntpath, err))

def main():
    """Entry point: list the partitions or set a new boot partition.

    Locates the block device holding the partition labelled ONIE-BOOT and
    prints the partition currently mounted as rootfs. Without -i/--index
    the partitions of that device are listed; with it, the given partition
    is made the boot partition unless it already is the current one.

    Returns:
        0 on success, -1 on any failure (the reason is logged).
    """
    parser = argparse.ArgumentParser(description='The default output is a list of partition information.')
    parser.add_argument('-i', '--index', type=int,
        help='Set the index partition as boot partition. After reboot, GRUB will boot from that partition.')
    parser.add_argument("-v", "--verbose", action="store_true",
        help="increase output verbosity")
    args = parser.parse_args()

    if args.verbose:
        logger.setLevel(logging.INFO)

    ## Find ONIE partition and get the block device containing ONIE
    out = runcmd(["sudo", "blkid"])
    if not out:
        logger.error('Cannot get the list of block devices from blkid')
        return -1
    for line in out.splitlines():
        ## The device name is matched non-greedily so that the whole
        ## trailing digit run is taken as the partition number; a greedy
        ## match would turn /dev/sda12 into block device 'sda1'
        m = re.match(r'/dev/(\w+?)\d+: LABEL="ONIE-BOOT"', line)
        if not m:
            continue
        blkdev0 = m.group(1)
        blkdev = '/dev/' + blkdev0
        logger.info('blkdev={0}'.format(blkdev))
        break
    else:
        logger.error('Cannot find block device containing ONIE')
        return -1

    cur = get_boot_partition(blkdev)
    if cur is None:
        logger.error('Failed to get boot partition')
        return -1
    print('Current rootfs partition is: {0}'.format(cur))

    ## Handle the command line
    if args.index is None:
        ## print_partitions compares against lsblk NAME values, which carry
        ## no '/dev/' prefix
        print_partitions(blkdev0)
        return 0
    if args.index <= 0:
        logger.error('Index should be larger than 0')
        return -1

    logger.info("cur={0}".format(cur))
    if cur == args.index:
        logger.info('No action needed, the partition is already the boot partition')
        return 0

    suc = set_boot_partition(blkdev, args.index)
    if not suc:
        logger.error('Failed to set boot partition')
        return -1
    return 0

## Run main() only when executed as a script and use its return value as
## the process exit status
if __name__ == "__main__":
    sys.exit(main())
