import os
import copy
import pandas
import pickle
import torch
import numpy as np
import transformers
import torch.nn as nn
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from datasets import load_dataset, load_metric
import datasets
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

from .prune import *
from .adapter_model import Adapter, AdapterModel
from .trainer import *

class Glue_trainer():
    def __init__(self, adapter_size, model_path, prune_persent, prune_iter, save_dir):
        """Store the settings shared by the training routines of this class.

        Args:
            adapter_size: Passed as the third argument to ``AdapterModel``
                whenever a model is built.
            model_path: Passed to ``AutoTokenizer.from_pretrained`` and to
                ``AdapterModel``.
            prune_persent: Pruning amount handed to the pruning helpers.
            prune_iter: Exclusive upper bound of the train/prune loops.
            save_dir: Root directory for the config file, the pickles and
                the per-iteration output directories.
        """
        # Nothing is validated or created here; the values are only kept.
        self.adapter_size, self.model_path = adapter_size, model_path
        self.prune_persent, self.prune_iter = prune_persent, prune_iter
        self.save_dir = save_dir
        
        
    def saveInitAdapter(self, state_dict):
        """Pickle the adapter and classifier entries of ``state_dict``.

        Entries whose key contains ``'adapter'`` or ``'classifier'`` are kept
        (in their original order) and written to
        ``<save_dir>/init_adapter_clsfier.pkl``; every other entry is dropped.
        """
        wanted_tags = ('adapter', 'classifier')
        kept = {
            name: tensor
            for name, tensor in state_dict.items()
            if any(tag in name for tag in wanted_tags)
        }
        with open(f'{self.save_dir}/init_adapter_clsfier.pkl', 'wb') as out:
            pickle.dump(kept, out)
    def setTrainingArgs(self, args):
        """Keep ``args`` on the instance as ``self.args``."""
        self.args = args
        
    # Create dir for the training
    def create_save_dir_with_log(self, target):
        """Create ``save_dir`` and record the run settings in ``config.txt``.

        The file is opened in exclusive-creation mode (``'x'``), so a
        ``config.txt`` already present in ``save_dir`` raises
        ``FileExistsError`` instead of being overwritten.

        Args:
            target: Label written on the ``prune target`` line.
        """
        os.makedirs(self.save_dir, exist_ok=True)
        config_path = f'{self.save_dir}/config.txt'
        with open(config_path, 'x', encoding='utf-8') as config_file:
            entries = (
                ('Adapter_size', self.adapter_size),
                ('prune_iter', self.prune_iter),
                ('persent', self.prune_persent),
                ('prune target', target),
            )
            for label, value in entries:
                config_file.write(f'{label} = {value}\n')

    # Convert datasets sentence to ids
    def get_dataset(self, task, dataset, tokenizer):
        """Tokenize the text column(s) of ``dataset`` for a GLUE task.

        Args:
            task: Key into the task table below; an unknown task raises
                ``KeyError``.
            dataset: Object exposing ``map(function, batched=...)``.
            tokenizer: Callable taking one or two text columns plus
                ``truncation=True``.

        Returns:
            Whatever ``dataset.map`` returns for the tokenizing function.
        """
        # Text column names per task; the second is None for single-sentence tasks.
        columns_by_task = {
            "cola": ("sentence", None),
            "mnli": ("premise", "hypothesis"),
            "mnli-mm": ("premise", "hypothesis"),
            "mrpc": ("sentence1", "sentence2"),
            "qnli": ("question", "sentence"),
            "qqp": ("question1", "question2"),
            "rte": ("sentence1", "sentence2"),
            "sst2": ("sentence", None),
            "stsb": ("sentence1", "sentence2"),
            "wnli": ("sentence1", "sentence2"),
        }
        first_key, second_key = columns_by_task[task]

        def tokenize_batch(examples):
            # Pass the second column positionally only when the task has one.
            text_columns = [examples[first_key]]
            if second_key is not None:
                text_columns.append(examples[second_key])
            return tokenizer(*text_columns, truncation=True)

        return dataset.map(tokenize_batch, batched=True)
    
    # Preparing dataset and metric
    def prepare_dataset(self, task, cache_dir):
        """Load the dataset, metric and tokenizer, then tokenize the dataset.

        Args:
            task: GLUE task name; ``"mnli-mm"`` uses the ``"mnli"`` metric
                configuration.
            cache_dir: Directory handed to ``datasets.load_from_disk``.

        Returns:
            A ``(encoded_dataset, metric, tokenizer)`` tuple.
        """
        print("Loading Dataset...")
        metric_config = task
        if task == "mnli-mm":
            metric_config = "mnli"
        raw_dataset = datasets.load_from_disk(cache_dir)
        # The metric is loaded from the script 'glue.py' (a relative path).
        metric = load_metric('glue.py', metric_config)

        print("Setting tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=True)

        print("Preprocessing datasets...")
        encoded_dataset = self.get_dataset(task, raw_dataset, tokenizer)
        return encoded_dataset, metric, tokenizer
    
    # Setting num_labels, metric_name, validation_key
    def setTask(self, task):
        """Derive label count, metric name and validation split from ``task``.

        Sets ``self.num_labels``, ``self.metric_name`` and
        ``self.validation_key``.
        """
        if task.startswith("mnli"):
            label_count = 3
        elif task == "stsb":
            label_count = 1
        else:
            label_count = 2
        self.num_labels = label_count

        # Every task not listed falls back to the default given to .get().
        metric_by_task = {"stsb": "pearson", "cola": "matthews_correlation"}
        self.metric_name = metric_by_task.get(task, "accuracy")

        split_by_task = {
            "mnli-mm": "validation_mismatched",
            "mnli": "validation_matched",
        }
        self.validation_key = split_by_task.get(task, "validation")
    
    # Load a pickled object from disk (used for the mask and the prune list).
    # Adapter/classifier weights are loaded with loadAdapterWithClassifier(path, model) instead.
    def load_params(self, path):
        """Return the object unpickled from the file at ``path``.

        NOTE: ``pickle.load`` can execute arbitrary code; only use this on
        files written by a trusted run.
        """
        with open(str(path), 'rb') as source:
            return pickle.load(source)
    
    
    # Train the model (Prune weights and neurons)
    def train_adapter_llt(self, task, target, args, encoded_dataset, metric, tokenizer,start_iter=0,checkpoint=False):
        """Run iterative train/prune rounds on weights or neurons.

        Every round builds an ``LTTrainer`` around one shared
        ``AdapterModel``, trains it under the current mask, records the
        evaluation result and then updates the mask with the pruning helper
        selected by ``target``. From the second round on,
        ``original_initialization`` is called with the mask and the saved
        initial state before training. ``self.num_labels`` and
        ``self.validation_key`` must already be set (see ``setTask``).

        Args:
            task: GLUE task name; ``"stsb"`` makes ``compute_metrics`` take
                column 0 of the predictions instead of an argmax.
            target: ``'weight'`` or ``'neuron'``; selects the pruning helper.
            args: Training arguments object. It is mutated in place
                (``output_dir``, ``logging_dir``, ``warmup_steps``).
            encoded_dataset: Mapping with a ``"train"`` split and the split
                named by ``self.validation_key``.
            metric: Object whose ``compute(predictions=, references=)`` is
                called by ``compute_metrics``.
            tokenizer: Passed through to the trainer.
            start_iter: Round to start from. A non-zero value loads
                ``Iter{start_iter}/mask.pkl`` and the pickled initial
                adapter/classifier weights instead of creating a new run
                directory.
            checkpoint: When true, ``resume_from_checkpoint`` is ``True`` for
                the first resumed round and ``False`` for the others; when
                false it is ``None``.

        Raises:
            AssertionError: If ``target`` is neither ``'weight'`` nor
                ``'neuron'``.
        """
        assert target == 'weight' or target == 'neuron', "Prune Target should be 'weight' or 'neuron'"

        def compute_metrics(eval_pred):
            predictions, labels = eval_pred
            if task != "stsb":
                predictions = np.argmax(predictions, axis=1)
            else:
                predictions = predictions[:, 0]
            return metric.compute(predictions=predictions, references=labels)

        print(f"Start training Lottery Ticke in {target}")
        torch.manual_seed(args.seed)
        model = AdapterModel(self.model_path, self.num_labels, self.adapter_size)

        # NOTE: on resume these lists only cover rounds from start_iter on,
        # and the pickles written at the end replace any earlier ones.
        best_result = []
        persent_lst = []
        initial_state_dict = None
        mask = {}

        if start_iter != 0:
            print('Resume trainig at', start_iter)
            print("Loading mask")
            mask = self.load_params(f'{self.save_dir}/Iter{start_iter}/mask.pkl')
            print("Loading initial weight")
            loadAdapterWithClassifier(f'{self.save_dir}/init_adapter_clsfier.pkl', model)
            initial_state_dict = copy.deepcopy(model.state_dict())
        else:
            # New run: create the directory and write config.txt.
            self.create_save_dir_with_log(target)

        # Warmup is only applied from round 1 on; round 0 always uses 0.
        warmup = args.warmup_steps
        args.warmup_steps = 0

        # The loop variable is not called ``iter`` so the builtin stays usable.
        for prune_round in range(start_iter, self.prune_iter):
            print("Iter", prune_round, 'Start')
            # Each round writes into its own sub-directory.
            args.output_dir = self.save_dir+"/Iter"+str(prune_round)
            args.logging_dir = args.output_dir+'/runs'
            args.warmup_steps = warmup if prune_round != 0 else 0

            trainer = LTTrainer(
                model,
                args=args,
                train_dataset=encoded_dataset["train"],
                eval_dataset=encoded_dataset[self.validation_key],
                tokenizer=tokenizer,
                compute_metrics=compute_metrics,
                callbacks=[transformers.trainer_callback.EarlyStoppingCallback(early_stopping_patience=3)],
            )

            if prune_round != 0:
                print('Trainer warmup', trainer.args.warmup_steps)
                # Reset the model from the saved initial state under the mask.
                original_initialization(mask, initial_state_dict, trainer.model)
            else:
                # First round: build the mask and remember the initial weights.
                mask = make_mask(trainer.model)
                initial_state_dict = copy.deepcopy(trainer.model.state_dict())
                self.saveInitAdapter(initial_state_dict)

            nonzeros_persent = persentage_nonzeros(trainer.model)
            print("Prune Iter:", prune_round, "Non zero:", str(nonzeros_persent)+"%")

            remain_persent = print_nonzeros(trainer.model)
            persent_lst.append(remain_persent)

            # The mask is written before training so a crashed round can be
            # resumed; make sure the directory exists regardless of whether
            # the trainer constructor created it.
            os.makedirs(args.output_dir, exist_ok=True)
            with open(f'{args.output_dir}/mask.pkl', 'wb') as f:
                pickle.dump(mask, f)
            with open(f'{args.output_dir}/remain_persent.txt', 'w', encoding='utf-8') as file:
                file.write(f'Remain persent = {remain_persent}\n')

            trainer.mask = mask
            if checkpoint:
                resume = start_iter != 0 and prune_round == start_iter
            else:
                resume = None
            trainer.train(resume_from_checkpoint=resume)

            # Evaluate BEFORE pruning, as train_adapter_layer does, so the
            # recorded metric belongs to the trained model of this round.
            # NOTE(review): assumes the pruning helpers may modify the model
            # in place; if they only touch the mask the result is unchanged.
            best_result.append(trainer.evaluate())

            # Update the mask for the next round.
            if target == 'weight':
                prune_by_percentile_global(self.prune_persent, trainer.model, mask)
            else:
                prune_by_percentile_node_global(self.prune_persent, trainer.model, mask)

            removeModel(args.output_dir)

        with open(f'{self.save_dir}/best_result.pkl', 'wb') as f:
            pickle.dump(best_result, f)
        with open(f'{self.save_dir}/persent_lst.pkl', 'wb') as f:
            pickle.dump(persent_lst, f)
            
    # Prune the adapter layer
    def train_adapter_layer(self, task, target, args, encoded_dataset, metric, tokenizer, start_iter = 0):
        """Run iterative train/prune rounds that remove whole adapters.

        Every round trains an ``AdapterTrainer`` around one shared
        ``AdapterModel``, records the evaluation result, pickles the prune
        list in effect for that round and, except on the last round, updates
        the prune list from ``adapterL1`` and applies it with
        ``pruneAdapter``. From the second round on the saved initial state is
        reloaded (non-strict) before training. ``self.num_labels`` and
        ``self.validation_key`` must already be set (see ``setTask``).

        Args:
            task: GLUE task name; ``"stsb"`` makes ``compute_metrics`` take
                column 0 of the predictions instead of an argmax.
            target: Unused here; the config log always records ``'layer'``.
            args: Training arguments object. It is mutated in place
                (``output_dir``, ``logging_dir``, ``warmup_steps``).
            encoded_dataset: Mapping with a ``"train"`` split and the split
                named by ``self.validation_key``.
            metric: Object whose ``compute(predictions=, references=)`` is
                called by ``compute_metrics``.
            tokenizer: Passed through to the trainer.
            start_iter: Round to start from. A non-zero value loads
                ``Iter{start_iter}/cur_prune_lst.pkl`` and the pickled
                initial adapter/classifier weights instead of creating a new
                run directory.
        """
        def compute_metrics(eval_pred):
            predictions, labels = eval_pred
            if task != "stsb":
                predictions = np.argmax(predictions, axis=1)
            else:
                predictions = predictions[:, 0]
            return metric.compute(predictions=predictions, references=labels)

        print("Start training Lottery ticket in Layer")
        torch.manual_seed(args.seed)
        model = AdapterModel(self.model_path, self.num_labels, self.adapter_size)

        # NOTE(review): 12 is hard-coded; presumably one [flag, flag] pair per
        # encoder layer -- confirm before using a backbone of another depth.
        prune_lst = [[False, False] for _ in range(12)]
        initial_state_dict = None

        if start_iter != 0:
            print('Resume training at', start_iter)
            print('loading prune list')
            prune_lst = self.load_params(f"{self.save_dir}/Iter{start_iter}/cur_prune_lst.pkl")
            print('loading adapter and classifier modules')
            loadAdapterWithClassifier(f"{self.save_dir}/init_adapter_clsfier.pkl", model)
            initial_state_dict = copy.deepcopy(model.state_dict())
            # The model was just rebuilt, so the restored prune list has to be
            # applied to it; otherwise the first resumed round would train
            # adapters that the list marks as pruned.
            pruneAdapter(model.model.bert.encoder.layer, prune_lst)
        else:
            self.create_save_dir_with_log('layer')

        # Warmup is only applied from round 1 on; round 0 always uses 0.
        warmup = args.warmup_steps
        args.warmup_steps = 0

        # NOTE: on resume these lists only cover rounds from start_iter on,
        # and the pickles written at the end replace any earlier ones.
        lst = []
        best_result = []

        for i in range(start_iter, self.prune_iter):
            print("Iter", i, 'Start')

            # Each round writes into its own sub-directory.
            args.output_dir = self.save_dir+"/Iter"+str(i)
            args.logging_dir = args.output_dir+'/runs'
            args.warmup_steps = warmup if i != 0 else 0

            trainer = AdapterTrainer(
                model,
                args=args,
                train_dataset=encoded_dataset["train"],
                eval_dataset=encoded_dataset[self.validation_key],
                tokenizer=tokenizer,
                compute_metrics=compute_metrics,
                callbacks=[transformers.trainer_callback.EarlyStoppingCallback(early_stopping_patience=3)],
            )
            if i == 0:
                initial_state_dict = copy.deepcopy(trainer.model.state_dict())
                self.saveInitAdapter(initial_state_dict)
            else:
                # Non-strict so keys that no longer match the model are tolerated.
                trainer.model.load_state_dict(initial_state_dict, strict=False)

            trainer.train()
            best_result.append(trainer.evaluate())

            # Record the prune list that was in effect during this round.
            os.makedirs(args.output_dir, exist_ok=True)
            with open(f'{args.output_dir}/cur_prune_lst.pkl', 'wb') as f:
                pickle.dump(prune_lst, f)

            # No further pruning after the final round.
            if i != self.prune_iter-1:
                dic = adapterL1(trainer.model.model.bert.encoder.layer, prune_lst)
                lst.append(dic)
                if not updatePruneLst(self.prune_persent/100, dic, prune_lst):
                    # Stopping early must not skip this round's cleanup.
                    removeModel(args.output_dir)
                    break
                print("Total Pruned Adapters:", sum(sum(item) for item in prune_lst))
                pruneAdapter(trainer.model.model.bert.encoder.layer, prune_lst)
                print("Acc after Prune", trainer.evaluate())
            removeModel(args.output_dir)

        with open(f'{self.save_dir}/prune_dic_lst.pkl', 'wb') as f:
            pickle.dump(lst, f)
        with open(f'{self.save_dir}/best_result.pkl', 'wb') as f:
            pickle.dump(best_result, f)
            
    def train_adapter(self, task, args, encoded_dataset, metric, tokenizer):
        """Train and evaluate one adapter model without any pruning.

        The run writes into ``self.save_dir`` directly: ``config.txt`` (with
        prune target ``'Origin'``), the trainer output and
        ``best_result.pkl``. ``args.output_dir`` and ``args.logging_dir`` are
        overwritten on the passed-in ``args`` object. ``self.num_labels`` and
        ``self.validation_key`` must already be set (see ``setTask``).
        """
        def compute_metrics(eval_pred):
            logits, references = eval_pred
            # "stsb" takes column 0 of the predictions; other tasks use argmax.
            if task == "stsb":
                predicted = logits[:, 0]
            else:
                predicted = np.argmax(logits, axis=1)
            return metric.compute(predictions=predicted, references=references)

        print("Start training Adapter model baseline")

        args.output_dir = self.save_dir
        args.logging_dir = args.output_dir+'/runs'

        self.create_save_dir_with_log('Origin')

        # The trainer builds the model itself through model_init.
        trainer = AdapterTrainer(
            model_init=lambda: AdapterModel(self.model_path, self.num_labels, self.adapter_size),
            args=args,
            train_dataset=encoded_dataset["train"],
            eval_dataset=encoded_dataset[self.validation_key],
            tokenizer=tokenizer,
            compute_metrics=compute_metrics,
        )
        trainer.train()

        final_metrics = trainer.evaluate()
        with open(f'{self.save_dir}/best_result.pkl', 'wb') as result_file:
            pickle.dump(final_metrics, result_file)

        removeModel(args.output_dir)
        
