#!/usr/bin/python3
#
# (C) 2006-2007 XenSource Ltd.
# Copyright (C) 2008-2010 Citrix Ltd.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation; version 2.1 only.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
#
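
"""XenAPI host plugin implementing VM Snapshot Schedules (VMSS).

A cron job calls trigger_schedule_snapshots(), which exits unless the host is
the pool master and then invokes the "schedule_snapshots" plugin call on the
local host. Every enabled VMSS policy that is due is processed by a
ProcessPolicy thread, and each VM in a policy is snapshotted (snapshot,
checkpoint or snapshot_with_quiesce) by a TakeSnapshot thread, with the
oldest VMSS snapshot removed once the policy's retention value is exceeded.
"""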

import errno
import os
import sys
import XenAPIPlugin
import time
from threading import Thread
from xmlrpc import client
import datetime
import XenAPI
import random
import syslog
import fcntl
from contextlib import contextmanager

INITIAL_RUN_TIME = u'19700101T00:00:00Z'
VMSS_SYSLOG_FACILITY = syslog.LOG_LOCAL1
LOG_INFO = syslog.LOG_INFO
POOL_CONF_FILE = "/etc/xensource/pool.conf"
VM_THREAD_MAX = 1
POLICY_THREAD_MAX = 1
VERBOSE = True

errorcode_to_error_map = {
    'VMSS_SNAPSHOT_LOCK_FAILED': 'The snapshot phase is already executing for this snapshot policy. Please try again later',
    'VMSS_SNAPSHOT_SUCCEEDED': 'Successfully performed the snapshot phase of the snapshot policy',
    'VMSS_SNAPSHOT_FAILED': 'The snapshot phase of the snapshot policy failed',
    'VMSS_XAPI_LOGON_FAILURE':'Could not login to API session',
    'VMSS_SNAPSHOT_MISSED_EVENT': 'A scheduled snapshot event was missed due to another on-going scheduled snapshot run. This is unexpected behaviour, please re-configure your snapshot sub-policy',
}


def log_message(message, ident="VMSS", priority=LOG_INFO):
    for message_line in str(message).split('\n'):
        syslog.openlog(ident, 0, VMSS_SYSLOG_FACILITY)
        syslog.syslog(priority, "[%d] %s" % (os.getpid(), message_line))
        syslog.closelog()


def get_current_host_ref(session):
    with open("/etc/xensource-inventory") as fd:
        for line in fd:
            if line.strip().startswith("INSTALLATION_UUID"):
                uuid = line.split("'")[1]
                return session.xenapi.host.get_by_uuid(uuid)


class Lock:

    # Simple file-based lock on a local filesystem, with shared reader/writer
    # attributes. This replicates the SM Lock class rather than importing it,
    # since importing it here would violate design principles, as suggested
    # by Germano.

    BASE_DIR = "/var/lock/vmss"
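
    # Typical usage within this module (see create_global_lock, acquire_lock
    # and release_lock below):
    #
    #     lock = Lock("schedule.all", "vmss")
    #     if lock.acquireNoblock():
    #         try:
    #             ...  # critical section
    #         finally:
    #             lock.release()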

    def _open(self):
        """Create and open the lockable attribute base, if it doesn't exist.
        (But don't lock it yet.)"""

        # one directory per namespace
        self.nspath = os.path.join(Lock.BASE_DIR, self.ns)

        # the lockfile inside that namespace directory
        self.lockpath = os.path.join(self.nspath, self.name)

        number_of_enoent_retries = 10

        while True:
            self._mkdirs(self.nspath)

            try:
                self._open_lockfile()
            except IOError as e:
                # If another lock within the namespace has already
                # cleaned up the namespace by removing the directory,
                # _open_lockfile raises an ENOENT, in this case we retry.
                if e.errno == errno.ENOENT:
                    if number_of_enoent_retries > 0:
                        number_of_enoent_retries -= 1
                        continue
                raise
            break

    def _open_lockfile(self):
        """Provide a seam, so extreme situations could be tested"""
        log_message("lock: opening lock file {0:s}" .format(self.lockpath))
        self.lockfile = open(self.lockpath, "w+")

    def _close(self):
        """Close the lock, which implies releasing the lock."""
        if self.lockfile is not None:
            self.lockfile.close()
            log_message("lock: closed {0:s}" .format(self.lockpath))
            self.lockfile = None

    def _mknamespace(ns):

        if ns is None:
            return ".nil"

        assert not ns.startswith(".")
        assert ns.find(os.path.sep) < 0
        return ns
    _mknamespace = staticmethod(_mknamespace)

    def __init__(self, name, ns=None):
        self.lockfile = None

        self.ns = Lock._mknamespace(ns)

        assert not name.startswith(".")
        assert name.find(os.path.sep) < 0
        self.name = name

        self._open()

    __del__ = _close

    def cleanup(name, ns = None):
        ns = Lock._mknamespace(ns)
        path = os.path.join(Lock.BASE_DIR, ns, name)
        if os.path.exists(path):
            Lock._unlink(path)

    cleanup = staticmethod(cleanup)

    def cleanupAll(ns = None):
        ns = Lock._mknamespace(ns)
        nspath = os.path.join(Lock.BASE_DIR, ns)

        if not os.path.exists(nspath):
            return

        for file in os.listdir(nspath):
            path = os.path.join(nspath, file)
            Lock._unlink(path)

        Lock._rmdir(nspath)

    cleanupAll = staticmethod(cleanupAll)

    #
    # Lock and attribute file management
    #

    def _mkdirs(path):
        """Concurrent makedirs() catching EEXIST."""
        if os.path.exists(path):
            return
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise Exception("Failed to makedirs({0:s})" .format(path))
    _mkdirs = staticmethod(_mkdirs)

    def _unlink(path):
        """Non-raising unlink()."""
        log_message("lock: unlinking lock file {0:s}" .format(path))
        try:
            os.unlink(path)
        except Exception as e:
            log_message("Failed to unlink({0:s}): {1:s}" .format(path, e))
    _unlink = staticmethod(_unlink)

    def _rmdir(path):
        """Non-raising rmdir()."""
        log_message("lock: removing lock dir {0:s}" .format(path))
        try:
            os.rmdir(path)
        except Exception as e:
            log_message("Failed to rmdir({0:s}): {1:s}" .format(path, e))
    _rmdir = staticmethod(_rmdir)

    #
    # Actual Locking
    #

    def acquire(self):
        """Blocking lock aquisition, with warnings. We don't expect to lock a
        lot. If so, not to collide. Coarse log statements should be ok
        and aid debugging."""
        fd = self.lockfile.fileno()
        fcntl.flock(fd, fcntl.LOCK_EX)

        if VERBOSE:
            log_message("lock: acquired {0:s}" .format(self.lockpath))

    def acquireNoblock(self):
        """Acquire lock if possible, or return false if lock already held"""
        fd = self.lockfile.fileno()
        try:
            fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
            ret = True
        except IOError as ioe:
            if ioe.errno not in [errno.EAGAIN, errno.EACCES, errno.EWOULDBLOCK]:
                raise
            ret = False

        if VERBOSE:
            log_message("lock: tried lock {0:s}, acquired: {1:b}"
                    .format(self.lockpath, ret))
        return ret


    def release(self):
        """Release a previously acquired lock."""
        fd = self.lockfile.fileno()
        fcntl.flock(fd, fcntl.LOCK_UN)

        if VERBOSE:
            log_message("lock: released {0:s}" .format(self.lockpath))

def get_session():
    session = XenAPI.xapi_local()
    try:
        session.xenapi.login_with_password('__dom0__vmss','')
    except Exception as e:
        raise Exception("%s. Error: %s" %
                        (errorcode_to_error_map['VMSS_XAPI_LOGON_FAILURE'],
                        str(e)))
    return session

@contextmanager
def xapi_session():
    # By tightly coupling session creation and session destroy, we
    # prevent session leaks and can add extra intelligence in the
    # future to retry etc.

    session = None
    try:
        session = get_session()
        yield session
        # Do not capture any exception from yield on purpose. Let
        # it be handled outside this session manager.
    finally:
        if session is not None:
            session.xenapi.session.logout()

def destroy_snapshot(session, snap_ref):
    try:
        # Generate a list of VM VDIs so we can verify snapshot VDIs are related
        vdimap = {}
        VBDs = session.xenapi.VM.get_VBDs(snap_ref)
        for vbd in VBDs:
            if session.xenapi.VBD.get_type(vbd) == 'Disk':
                # store the vdi
                vdimap[session.xenapi.VBD.get_VDI(vbd)] = '1'

        # Now destroy the VM
        # First hard_shutdown the VM; this is required for a checkpoint
        try:
            session.xenapi.VM.hard_shutdown(snap_ref)
        except Exception as e:
            # This must be a snapshot, rather than a checkpoint
            pass

        # Now try destroying the VM, this should work for both snapshot
        # and checkpoint
        try:
            session.xenapi.VM.destroy(snap_ref)
        except Exception as e:
            log_message("Could not destroy the VM: {0:s}, error: {1:s}"
                        .format(snap_ref, e))

        for vdi in vdimap.keys():
            try:
                session.xenapi.VDI.destroy(vdi)
            except Exception as e:
                log_message("Could not destroy the vdi: {0:s}. Error: {1:s}"
                            .format(vdi, e))

        # Now attempt to destroy the VBDs; if that fails, don't worry about
        # it, they will be GC'd later
        for vbd in VBDs:
            try:
                session.xenapi.VBD.destroy(vbd)
            except:
                pass

    except Exception as e:
        log_message("Could not destroy snapshot successfully, please destroy "
                    "the snapshot {0:s} manually. Error: {1:s}" .format(snap_ref,e))

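# One TakeSnapshot thread is created per VM of a policy (see process_VMs).
# Each thread takes the snapshot/checkpoint/snapshot_with_quiesce, renames it
# to the snapshot time and, once the policy's retention value is exceeded,
# destroys the oldest VMSS-created snapshot.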
class TakeSnapshot(Thread):
    def __init__(self, session, vmss_ref, vm_ref):
        Thread.__init__(self) # init the thread
        self.session = session
        self.vmss_ref = vmss_ref
        self.vm_ref = vm_ref
        self.ret_val = str(True)

    def run(self):
        try:

            # Identify the snapshot type
            snapshot_type = self.session.xenapi.VMSS.get_type(
                                                     self.vmss_ref)

            # Now create the snapshot name
            vm = self.session.xenapi.VM.get_uuid(self.vm_ref)
            vm_name = self.session.xenapi.VM.get_name_label(self.vm_ref)
            snap_name = ("%s-%s-%s" %
                         (vm_name, vm[0:16],time.strftime("%Y%m%d-%H%M",
                          time.localtime())))
            snap_name = snap_name.replace(' ', '-')

            # Start a snapshot/checkpoint operation for the VM
            log_message("Processing VM: {0:s} " .format(vm))
            if snapshot_type == "snapshot":
                snap_ref = self.session.xenapi.VM.snapshot(
                                                self.vm_ref, snap_name)
            elif snapshot_type == "checkpoint":
                snap_ref = self.session.xenapi.VM.checkpoint(
                                                self.vm_ref, snap_name)
            elif snapshot_type == "snapshot_with_quiesce":
                snap_ref = self.session.xenapi.VM.snapshot_with_quiesce(
                    self.vm_ref, snap_name)

            # Set the snapshot name to YYYYMMDD-HHMM
            timeOfSnap = str(self.session.xenapi.VM.get_snapshot_time(
                                                snap_ref))

            snap_name = ("%s%s%s-%s%s" %
                         (timeOfSnap[0:4], timeOfSnap[4:6], timeOfSnap[6:8],
                          timeOfSnap[9:11], timeOfSnap[12:14]))
            self.session.xenapi.VM.set_name_label(snap_ref, snap_name)

            # Find the oldest snapshot and delete it if required
            snaps = self.session.xenapi.VM.get_snapshots(self.vm_ref)
            oldest = self.session.xenapi.host.get_servertime(
                        get_current_host_ref(self.session))
            noOfSnaps = 0
            oldestSnap = ''
            current_retention_value = int(self.session.xenapi.VMSS.get_retained_snapshots(
                                                         self.vmss_ref))
            for snap_ref in snaps:
                if not self.session.xenapi.VM.get_is_vmss_snapshot(snap_ref):
                    continue
                else:
                    noOfSnaps += 1

                if oldest > self.session.xenapi.VM.get_snapshot_time(snap_ref):
                    oldest = self.session.xenapi.VM.get_snapshot_time(snap_ref)
                    oldestSnap = snap_ref

            # If the number of snapshots has passed the retention value (it
            # should only be exceeded by 1; if not, log a warning)
            if (noOfSnaps > current_retention_value):
                if (noOfSnaps -
                       int(self.session.xenapi.VMSS.get_retained_snapshots(
                                                         self.vmss_ref))) != 1:
                    log_message("WARNING: The difference between number of "
                                "snapshots ({0:d}) and the retention value ({1:d}) is more "
                                "than one, this is an inconsistent state. "
                                "Please contact your pool operator." .format(noOfSnaps,current_retention_value))

                if not oldestSnap:
                    raise Exception("no snapshots found older than current snapshot.")

                log_message("Snapshot retention value reached for VM: {0:s}. "
                            "Deleting the oldest snapshot: {1:s}" .format(
                    vm, self.session.xenapi.VM.get_uuid(oldestSnap)))

                destroy_snapshot(self.session, oldestSnap)

            log_message("Completed processing VM: {0:s} " .format(vm))

        except Exception as e:
            log_message("snapshot failed with exception: {0:s}" .format(e))
            self.ret_val = str(e)


    def result(self):
        return (self.ret_val, self.vm_ref)

def process_VMs(session, vmss_ref):
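    # Takes the snapshots for every VM in the policy and returns the tuple
    # (ret_val, vm_to_snapshot_result, error): ret_val is False if any VM
    # failed, vm_to_snapshot_result maps each VM ref to str(True) or its
    # error string, and error carries any policy-level exception text.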

    vm_to_snapshot_result = {}
    ret_val = True
    task_ref = None
    task_status = "success"
    error = ''
    vmss_uuid = ''

    try:
        snapshot_schedule = session.xenapi.VMSS.get_schedule(vmss_ref)
        snapshot_schedule['frequency'] = \
            session.xenapi.VMSS.get_frequency(
                        vmss_ref)

        # Get the time before snapshots
        before = datetime.datetime.now()

        #create xapi task here
        vmss_uuid = session.xenapi.VMSS.get_uuid(vmss_ref)
        task_name = "Executing policy: " + vmss_uuid
        task_ref = session.xenapi.task.create(task_name,'')
        task_uuid = session.xenapi.task.get_uuid(task_ref)
        log_message("task: {0:s} created for policy: {1:s}" .format(task_uuid,
                                                                    vmss_uuid))

        vms = session.xenapi.VMSS.get_VMs(vmss_ref)
        no_of_vms = len(vms)

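        # VMs are processed in batches of at most VM_THREAD_MAX concurrent
        # snapshot threads; no_of_batches below is the ceiling of
        # no_of_vms / VM_THREAD_MAX.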
        if no_of_vms % VM_THREAD_MAX:
            no_of_batches = no_of_vms // VM_THREAD_MAX + 1
        else:
            no_of_batches = no_of_vms // VM_THREAD_MAX

        entire_list_threads = []
        log_message("No of VMs: {0:d}" .format(no_of_vms))
        log_message("Max number of threads for processing VM: {0:d}" .format(
            VM_THREAD_MAX))

        for iter in range(0,no_of_batches):
            listThreads = []
            log_message("VM Batch: {0:d}"  .format(iter))
            for vmindex in range(0,VM_THREAD_MAX):
                realIndex = iter * VM_THREAD_MAX + vmindex
                if realIndex < no_of_vms:
                    # In each of these threads
                    s = TakeSnapshot(session, vmss_ref, vms[realIndex])
                    listThreads.append(s)
                    entire_list_threads.append(s)
                else:
                    break

            # Start the batch of threads simultaneously.
            for thread in listThreads:
                thread.start()

            # Wait till all the threads in a batch have finished.
            for thread in listThreads:
                thread.join()


        # If the snapshot failed for one or more VMs generate an
        # appropriately formatted error message to return to the caller.

        for thread in entire_list_threads:
            vm_to_snapshot_result[thread.result()[1]] = thread.result()[0]
            if thread.result()[0] != str(True):
                log_message("The snapshot for VM {0:s} failed with exception "
                            "{1:s}" .format(thread.result()[1],
                                            thread.result()[0]))
                ret_val = False
                task_status = "failure"
        # Get the time after the snapshot
        after = datetime.datetime.now()

        # Get the last expected run time according to the schedule
        last_expected_run_time = \
                        get_last_expected_run_time(snapshot_schedule)

        log_message("snapshot start time: {0:s}, end time: {1:s}, "
                    "Last expected run time: {2:s}" .format(str(before),
                                                            str(after),
                                                            str(last_expected_run_time)))

        # When there are many VMs and the scheduled interval is short, there
        # is a chance that a scheduled run of this policy is skipped because
        # the previous invocation is still running. Raise an alert for this
        # event.

        if (before < last_expected_run_time and after > last_expected_run_time):
            create_alert( session, session.xenapi.VMSS.get_uuid(vmss_ref),
                          "warn", create_structured_alert(
                    session, 'warn', {},
                    "VMSS_SNAPSHOT_MISSED_EVENT"), create_email_body(
                    session, {}, 'VMSS_SNAPSHOT_MISSED_EVENT'),
                          "VMSS_SNAPSHOT_MISSED_EVENT")

        # We reach this point only if no exceptions were raised, hence update
        # the last run time of the policy for future reference

        session.xenapi.VMSS.set_last_run_time(vmss_ref,client.DateTime(str(
        client.DateTime(time.mktime(datetime.datetime.utcnow().timetuple()))) + "Z"))

    except Exception as e:
        log_message("The snapshot for the schedule policy {0:s} failed with "
                    "exception {1:s}" .format(vmss_uuid, e))
        error = str(e)
        ret_val = False
        task_status = "failure"

    finally:
        if task_ref:
            session.xenapi.task.set_status(task_ref, task_status)
        return (ret_val, vm_to_snapshot_result, error)

#
#TODO: remove unwanted arguments passed by xapi
#
def schedule_snapshots( xapi1, xapi2):

    # schedule_snapshots is called from the xapi host-call plugin mechanism,
    # which by default passes two arguments to the function being called, so
    # we accept (xapi1, xapi2) and ignore them

    ret_val = str(True)
    try:
        child_list = []
        with xapi_session() as session:

            # Get all VMSS objects from the system
            vmss_list = session.xenapi.VMSS.get_all()
            vmss_list = random.sample(vmss_list, len(vmss_list))

            for vmss in vmss_list:
                if not session.xenapi.VMSS.get_enabled(vmss):
                    continue
                # Handle each object in a separate thread.
                s = ProcessPolicy(vmss)
                child_list.append(s)

        # If the list is non-empty, spawn threads in batches from a child
        # process

        if child_list:
            if os.fork() == 0:

                # Place a lock to have only one instance of VMSS running at any
                # given time

                vmss_lock = create_global_lock()
                if not acquire_lock(vmss_lock):
                    raise Exception("%s" %
                            errorcode_to_error_map["VMSS_SNAPSHOT_LOCK_FAILED"])

                no_of_policies = len(child_list)

                if no_of_policies % POLICY_THREAD_MAX:
                    no_of_batches = no_of_policies // POLICY_THREAD_MAX + 1
                else:
                    no_of_batches = no_of_policies // POLICY_THREAD_MAX

                log_message("No of Policies: %s" % no_of_policies)
                log_message("Max number of threads allocated for policy : %d" %
                      POLICY_THREAD_MAX)
                for iter in range(0,no_of_batches):
                    list_threads = []
                    log_message("Policy Batch: %s" % iter)
                    for index in range(0,POLICY_THREAD_MAX):
                        real_index = iter * POLICY_THREAD_MAX + index
                        if real_index < no_of_policies:
                            # In each of these threads
                            list_threads.append(child_list[real_index])
                        else:
                            break

                    # Start all the threads simultaneously.
                    for thread in list_threads:
                        thread.start()

                    # Wait till all the threads have finished.
                    for thread in list_threads:
                        thread.join()

                if vmss_lock:
                    release_lock(vmss_lock)

    except Exception as e:
        log_message("Exception in schedule_snapshots: {0:s}" .format(e))
        ret_val = str(e)

    finally:
        return ret_val

class ProcessPolicy(Thread):
    def __init__(self, vmss_ref):
        Thread.__init__(self) # init the thread
        self.vmss_ref = vmss_ref


    def run(self):
        snapshot = False
        args = {}

        try:
            # Get the last snapshot run time for the policy from XAPI.
            with xapi_session() as session:
                vmss_uuid = session.xenapi.VMSS.get_uuid(self.vmss_ref)
                snapshot_last_run_time = \
                    session.xenapi.VMSS.get_last_run_time(
                                                        self.vmss_ref)

                # Get the snapshot schedule details for the policy from XAPI.

                snapshot_schedule = \
                    session.xenapi.VMSS.get_schedule(
                                                   self.vmss_ref)
                snapshot_schedule['frequency'] = \
                    session.xenapi.VMSS.get_frequency(self.vmss_ref)

            # Use the snapshot schedule, last snapshot run time and the current
            # time to figure out if a snapshot should be executed now.

            snapshot = \
                should_operation_be_run(
                            snapshot_schedule, snapshot_last_run_time)

            # Prepare args for execute_policy

            args['vmss_uuid'] = vmss_uuid
            if snapshot:
                execute_policy("None", args)
            else:
                log_message("Not processing policy: {0:s}" .format(
                    vmss_uuid))

        except Exception as e:
            log_message("ProcessPolicy failed with exception: {0:s}" .format(e))


def get_last_expected_run_time(schedule, inUTC = False):
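    # 'schedule' is the dictionary returned by VMSS.get_schedule(), augmented
    # by the callers with a 'frequency' key. Keys used here:
    #   frequency: 'hourly' | 'daily' | 'weekly'
    #   min:       minute of the hour (all frequencies)
    #   hour:      hour of the day ('daily' and 'weekly')
    #   days:      comma-separated day names, e.g. 'Monday,Thursday' ('weekly')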
    last_expected_run_time = None

    try:
        now = datetime.datetime.now()

        # check operation frequency
        if schedule['frequency'] == 'hourly':
            # calculate the last expected run time, based on the current time.
            if now.minute > int(schedule['min']):
                # current mins are more than schedule mins so no need to
                # change the hour
                last_expected_run_time = \
                    datetime.datetime(now.year, now.month, now.day, now.hour,
                                      int(schedule['min']),0,0)
            else:
                last_expected_run_time = \
                    (datetime.datetime(now.year, now.month, now.day, now.hour,
                                       int(schedule['min']),0,0) -
                     datetime.timedelta(hours = 1))
        elif schedule['frequency'] == 'daily':
            # calculate the last expected run time, based on the
            # current date and time.
            if (now.hour > int(schedule['hour']) or
               ((now.hour == int(schedule['hour'])) and
                (now.minute > int(schedule['min'])))):
                # current hours are more than schedule hours so no need
                # to change the day
                last_expected_run_time = \
                    datetime.datetime(now.year, now.month, now.day,
                                      int(schedule['hour']),
                                      int(schedule['min']),0,0)
            else:
                last_expected_run_time = \
                    (datetime.datetime(now.year, now.month, now.day,
                                       int(schedule['hour']),
                                       int(schedule['min']),0,0) -
                     datetime.timedelta(days = 1))
        elif schedule['frequency'] == 'weekly':
            # First create a map of the days in the schedule for
            # ease of computation later
            dayMap = {}
            for day in schedule['days'].split(','):
                dayMap[day] = '1'

            lastDayFound = False

            # calculate the last expected run time, based on the current
            # date and time.
            # if current time is less than scheduled time
            if (now.hour < int(schedule['hour']) or
                ((now.hour == int(schedule['hour']) and
                 now.minute < int(schedule['min'])))):
                # go to the last day on the scheduled list excluding today
                noOfDays = 1
            else:
                # go to the last day on the scheduled list including today
                noOfDays = 0

            newDate = now
            while not lastDayFound and noOfDays < 8:
                td = datetime.timedelta(days = noOfDays)
                newDate = now - td
                if newDate.strftime("%A") in dayMap:
                    lastDayFound = True
                else:
                    noOfDays += 1

            if not lastDayFound:
                raise Exception("Could not find the last expected execution "
                                "time for the schedule: %s"
                                % schedule)

            # generate a date with this day and the scheduled time
            last_expected_run_time = \
                datetime.datetime(newDate.year, newDate.month, newDate.day,
                                  int(schedule['hour']),
                                  int(schedule['min']),0,0)

        # Now check if this needs to be converted into UTC time
        if inUTC:
            secs = time.mktime(last_expected_run_time.timetuple())
            last_expected_run_time = time.gmtime(secs)

    except Exception as e:
        log_message("There was an exception in finding out the last expected "
                    "run time of a schedule. {0:s}" .format(e))

    return last_expected_run_time

def should_operation_be_run(schedule, last_run_time):
    try:
        # Get the current time
        now = datetime.datetime.utcnow()

        # check if the operation is due because of the schedule
        if is_due_for_run(schedule):
            return True

        # not due for run yet, if the last run time is the initial
        # time then check if we still have to run it!
        if last_run_time == INITIAL_RUN_TIME:
            last_expected_run_time = \
            client.DateTime(get_last_expected_run_time(schedule, True))
            if client.DateTime(time.mktime(now.timetuple())) > last_expected_run_time:
                log_message("scheduling policy for first time")
                return True
            return False


        # If not, check whether it should be run anyway because it wasn't run
        # in the last timeslot for some reason

        # if the last run time is in the future then run the operation
        if (last_run_time >
                 client.DateTime(time.mktime(now.timetuple()))):
            log_message("The last run time is in the future then run the "
                        "operation, run the operation to be safe.")
            return True

        last_expected_run_time = \
                client.DateTime(get_last_expected_run_time(schedule, True))

        # Now check if the last run time was before the last expected run time
        if last_run_time < last_expected_run_time:
            log_message("The last expected run time was {0:s}, however the "
                        "operation was last run at {1:s}, hence run it "
                        "again." .format(last_expected_run_time, last_run_time))
            return True
        else:
            return False

    except Exception as e:
        log_message("Exception in should_operation_be_run: {0:s}" .format(e))
        return False

def is_due_for_run(schedule):
    try:
        # Find the current time and extract required information
        now = datetime.datetime.now()
        day = now.strftime("%A")
        hour = now.hour
        min = now.minute

        # Now compare with the schedule passed in
        if min != int(schedule['min']):
            return False

        if schedule['frequency'] == 'hourly':
            return True

        if hour != int(schedule['hour']):
            return False

        if schedule['frequency'] == 'daily':
            return True

        # If we have come to this point the frequency is definitely weekly;
        # however, we still check explicitly in case a monthly frequency is
        # added in a later release

        day_map = {}
        for dayofweek in schedule['days'].split(','):
            day_map[dayofweek] = '1'

        return schedule['frequency'] == 'weekly' and day in day_map

    except Exception as e:
        log_message("Exception in is_due_for_run: {0:s}" .format(e))
        return False

    return False

#
#TODO: remove unwanted arguments passed by xapi
#
def execute_policy(xapi1, args):
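    # Entry point for the "snapshot_now" plugin call (see the dispatch table
    # at the bottom of this file); also invoked by ProcessPolicy.run. 'args'
    # must contain 'vmss_uuid', the UUID of the VMSS policy to execute.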
    ret_val = str(True)
    policy_lock = None
    try:
        log_message("Processing policy: {0:s}" .format(args['vmss_uuid']))
        with xapi_session() as session:
            vmss_uuid = args['vmss_uuid']
            vmss_ref = session.xenapi.VMSS.get_by_uuid(vmss_uuid)
            if not session.xenapi.VMSS.get_enabled(vmss_ref):
                log_message("Policy {0:s} is not enabled" .format(args[
                                                                      'vmss_uuid']))
                return ret_val # true

            if len(session.xenapi.VMSS.get_VMs(vmss_ref)) == 0:
                log_message("No VMs assigned to policy: {0:s}" .format(args[
                                'vmss_uuid']))
                return ret_val # true

            # we reach this point only when we need to process a policy
            # therefore acquire a policy lock

            policy_lock = get_snapshot_lock(vmss_uuid)
            if not acquire_lock(policy_lock):
                create_alert(session, vmss_uuid, "warn",
                             create_structured_alert(session, 'warn',{},
                                        "VMSS_SNAPSHOT_LOCK_FAILED"),
                             create_email_body(session, {},"VMSS_SNAPSHOT_LOCK_FAILED"),
                             "VMSS_SNAPSHOT_LOCK_FAILED")
                raise Exception("%s" %
                        errorcode_to_error_map["VMSS_SNAPSHOT_LOCK_FAILED"])

            (ret_val_snapshot, vm_to_snapshot_result, error_snapshot) = \
                process_VMs(session, vmss_ref)

            if ret_val_snapshot:
                create_alert(session, vmss_uuid, "info",
                             create_structured_alert(session, 'info',{},
                                        "VMSS_SNAPSHOT_SUCCEEDED"),
                             create_email_body(session, vm_to_snapshot_result),
                             "VMSS_SNAPSHOT_SUCCEEDED")
            else:
                # Generate snapshot schedule failure alerts here.
                create_alert(session, vmss_uuid, "error",
                             create_structured_alert(session, 'error',
                                                     vm_to_snapshot_result),
                    create_email_body(session, vm_to_snapshot_result),
                    "VMSS_SNAPSHOT_FAILED")
                log_message("process_VMs failed with the following error "
                            "details: {0:s} and {1:s}." .format(
                    vm_to_snapshot_result, error_snapshot))

                raise Exception(errorcode_to_error_map["VMSS_SNAPSHOT_FAILED"])

            log_message("Completed processing policy: {0:s}" .format(vmss_uuid))

    except Exception as e:
        log_message("Exception in execute_policy: {0:s}" .format(e))
        ret_val = '%s.' % str(e)

    finally:
        if policy_lock:
            release_lock(policy_lock)  # release policy lock
        return ret_val

def acquire_lock(l):
    try:
        return l.acquireNoblock()

    except Exception as e:
        log_message("There was an exception in acquiring lock. Exception: {"
                    "0:s}" .format(e))
        return False

def create_global_lock():
    return Lock("schedule.all","vmss")

def get_snapshot_lock(vmss_uuid):
    return Lock("%s.running" % vmss_uuid, "vmss")

def release_lock(l):
    try:
        l.release()
        return True

    except Exception as e:
        log_message("There was an exception releasing lock. Exception: {0:s}" .format(e))
        return False

def trigger_schedule_snapshots():

    # This function is the entry point for the cron job.
    # Algo:
    # 1. Check if the host is the pool master; if not, exit
    # 2. Check if at least one VMSS is enabled; if yes, call schedule_snapshots

    with open(POOL_CONF_FILE, 'r') as f:
        if f.read() != 'master':
            return

    try:
        with xapi_session() as session:
            log_message("===Kicking cron job for VMSS===")
            call_plugin = False
            for vmss in session.xenapi.VMSS.get_all():
                if session.xenapi.VMSS.get_enabled(vmss):
                    call_plugin = True
                    break
            if not call_plugin:
                log_message("VMSS policy not enabled for this pool, Exiting cron "
                            "job.")
            else:
                # Find the local host uuid
                host_ref = get_current_host_ref(session)
                text = session.xenapi.host.call_plugin( host_ref, "vmss",
                                                        "schedule_snapshots", {})

    except Exception as e:
        log_message("Exception in trigger_schedule_snapshots: %s" % str(e))

def create_email_body(session, vm_to_error_map = {}, error_code = '',
                      additional_error_info = ''):
    # This will only be called for errors and warnings, so we do not need an
    # alert type. First handle the case where a VM-to-error map is passed in.

    failed_VMs = 0
    error_str = ''
    try:
        if vm_to_error_map != {}:
            for vm in vm_to_error_map.keys():
                if vm_to_error_map[vm] != str(True):
                    failed_VMs += 1
                    vm_uuid = session.xenapi.VM.get_uuid(vm)
                    vm_name = session.xenapi.VM.get_name_label(vm)
                    error_str += ("VM: %s UUID: %s Error:%s" %
                                 (vm_name, vm_uuid, vm_to_error_map[vm]))
                    error_str += ',\n'

            error_str = error_str.strip('\n')
            error_str = error_str.strip(',')

            return ("Snapshot failed on {0:d} out of {1:d} VMs with the "
                    "following errors: \n\nDetails:\n{2:s}" .format(failed_VMs,
                                          len(vm_to_error_map.keys()), error_str))

        # Now handle if an error code is passed in
        if error_code != '':
            if additional_error_info != '':
                return ("failed with error: {0:s}. Additional error details: "
                        "{1:s}." .format(errorcode_to_error_map[error_code],
                                errorcode_to_error_map[additional_error_info]))
            else:
                return ("failed with error: {0:s}." .format(errorcode_to_error_map[error_code]))

    except:
        log_message("Exception in create_email_body")
        return ''

def create_structured_alert(session, alert_type,
                          vm_to_error_map = {}, error_code = ''):
    try:
        data_str = ("<XCData><time>%s</time><messagetype>%s</messagetype>" %
                   (datetime.datetime.now(), alert_type))
        if alert_type == 'error':
            if vm_to_error_map != {}:
                # Normal error with a vm to error map
                for vm in vm_to_error_map.keys():
                    if vm_to_error_map[vm] != str(True):
                        vm_uuid = session.xenapi.VM.get_uuid(vm)
                        error = vm_to_error_map[vm].split(',')[0]
                        error = error.lstrip('[')
                        error = error.lstrip('\'')
                        error = error.rstrip(']')
                        error = error.rstrip('\'')
                        data_str += \
                          ("<error><vm>%s</vm><errorcode>%s</errorcode></error>"
                          % (vm_uuid, error))

        if alert_type == 'warn' or alert_type == 'info':
            data_str += '<message>%s</message>' % (error_code)

        data_str += "</XCData>"

        log_message ("RETURN in create_structured_alert: {0:s}" .format(str(
            data_str)))
        return data_str
    except:
        log_message ("Exception in create_structured_alert")
        return ''

def create_alert(session, vmss_uuid, alert_type, structured_alert,
                 email_body = '', error_code = ''):
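    # XenAPI message priority: lower numbers indicate higher severity; this
    # maps 'error' to 1, 'warn' to 3 and 'info' to 5.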
    try:
        if alert_type == 'error':
            session.xenapi.message.create(error_code, "1", "VMSS", vmss_uuid,
                                email_body)
        elif alert_type == 'warn':
            session.xenapi.message.create(error_code, "3", "VMSS", vmss_uuid,
                                        email_body)
        elif alert_type == 'info':
            session.xenapi.message.create(error_code, "5", "VMSS", vmss_uuid,
                                        error_code)
    except Exception as e:
        log_message("Failed to create alerts for vmss {0:s} with alert level: "
                    "{1:s}. Error: {2:s}" .format(
            vmss_uuid, alert_type,str(e)))

if __name__ == "__main__":
    log_message("Entering VMSS")
    XenAPIPlugin.dispatch({"schedule_snapshots": schedule_snapshots,
                           "snapshot_now": execute_policy})
