import math
from datetime import datetime
from utils import run_script_with_kwargs
import random
import numpy as np


# ---------------------------------------------------------------------------
# Experiment grid. The ``__main__`` launcher runs every combination of the
# values below, N_REPEAT times over.
# ---------------------------------------------------------------------------
N_REPEAT = 7

# Tasks to run in this sweep. Other task names used in earlier sweeps:
# 'QQP', 'MNLI_M', 'MNLI_MM', 'MNLI', 'COLA', 'STS-B', 'MRPC', 'RTE', 'IMDB',
# '20_NEWS_GROUP', 'RACE', 'RACE_HIGH', and fixed-length variants such as
# 'SST-2_48', 'SST-2_32', 'SST-2_16', 'MNLI_M_112', 'MNLI_M_96', 'MNLI_M_80',
# 'MNLI_M_64', 'MNLI_M_48'.
TASKS = [
    "SST-2",
    "QNLI",
]

NUM_EPOCHS = [3]

# Learning rates. Earlier sweeps also tried 5e-5, 1e-6, 5e-6, 1e-4, 5e-4,
# 1e-7, 2e-6, 1e-5, 2e-4, 2e-3, 1e-3 and 2e-2.
LRs = [
    2e-5,
]

# Train batch size; BATCH_SIZES_PER_DATASET overrides it per task.
# Earlier sweeps also tried 8, 10 and 12.
TRAIN_BATCH_SIZES = [
    48,
]

# No longer read: the COSINE_SIM_THRESHOLDS and LAYERS_TO_TOKEN_REDUCE sweeps
# from earlier experiments.

# Token-reduction methods. Other values used before: 'CLS_NO_DUMMY',
# 'RANDOM', 'ATT_ONLY', 'CLS_ATT', 'FIRSTX_NO_PRUNE'.
REDUCTION_METHOD = [
    "KMEANSPLUS",
]

# k-centers parameter ``s`` (only swept for KMEANSPLUS); earlier: 0.1-0.4.
KCENTERS_PARAM_S = [
    1,
]

# Prefix for every output directory name. Earlier markers: 'full-',
# 'km_para', 'MoreAgg_km_1_2', 'benchmark'.
MARKER = "full3-"
# Retention schedules to sweep. Each entry is either
#   * a tuple ``(method, param)`` or ``(method, param, final_layer_idx)``,
#     expanded by get_retention_config_multiple_reduction(), or
#   * a dict with at most one ``{'layer_<i>': value}`` entry (``{}`` means
#     no reduction).
# For ``('expo_decay', fraction, final_layer_idx)`` the kept token count shrinks
# geometrically per layer, reaching ``fraction`` of the sequence length at
# ``final_layer_idx``.
# Earlier sweeps also used fractions 0.15, 0.18, 0.22, 0.3, 0.5, 0.75 with final
# layers 0-12, plus ('poly', n), ('linear', n), ('1layer', n) and single-layer
# dicts such as {'layer_3': 0.3}.
_BASE_SWEEP = [
    ("expo_decay", 0.1, 2),
    ("expo_decay", 0.2, 2),
    ("expo_decay", 0.2, 3),
]
_FINE_FRACTIONS = (0.17, 0.19, 0.25)

RETENTION_CONFIG = (
    _BASE_SWEEP
    + [("expo_decay", fraction, 2) for fraction in _FINE_FRACTIONS]
    + [("expo_decay", fraction, 3) for fraction in _FINE_FRACTIONS]
)

# Truncated input lengths tried by FIRSTX_NO_PRUNE runs, keyed by a dataset's
# full max sequence length (the value looked up in MAX_SEQ_LENGTH_PER_DATASET).
MAX_SEQ_LEN_FINDER = {
    256: [
        240, 224, 208, 192, 160, 128, 112,
        96, 80, 64, 48, 32, 16,
    ],
    128: [112, 96, 80, 64, 48, 32, 16],  # add 8 for MRPC
    64: [32, 16, 8, 4],  # earlier: [48, 32, 16, 8, 4]
}


# Max input sequence length per dataset; tasks not listed fall back to 128.
MAX_SEQ_LENGTH_PER_DATASET = {
    "RACE": 512,
    "20_NEWS_GROUP": 512,
    "IMDB": 512,
    "RTE": 256,
    "COLA": 64,
    "SST-2": 64,
    "YELP_FULL": 512,
    "QNLI": 128,
    "QQP": 128,
    "STS-B": 128,
}

# Per-dataset train batch size, overriding the TRAIN_BATCH_SIZES sweep value.
BATCH_SIZES_PER_DATASET = {
    "IMDB": 4,
    "RTE": 16,
    "RACE": 4,
    "20_NEWS_GROUP": 10,
    "YELP_FULL": 10,
    "MNLI_M": 32,
    "MNLI_MM": 32,
}

# Per-dataset epoch count, overriding the NUM_EPOCHS sweep value.
NUM_EPOCHS_PER_DATASET = {
    "20_NEWS_GROUP": 4,
}

# Per-dataset learning rate, overriding the LRs sweep value.
LR_PER_DATASET = {
    "RACE": 5e-6,
}


def get_retention_config_multiple_reduction(task, method, last_num_token, final_layer_idx=3):
    """Build a per-layer token-retention schedule for ``task``.

    Args:
        task: Dataset name, used to look up the input sequence length
            (unlisted tasks use 128).
        method: Schedule family: ``"poly"``, ``"linear"``, ``"aggressive"``,
            ``"1layer"``, ``"expo_decay"`` or ``"rand_seq"``.
        last_num_token: Method-specific parameter:
            * ``"poly"`` / ``"linear"``: tokens kept at ``layer_10``.
            * ``"expo_decay"``: fraction in (0, 1) of the sequence length
              kept at layer ``final_layer_idx``.
            * ``"rand_seq"``: random seed.
            * ``"aggressive"`` / ``"1layer"``: ignored.
        final_layer_idx: Last layer index scheduled by ``"expo_decay"``.

    Returns:
        dict mapping ``"layer_<i>"`` to the number of tokens to keep.

    Raises:
        ValueError: ``"expo_decay"`` with ``last_num_token`` outside (0, 1).
        NotImplementedError: unknown ``method``.
    """
    # Input sequence length per task. Superset of the module-level
    # MAX_SEQ_LENGTH_PER_DATASET that also covers fixed-length task variants.
    seq_len_per_task = {
        "RACE": 512,
        '20_NEWS_GROUP': 512,
        "IMDB": 512,
        "RTE": 256,
        "COLA": 64,
        "SST-2": 64,
        'YELP_FULL': 512,
        'QNLI': 128,
        'QQP': 128,
        'STS-B': 128,
        'SST-2_48': 48,
        'SST-2_32': 32,
        'SST-2_16': 16,
        'MNLI_M_112': 112,
        'MNLI_M_96': 96,
        'MNLI_M_80': 80,
        'MNLI_M_64': 64,
        'MNLI_M_48': 48,
    }
    seq_len = seq_len_per_task.get(task, 128)
    res_retention_config = {}

    if method == "poly":
        # Three fixed shrink steps (x0.8, x0.7, x0.6) at layers 0/3/6, then
        # ``last_num_token`` tokens at layer 10.
        keep = int(seq_len * 0.8)
        res_retention_config["layer_0"] = keep
        keep = int(keep * 0.7)
        res_retention_config["layer_3"] = keep
        keep = int(keep * 0.6)
        res_retention_config["layer_6"] = keep
        res_retention_config["layer_10"] = last_num_token
        print(f"for {task}, the config is: {res_retention_config} ")

    elif method == "linear":
        # Straight line from ``seq_len`` down to ``last_num_token`` over 11 layers.
        slope = float(last_num_token - seq_len) / 11.0
        for i in range(10):
            res_retention_config[f"layer_{i}"] = int(slope * (i + 1) + seq_len)
        res_retention_config["layer_10"] = last_num_token
        print(f"for {task}, the config is: {res_retention_config} ")

    elif method == 'aggressive':
        res_retention_config['layer_0'] = int(seq_len / 2)
        res_retention_config['layer_1'] = int(seq_len / 4)

    elif method == '1layer':
        res_retention_config['layer_1'] = int(seq_len / 2)

    elif method == "expo_decay":
        # log(last_num_token) must be finite and negative for the schedule to
        # shrink; 1 would divide by zero and <= 0 is outside log's domain.
        if not 0 < last_num_token < 1:
            raise ValueError(
                f"expo_decay needs 0 < last_num_token < 1, got {last_num_token!r}")
        # Layer i keeps seq_len * last_num_token ** ((i+1) / (final_layer_idx+1)).
        expo_coefficient = -(final_layer_idx + 1) / math.log(last_num_token)
        for i in range(final_layer_idx + 1):
            num_tokens_2keep = int(seq_len * math.exp(-(i + 1) / expo_coefficient))
            if num_tokens_2keep < 3:
                # Stop scheduling once fewer than 3 tokens would remain.
                break
            res_retention_config[f"layer_{i}"] = num_tokens_2keep

    elif method == 'rand_seq':
        # Local RandomState: same draws as np.random.seed(...) + np.random.choice,
        # without clobbering numpy's global RNG state.
        rng = np.random.RandomState(last_num_token)
        picks = rng.choice(range(3, seq_len), size=11, replace=False)
        for i, n_tokens in enumerate(sorted(picks, reverse=True)):
            res_retention_config[f"layer_{i}"] = int(n_tokens)

    else:
        raise NotImplementedError(f"method {method} is not implemented yet.")

    return res_retention_config


def _resolve_retention_config(task, retention_config):
    """Return ``(reduction_note, retention_dict)`` for one RETENTION_CONFIG entry.

    Tuples ``(method, param)`` / ``(method, param, final_layer_idx)`` are
    expanded with get_retention_config_multiple_reduction(); dicts with zero
    or one entry are returned unchanged.

    Raises:
        ValueError: for a tuple of any other length, or a dict with more than
            one entry.
    """
    if isinstance(retention_config, tuple):
        if len(retention_config) == 2:
            method, last_num_token = retention_config
            note = method + '-' + str(last_num_token)
            resolved = get_retention_config_multiple_reduction(
                task=task,
                method=method,
                last_num_token=last_num_token,
            )
        elif len(retention_config) == 3:
            # ``elif``, not a second ``if``: the 2-tuple branch must not fall
            # through and re-test len() on the already-expanded dict.
            method, last_num_token, final_layer_idx = retention_config
            note = (method + '-layer' + str(last_num_token).replace(".", "-")
                    + '-pro' + str(final_layer_idx))
            resolved = get_retention_config_multiple_reduction(
                task=task,
                method=method,
                last_num_token=last_num_token,
                final_layer_idx=final_layer_idx,
            )
        else:
            raise ValueError(
                f"retention_config tuple must have 2 or 3 elements, got {retention_config!r}")
        return note, resolved

    assert isinstance(retention_config, dict) is True
    items = list(retention_config.items())
    if len(items) == 1:
        layer_name, value = items[0]
        # NOTE(review): str.join over a string interleaves its characters, so
        # 'layer_10' becomes '1_0'. Kept unchanged so output directory names
        # do not change.
        note = f"cos{int(value * 100)}_layer{'_'.join(layer_name.replace('layer_', ''))}"
    elif not items:
        note = "cos100_layerNO"
    else:
        raise ValueError("Multiple reduction configs are found. Used the (method_name, final_num_token_to_reduce) apporach.")
    return note, retention_config


def get_kwargs(
        task,
        num_epochs,
        lr,
        train_batch_size,
        reduction_method,
        retention_config,
        experiment_slug,
        kcenters_param_s,
        each_seq_len_noprune,
):
    """Build the script kwargs and job name for one fine-tuning run.

    Args:
        task: Dataset name (selects the /datasets/<task>_* input files).
        num_epochs, lr, train_batch_size: Sweep values; per-dataset entries in
            NUM_EPOCHS_PER_DATASET / LR_PER_DATASET / BATCH_SIZES_PER_DATASET
            take precedence.
        reduction_method: One of the supported token-reduction method names.
        retention_config: A RETENTION_CONFIG entry (tuple or dict).
        experiment_slug: Last path component of ``model_dir``.
        kcenters_param_s: k-centers parameter (int or float).
        each_seq_len_noprune: '' for normal runs; otherwise a truncated input
            length, which selects the '<task>_seql_<len>' dataset variant and
            requires ``retention_config == {}``.

    Returns:
        ``(kwargs, job_name)`` where ``job_name`` is the path components of
        ``model_dir`` after '/output/<reduction_method>/', joined with '-'.
    """
    assert reduction_method in ('ATT_ONLY', 'CLS_NO_DUMMY', 'KMEANSPLUS', 'CLS_ATT', 'RANDOM', 'FIRST_X', 'FIRSTX_NO_PRUNE')

    # Per-dataset overrides win over the sweep values.
    train_batch_size = BATCH_SIZES_PER_DATASET.get(task, train_batch_size)
    lr = LR_PER_DATASET.get(task, lr)
    num_epochs = NUM_EPOCHS_PER_DATASET.get(task, num_epochs)

    task_note = task
    varies_seq_note_noprune = ''
    if each_seq_len_noprune != '':
        assert retention_config == {}
        task = task + '_seql_' + str(each_seq_len_noprune)
        varies_seq_note_noprune = task + '_'

    reduction_note, retention_config = _resolve_retention_config(task, retention_config)

    if isinstance(kcenters_param_s, float):
        kcenters_param_s_marker = str(kcenters_param_s).replace('.', '_')
    else:
        assert isinstance(kcenters_param_s, int) is True
        kcenters_param_s_marker = str(kcenters_param_s)

    lr_note = '{:.1e}'.format(lr).replace('.0', '')
    model_dir = (
        f"/output/{reduction_method}/{task_note}/"
        f"{MARKER}{varies_seq_note_noprune}{reduction_method}_epo{num_epochs}_lr{lr_note}"
        f"_B{train_batch_size}_{reduction_note}_kcenters_{kcenters_param_s_marker}/"
        f"{experiment_slug}"
    )

    # Key order matches the order in which the options were historically added.
    kwargs = {
        'mode': 'train_and_eval',
        'bert_config_file': '/pretrained_model_checkpoints/bert_config.json',
        'init_checkpoint': '/pretrained_model_checkpoints/bert_model.ckpt',
        'eval_batch_size': 128,
        'steps_per_loop': 1,
        'distribution_strategy': 'mirrored',
        'input_meta_data_path': f"/datasets/{task}_meta_data",
        'train_data_path': f"/datasets/{task}_train.tf_record",
        'eval_data_path': f"/datasets/{task}_eval.tf_record",
        'train_batch_size': train_batch_size,
        'learning_rate': lr,
        'num_train_epochs': num_epochs,
        'reduction_method': reduction_method,
        'retention_config': retention_config,
        'kcenters_param_s': kcenters_param_s,
        'model_dir': model_dir,
    }

    return kwargs, '-'.join(model_dir.split('/')[3:])

if __name__ == '__main__':
    import itertools

    # Sweep the full grid. itertools.product varies the right-most iterable
    # fastest, which gives the same run order as nested for-loops listed in
    # this order (repeat outermost, retention config innermost).
    sweep = itertools.product(
        range(N_REPEAT),
        TASKS,
        NUM_EPOCHS,
        LRs,
        TRAIN_BATCH_SIZES,
        REDUCTION_METHOD,
        RETENTION_CONFIG,
    )
    for _, task, num_epochs, lr, train_batch_size, reduction_method, retention_config in sweep:
        # Only KMEANSPLUS sweeps the k-centers parameter; other methods run
        # once with a placeholder 0. Use a local list rather than reassigning
        # KCENTERS_PARAM_S, which would leak [0] into later KMEANSPLUS runs.
        kcenters_values = KCENTERS_PARAM_S if reduction_method == 'KMEANSPLUS' else [0]

        for kcenters_param_s in kcenters_values:
            if reduction_method == 'FIRSTX_NO_PRUNE':
                # No pruning: instead rerun on shorter truncated inputs.
                # NOTE: MAX_SEQ_LEN_FINDER has no 512 entry, so 512-length
                # datasets (e.g. IMDB) raise KeyError here.
                assert retention_config == {}
                seq_len_variants = MAX_SEQ_LEN_FINDER[MAX_SEQ_LENGTH_PER_DATASET.get(task, 128)]
            else:
                seq_len_variants = ['']

            for each_seq_len_noprune in seq_len_variants:
                # Timestamp (to the microsecond) used as the last model_dir component.
                experiment_slug = datetime.now().strftime('%b%d_%H-%M-%S-%f')
                kwargs, job_name = get_kwargs(
                    task=task,
                    num_epochs=num_epochs,
                    lr=lr,
                    train_batch_size=train_batch_size,
                    reduction_method=reduction_method,
                    retention_config=retention_config,
                    experiment_slug=experiment_slug,
                    kcenters_param_s=kcenters_param_s,
                    each_seq_len_noprune=each_seq_len_noprune,
                )
                run_script_with_kwargs(
                    'official.nlp.bert.run_classifier_customize',
                    kwargs,
                    session_name=job_name,
                )