import collections
from datasets import load_from_disk, load_metric
import random
from transformers import AutoTokenizer, AutoModelWithHeads, TrainingArguments, Trainer, EvalPrediction
from transformers.adapters.composition import Fuse
from collections import defaultdict
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import pickle
import os
import wandb
from fvcore.nn import FlopCountAnalysis, flop_count_table
import multiprocessing
# Uncomment the next line to set WANDB_DISABLED and switch Weights & Biases off.
# os.environ["WANDB_DISABLED"] = "true"
# Set WANDB_DIR to /tmp/ (presumably so wandb keeps its local run files there — confirm).
os.environ["WANDB_DIR"] = "/tmp/"

class PRUNING_STRATEGY:
    """String identifiers for the adapter pruning strategies.

    The values are concatenated into result and submission directory paths,
    so they must stay valid path components.
    """
    # Prune whole adapter layers.
    LAYER = 'Layer'
    # Prune individual weights.
    WEIGHT = 'Weight'
    # Misspelled historical name; kept so existing references keep working.
    NUERON = 'Neuron'
    # Correctly spelled alias of NUERON.
    NEURON = NUERON
    # No pruning: use the original adapters.
    NONE = 'Origin'

# Internal task identifier -> file stem used when a task's predictions are
# written out as "<stem>.tsv".
name_look_up = {
    "cola": "CoLA",
    "mnli": "MNLI-m",
    "mnli-mm": "MNLI-mm",
    "mrpc": "MRPC",
    "qnli": "QNLI",
    "qqp": "QQP",
    "rte": "RTE",
    "sst2": "SST-2",
    "stsb": "STS-B",
    "wnli": "WNLI",
}
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch with one common value.

    Args:
        seed: Integer seed to apply. ``0`` is a valid seed and is used as
            given. When ``None``, a seed is derived from
            ``torch.initial_seed()`` and printed so the run can be reproduced.

    Returns:
        The seed that was actually applied.
    """
    if seed is None:
        # np.random.seed only accepts values below 2**32, so fold torch's
        # 64-bit seed into that range.
        seed = torch.initial_seed() % 2**32
        print("Seed", seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return seed

def get_encode_batch(tokenizer, sentence1_key, sentence2_key):
    """Build a batch-encoding function bound to a tokenizer and column names.

    Args:
        tokenizer: Callable invoked with one or two text columns and
            ``truncation=True``.
        sentence1_key: Key of the first text column in a batch.
        sentence2_key: Key of the second text column, or ``None`` when the
            task has a single text column.

    Returns:
        A function taking a batch (a mapping from column name to values) and
        returning whatever the tokenizer returns for it.
    """
    def encode_batch(examples):
        """Tokenize the configured column(s) of ``examples`` with truncation."""
        text_columns = [examples[sentence1_key]]
        if sentence2_key is not None:
            text_columns.append(examples[sentence2_key])
        return tokenizer(*text_columns, truncation=True)
    return encode_batch


def load_adapters(directory, af_adapters, model, with_head = False):
    """Load every adapter named in ``af_adapters`` from ``directory`` into ``model``.

    Each adapter is read from ``<directory>/<adapter name>`` through
    ``model.load_adapter`` with ``overwrite_ok=True``.

    Args:
        directory: Directory holding one sub-directory per adapter.
        af_adapters: Iterable of adapter names to load.
        model: Object exposing ``load_adapter``.
        with_head: Forwarded to ``model.load_adapter`` as ``with_head``.
    """
    for adapter_name in af_adapters:
        adapter_path = '/'.join((directory, adapter_name))
        model.load_adapter(adapter_path, with_head=with_head, overwrite_ok=True)


def get_compute_metrics(metric, task):
    """Build a metrics callback that scores predictions with ``metric``.

    Args:
        metric: Object exposing ``compute(predictions=..., references=...)``.
        task: Task name. For ``'stsb'`` the first output column is used as
            the prediction; for every other task the arg-max over axis 1 is.

    Returns:
        A function taking a ``(predictions, labels)`` pair and returning the
        result of ``metric.compute``.
    """
    use_first_column = task == 'stsb'

    def compute_metrics(eval_pred):
        """Reduce the raw model outputs to predictions and score them."""
        raw_outputs, references = eval_pred
        if use_first_column:
            reduced = raw_outputs[:, 0]
        else:
            reduced = np.argmax(raw_outputs, axis=1)
        return metric.compute(predictions=reduced, references=references)

    return compute_metrics

def importance(dic, af=True):
    """Measure how much each branch projects onto the combined output.

    For the original output and for every adapter output ``x`` this computes
    ``dot(x, combined) / sum(combined ** 2)`` over the last feature axis and
    averages the ratio over the first two axes.

    Args:
        dic: Mapping with tensors ``'origin_out'`` and ``'combined'``, plus
            ``'values'`` and ``'scores'`` when ``af`` is true, or
            ``'adapter_out'`` otherwise.
            NOTE(review): the indexing assumes 3-D ``origin_out`` /
            ``combined`` / ``adapter_out`` / ``scores`` and a 4-D ``values``
            (presumably batch, sequence, [adapter,] hidden) — confirm.
        af: Whether the tensors come from an adapter-fusion module.

    Returns:
        A CPU tensor whose first entry belongs to the original output and
        whose remaining entries belong to the adapter output(s).
    """
    base_out = dic['origin_out']
    fused_out = dic['combined']

    if af:
        # Weight every adapter's value vector by its score.
        value_vectors = dic['values']
        fusion_scores = dic['scores']
        branch_outputs = fusion_scores.unsqueeze(3) * value_vectors
    else:
        branch_outputs = dic['adapter_out'].unsqueeze(2)

    # Row vectors times a column vector give one dot product per branch.
    fused_column = fused_out.unsqueeze(3)
    base_dot = (base_out.unsqueeze(2) @ fused_column).squeeze(3)
    branch_dot = (branch_outputs @ fused_column).squeeze(3)

    # Original output first, adapters after it.
    all_dots = torch.cat((base_dot, branch_dot), dim=-1)
    fused_sq_norm = fused_column.pow(2).sum(dim=2)
    ratios = all_dots / fused_sq_norm

    return ratios.mean(dim=(0, 1)).cpu()

def get_in_out(name, layer, connections, imp_dict, is_af):
    """Build a forward hook that records inputs/outputs and their importance.

    On every call the hook stores the detached last positional input as
    ``'origin_out'`` and the detached output as ``'combined'`` under
    ``connections[(name, layer)]``, then appends one new column of
    ``importance(...)`` to ``imp_dict[(name, layer)]['importance']``.

    Args:
        name: Label of the hooked position; first half of the dictionary key.
        layer: Layer index; second half of the dictionary key.
        connections: Mapping keyed by ``(name, layer)`` holding captured tensors.
        imp_dict: Mapping keyed by ``(name, layer)`` accumulating importance.
        is_af: If true the module output is used directly; otherwise its
            first element is used.

    Returns:
        A hook with the ``(module, input, output)`` forward-hook signature.
    """
    key = (name, layer)

    def ori_hook(self, input, output):
        """Capture this call's tensors and extend the importance history."""
        residual_in = input[-1].data.detach()
        if is_af:
            merged_out = output.data.detach()
        else:
            merged_out = output[0].data.detach()

        captured = connections[key]
        captured['origin_out'] = residual_in
        captured['combined'] = merged_out

        history = imp_dict[key]
        latest = importance(captured, af=is_af).unsqueeze(dim=-1)
        if 'importance' in history:
            # One extra column per forward pass.
            history['importance'] = torch.cat((history['importance'], latest), dim=-1)
        else:
            history['importance'] = latest

    return ori_hook

def get_output(name, layer, variable, connections):
    """Build a forward hook that stores a module's detached output.

    Args:
        name: Label of the hooked position; first half of the dictionary key.
        layer: Layer index; second half of the dictionary key.
        variable: Key under which the output is stored.
        connections: Mapping keyed by ``(name, layer)``; the entry for this
            key receives the tensor.

    Returns:
        A hook with the ``(module, input, output)`` forward-hook signature.
    """
    slot = (name, layer)

    def output_hook(self, input, output):
        """Save the detached output under ``connections[slot][variable]``."""
        snapshot = output.data.detach()
        connections[slot][variable] = snapshot

    return output_hook

def add_register(module, name, layer, connections, imp_dict, af=True):
    """Attach the importance-tracking forward hooks to one module.

    With ``af`` true and at least one fusion layer present, hooks are placed
    on the first fusion layer's ``softmax`` and ``value`` sub-modules and on
    the fusion layer itself. Otherwise, if the module holds adapters, hooks
    are placed on the first adapter's ``adapter_up`` and on the adapter
    itself. In both cases the adapter names are stored in
    ``imp_dict[(name, layer)]['adapter_names']``.

    Args:
        module: Object exposing ``adapter_fusion_layer`` and ``adapters``
            mappings.
        name: Label of the hooked position.
        layer: Layer index.
        connections: Mapping that receives the captured tensors.
        imp_dict: Mapping that receives adapter names and importance values.
        af: Whether to look for an adapter-fusion layer first.

    Returns:
        The list of hook handles that were registered; empty when the module
        has nothing to hook.
    """
    fusion_layers = module.adapter_fusion_layer
    key = (name, layer)
    fusion_names = list(fusion_layers.keys()) if af else []

    if fusion_names:
        imp_dict[key]['adapter_names'] = list(module.adapters.keys())
        fusion = fusion_layers[fusion_names[0]]
        scores_handle = fusion.softmax.register_forward_hook(
            get_output(name, layer, 'scores', connections))
        values_handle = fusion.value.register_forward_hook(
            get_output(name, layer, 'values', connections))
        fusion_handle = fusion.register_forward_hook(
            get_in_out(name, layer, connections, imp_dict, is_af = True))
        return [scores_handle, values_handle, fusion_handle]

    adapter_names = list(module.adapters.keys())
    if not adapter_names:
        return []

    imp_dict[key]['adapter_names'] = adapter_names
    adapter = module.adapters[adapter_names[0]]
    up_handle = adapter.adapter_up.register_forward_hook(
        get_output(name, layer, 'adapter_out', connections))
    adapter_handle = adapter.register_forward_hook(
        get_in_out(name, layer, connections, imp_dict, is_af = False))
    return [up_handle, adapter_handle]

### RUNNING TEST SET

def convert_predict_result(predict, name_lst, task):
    """Turn raw model predictions into values for a submission file.

    Args:
        predict: Prediction array as found in the ``predictions`` field of
            ``trainer.predict``'s result.
        name_lst: Label names indexed by class id.
        task: Task name deciding the output format.

    Returns:
        ``predict`` unchanged for ``'stsb'``; a list of label names for
        ``'mnli'``, ``'mnli-mm'``, ``'qnli'`` and ``'rte'``; a list of class
        ids (arg-max over axis 1) for every other task.
    """
    if task == 'stsb':
        # Regression output is passed through untouched.
        return predict

    class_ids = predict.argmax(axis=1).tolist()
    if task in ('mnli', 'mnli-mm', 'qnli', 'rte'):
        # These tasks are reported by label name rather than class id.
        return [name_lst[class_id] for class_id in class_ids]
    return class_ids

def optimal_num_of_loader_workers():
    """Pick a data-loader worker count from the visible CPUs and GPUs.

    Returns:
        ``min(cpu_count, 4 * gpu_count)`` when at least one CUDA device is
        reported, otherwise ``cpu_count - 1``.
    """
    cpu_total = multiprocessing.cpu_count()
    gpu_total = torch.cuda.device_count()
    if not gpu_total:
        return cpu_total - 1
    return min(cpu_total, gpu_total * 4)

def test(task,
         af_adapters,
         model_config,
         save_model_path = '/tmp/',
         model_checkpoint = 'bert-base-uncased',
         pruning_strategy = PRUNING_STRATEGY.NONE,
         wandb_run = None,
         cal_imp = True,
         imp_training = True,
         cal_flops = True,
         do_train = False,
         do_eval = True,
         save_eval = True,
         gen_submit = True,
         skip_layers = None,

         ):
    """Test run of model AF or ST-A

    Loads the dataset, metric, tokenizer, model and adapters for ``task``,
    optionally trains, evaluates, measures adapter importance, counts FLOPs
    and writes a submission TSV.

    Args:
        task (str):                         Target task name
        af_adapters (List):                 A list of adapters in AF model, if it only has one adapter, model will be ST-A
        model_config (Dict):                Model configuration; the keys 'lr', 'epoch', 'fp16', 'num_gpu' and 'seed' are read.
        save_model_path (str, optional):    Path to save model. Defaults to '/tmp/'.
        model_checkpoint (str, optional):   Model checkpoint. Defaults to 'bert-base-uncased'.
        pruning_strategy (str, optional):   Type of pruning strategy (none, layer, weight, neuron). Defaults to PRUNING_STRATEGY.NONE.
        wandb_run (wandb.Run, optional):    Run object of wandb, if set none result will be saved locally. Defaults to None.
        cal_imp (bool, optional):           Calculate the IMP of adapters or not. Defaults to True.
        imp_training (bool, optional):      Use part of training set to calculate IMP or not (only has an effect when cal_imp is set). Defaults to True.
        cal_flops (bool, optional):         Calculate the FLOPS of model or not. Defaults to True.
        do_train (bool, optional):          Training the model. Defaults to False.
        do_eval (bool, optional):           Evaluate on the validation split. Defaults to True.
        save_eval (bool, optional):         Pickle the evaluation result (needs do_eval). Defaults to True.
        gen_submit (bool, optional):        Predict the test split and write a submission TSV. Defaults to True.
        skip_layers (optional):             Forwarded to ``set_active_adapters`` for the ST-A model. Defaults to None.
    """

    _is_af = len(af_adapters) != 1
    _use_prune_adapter = pruning_strategy != PRUNING_STRATEGY.NONE
    _is_super_glue = task in ["boolq", "cb", "copa", "multirc", "record", "wic", "wsc", "wsc.fixed", "axb", "axg"]  # "rte" is deliberately treated as GLUE

    _task_type = "super_glue" if _is_super_glue else "glue"
    print(f"Start training {task}-{_task_type}, with adapters {af_adapters}")
    print(f'Model saved in {save_model_path}')
    if _use_prune_adapter:
        print("Using LTH model in AF")
    if _is_af:
        print("Using Adapter fusion")

    if not save_model_path.endswith('/'):
        save_model_path = save_model_path + '/'

    if _is_super_glue:
        dataset_cache_dir = '/mnt/Code/super_glue/' + task
        metric_script = '/mnt/Code/util/super_glue.py'
    else:
        dataset_cache_dir = '/mnt/Code/dataset/' + task
        metric_script = '/mnt/Code/util/glue.py'

    # "mnli-mm" shares the "mnli" metric configuration.
    actual_task = "mnli" if task == "mnli-mm" else task

    print('Loading metric and datasets')

    metric = load_metric(metric_script, actual_task, cache_dir = './dataset')
    obtain_data = load_from_disk(dataset_cache_dir)

    # Text column(s) per task; None marks single-sentence tasks.
    task_to_keys = {
        "cola": ("sentence", None),
        "mnli": ("premise", "hypothesis"),
        "mnli-mm": ("premise", "hypothesis"),
        "mrpc": ("sentence1", "sentence2"),
        "qnli": ("question", "sentence"),
        "qqp":  ("question1", "question2"),
        "rte":  ("sentence1", "sentence2"),
        "sst2": ("sentence", None),
        "stsb": ("sentence1", "sentence2"),
        "wnli": ("sentence1", "sentence2"),
        'cb':   ("premise", "hypothesis"),
    }
    sentence1_key, sentence2_key = task_to_keys[task]

    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
    # Encode the input data
    dataset = obtain_data.map(get_encode_batch(tokenizer, sentence1_key, sentence2_key), batched=True)
    print('Dataset is ready')

    # Load from saved adapter
    print('Load from saved adapter')

    # NOTE(review): 'cb' falls into the 2-label default here although it is
    # usually a 3-class task — confirm before running on it.
    num_labels = 3 if task.startswith("mnli") else 1 if task == "stsb" else 2

    model = AutoModelWithHeads.from_pretrained(model_checkpoint)
    # Add a classification head for our target task
    model.add_classification_head(task, num_labels=num_labels, layers=1, overwrite_ok=True, use_pooler=True)
    save_adapters_path = 'adapters_for_af' if _use_prune_adapter else 'adapters_for_af_base'
    # For ST-A the adapter's own head is loaded with it; for AF it is not.
    load_adapters(save_adapters_path, af_adapters, model, with_head = not _is_af)

    if _is_af:
        # Add a fusion layer for all loaded adapters
        print("Loading adapterfusion layer")
        model.add_adapter_fusion(Fuse(*af_adapters))
        adapter_setup = Fuse(*af_adapters)
        model.train_adapter_fusion(adapter_setup)
        model.set_active_adapters(Fuse(*af_adapters))
    else:
        model.set_active_adapters(task, skip_layers=skip_layers)

    print('Model finished setup')

    # Currently unused: the best checkpoint is selected by loss (see below).
    metric_name = "pearson" if task == "stsb" else "matthews_correlation" if task == "cola" else "accuracy"
    validation_key = "validation_mismatched" if task == "mnli-mm" else "validation_matched" if task == "mnli" else "validation"
    batch_size = 8 if task == 'qnli' or task == 'mnli' or task == 'mnli-mm' else 16 if task == 'rte' or task == 'qqp' else 32
    gradient_accumulation_steps = 4 if task == 'mnli' or task == 'mnli-mm' or task == 'qnli' else 2 if task == 'qqp' or task == 'rte' else 1

    lr = model_config['lr']
    epochs = model_config['epoch']
    fp16 = model_config['fp16']
    num_gpu = model_config['num_gpu']
    seed = model_config['seed']

    if wandb_run:
        wandb_run.config.seed = seed
        wandb_run.config.task = task

    training_args = TrainingArguments(
        learning_rate = lr,
        num_train_epochs = epochs,
        per_device_train_batch_size=batch_size * num_gpu,
        per_device_eval_batch_size=batch_size * num_gpu,
        logging_steps=50,
        # logging_strategy='steps',
        output_dir= save_model_path + task,
        overwrite_output_dir=True,
        remove_unused_columns=True,
        evaluation_strategy='epoch',
        save_strategy='epoch',
        metric_for_best_model="loss",  # metric_name
        load_best_model_at_end=True,
        report_to="wandb" if wandb_run else 'none',  # enable logging to W&B
        fp16 = fp16,
        # Integer division can reach 0 when num_gpu exceeds the base value;
        # at least one accumulation step is required.
        gradient_accumulation_steps=max(1, gradient_accumulation_steps // num_gpu),
        weight_decay=0.01,
        dataloader_num_workers = optimal_num_of_loader_workers()
    )

    trainer = Trainer(
        model = model,
        args = training_args,
        train_dataset = dataset['train'],
        eval_dataset = dataset[validation_key],
        tokenizer = tokenizer,
        compute_metrics = get_compute_metrics(metric, task),
    )

    # Train only on request (the flag used to be evaluated inverted).
    if do_train:
        print('Constructed trainer, start training')
        trainer.train()
        print('Training finished')

    if cal_imp:
        print('Adding forward hooks to check connections')
        connections = defaultdict(dict)
        imp_dict = defaultdict(dict)

        handles = []
        # Hook both adapter positions of every encoder layer, whatever the
        # depth of the checkpoint.
        for i, encoder_layer in enumerate(trainer.model.bert.encoder.layer):
            handle_att = add_register(encoder_layer.attention.output, 'attention', i, connections, imp_dict, _is_af)
            handle_out = add_register(encoder_layer.output, 'output', i, connections, imp_dict, _is_af)
            handles.extend(handle_att)
            handles.extend(handle_out)

    if do_eval:
        # With cal_imp set, the hooks are already active, so these forward
        # passes contribute to the importance statistics as well.
        eval_result = trainer.evaluate()
        print(eval_result)

    # The extra forward passes only feed the hooks, so skip them when no
    # hooks are registered.
    if cal_imp and imp_training:
        data_loader = trainer.get_train_dataloader()
        dataiter = iter(data_loader)
        dataset_limit = min(len(data_loader.dataset) // batch_size, 3000)
        with torch.no_grad():
            for _, data_item in tqdm(zip(range(dataset_limit), dataiter), total=dataset_limit):
                # Keyword arguments keep token ids and mask from being
                # swapped (the model's first positional argument is input_ids).
                model(
                    input_ids=data_item['input_ids'].to(model.device),
                    attention_mask=data_item['attention_mask'].to(model.device),
                )

    result_path = './results/'
    result_path += 'AF/' if _is_af else 'ST-A/'
    result_path += pruning_strategy + '/'
    result_path += task
    # The path is several levels deep; create missing parents too.
    os.makedirs(result_path, exist_ok=True)

    if save_eval and do_eval:
        # Write eval_result to file
        eval_path = result_path + '/eval_result_' + wandb_run.id + '.pkl' if wandb_run else result_path + '/eval_result.pkl'

        with open(eval_path, 'wb') as f:
            pickle.dump(eval_result, f)
    print("Finished evaluation")

    if cal_imp:
        print("Finished IMP calculate")
        total_adapters = 0
        used_adapters = 0
        for k, v in imp_dict.items():
            if 'importance' not in v:
                # The hooked module never ran a forward pass, so there is
                # nothing to summarise for it.
                continue
            print(k)
            print('Ori\t', end='')
            for name in v['adapter_names']:
                print(name, end='\t')
            total_adapters += len(v['adapter_names'])

            print()
            # Average the per-forward-pass columns into one score per branch.
            scores = v['importance'].mean(dim=-1).numpy()
            v['importance'] = scores
            for i, score in enumerate(scores):
                print(f"{score:.2f}", end='\t')
                # Index 0 is the original (non-adapter) output.
                if score >= 0.01 and i != 0:
                    used_adapters += 1
            print()

        if total_adapters != 0:
            print("Total adapters", total_adapters, "Used adapters", used_adapters, "Used percent", used_adapters / total_adapters)
            if wandb_run:
                wandb_run.summary["Used adapters percent"] = used_adapters / total_adapters
                wandb_run.summary["Total adapters"] = total_adapters
                wandb_run.summary["Used adapters"] = used_adapters

        # Write importance to file
        imp_path = result_path + '/importance_' + wandb_run.id + '.pkl' if wandb_run else result_path + '/importance.pkl'
        with open(imp_path, 'wb') as f:
            pickle.dump(imp_dict, f)

        print(' Removing hooks')
        for handle in handles:
            handle.remove()
        print('hooks removed')

    if cal_flops:
        print("FLOPS counting")
        data_loader = trainer.get_train_dataloader()
        dataiter = iter(data_loader)
        data_item = next(dataiter)
        # Positional order must follow the model's forward signature:
        # input_ids first, then attention_mask.
        flop_inputs = (data_item['input_ids'].to(model.device), data_item['attention_mask'].to(model.device))
        with torch.no_grad():
            flops = FlopCountAnalysis(model, flop_inputs)
            total_flops = flops.total()
            print("Total FLOPS", total_flops)
            if wandb_run:
                wandb_run.summary["FLOPS"] = total_flops
            # print(flop_count_table(flops, max_depth=6, show_param_shapes=False))

        print("FLOPS count finished")

    if gen_submit:
        print("Submission Started")
        submission_file = 'submission/' + pruning_strategy + '/' + str(seed) + '/'
        # Nested path: create missing parents too.
        os.makedirs(submission_file, exist_ok=True)

        test_dataset_name = 'test_matched' if task == 'mnli' else 'test_mismatched' if task == 'mnli-mm' else 'test'

        testData = dataset[test_dataset_name].remove_columns("label")
        prediction = trainer.predict(testData)

        if task == 'stsb':
            # Clamp the regression output to the [0, 5] range.
            predict_result = []
            for item in prediction.predictions.squeeze():
                if item > 5:
                    predict_result.append(5)
                elif item < 0:
                    predict_result.append(0)
                else:
                    predict_result.append(item)
        else:
            predict_result = convert_predict_result(prediction.predictions, dataset[test_dataset_name].features['label'].names, task)

        df = pd.DataFrame({'index': dataset[test_dataset_name]['idx'], 'prediction': predict_result})
        df.to_csv(f'{submission_file}/{name_look_up[task]}.tsv', sep='\t', index=False)

        if task == 'mnli':
            # AX
            print("Start AX")
            # SET AX DATASET
            # AX is predicted with the MNLI model; the tokenizer and column
            # keys loaded above are reused.
            ax_data = load_from_disk('/mnt/Code/dataset/ax')

            # Encode the input data
            ax_dataset = ax_data.map(get_encode_batch(tokenizer, sentence1_key, sentence2_key), batched=True)

            test_dataset_name = 'test'
            testData = ax_dataset[test_dataset_name].remove_columns("label")

            prediction = trainer.predict(testData)
            predict_result = convert_predict_result(prediction.predictions, ax_dataset[test_dataset_name].features['label'].names, task)
            df = pd.DataFrame({'index': ax_dataset[test_dataset_name]['idx'], 'prediction': predict_result})
            df.to_csv(f'{submission_file}/AX.tsv', sep='\t', index=False)

        del model
        torch.cuda.empty_cache()
