import torch
import transformers
import torch.nn as nn
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import pickle
import os

from transformers import PreTrainedModel, WEIGHTS_NAME


### `path` is the location of a pickled adapter state dict, e.g. .../adapter.pkl
def loadAdapterWithClassifier(path,model):
    """Overwrite every trainable parameter of ``model`` from a pickled state.

    Args:
        path: Path to a pickle file holding a ``{parameter_name: tensor}``
            mapping (e.g. ``.../adapter.pkl``).
        model: Module whose parameters with ``requires_grad=True`` are
            replaced by the tensors stored under the same name.

    Raises:
        KeyError: If a trainable parameter has no entry in the pickled state.
    """
    # NOTE(security): pickle.load can execute arbitrary code while
    # deserializing -- only load files that come from a trusted source.
    with open(path, 'rb') as f:
        state = pickle.load(f)
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if name not in state:
            raise KeyError(
                f"trainable parameter {name!r} not found in {path!r}"
            )
        # Keep the parameter on the model's own device/dtype instead of
        # silently adopting whatever the tensor was pickled with.
        p.data = state[name].to(device=p.device, dtype=p.dtype)
def loadAdapter(path,model):
    """Overwrite the trainable adapter parameters of ``model`` from a pickle.

    Only parameters that have ``requires_grad=True`` and whose name contains
    ``'adapter'`` are touched; everything else (including any classifier
    weights) is left as is.

    Args:
        path: Path to a pickle file holding a ``{parameter_name: tensor}``
            mapping (e.g. ``.../adapter.pkl``).
        model: Module whose adapter parameters are replaced.

    Raises:
        KeyError: If a trainable adapter parameter has no entry in the
            pickled state.
    """
    # NOTE(security): pickle.load can execute arbitrary code while
    # deserializing -- only load files that come from a trusted source.
    with open(path, 'rb') as f:
        state = pickle.load(f)
    for name, p in model.named_parameters():
        if not (p.requires_grad and 'adapter' in name):
            continue
        if name not in state:
            raise KeyError(
                f"adapter parameter {name!r} not found in {path!r}"
            )
        # Keep the parameter on the model's own device/dtype instead of
        # silently adopting whatever the tensor was pickled with.
        p.data = state[name].to(device=p.device, dtype=p.dtype)
            
def removeModel(path):
    """Delete full-model weights and scheduler state from checkpoint folders.

    For every immediate sub-directory of ``path``, removes
    ``pytorch_model.bin`` and ``scheduler.pt`` when present and prints each
    removed file. Files directly inside ``path`` and any other files in the
    sub-directories are left alone.
    """
    doomed_files = ('pytorch_model.bin', 'scheduler.pt')
    for entry in os.listdir(path):
        # Only look inside directories; plain files at the top level are kept.
        if not os.path.isdir(os.path.join(path, entry)):
            continue
        for filename in doomed_files:
            target = path+'/'+entry+'/'+filename
            if not os.path.isfile(target):
                continue
            print('Remove file',target)
            os.remove(target)
class SaverTraniner(transformers.Trainer):
    """``transformers.Trainer`` that also dumps adapter/classifier weights.

    When the wrapped model is not a ``PreTrainedModel``, every save writes the
    full state dict to ``WEIGHTS_NAME`` and, next to it, an ``adapter.pkl``
    pickle containing only the entries whose key mentions ``adapter`` or
    ``classifier``.
    """

    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        tokenizer  = None,
        model_init = None,
        compute_metrics = None,
        callbacks = None,
        optimizers = (None, None),
    ):
        """Forward all arguments unchanged to ``transformers.Trainer``."""
        # Forward by keyword so the call does not depend on the positional
        # order of the base-class constructor.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
        )

    def _save(self, output_dir: Optional[str] = None):
        """Write model, tokenizer and training arguments to ``output_dir``.

        Args:
            output_dir: Target directory; falls back to ``self.args.output_dir``
                when ``None``. Created if it does not exist.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        if not isinstance(self.model, PreTrainedModel):
            # Not a `PreTrainedModel`: save the raw state dict, then the
            # adapter/classifier subset as a separate pickle.
            state_dict = self.model.state_dict()
            torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
            adapter_dict = self._get_adapter_and_classifier_dict(state_dict)

            print("Saving Adapter and classifier")
            with open(os.path.join(output_dir, 'adapter.pkl'), 'wb') as f:
                pickle.dump(adapter_dict, f)
        else:
            # `save_pretrained()` output can be reloaded with `from_pretrained()`.
            # Note that no adapter.pkl is written on this path.
            self.model.save_pretrained(output_dir)
        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def _get_adapter_and_classifier_dict(self, state_dict):
        """Return the entries of ``state_dict`` whose key contains
        ``'adapter'`` or ``'classifier'`` (values are not copied)."""
        return {
            key: value
            for key, value in state_dict.items()
            if 'adapter' in key or 'classifier' in key
        }

    def get_saved_dict(self):
        """Load and return the mapping stored in
        ``<self.args.output_dir>/adapter.pkl``."""
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserializing -- only load files that come from a trusted source.
        with open(os.path.join(self.args.output_dir, 'adapter.pkl'), 'rb') as f:
            state = pickle.load(f)
        # The original loaded the pickle but dropped it, always returning None.
        return state
    
    
class AdapterTrainer(SaverTraniner):
    """Trainer for adapter runs.

    Adds nothing on top of ``SaverTraniner``: the constructor simply hands
    every argument through to the parent class.
    """

    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        tokenizer  = None,
        model_init = None,
        compute_metrics = None,
        callbacks = None,
        optimizers = (None, None)
    ):
        """Pass all arguments through to ``SaverTraniner.__init__``."""
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
        )
        
        
        
class LTTrainer(SaverTraniner):
    """``SaverTraniner`` that zeroes the gradients of pruned weights.

    ``self.mask`` must be set (after construction) to a mapping from parameter
    name to an array/tensor; wherever a mask entry equals 0 the corresponding
    gradient element is reset to 0 after every backward pass.
    """

    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        tokenizer  = None,
        model_init = None,
        compute_metrics = None,
        callbacks = None,
        optimizers = (None, None)
    ):
        """Forward all arguments to ``SaverTraniner`` and start with no mask."""
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
        )
        # Mapping {parameter_name: mask}; must be assigned before training.
        self.mask = None

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs, then clear the
        gradients of pruned weights via :meth:`wipe_out_grad`.

        Args:
          model (:obj:`nn.Module`):
              The model to train.
          inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
              The inputs and targets of the model.

              The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
              argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
          :obj:`torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.use_amp:
            # `autocast` is not imported at module level in this file;
            # import it here so the native-AMP path does not raise NameError.
            from torch.cuda.amp import autocast

            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.use_amp:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            # `amp` is likewise undefined at module level; apex is only
            # imported on the branch that actually needs it.
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            self.deepspeed.backward(loss)
        else:
            loss.backward()
        self.wipe_out_grad()
        return loss.detach()

    # Clear the Gradient of pruned weights
    def wipe_out_grad(self):
        """Zero the gradient wherever ``self.mask[name] == 0``.

        Applies to trainable parameters of ``self.model`` whose name contains
        both ``'weight'`` and ``'dense'``. Parameters that received no
        gradient are skipped.

        Raises:
            TypeError: If a matching parameter is found while ``self.mask``
                is still ``None``.
            KeyError: If a matching parameter has no entry in ``self.mask``.
        """
        with torch.no_grad():
            for name, p in self.model.named_parameters():
                if not ('weight' in name and p.requires_grad and 'dense' in name):
                    continue
                if p.grad is None:
                    # No gradient was produced for this parameter: nothing to wipe.
                    continue
                if self.mask is None:
                    raise TypeError(
                        "LTTrainer.mask is None; assign a {parameter_name: mask} "
                        "mapping before training"
                    )
                # Mask on the gradient's own device instead of round-tripping
                # the gradient through CPU/NumPy on every step. A mask that is
                # already a tensor on that device is used without copying.
                pruned = torch.as_tensor(self.mask[name], device=p.grad.device) == 0
                p.grad.masked_fill_(pruned, 0)
