import argparse
import itertools
import os
import random
import shlex
import subprocess
import sys
import time

argparser = argparse.ArgumentParser()
argparser.add_argument('--partition', default='p:31', type=str,
                       help='comma-separated NAME:MAX_JOBS pairs, e.g. "p:31,q:8"')
argparser.add_argument('--repeats', default=1, type=int,
                       help='number of random seeds per hyperparameter combination')
# NOTE(review): not referenced anywhere below; srun's memory request is
# hard-coded in add_to_partition. Kept so existing invocations still parse.
argparser.add_argument('--cpu_memory', default='48GB', type=str)
argparser.add_argument('--output_dir', default='hyperparams', type=str,
                       help='root under which a timestamped tune-* directory is created')
# Required: without it every submitted job would run the literal command "None".
argparser.add_argument('--script', type=str, required=True,
                       help='training command; --hparams and --save_dir are appended to it')
argparser.add_argument('--exp_remark', type=str,
                       help='label embedded in the output directory name')

args = argparser.parse_args(sys.argv[1:])

user = os.environ["USER"]

datetime_str = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

# Every job of this sweep gets a subdirectory under out_dir.
out_dir = os.path.join(args.output_dir, "tune-{}-".format(args.exp_remark) + datetime_str)
print("Writing to output dir: %s" % out_dir)

# exist_ok avoids the exists()/makedirs() check-then-act race.
os.makedirs(out_dir, exist_ok=True)


def _parse_partitions(spec):
    """Parse 'NAME:MAX_JOBS[,NAME:MAX_JOBS...]' into [(NAME, MAX_JOBS), ...].

    Reports a malformed entry through argparse (usage message + exit status 2)
    instead of an opaque unpacking or int() error.
    """
    parsed = []
    for entry in spec.split(','):
        name, sep, max_jobs = entry.strip().partition(':')
        if not name or not sep or not max_jobs.strip().isdigit():
            argparser.error('--partition entries must look like NAME:MAX_JOBS, got %r' % entry)
        parsed.append((name, int(max_jobs)))
    return parsed


partition_maxjobs = _parse_partitions(args.partition)

# Grid to sweep: every combination of the listed values becomes one job.
# These will be passed as a list of hyperparams to be parsed by tf.contrib.HParams.
params = {
    # 'parse_dep_headcount': [1, 2, 5, 8],
    # 'parse_dep_injection': ['injection'],
    'prod_mode': ['eye-end_node-noop', 'noop-end_node-noop'],
    'use_labeled_adjacency_mtx_hparams_option': [True],
    'head_label_aggregation': ['masking_with_gate'],
    'label_score_aggregation': ['expectation'],
    'use_strength_bias': [False],
    # 'on_value': [-1e0, -2e0, -5e-1, -2e-1],
    # One seed per repeat, derived from the launch time so sweeps differ.
    'random_seed': [int(time.time()) + i for i in range(args.repeats)]
}

# for SA
# predicate_layers="2 3 4"

# for LISA
# parents_layers="parents:4 parents:5"
# predicate_layers="3 4"


def make_job_str(_setting):
    """Turn one grid point into a (log directory name, hparams flag) pair.

    `_setting` holds one value per hyperparameter, aligned with the
    module-level `names` tuple.

    Returns:
        log_str: the setting's values joined with '___'. Parameters in the
            exclusion tuple below are left out. Used as the job's log dir name.
        setting_str: '--hparams name=value,name=value,...' covering every
            parameter, banned ones included.
    """
    # Parameters that go into --hparams but not into the log directory name.
    excluded_from_log = ('use_labeled_adjacency_mtx_hparams_option',)
    named = dict(zip(names, _setting))
    pairs = []
    log_parts = []
    for key, val in named.items():
        pairs.append('%s=%s' % (key, str(val)))
        if key not in excluded_from_log:
            log_parts.append(str(val))
    setting_str = '--hparams ' + ','.join(pairs)
    return '___'.join(log_parts), setting_str


def add_to_partition(_partition, _setting_str, _log_str):
    """Start one training job on `_partition` through `srun`, in the background.

    Creates `<out_dir>/<_log_str>/` and writes the training command (script
    plus hparams, without --save_dir) to `run.cmd` there. Then it launches
    `srun ... <script> <hparams> --save_dir <log_dir>/model` through the
    shell, with stdout and stderr redirected to `<log_dir>/train.log`. The
    shell command ends in `&`, so this returns once srun has been backgrounded.
    It does not wait for the job.

    Args:
        _partition: SLURM partition name for `srun --partition`.
        _setting_str: '--hparams k=v,...' string as built by make_job_str.
        _log_str: per-job directory name as built by make_job_str.
    """
    hparams_flag = '--hparams '
    if _setting_str.startswith(hparams_flag):
        # The SLURM job name is the bare hparams list (flag stripped).
        job_name = _setting_str[len(hparams_flag):]
        # Quote the value so a space or shell metacharacter inside a
        # hyperparameter value still reaches the script as one argument.
        shell_setting = hparams_flag + shlex.quote(job_name)
    else:
        # Not in make_job_str's format: pass it through untouched.
        job_name = _setting_str
        shell_setting = _setting_str
    slurm_cmd = ('srun --gres=gpu:1 --partition=%s -J %s --cpus-per-task=9 '
                 '--mem-per-cpu=10240MB' % (shlex.quote(_partition), shlex.quote(job_name)))
    # create dir for this specific job
    log_dir = '%s/%s' % (out_dir, _log_str)
    os.makedirs(log_dir, exist_ok=True)
    # write run cmd to file in logdir
    with open('%s/%s' % (log_dir, 'run.cmd'), 'w') as outf:
        outf.write('%s %s\n' % (args.script, _setting_str))
    # Paths are quoted because out_dir (via --exp_remark) may contain spaces.
    save_str = "--save_dir %s" % shlex.quote(os.path.join(log_dir, "model"))
    # args.script is left unquoted on purpose so it can carry its own
    # arguments; quoting would turn "script.sh conf" into a single word.
    full_cmd = '%s %s %s %s' % (slurm_cmd, args.script, shell_setting, save_str)
    bash_cmd = '%s > %s 2>&1 &' % (full_cmd, shlex.quote('%s/train.log' % log_dir))
    print(bash_cmd)
    subprocess.call(bash_cmd, shell=True)

def _queued_job_count(_partition):
    """Return how many of `user`'s jobs squeue lists on `_partition`.

    Uses --noheader, so the count is jobs only (no header line). Returns None
    when squeue exits non-zero. The caller then treats the partition as busy,
    rather than reading the empty output as "no jobs" and over-submitting.
    """
    try:
        out = subprocess.check_output(
            ['squeue', '--noheader', '-u', user, '-p', _partition])
    except subprocess.CalledProcessError as err:
        print('squeue failed for partition %s (%s); will retry' % (_partition, err))
        return None
    return len(out.splitlines())


def _submit_when_free(_setting_str, _log_str):
    """Poll the partitions in order until one has room, then launch the job there.

    A partition has room while fewer than its max_jobs of our jobs are listed.
    Polling stops right after the job is launched.
    """
    while True:
        for _partition, _max_jobs in partition_maxjobs:
            count = _queued_job_count(_partition)
            if count is not None and count < _max_jobs:
                add_to_partition(_partition, _setting_str, _log_str)
                # Give squeue a moment to register the new job before the
                # next poll, so it counts towards the limit.
                time.sleep(1)
                return
            time.sleep(2)


print(args.script)
# `names` must stay a module-level global: make_job_str reads it.
names, all_params = zip(*params.items())
all_jobs = list(itertools.product(*all_params))
print('Starting %d jobs' % (len(all_jobs)))
# Launch the grid points in random order.
random.shuffle(all_jobs)

for setting in all_jobs:
    log_str, setting_str = make_job_str(setting)
    _submit_when_free(setting_str, log_str)


print('Done. Ran %d jobs.' % len(all_jobs))
