#!/usr/bin/env python

import os
import json
import copy
from copy import deepcopy
import subprocess
import math
import sys
import subprocess
import logging
import argparse


# Default resource requirements. parse_makeflow_json applies these to any rule
# that is marked local or whose category is not listed in the makeflow JSON.
UNCATEGORIZED_CORES = 1
UNCATEGORIZED_MEMORY = 200 # MBytes
UNCATEGORIZED_DISK = 20 # MBytes
UNCATEGORIZED_TIME = 30 # seconds
# File transfer speed; workflow_job adds disk / TCP_BANDWIDTH to each job's
# run time. Overridden by -b/--tcp_bandwidth.
TCP_BANDWIDTH = 10 # MBytes/sec
# Durations used by the aws_vm timers for the spinup, shutdown and idle states.
VM_SPIN_UP_TIME = 100 # seconds
VM_SHUT_DOWN_TIME = 60 # seconds
VM_IDLE_TIME = 30 # seconds, overridden by -i/--idle_time_left
# Number of vm slots handed to makeflow_aws_module. Overridden by -m/--max_vm_count.
MAX_VM_COUNT = 20
# Simulation step passed to elapse_time on every loop iteration.
# Overridden by -t/--time_interval.
TIME_STRIDES = 60 # seconds
# Utilization threshold passed to find_utilizable_idle_vm_slot by approach 3.
# Overridden by -p/--util_percentage. Rule as stated by the original author:
# reuse a vm iff job.cores > util% * vm.cores && job.memory > util% * vm.memory
UTIL_PERCENTAGE = .5
# AWS settings; the main block fills these in from the -c/--config_file lines
# whose first word is vpc, subnet, gateway, security_group_id or keypair_name.
VPC = None
SUBNET = None
GATEWAY = None
SECURITY_GROUP_ID = None
KEYPAIR_NAME = None
# Instance type used by the main block when it runs "aws ec2 run-instances".
EXAMPLE_INSTANCE_TYPE = "t2.micro"

# Default catalogue of vm types, keyed by instance type name. The main block
# uses it when no -V/--vm_info file is given.
#   cores  - compared against a job's core requirement
#   memory - compared against a job's memory requirement (presumably MBytes,
#            like UNCATEGORIZED_MEMORY -- TODO confirm)
#   price  - billed as time_elapsed (seconds) * price / 3600, i.e. a per-hour
#            rate; the final report labels the total as dollars
# NOTE(review): m4.large lists more memory (8192) than m4.xlarge (7680), and
# m4.10xlarge / m4.16xlarge list the same memory as c4.8xlarge -- verify these
# values against the provider's published instance specifications.
VM_INFO = {
  "c4.large": {
    "cores": 2,
    "memory": 3840,
    "price": 0.1
  },
  "c4.xlarge": {
    "cores": 4,
    "memory": 7680,
    "price": 0.199
  },
  "c4.2xlarge": {
    "cores": 8,
    "memory": 15360,
    "price": 0.398
  },
  "c4.4xlarge": {
    "cores": 16,
    "memory": 30720,
    "price": 0.796
  },
  "c4.8xlarge": {
    "cores": 36,
    "memory": 61440,
    "price": 1.591
  },
  "m4.large": {
    "cores": 2,
    "memory": 8192,
    "price": 0.1
  },
  "m4.xlarge": {
    "cores": 4,
    "memory": 7680,
    "price": 0.2
  },
  "m4.2xlarge": {
    "cores": 8,
    "memory": 15360,
    "price": 0.4
  },
  "m4.4xlarge": {
    "cores": 16,
    "memory": 30720,
    "price": 0.8
  },
  "m4.10xlarge": {
    "cores": 40,
    "memory": 61440,
    "price": 2
  },
  "m4.16xlarge": {
    "cores": 64,
    "memory": 61440,
    "price": 3.2
  }
}

# parsing the json file, find input files of all jobs
def find_input_files(makeflow_json):
    """Return the set of every file listed under 'inputs' in any rule.

    Rules that have no 'inputs' key contribute nothing.
    """
    return {
        file_name
        for rule in makeflow_json['rules']
        if 'inputs' in rule
        for file_name in rule['inputs']
    }

# parsing the json file, find output files of all jobs
def find_output_files(makeflow_json):
    """Return the set of every file listed under 'outputs' in any rule.

    Rules that have no 'outputs' key contribute nothing.
    """
    return {
        file_name
        for rule in makeflow_json['rules']
        if 'outputs' in rule
        for file_name in rule['outputs']
    }

# find all the initial files to start with, by set input files - set output files
def find_starting_files(makeflow_json):
    """Return the files that some rule consumes but no rule produces.

    These are the files that must already exist before the workflow starts.
    """
    consumed = find_input_files(makeflow_json)
    produced = find_output_files(makeflow_json)
    return consumed - produced

# parsing the json file, find the job types of all jobs
# job: {cores:(INT), memory:(INT), disk:(INT), time:(INT), input:(LIST OF FILES), output:(LIST OF FILES), local_job:(TRUE FALSE)}
def parse_makeflow_json(makeflow_json):
    """Turn the rules of a makeflow JSON document into workflow_job objects.

    A rule gets the resources of its category (CORES, DISK, MEMORY and
    WALLTIME from the category's 'environment') when the rule is not marked
    'local_job' and its category is listed under 'categories'.  Every other
    rule -- local, uncategorized, or with an unknown category -- gets the
    UNCATEGORIZED_* defaults.

    Missing 'categories', 'category', 'inputs' or 'outputs' keys are treated
    as absent/empty instead of raising KeyError, matching the tolerance that
    find_input_files / find_output_files already have.

    Returns:
        list of workflow_job, in the order the rules appear.
    """
    # 'or {}' also covers an explicit null in the JSON.
    category_info = makeflow_json.get('categories') or {}

    job_queue = []
    for rule in makeflow_json['rules']:
        local_job = 'local_job' in rule
        input_files = set(rule.get('inputs', ()))
        output_files = set(rule.get('outputs', ()))
        category = rule.get('category')

        if not local_job and category in category_info:
            environment = category_info[category]['environment']
            cores = float(environment['CORES'])
            disk = int(environment['DISK'])
            memory = int(environment['MEMORY'])
            time = int(environment['WALLTIME'])
        else:
            cores = UNCATEGORIZED_CORES
            disk = UNCATEGORIZED_DISK
            memory = UNCATEGORIZED_MEMORY
            time = UNCATEGORIZED_TIME

        job_queue.append(workflow_job(time, cores, memory, disk, input_files, output_files, local_job))
    return job_queue

############################################################
###################### workflow_job ########################
############################################################

class workflow_job:
    """Resource requirements and remaining run time of one workflow rule."""

    # Class-level fallbacks; __init__ sets every one of them on the instance.
    local_job = True
    time = UNCATEGORIZED_TIME
    cores = UNCATEGORIZED_CORES
    memory = UNCATEGORIZED_MEMORY
    disk = UNCATEGORIZED_DISK
    input_files = set()
    output_files = set()

    def __init__(self, time, cores, memory, disk, input_files, output_files, local_job):
        """Store the job's requirements.

        The stored time is the given time plus the time needed to move
        `disk` worth of data at TCP_BANDWIDTH.
        """
        transfer_time = disk / TCP_BANDWIDTH
        self.local_job = local_job
        self.input_files, self.output_files = input_files, output_files
        self.cores, self.memory, self.disk = cores, memory, disk
        self.time = time + transfer_time

    def job_time_elapse(self, time):
        """Reduce the remaining run time by `time`."""
        self.time = self.time - time

    def print_job(self):
        """Print a one-line summary of the job's resources and remaining time."""
        summary = 'This job has {} cores and {} memory {} disk and {} time'.format(
            self.cores, self.memory, self.disk, self.time)
        print(summary)

############################################################
###################### aws_vm ##############################
############################################################

class aws_vm:
    """One simulated vm slot: its state, its instance type and its current job.

    status takes one of five values and moves between them as follows:
      'empty'     -> 'spinup'     queue_new_job
      'spinup'    -> 'executing'  spin-up timer expired
      'executing' -> 'idle' or 'shutdown'   set by the owner when the job ends
      'idle'      -> 'executing'  idle_queue_job (vm reused)
      'idle'      -> 'shutdown'   idle timer expired with no job
      'shutdown'  -> 'empty'      shut-down timer expired with no job
      'shutdown'  -> 'spinup'     shut-down timer expired with a job waiting
    """

    status = 'empty'
    current_job = None
    vm_type = None
    vm_info = None
    spin_up_time_left = VM_SPIN_UP_TIME
    shut_down_time_left = VM_SHUT_DOWN_TIME
    idle_time_left = VM_IDLE_TIME

    def __init__(self, status, current_job, vm_info):
        """Create a slot.

        Args:
            status: initial status string (see class docstring).
            current_job: job assigned to the slot, or None.
            vm_info: dict mapping vm type name to a dict with 'cores',
                'memory' and 'price'.
        """
        self.status = status
        self.current_job = current_job
        self.vm_info = vm_info
        # The class-level timers were captured when the class was defined;
        # re-read the module constants so command-line overrides apply from
        # the first state onwards.
        self.time_reset()

    def queue_new_job(self, new_job):
        """Assign a job to this slot, pick a vm type for it and start spin-up."""
        self.current_job = new_job
        self.assign_smallest_vm(new_job)
        self.status = 'spinup'

    def idle_queue_job(self, new_job):
        """Reuse this already-running vm for a new job (no spin-up)."""
        self.current_job = new_job
        # Restart the timers so the next idle period gets the full allowance
        # rather than whatever was left of this one.
        self.time_reset()
        self.status = 'executing'

    def return_status(self):
        """Return the current status string."""
        return self.status

    def return_type(self):
        """Return the current vm type name, or None if none was assigned."""
        return self.vm_type

    def time_reset(self):
        """Reset the spin-up, shut-down and idle timers to the module constants."""
        self.spin_up_time_left = VM_SPIN_UP_TIME
        self.shut_down_time_left = VM_SHUT_DOWN_TIME
        self.idle_time_left = VM_IDLE_TIME

    def elapse_time(self, time_elapsed):
        """Advance this slot by time_elapsed, then apply any due transition.

        Only the timer that belongs to the current status is advanced; an
        'empty' slot is unaffected.
        """
        if self.status == 'executing':
            self.current_job.job_time_elapse(time_elapsed)
        elif self.status == 'shutdown':
            self.shut_down_time_left -= time_elapsed
        elif self.status == 'spinup':
            self.spin_up_time_left -= time_elapsed
        elif self.status == 'idle':
            self.idle_time_left -= time_elapsed
        self.transition_status()

    def transition_status(self):
        """Move to the next status if the current state's timer has expired."""
        if self.status == 'shutdown' and self.shut_down_time_left <= 0:
            self.time_reset()
            if self.current_job is None:
                self.status = 'empty'
            else:
                self.assign_smallest_vm(self.current_job)
                # This was a comparison (==) before, which left the vm stuck
                # in 'shutdown'.
                self.status = 'spinup'
        elif self.status == 'spinup' and self.spin_up_time_left <= 0:
            self.time_reset()
            self.status = 'executing'
        elif self.status == 'idle' and self.current_job is None and self.idle_time_left <= 0:
            self.time_reset()
            self.status = 'shutdown'

    def assign_smallest_vm(self, new_job):
        """Set vm_type to the cheapest type with enough cores and memory.

        Ties go to the type that comes first in self.vm_info.  If no type is
        large enough, vm_type is left unchanged and a warning is logged.
        """
        candidates = [
            name for name, spec in self.vm_info.items()
            if spec['cores'] >= new_job.cores and spec['memory'] >= new_job.memory
        ]
        if not candidates:
            logging.warning(
                'no vm type fits a job needing %s cores and %s memory',
                new_job.cores, new_job.memory)
            return
        self.vm_type = min(candidates, key=lambda name: self.vm_info[name]['price'])

    def job_vm_resource_utilize(self, new_job, util_percentage):
        """Return True if new_job is a good match for this vm's current type.

        A job matches when it fits on the vm and needs at least
        util_percentage of both the vm's cores and its memory, so that small
        jobs do not occupy large vms.  (The comparison was previously the
        other way round and did not check that the job fits.)

        Returns False when no vm type has been assigned yet.
        """
        if not self.job_fits(new_job):
            return False
        spec = self.vm_info[self.vm_type]
        return (new_job.cores >= spec['cores'] * util_percentage
                and new_job.memory >= spec['memory'] * util_percentage)

    def job_fits(self, job):
        """Return True if the vm's current type has enough cores and memory.

        Returns False when no known vm type has been assigned.
        """
        spec = self.vm_info.get(self.vm_type)
        if spec is None:
            return False
        return job.cores <= spec['cores'] and job.memory <= spec['memory']

    def print_vm_state(self):
        """Print the status, type, current job and active timer of this slot."""
        job = self.current_job
        print('vm status: ' + self.status)
        if self.vm_type:
            print('vm type: ' + self.vm_type)
        if job:
            print('current job has {} MB memory, {} cores and {} time left'.format(job.memory, job.cores, job.time))
        if self.status == 'spinup':
            print('virtualmachine spinning up with {} time left'.format(self.spin_up_time_left))
        if self.status == 'shutdown':
            print('virtualmachine shutting down with {} time left'.format(self.shut_down_time_left))

############################################################
###################### aws_server ########################## VM_QUEUE
############################################################
# class aws_server:
# approach: enum_int, list of instances

class aws_server:
    """A fixed pool of vm slots plus the policy for placing jobs on them.

    approach selects the policy used by queue_jobs:
      1 - never reuse a vm: every job gets a fresh slot, which is shut down
          as soon as its job finishes
      2 - reuse any idle vm the job fits on, cheapest first
      3 - reuse an idle vm only if the vm itself reports the job as a
          suitable match (job_vm_resource_utilize), cheapest first
    total_price accumulates the cost of all 'executing' and 'idle' time.
    """

    # vm_list is deliberately NOT a class attribute: a class-level list would
    # be shared, so every new server would add its slots to the same pool.
    total_price = 0.0
    max_vm_count = 1
    vm_info = None
    approach = 1

    def __init__(self, max_vm_count, vm_info, approach):
        """Create max_vm_count empty slots.

        Args:
            max_vm_count: number of vm slots in the pool.
            vm_info: dict mapping vm type name to 'cores'/'memory'/'price'.
            approach: scheduling policy, 1, 2 or 3 (see class docstring).
        """
        self.max_vm_count = max_vm_count
        self.vm_info = vm_info
        self.approach = approach
        self.total_price = 0.0
        self.vm_list = [aws_vm('empty', None, vm_info) for _ in range(max_vm_count)]

    def _count_status(self, status):
        """Return how many slots currently have the given status."""
        return sum(1 for vm in self.vm_list if vm.status == status)

    def vm_empty_slots(self):
        """Return the number of 'empty' slots."""
        return self._count_status('empty')

    def vm_idle_slots(self):
        """Return the number of 'idle' slots."""
        return self._count_status('idle')

    def find_empty_vm_slot(self):
        """Return the first 'empty' slot, or None if there is none."""
        for vm in self.vm_list:
            if vm.status == 'empty':
                return vm
        return None

    def queue_jobs(self, job_queue):
        """Place as many jobs from job_queue as the pool can take.

        Placed jobs are removed from job_queue in place.  For approaches 2
        and 3, idle vms are tried first; then, for all three approaches, the
        remaining jobs are popped from the end of the queue onto empty slots.
        Any other approach value places nothing.
        """
        if self.approach not in (1, 2, 3):
            return

        if self.approach in (2, 3):
            # Walk backwards so that popping index i does not shift the
            # indices still to be visited.
            for index in range(len(job_queue) - 1, -1, -1):
                job = job_queue[index]
                if self.approach == 2:
                    vm = self.find_idle_vm_slot(job)
                else:
                    vm = self.find_utilizable_idle_vm_slot(job, UTIL_PERCENTAGE)
                if vm is not None:
                    vm.idle_queue_job(job_queue.pop(index))

        for _ in range(min(self.vm_empty_slots(), len(job_queue))):
            self.find_empty_vm_slot().queue_new_job(job_queue.pop())

    def print_vm_states(self):
        """Print the state of every slot."""
        for vm in self.vm_list:
            vm.print_vm_state()

    def elapse_time(self, time_elapsed):
        """Advance every slot by time_elapsed seconds and bill running vms.

        A vm is billed for the whole step if it is 'executing' or 'idle'
        after being advanced; price is a per-hour rate.
        """
        for vm in self.vm_list:
            vm.elapse_time(time_elapsed)
            if vm.status in ('executing', 'idle'):
                hourly_price = self.vm_info[vm.vm_type]['price']
                self.total_price += time_elapsed * hourly_price / 3600.0

    def isEmpty(self):
        """Return True when every slot is 'empty'."""
        return all(vm.status == 'empty' for vm in self.vm_list)

    def _cheapest_idle_vm(self, accepts):
        """Return the cheapest idle vm for which accepts(vm) is true, or None.

        Ties go to the vm that comes first in vm_list.
        """
        best_vm = None
        best_price = float('inf')
        for vm in self.vm_list:
            if vm.status != 'idle' or not accepts(vm):
                continue
            price = self.vm_info[vm.vm_type]['price']
            if price < best_price:
                best_vm, best_price = vm, price
        return best_vm

    def find_idle_vm_slot(self, job):
        """Return the cheapest idle vm that job fits on, or None."""
        def fits(vm):
            spec = self.vm_info.get(vm.vm_type)
            return (spec is not None
                    and job.cores <= spec['cores']
                    and job.memory <= spec['memory'])
        return self._cheapest_idle_vm(fits)

    def find_utilizable_idle_vm_slot(self, job, util_percentage):
        """Return the cheapest idle vm that accepts job at util_percentage.

        Acceptance is decided by the vm's job_vm_resource_utilize method.
        Returns None if no idle vm accepts the job.
        """
        return self._cheapest_idle_vm(
            lambda vm: vm.job_vm_resource_utilize(job, util_percentage))

    def finish_jobs_return_files(self):
        """Retire every finished job and return the files they produced.

        A job is finished when its vm is 'executing' and the job's remaining
        time is <= 0.  The vm is then shut down (approach 1) or left idle for
        reuse (other approaches).

        Returns:
            set of output files of all jobs retired by this call.
        """
        output_files = set()
        for vm in self.vm_list:
            job = vm.current_job
            if vm.status == 'executing' and job is not None and job.time <= 0:
                output_files.update(job.output_files)
                vm.current_job = None
                vm.status = 'shutdown' if self.approach == 1 else 'idle'
        return output_files

    def clear(self):
        """Reset the pool to max_vm_count empty slots and zero the total price."""
        self.total_price = 0.0
        self.vm_list = [aws_vm('empty', None, self.vm_info) for _ in range(self.max_vm_count)]

############################################################
###################### makeflow_aws_module: ################ JOB_QUEUE
############################################################

# Responsible for tracking the job queue, working out which jobs are ready to run, and handing those jobs to the aws_server.

class makeflow_aws_module:
    """Drives one simulated workflow run against an aws_server.

    Attributes set per instance:
      job_queue   - jobs whose inputs are not all available yet
      jobs_ready  - jobs whose inputs are available, waiting for a vm
      local_files - files available so far (starting files plus the outputs
                    of finished jobs)
      aws_config  - the aws_server that runs the jobs
    """

    # Mutable state (job_queue, jobs_ready, local_files) lives on the
    # instance only, so separate simulations cannot share it by accident.
    global_clock = 0
    aws_config = None
    vm_info = None
    approach = 1

    def __init__(self, makeflow_json, max_vm_count, vm_info, approach):
        """Set up a run.

        Args:
            makeflow_json: parsed makeflow JSON document.
            max_vm_count: number of vm slots for the aws_server.
            vm_info: dict mapping vm type name to 'cores'/'memory'/'price'.
            approach: scheduling policy passed on to the aws_server.
        """
        self.local_files = find_starting_files(makeflow_json)
        self.job_queue = parse_makeflow_json(makeflow_json)
        self.jobs_ready = []
        self.aws_config = aws_server(max_vm_count, vm_info, approach)
        self.approach = approach
        self.vm_info = vm_info

    def find_ready_jobs(self):
        """Move every job whose inputs are all available to jobs_ready.

        The queue is partitioned rather than edited while being iterated;
        removing during iteration skipped the job after each removed one.
        Both lists are updated in place and keep their relative order.
        """
        still_waiting = []
        for job in self.job_queue:
            if job.input_files.issubset(self.local_files):
                self.jobs_ready.append(job)
            else:
                still_waiting.append(job)
        self.job_queue[:] = still_waiting

    def queue_jobs(self):
        """Hand the ready jobs to the aws_server, which removes those it places."""
        self.aws_config.queue_jobs(self.jobs_ready)

    def elapse_time(self, time_elapsed):
        """Advance the simulation by time_elapsed and collect finished outputs."""
        self.aws_config.elapse_time(time_elapsed)
        self.retrieve_finished_files()

    def retrieve_finished_files(self):
        """Add the output files of all jobs that just finished to local_files."""
        self.local_files = self.local_files.union(self.aws_config.finish_jobs_return_files())

    def workflow_finished_submitting(self):
        """Return True when no job is waiting for inputs or for a vm."""
        return not self.job_queue and not self.jobs_ready

if __name__ == '__main__':
    # Command-line driver: read the options, optionally check the local aws
    # setup, then simulate the workflow once per scheduling approach and
    # print the elapsed time and cost of each.

    parser = argparse.ArgumentParser(description='Basic Options: ')
    # file type
    parser.add_argument('makefile_json', metavar='',
                    help='makefile.json')

    parser.add_argument("-l", "--local", help="estimate workflow execution locally. Not creating a vm to measure spinup time, shutdown time, network bandwidth etc", action="store_true")
    parser.add_argument('-b','--tcp_bandwidth', type=int, default = 10, help='Set the tcp bandwidth for file transfer speed (default is 10 Mbytes/sec)')
    parser.add_argument('-m','--max_vm_count', type=int, default = 20, help='Set the maximum vm count (default is 20 vms)')
    parser.add_argument('-t','--time_interval', type=int, default = 60, help='Set the time interval for vm status update cycle (default is 60s)')
    parser.add_argument('-i','--idle_time_left', type=int, default = 30, help='Set the amount of time idle (default is 30s)')
    parser.add_argument('-p','--util_percentage', type=float, default = .5, help='Set the percentage for a vm to be considered propertly utilized (0 ~ 1.0, default is .5)')
    parser.add_argument('-V','--vm_info', type=str, help='Specifying a json file with avaliable vm name, memory and price info, see script for default dictionary')
    parser.add_argument('-c','--config_file', type=str, help='config file generated by makeflow_ec2_setup, specifying subnet, image-id etc')
    parser.add_argument('-d','--debug', help='Enables debugging and prints debugging logs', action="store_true")

    args = parser.parse_args()

    # A zero bandwidth would divide by zero in workflow_job, and a
    # non-positive time step would keep the simulation loop from advancing.
    if args.tcp_bandwidth <= 0:
        parser.error('--tcp_bandwidth must be greater than 0')
    if args.time_interval <= 0:
        parser.error('--time_interval must be greater than 0')

    # Every option has a default, so the module-level settings can be
    # overwritten unconditionally.
    TCP_BANDWIDTH = args.tcp_bandwidth
    MAX_VM_COUNT = args.max_vm_count
    TIME_STRIDES = args.time_interval
    VM_IDLE_TIME = args.idle_time_left
    UTIL_PERCENTAGE = args.util_percentage
    if args.debug:
        logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)

    # Use the user's vm catalogue when one is given, the built-in one otherwise.
    if args.vm_info is not None:
        with open(args.vm_info, 'r') as vm_info_f:
            vm_info = json.load(vm_info_f)
    else:
        vm_info = VM_INFO

    # The config file holds one "<key> <value>" pair per line.
    IMAGE_NAME = None
    if args.config_file is not None:
        with open(args.config_file, 'r') as config_f:
            for line in config_f:
                words = line.split()
                if len(words) < 2:
                    continue  # blank line or a key without a value
                key, value = words[0], words[1]
                if key == 'vpc':
                    VPC = value
                elif key == 'subnet':
                    SUBNET = value
                elif key == 'gateway':
                    GATEWAY = value
                elif key == 'security_group_id':
                    SECURITY_GROUP_ID = value
                elif key == 'keypair_name':
                    KEYPAIR_NAME = value
                elif key == 'ami':
                    IMAGE_NAME = value

    if not args.local:
        sys.stdout.write("Checking for aws command in PATH...")
        if not os.system("which aws >/dev/null 2>&1"):
            print("okay")
        else:
            print("failed")
            print("{}: The \"aws\" command must be in your path to use this script.".format(sys.argv[0]))
            sys.exit(1)

        sys.stdout.write("Checking for aws configuration...")
        if os.path.isfile(os.path.expanduser("~/.aws/config")):
            print("okay")
        else:
            print("failed")
            print("{}: You must run \"aws configure\" before using this script.".format(sys.argv[0]))
            sys.exit(1)

        print("Checking for correct credentials...")
        if not os.system("aws ec2 describe-instances >/dev/null 2>&1"):
            print("okay\n\n")
        else:
            print("failed")
            print("{}: Your Amazon credentials are not set up correctly. Try \"aws ec2 describe-instances\" to troubleshoot.".format(sys.argv[0]))
            sys.exit(1)

        if args.config_file is not None:
            required = {
                'subnet': SUBNET,
                'ami': IMAGE_NAME,
                'keypair_name': KEYPAIR_NAME,
                'security_group_id': SECURITY_GROUP_ID,
            }
            missing = sorted(name for name, value in required.items() if value is None)
            if missing:
                print("{}: config file is missing: {}".format(sys.argv[0], ', '.join(missing)))
                sys.exit(1)

            # NOTE(review): this starts a real instance and nothing visible in
            # this script terminates it afterwards -- confirm that is intended.
            # The command is passed as an argument list (no shell), so values
            # from the config file cannot be interpreted by a shell.
            run_vm_cmd = [
                "aws", "ec2", "run-instances",
                "--subnet", SUBNET,
                "--image-id", IMAGE_NAME,
                "--instance-type", EXAMPLE_INSTANCE_TYPE,
                "--key-name", KEYPAIR_NAME,
                "--security-group-ids", SECURITY_GROUP_ID,
                "--associate-public-ip-address",
                "--output", "json",
            ]
            run_vm_proc = subprocess.Popen(run_vm_cmd, stdout=subprocess.PIPE)
            run_vm_json, _ = run_vm_proc.communicate()
            if run_vm_proc.returncode != 0:
                print("{}: \"aws ec2 run-instances\" failed with exit status {}".format(sys.argv[0], run_vm_proc.returncode))
                sys.exit(1)
            run_vm_dict = json.loads(run_vm_json.decode('utf-8').rstrip())

            print(run_vm_dict)

    # load makefile after makeflow_viz
    with open(args.makefile_json.rstrip(), 'r') as makeflow_json_info_f:
        makeflow_json = json.load(makeflow_json_info_f)

    for app_num in [1, 2, 3]:
        # A fresh module (and vm pool) for every approach.
        my_makeflow_aws_module = makeflow_aws_module(makeflow_json, MAX_VM_COUNT, vm_info, approach = app_num)
        time = 0
        my_makeflow_aws_module.find_ready_jobs()

        # Run until every job has been handed to a vm and every vm is empty again.
        while not (my_makeflow_aws_module.workflow_finished_submitting() and my_makeflow_aws_module.aws_config.isEmpty()):
            logging.debug(time)
            my_makeflow_aws_module.find_ready_jobs()
            logging.debug('len of job queue %s', len(my_makeflow_aws_module.job_queue))
            logging.debug('len of jobs ready %s', len(my_makeflow_aws_module.jobs_ready))
            my_makeflow_aws_module.queue_jobs()
            # Advancing time also collects the output files of finished jobs.
            my_makeflow_aws_module.elapse_time(TIME_STRIDES)
            time += TIME_STRIDES

        if app_num == 1:
            print('We took the approach #1 where no vms are reused')
        elif app_num == 2:
            print('We took the approach #2 where all vms are reused')
        elif app_num == 3:
            print('We took the approach #3 where suitable(utilized percentage >= {}) vms are reused'.format(UTIL_PERCENTAGE))

        print('Workflow executed for {} seconds'.format(time))
        print('AWS total cost is {} dollars\n'.format(my_makeflow_aws_module.aws_config.total_price))
