#!/usr/bin/env python
# coding: utf-8

# In[1]:


import matplotlib.pyplot as plt
import numpy as np
import pandas
import pickle
import transformers
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import copy
from util.prune import *
from util.trainer import *
from util.adapter_model import Adapter, AdapterModel
from util.training_glue import *
import argparse
import json
# NOTE(review): this module-level parser is never configured or used in the
# visible code; main() creates and populates its own local ``parser``.
# Candidate for removal once nothing external is confirmed to import it.
parser = argparse.ArgumentParser()


def _str2bool(value):
    """Convert a command-line token such as ``True``/``false``/``1``/``no`` to bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty token -- including ``"False"`` -- would be parsed as ``True``.
    The empty string maps to ``False`` (the same result ``bool("")`` gave).

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def main(argv=None):
    """Parse the command line and train an adapter model on every listed task.

    The task list is read as JSON from ``--tasks_file``. For each task a
    ``Glue_trainer`` is built, the dataset is prepared, ``TrainingArguments``
    are assembled, and training is dispatched on ``--target``:
    ``'base'`` -> ``train_adapter``, ``'layer'`` -> ``train_adapter_layer``,
    anything else -> ``train_adapter_llt``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--target', type=str, default='layer',
                        choices=['base', 'weight', 'neuron', 'layer'],
                        help='Type of training')
    parser.add_argument('--adapter_size', type=int, default=128, help='Size of adapter')
    parser.add_argument('--prune_persent', type=int, default=20,
                        help="Percentage to prune each iteration")
    parser.add_argument('--lr', type=float, default=5e-4, help="learning rate")
    parser.add_argument('--model_save_dir', type=str, help="Cache model directory")
    parser.add_argument('--resume_iter', type=int, default=0,
                        help="Iteration to resume training")
    # type=bool would turn "--use_checkpoint False" into True; parse it explicitly.
    parser.add_argument('--use_checkpoint', type=_str2bool, default=False,
                        help="To resume model in checkpoint")
    parser.add_argument('--seed', type=int, default=36, help="Random Seed in the model")
    parser.add_argument('--tasks_file', type=str, default='tasks.txt',
                        help="The file contain list of tasks to train on")

    pargs = parser.parse_args(argv)
    target = pargs.target
    adapter_size = pargs.adapter_size
    prune_persent = pargs.prune_persent
    lr = pargs.lr
    model_save_dir = pargs.model_save_dir
    resume_iter = pargs.resume_iter
    checkpoint = pargs.use_checkpoint
    seed = pargs.seed
    prune_iter = 15 if target != 'layer' else 12

    # The tasks file holds a JSON document that is iterated as the task list.
    with open(pargs.tasks_file, 'r', encoding='utf-8') as f:
        GLUE_TASKS = json.load(f)
        print(GLUE_TASKS)

    if model_save_dir is None:
        model_save_dir = f'../adapter-{target}'

    # # model_checkpoint = "distilbert-base-uncased"
    model_checkpoint = "bert-base-uncased"

    # Tasks that run with a per-device batch of 16 and 2 gradient-accumulation
    # steps instead of a batch of 32 with no accumulation.
    half_batch_tasks = ('mnli', 'mnli-mm', 'qnli', 'rte')

    for task in GLUE_TASKS:
        # Load parameters:
        half_batch = task in half_batch_tasks
        batch_size = 16 if half_batch else 32
        accumulation_steps = 2 if half_batch else 1
        epoch = 5 if task in ('mrpc', 'wnli') else 4

        print("Training settings:",task,"Batch size",batch_size,'lr',lr,'epoch',epoch,'seed',seed)

        glue_trainer = Glue_trainer(
            adapter_size=adapter_size,
            model_path=model_checkpoint,
            prune_persent=prune_persent,
            prune_iter=prune_iter,
            save_dir=f'{model_save_dir}/{task}/{lr}-{batch_size}',
        )

        # Setting num_labels, metric_name, validation_key
        glue_trainer.setTask(task)
        # Load datasets
        encoded_dataset, metric, tokenizer = glue_trainer.prepare_dataset(task)
        steps_per_epoch = encoded_dataset.num_rows['train'] / batch_size

        # Set warmup steps.
        # NOTE(review): the parentheses only spell out how the original chained
        # conditional already parsed. As written, target == 'base' yields 0
        # warmup only when steps_per_epoch >= 1e4; otherwise every target warms
        # up for one epoch. Confirm whether 'base' was meant to always use 0.
        warmup = steps_per_epoch if steps_per_epoch < 1e4 else (1e4 if target != 'base' else 0)
        print("warmup",warmup,'steps per epoch',steps_per_epoch)

        args = TrainingArguments(
            "test-glue",
            evaluation_strategy="epoch",
            learning_rate=lr,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=epoch,
            weight_decay=0.01,
            load_best_model_at_end=True,
            metric_for_best_model=glue_trainer.metric_name if target == 'base' else 'loss',
            seed=seed,
            warmup_steps=warmup,
            # Half an epoch, halved again for the tasks that accumulate gradients.
            eval_steps=steps_per_epoch // 2 // accumulation_steps,
            gradient_accumulation_steps=accumulation_steps,
        )

        if target == 'base':
            print("Running Adapter baseline model")
            glue_trainer.train_adapter(task, args, encoded_dataset, metric, tokenizer, resume_iter)
        elif target == 'layer':
            glue_trainer.train_adapter_layer(task, target, args, encoded_dataset, metric,
                                             tokenizer, resume_iter)
        else:
            glue_trainer.train_adapter_llt(task, target, args, encoded_dataset, metric,
                                           tokenizer, resume_iter, checkpoint)


# Run the training driver only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
