import math
import torch
import numpy as np
from torch.utils.data import DataLoader
from transformers.trainer_callback import TrainerCallback

class ISTCallback(TrainerCallback):
    """Importance-aware sparse tuning (IST) callback for a Hugging Face ``Trainer``.

    At every training step only ``N_u`` transformer layers keep their adapter
    parameters trainable.  Those layers are sampled without replacement with
    weights ``sigmoid(importance_score)``.

    Every ``T_c`` steps (skipping step 0) the importance scores are refreshed:
    ``N_c`` candidate subsets of ``N_v`` layers are evaluated on one batch.  For
    each subset, the adapter scaling of every layer *outside* it is multiplied by
    ``beta`` ("response suppression").  Subsets yielding a lower loss get a
    positive mean-centred reward ``exp(-loss)``, which raises the scores of their
    layers.

    Args:
        model: PEFT-wrapped model; ``model.get_base_model()`` must be one of the
            classes listed in ``class_to_layers_map``.
        dataset: Dataset used to draw batches for the importance updates.
        data_collator: ``collate_fn`` for the importance-update ``DataLoader``.
        ist_mu: Importance updating rate.
        ist_beta: Response suppression rate applied to non-selected layers.
        batch_size: Batch size of the importance-update ``DataLoader``
            (default 16, the LLM-Adapter default).

    Raises:
        NotImplementedError: If the base model class has no known layer path,
            or no adapter scaling factor can be found in the layers.
    """

    def __init__(self, model, dataset, data_collator, ist_mu = 10, ist_beta = 0.25, batch_size=16):
        super().__init__()
        self.batch_size = batch_size  # default 16 follows LLM-Adapter
        self.model = model.get_base_model()
        self.dataset = dataset
        self.data_collator = data_collator
        self.dataloader = self._build_dataloader()

        # Dotted attribute path (relative to ``self``) of the transformer layer list.
        class_to_layers_map = {
            'LlamaForCausalLM': 'model.model.layers',
            'Qwen2ForCausalLM': 'model.model.layers',
            'MistralForCausalLM': 'model.model.layers',
            'MixtralForCausalLM': 'model.model.layers',
            'GemmaForCausalLM': 'model.model.layers',
            'GPT2LMHeadModel': 'model.transformer.h',
        }
        model_class_name = self.model.__class__.__name__
        if model_class_name not in class_to_layers_map:
            print('Pls inject the class to layers map manually', model_class_name)
            raise NotImplementedError(f'No layer path known for model class {model_class_name}')
        self.layers_attribute = class_to_layers_map[model_class_name]

        layers = self._get_layers()
        self.total_layers = len(layers)
        self.importance_score = torch.zeros(self.total_layers)

        self.mu = ist_mu  # importance updating rate
        self.beta = ist_beta  # response suppression rate

        self.N_c = 3  # number of candidate subsets per importance update
        self.T_c = 10  # importance update interval in steps

        # torch.multinomial cannot draw 0 samples, so always select at least one layer.
        self.N_u = max(1, self.total_layers // 4)  # ~25% of layers for the fine-tuning loop
        self.N_v = max(1, self.total_layers // 2)  # ~50% of layers for the importance updating loop

        self.active_layers_indices = []
        self.trainable_module_name = []
        self.raw_scaling = None
        # (layer index, adapter module, attribute name, original value) for every
        # adapter scaling factor, so each module is restored to its *own* value.
        self._scaling_records = []
        for layer_idx, layer in enumerate(layers):
            for module in layer.modules():
                # Only adapter modules (identified by ``disable_adapters``, as in
                # switch_active_layers) are touched; other modules may also carry a
                # ``scaling`` attribute that must not be overwritten.
                if not hasattr(module, 'disable_adapters'):
                    continue
                for attr in ('scaling', 'adapter_scaling'):
                    if hasattr(module, attr):
                        value = getattr(module, attr)
                        self.raw_scaling = value
                        self._scaling_records.append(
                            (layer_idx, module, attr, self._copy_scaling(value)))
                for param_name, param in module.named_parameters():
                    if param.requires_grad and param_name not in self.trainable_module_name:
                        self.trainable_module_name.append(param_name)

        print('The name of trainable modules in each layer: ',self.trainable_module_name)

        if self.raw_scaling is None:
            print('Cannot find a scaling factor for response suppression')
            raise NotImplementedError('Cannot find a scaling factor for response suppression')
        print(f'Default scaling is {self.raw_scaling}')

    def _get_layers(self):
        """Resolve the dotted ``self.layers_attribute`` path starting from ``self``."""
        obj = self
        for part in self.layers_attribute.split('.'):
            obj = getattr(obj, part)
        return obj

    def _build_dataloader(self):
        """Return an iterator over a freshly shuffled ``DataLoader`` of ``self.dataset``."""
        return iter(DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True,
                               collate_fn=self.data_collator))

    def _next_batch(self):
        """Draw the next batch, restarting the dataloader once it is exhausted."""
        try:
            return next(self.dataloader)
        except StopIteration:
            self.dataloader = self._build_dataloader()
            return next(self.dataloader)

    @staticmethod
    def _copy_scaling(value):
        """Shallow-copy dict-valued scaling factors so stored originals are never aliased."""
        return dict(value) if isinstance(value, dict) else value

    @staticmethod
    def _scale(value, factor):
        """Multiply a scaling factor by ``factor``; dict values (e.g. one entry per adapter name) are scaled per entry."""
        if isinstance(value, dict):
            return {key: val * factor for key, val in value.items()}
        return value * factor

    def sampling_more_important_selection(self, num):
        """Sample ``num`` distinct layer indices weighted by ``sigmoid(importance_score)``.

        Returns:
            A 1-D LongTensor of indices in ascending order.
        """
        prob = self.importance_score.sigmoid()
        select = torch.sort(torch.multinomial(prob, num))[0]
        return select

    def on_step_begin(self, args, state, control, **kwargs):
        """Run the importance updating loop every ``T_c`` steps (not at step 0), then resample active layers."""
        if state.global_step % self.T_c == 0 and state.global_step > 0:
            self._update_importance()
        # Fine-tuning loop: the trainable layers are resampled at every step.
        self.switch_active_layers()

    def _update_importance(self):
        """Score ``N_c`` sampled layer subsets on one batch and update ``importance_score``."""
        batch = self._next_batch()
        # Move tensors to the model's device (instead of a hard-coded ``.cuda()``).
        device = next(self.model.parameters()).device
        batch = {key: val.to(device) if torch.is_tensor(val) else val
                 for key, val in batch.items()}

        selects = []
        losses = []
        try:
            for _ in range(self.N_c):
                select = self.sampling_more_important_selection(self.N_v)
                selects.append(select)
                self.adapter_response_suppression(select)

                self.model.eval()
                with torch.inference_mode():
                    outputs = self.model(**batch)
                losses.append(outputs.loss.item())
        finally:
            # Always restore the original adapter scaling and train mode.
            self.resume_adapter()

        rewards = [math.exp(-loss) for loss in losses]
        baseline = float(np.mean(rewards))
        rewards = [reward - baseline for reward in rewards]

        # Probabilities are taken once, before any of this round's updates.
        prob = self.importance_score.sigmoid()
        for select, reward in zip(selects, rewards):
            p = prob[select]
            # p * (1 - p) is the sigmoid derivative; it damps updates for layers whose
            # probability is already near 0 or 1 (regularization against overfitting).
            # ``select`` holds distinct indices, so the indexed ``+=`` is well defined.
            self.importance_score[select] += self.mu * reward * p * (1 - p)

    def resume_adapter(self):
        """Put the model in train mode and restore every adapter scaling factor to its original value."""
        self.model.train()
        for _, module, attr, original in self._scaling_records:
            setattr(module, attr, self._copy_scaling(original))

    def adapter_response_suppression(self, select):
        """Scale adapter outputs of layers not in ``select`` by ``beta``; keep selected layers at their original scaling."""
        selected = {int(idx) for idx in select}
        for layer_idx, module, attr, original in self._scaling_records:
            if layer_idx in selected:
                setattr(module, attr, self._copy_scaling(original))
            else:
                setattr(module, attr, self._scale(original, self.beta))

    def freeze_all_layers(self):
        """Set ``requires_grad = False`` on every parameter of every transformer layer."""
        for layer in self._get_layers():
            for param in layer.parameters():
                param.requires_grad = False

    def switch_active_layers(self):
        """Freeze all layers, then re-enable the adapter parameters of ``N_u`` sampled layers."""
        self.freeze_all_layers()

        layers = self._get_layers()
        self.active_layers_indices = self.sampling_more_important_selection(self.N_u)
        print(
            f"Total layers: {self.total_layers}, Activating layers at indices: {self.active_layers_indices} for the next steps.",
            flush=True)

        # Enable gradients only for the adapter parameters found trainable at init.
        for idx in self.active_layers_indices:
            for module in layers[int(idx)].modules():
                if hasattr(module, 'disable_adapters'):
                    for param_name, param in module.named_parameters():
                        if param_name in self.trainable_module_name:
                            param.requires_grad = True