#! /usr/bin/python3

import os
import socket
import subprocess
import sys
import threading
import time

import yaml

from bottle import abort, Bottle, request, static_file

# Optional YAML file; when it exists at startup its keys are merged over the
# defaults that ProvisionDaemon sets up for itself.
CONFIG_PATH = '/etc/sonic/arista-provision-config.yaml'

class WaitForSonicThread(threading.Thread):
   """Background task that watches one slot boot and then disables provisioning.

   The thread first waits for the slot to either stop answering pings or
   to answer on the SSH port, then waits for ping and SSH to both be up,
   and finally runs the `arista linecard ... provision --set none` command.
   Both waits share a single deadline of `timeout` seconds from the moment
   the thread starts running.

   Args:
      slotId: integer slot number; also used to derive the slot address
         `127.100.<slotId>.1`.
      timeout: overall budget in seconds for both waiting phases.
      step: delay in seconds between two consecutive probes.
   """

   # Seconds allowed for the SSH connect and banner read. Without a bound a
   # peer that accepts the connection but never sends anything would block
   # the thread forever, regardless of the overall deadline.
   SSH_TIMEOUT = 5

   def __init__(self, slotId, timeout=5 * 60, step=1):
      super().__init__(name='wait for %d' % slotId)
      self.slotId = slotId
      self.timeout = timeout
      self.step = step
      self.ip = '127.100.%d.1' % slotId

   def log(self, fmt, *args):
      """Print a slot-prefixed message on stderr.

      `fmt` is only %-interpolated when `args` are supplied, so a
      ready-made message that happens to contain a '%' is printed as is
      instead of raising.
      """
      prefix = 'WaitForSlot(%d): ' % self.slotId
      message = fmt % args if args else fmt
      print(prefix + message, file=sys.stderr)

   def ping(self):
      """Return True when a single ping to the slot address succeeds."""
      cmd = ["ping", "-c1", "-w5", self.ip]
      res = subprocess.run(cmd, capture_output=True)
      return res.returncode == 0

   def checkSonic(self):
      """Return True when a TCP connection to port 22 of the slot succeeds.

      NOTE: simple SSH check for now. assume sonic is fine if ssh is up
      """
      try:
         # The context manager closes the socket even when recv() fails.
         with socket.create_connection((self.ip, 22),
                                       timeout=self.SSH_TIMEOUT) as c:
            banner = c.recv(1024)
      except OSError:
         # socket.error is an alias of OSError; this also covers timeouts.
         return False
      # The banner comes from the remote side: never let odd bytes or a
      # stray '%' turn a successful check into an uncaught exception.
      version = banner.decode('ascii', errors='replace').rstrip()
      self.log('ssh is up with version: %s', version)
      return True

   def _waitForAbootExited(self, until):
      """Block while the slot answers ping and SSH is not up yet.

      Returns as soon as a ping fails or checkSonic() succeeds.

      Raises:
         TimeoutError: the slot still answers ping once `until` (a
            time.time() value) has been reached.
      """
      while self.ping():
         if time.time() >= until:
            raise TimeoutError("Waiting for Aboot to exit timed out")
         if self.checkSonic():
            break
         time.sleep(self.step)

   def _waitForSonicReady(self, until):
      """Block until the slot answers ping and then passes checkSonic().

      Raises:
         TimeoutError: `until` (a time.time() value) was reached first.
      """
      while not self.ping():
         if time.time() >= until:
            raise TimeoutError("Waiting for Sonic mgmt interface timed out")
         time.sleep(self.step)

      while not self.checkSonic():
         if time.time() >= until:
            raise TimeoutError("Waiting for Sonic to be ready timed out")
         time.sleep(self.step)

   def _disableSonicProvisioning(self):
      """Run the command that sets the slot's provisioning mode to none."""
      cmd = [
         "arista", "linecard", "-i", str(self.slotId),
         "provision", "--set", "none"
      ]
      res = subprocess.run(cmd, capture_output=True)
      if res.returncode != 0:
         # Still best effort, but leave a trace instead of failing silently.
         self.log('disabling provisioning failed (rc=%d): %s',
                  res.returncode,
                  res.stderr.decode(errors='replace').strip())

   def run(self):
      """Thread body: wait for both phases, then disable provisioning.

      A timeout in either phase is logged and ends the thread without
      touching the provisioning setting.
      """
      end = time.time() + self.timeout

      try:
         self.log('waiting for Aboot phase to complete')
         self._waitForAbootExited(until=end)
         self.log('waiting for Sonic to be ready')
         self._waitForSonicReady(until=end)
      except TimeoutError as e:
         self.log('timeout %s', str(e))
         return

      self.log('disabling provisioning')
      self._disableSonicProvisioning()

class ProvisionMonitor(object):
   """Keeps at most one running WaitForSonicThread per slot id.

   Threads are stored in `tasks_`, keyed by slot id, and dropped from the
   table once they are found to have finished.
   """

   def __init__(self, daemon):
      self.tasks_ = {}
      self.daemon_ = daemon

   def __del__(self):
      # Block until every tracked thread has finished.
      pending = self.tasks_.values()
      for worker in pending:
         self.log('waiting for task %s', worker)
         worker.join()

   def log(self, fmt, *args):
      """Print a %-formatted message on stderr with the monitor prefix."""
      message = 'ProvisionMonitor(): ' + fmt % args
      print(message, file=sys.stderr)

   def _startMonitoringTask(self, slotId):
      """Create a watcher thread for `slotId`, record it and start it."""
      worker = WaitForSonicThread(slotId)
      self.tasks_[slotId] = worker
      worker.start()

   def checkPendingTasks(self):
      """Forget every tracked thread that is no longer alive."""
      # Walk a snapshot because entries are removed during the loop.
      snapshot = tuple(self.tasks_.items())
      for slot, worker in snapshot:
         if worker.is_alive():
            continue
         self.log('removing completed task %s', worker)
         worker.join(0)
         self.tasks_.pop(slot)

   def monitorSlotId(self, slotId):
      """Start watching `slotId` unless a live thread already does so."""
      self.checkPendingTasks()
      current = self.tasks_.get(slotId)
      if current is None or not current.is_alive():
         self._startMonitoringTask(slotId)

class ProvisionDaemon(object):
   """Bottle application serving per-slot provisioning files.

   Files for a slot are read from `<root_dir>/<slotId>`. Requesting the
   special path `manifest` returns a listing of those files and also asks
   the ProvisionMonitor to start watching that slot.
   """

   def __init__(self):
      self.app_ = Bottle()
      # Built-in defaults; loadConf() may override any of them.
      self.config_ = {
         'root_dir' : '/host/provision',
         'provision_address' : '127.100.1.1',
         'provision_port' : 12321,
      }

      self.monitor_ = ProvisionMonitor(self)

      self._setupRoutes()

   def _setupRoutes(self):
      """Register the URL routes handled by this daemon."""
      self.app_.route('/provision/<path:path>', callback=self.handleProvision)

   def loadConf(self, confPath):
      """Merge the YAML mapping stored at `confPath` over the current config.

      An empty file is accepted and leaves the configuration untouched.
      """
      with open(confPath) as c:
         conf = yaml.safe_load(c)
      # safe_load() returns None for an empty document, which
      # dict.update() would reject with a TypeError.
      if conf is None:
         return
      self.config_.update(conf)

   def _generateProvisionManifest(self, fromDir):
      """Return the manifest text describing every file below `fromDir`.

      Each file produces one line of the form
      `provision/<relative path>:/mnt/flash/<relative path>`. Entries are
      emitted in sorted order so the output does not depend on directory
      listing order. A missing or empty directory yields an empty string.
      """
      entries = []

      for pathDir, dirs, files in os.walk(fromDir):
         # Sorting in place also fixes the order os.walk() descends in.
         dirs.sort()
         for f in sorted(files):
            # relpath() copes with a trailing separator on fromDir, which
            # slicing by len(fromDir) + 1 does not.
            filePath = os.path.relpath(os.path.join(pathDir, f), fromDir)
            fileUrl = os.path.join('provision', filePath)
            fileTargetPath = os.path.join('/mnt/flash', filePath)
            entries.append('%s:%s\n' % (fileUrl, fileTargetPath))

      return ''.join(entries)

   def handleProvision(self, path):
      """Serve `path` for the slot named by the `slotId` request parameter.

      Answers 404 when `slotId` is missing or is not an integer.
      """
      slotId = request.params.get('slotId')
      try:
         # int(None) raises TypeError, so a missing parameter ends up in
         # the same 404 branch as a malformed one.
         slotId = int(slotId)
      except (TypeError, ValueError):
         abort(404)

      # NOTE(review): any integer is accepted here; a manifest request
      # starts a monitoring thread even when the slot directory is absent.
      queryDiskDir = os.path.join(self.config_['root_dir'], str(slotId))
      if path == 'manifest':
         self.monitor_.monitorSlotId(slotId)
         return self._generateProvisionManifest(queryDiskDir)

      return static_file(path, root=queryDiskDir)

   def run(self):
      """Run the Bottle application on the configured address and port."""
      self.app_.run(host=self.config_['provision_address'],
                    port=self.config_['provision_port'])

if __name__ == '__main__':
   # Start from the built-in defaults and overlay the on-disk
   # configuration only when the file is present.
   daemon = ProvisionDaemon()
   configPresent = os.path.exists(CONFIG_PATH)
   if configPresent:
      daemon.loadConf(CONFIG_PATH)
   daemon.run()
