import os

def train_vae(
    gpu, exp_name, batch_size=4, eval_batch_size=1, accum_steps=1, model_type='v1',
    lr=1e-4, warmup_steps=5000, num_epochs=3, num_steps=99999999999, valid_every=500, from_dapt=False,
    add_input=False, add_attn=False, add_softmax=False, attn_proj_vary=False, learn_prior=False,
    tuning_all_after_iters=2000, beta_0=0.0, cycle=20000, LAMBDA_0=1.0, num_category=100,
    max_length=32, min_length=0, num_beams=1, repetition_penalty=1.0,
    gpu_ratio=0.85, n_device=8, home_path="/home2/zhaoxl"
):
    """Assemble, print and run the ``train_vae.py`` command line.

    Every entry of ``params`` is rendered as ``--key=value`` and appended to
    ``python -u train_vae.py`` in key-sorted order. The resulting command is
    printed, then executed through ``os.system`` with its output piped to
    ``tee reddit/<exp_name>.txt``.

    Args:
        gpu: Forwarded as ``--gpu_list``.
        exp_name: Forwarded as ``--exp_name``; also names the ``tee`` log file.
        from_dapt: Accepted so existing callers keep working, but it is not
            forwarded to the command.
        home_path: Prefix of the ``--gpt2_cache_dir`` path.
        Remaining keyword arguments are forwarded unchanged under the option
        of the same name.

    Returns:
        None. The exit status reported by ``os.system`` is discarded.

    Note:
        Values are interpolated into a shell string without quoting, so only
        trusted values free of shell metacharacters should be passed.
    """
    params = {
        # Fixed data locations for this run.
        'train_file': 'pretrain_data/reddit_train_large.h5',
        # The comma after this entry was missing, which made the file a SyntaxError.
        'valid_file': 'pretrain_data/reddit_valid.jsonl',

        'batch_size': batch_size,
        'eval_batch_size': eval_batch_size,
        'num_steps': num_steps,
        'accum_steps': accum_steps,
        'lr': lr,
        'clip': 1.0,
        'model_type': model_type,

        'weight_decay': 0.01,
        'warmup_steps': warmup_steps,
        'adam_epsilon': 1e-8,
        'num_epochs': num_epochs,
        'tuning_all_after_iters': tuning_all_after_iters,
        'num_category': num_category,

        'print_every': 10,
        'valid_every': valid_every,

        'exp_name': exp_name,
        'log': 'reddit/log',
        'seed': 42,

        'gpt2_cache_dir': '{}/Data/pretrain-models/gpt2'.format(home_path),
        # Passed as an empty value: this launcher names no checkpoint.
        'pretrain_file': '',
        'max_utterance_len': 32,
        'max_utterance_num': 8,
        'add_input': add_input,
        'add_attn': add_attn,
        'add_softmax': add_softmax,
        'attn_proj_vary': attn_proj_vary,
        'learn_prior': learn_prior,

        'beta_0': beta_0,
        'cycle': cycle,
        'LAMBDA_0': LAMBDA_0,

        'max_length': max_length,
        'min_length': min_length,
        'num_beams': num_beams,
        'repetition_penalty': repetition_penalty,

        'gpu_list': gpu,
        'gpu_ratio': gpu_ratio,
        'n_device': n_device,
    }

    # Sort by option name so the printed command is stable across runs.
    pieces = ["python -u train_vae.py"]
    for name in sorted(params):
        pieces.append("--{}={}".format(name, params[name]))
    command = " ".join(pieces)
    print(command)

    # The printed command deliberately omits the tee suffix.
    command += ' | tee reddit/{}.txt'.format(params['exp_name'])
    os.system(command)

def test_vae(
    gpu, exp_name,  eval_batch_size=1, model_type='v1', pretrain_file=None,
    add_input=False, add_attn=False, add_softmax=False, attn_proj_vary=False, learn_prior=False,
    max_length=32, min_length=0, num_beams=1, repetition_penalty=1.0,
    gpu_ratio=0.85, n_device=8, home_path="/home2/zhaoxl"
):
    """Assemble, print and run the ``test_vae.py`` command line.

    The options collected below are rendered as ``--key=value`` in key-sorted
    order after ``python -u test_vae.py``. The command is printed and then run
    via ``os.system`` with its output piped to ``tee reddit/<exp_name>.txt``.

    Args:
        gpu: Forwarded as ``--gpu_list``.
        exp_name: Forwarded as ``--exp_name``; also names the ``tee`` log file.
        pretrain_file: Name inserted into
            ``reddit/log/<pretrain_file>/checkpoints/model-best``.
            NOTE(review): with the default ``None`` the path contains the
            literal text ``None`` -- confirm callers always pass a name.
        home_path: Prefix of the ``--gpt2_cache_dir`` path.
        Remaining keyword arguments are forwarded unchanged under the option
        of the same name.

    Returns:
        None.
    """
    checkpoint = 'reddit/log/{}/checkpoints/model-best'.format(pretrain_file)
    gpt2_dir = '{}/Data/pretrain-models/gpt2'.format(home_path)

    options = dict(
        # Both downstream test sets are always passed.
        test_daily_file='downstream_data/dailydialog/test.jsonl',
        test_convai_file="downstream_data/convai2/valid_original.jsonl",
        eval_batch_size=eval_batch_size,
        model_type=model_type,
        exp_name=exp_name,
        log='reddit/log',
        seed=42,
        gpt2_cache_dir=gpt2_dir,
        pretrain_file=checkpoint,
        max_utterance_len=32,
        max_utterance_num=8,
        add_input=add_input,
        add_attn=add_attn,
        add_softmax=add_softmax,
        attn_proj_vary=attn_proj_vary,
        learn_prior=learn_prior,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        gpu_list=gpu,
        gpu_ratio=gpu_ratio,
        n_device=n_device,
    )

    # Emit the flags alphabetically so the printed command is reproducible.
    flags = ''.join(
        ' --{}={}'.format(name, options[name]) for name in sorted(options)
    )
    cmd = "python -u test_vae.py" + flags
    print(cmd)

    # Only the executed command carries the tee suffix.
    os.system('{} | tee reddit/{}.txt'.format(cmd, exp_name))




def train_vae_ft(
    gpu, exp_name, batch_size=4, eval_batch_size=1, accum_steps=1, model_type='v1',
    dataset='dailydialog', pretrain_file=None,
    lr=1e-4, warmup_steps=5000, num_epochs=3, num_steps=99999999999, valid_every=500,
    add_input=False, add_attn=False, add_softmax=False, attn_proj_vary=False, learn_prior=False,
    tuning_all_after_iters=2000, beta_0=0.0, cycle=20000, LAMBDA_0=1.0,
    max_length=32, min_length=0, num_beams=1, repetition_penalty=1.0,
    gpu_ratio=0.85, n_device=8, home_path="/home2/zhaoxl"
):
    """Assemble, print and run the ``train_vae_ft.py`` command line.

    The train file, validation file and ``max_utterance_len`` are selected by
    ``dataset``. All options are rendered as ``--key=value`` in key-sorted
    order after ``python -u train_vae_ft.py``. The command is printed and then
    run via ``os.system`` with its output piped to ``tee reddit/<exp_name>.txt``.

    Args:
        gpu: Forwarded as ``--gpu_list``.
        exp_name: Forwarded as ``--exp_name``; also names the ``tee`` log file.
        dataset: ``'dailydialog'`` or ``'convai2'``.
        pretrain_file: Name inserted into
            ``reddit/log/<pretrain_file>/checkpoints/model-best``.
            NOTE(review): with the default ``None`` the path contains the
            literal text ``None`` -- confirm callers always pass a name.
        home_path: Prefix of the ``--gpt2_cache_dir`` path.
        Remaining keyword arguments are forwarded unchanged under the option
        of the same name.

    Returns:
        None.

    Raises:
        KeyError: If ``dataset`` is not one of the supported names.
    """
    # (train file, validation file, max utterance length) for each dataset.
    presets = {
        'dailydialog': (
            'downstream_data/dailydialog/train.jsonl',
            'downstream_data/dailydialog/valid.jsonl',
            32,
        ),
        'convai2': (
            'downstream_data/convai2/train.jsonl',
            'downstream_data/convai2/valid.jsonl',
            64,
        ),
    }
    train_path, valid_path, utterance_len = presets[dataset]

    options = dict(
        train_file=train_path,
        valid_file=valid_path,
        batch_size=batch_size,
        eval_batch_size=eval_batch_size,
        num_steps=num_steps,
        accum_steps=accum_steps,
        lr=lr,
        clip=1.0,
        model_type=model_type,
        weight_decay=0.01,
        warmup_steps=warmup_steps,
        adam_epsilon=1e-8,
        num_epochs=num_epochs,
        tuning_all_after_iters=tuning_all_after_iters,
        print_every=10,
        valid_every=valid_every,
        exp_name=exp_name,
        log='reddit/log',
        seed=42,
        gpt2_cache_dir='{}/Data/pretrain-models/gpt2'.format(home_path),
        pretrain_file='reddit/log/{}/checkpoints/model-best'.format(pretrain_file),
        max_utterance_len=utterance_len,
        max_utterance_num=8,
        add_input=add_input,
        add_attn=add_attn,
        add_softmax=add_softmax,
        attn_proj_vary=attn_proj_vary,
        learn_prior=learn_prior,
        beta_0=beta_0,
        cycle=cycle,
        LAMBDA_0=LAMBDA_0,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        gpu_list=gpu,
        gpu_ratio=gpu_ratio,
        n_device=n_device,
    )

    # Flags go out in alphabetical order of their names.
    pieces = ["python -u train_vae_ft.py"]
    for name in sorted(options):
        pieces.append("--{}={}".format(name, options[name]))
    cmd = " ".join(pieces)
    print(cmd)

    # Only the executed command carries the tee suffix.
    os.system(cmd + ' | tee reddit/{}.txt'.format(exp_name))

if __name__ == '__main__':
    # Keyword arguments that every run below passes with the same value.
    shared = dict(
        add_input=False,
        add_attn=False,
        add_softmax=True,
        attn_proj_vary=False,
        learn_prior=True,
        model_type='v1',
        max_length=32,
        repetition_penalty=1.0,
        gpu_ratio=0.85,
        n_device=8,
        home_path="/home2/zhaoxl",
    )
    # Extra arguments common to the three training runs.
    training = dict(shared, LAMBDA_0=1.0, beta_0=1.0, cycle=2000, num_steps=10000000)
    # Extra arguments common to the two test runs and the two fine-tuning runs.
    testing = dict(shared, pretrain_file='reddit_pretrain', eval_batch_size=8, num_beams=5)
    finetune = dict(
        training,
        pretrain_file='reddit_pretrain',
        tuning_all_after_iters=0,
        lr=5e-5,
        warmup_steps=500,
        num_epochs=20,
        valid_every=2000,
    )

    # Pre-train on Reddit.
    train_vae(
        '4,5,6,7', 'reddit_pretrain',
        from_dapt=True,
        tuning_all_after_iters=5000,
        batch_size=8, eval_batch_size=8, accum_steps=4,
        lr=2e-5, warmup_steps=5000, num_epochs=1, valid_every=20000,
        min_length=8, num_beams=4,
        **training
    )

    # Run the test script against the pre-trained checkpoint.
    test_vae('4', 'convai_zero_test', min_length=4, **testing)
    test_vae('5', 'dailydialog_zero_test', min_length=6, **testing)

    # Fine-tune on DailyDialog.
    train_vae_ft(
        '4', 'daily_ft',
        dataset='dailydialog',
        batch_size=4, eval_batch_size=16, accum_steps=4,
        min_length=5, num_beams=1,
        **finetune
    )

    # Fine-tune on ConvAI2.
    train_vae_ft(
        '5', 'convai_ft',
        dataset='convai2',
        batch_size=2, eval_batch_size=8, accum_steps=8,
        min_length=15, num_beams=3,
        **finetune
    )

    


