import torch
import logging
from torch import nn
from transformers import GPT2LMHeadModel
from peft import PrefixTuningConfig, get_peft_model, TaskType



class FedSPModel(nn.Module):
    """GPT-2 causal LM wrapped with PEFT prefix tuning.

    Args:
        model_name: ``"gpt2_medium"`` resolves to a local checkpoint directory
            (see ``_MODEL_PATHS``). Any other value is passed unchanged to
            ``GPT2LMHeadModel.from_pretrained`` (hub id or local path).
        role: ``"aux"`` builds a reduced-depth backbone by tiling the first
            ``aux_layer_num`` transformer blocks (see ``_build_aux_backbone``).
            Any other value keeps the full backbone.
        prefix_len: number of virtual prefix tokens (``num_virtual_tokens``).
        aux_layer_num: for ``role == "aux"``, how many leading transformer
            blocks to keep. Must be in ``[1, n_layer]``. Ignored for other roles.

    Raises:
        ValueError: if ``role == "aux"`` and ``aux_layer_num`` is outside
            ``[1, n_layer]``.
    """

    # Short model names -> checkpoint locations.
    # NOTE(review): machine-specific absolute path; consider making this configurable.
    _MODEL_PATHS = {
        "gpt2_medium": "/home/zhujh/20240417_FedQLoRA/model/gpt2_lm_medium",
    }

    def __init__(self, model_name, role, prefix_len, aux_layer_num=0):
        super().__init__()

        # Validate before the (expensive) checkpoint load. The previous code hit
        # a ZeroDivisionError here with the default aux_layer_num=0.
        if role == "aux" and aux_layer_num < 1:
            raise ValueError(
                f"aux_layer_num must be >= 1 for role 'aux', got {aux_layer_num}"
            )

        # Unknown names used to leave model_path unbound. They are now treated
        # as a hub id or local path.
        model_path = self._MODEL_PATHS.get(model_name, model_name)
        self.model = GPT2LMHeadModel.from_pretrained(model_path, device_map="auto")

        self.role = role
        self.prefix_len = prefix_len

        # Depth of the loaded checkpoint (before any aux rewiring).
        self.n_layer = self.model.config.n_layer
        self._aux_layer_num = aux_layer_num

        if role == "aux":
            self._build_aux_backbone()

        config = PrefixTuningConfig(
            num_virtual_tokens=self.prefix_len,
            prefix_projection=True,
            task_type=TaskType.CAUSAL_LM,
        )
        self.model = get_peft_model(self.model, config)

    def _build_aux_backbone(self):
        """Replace the block stack with the first ``_aux_layer_num`` blocks, tiled.

        The kept blocks are repeated ``n_layer // _aux_layer_num`` times. List
        repetition repeats references, so every copy of a block is the same
        module object and the copies share weights.
        """
        if self._aux_layer_num > self.n_layer:
            raise ValueError(
                f"aux_layer_num ({self._aux_layer_num}) exceeds the model depth "
                f"({self.n_layer})"
            )
        repeats, remainder = divmod(self.n_layer, self._aux_layer_num)
        if remainder:
            # Depth becomes repeats * aux_layer_num < n_layer; make this visible.
            logging.getLogger(__name__).warning(
                "n_layer=%d is not a multiple of aux_layer_num=%d; "
                "aux backbone will have %d blocks",
                self.n_layer,
                self._aux_layer_num,
                repeats * self._aux_layer_num,
            )
        kept = list(self.model.transformer.h[: self._aux_layer_num])
        self.model.transformer.h = nn.ModuleList(kept * repeats)

    def _set_requires_grad(self, params, flag):
        """Set ``requires_grad = flag`` on parameters whose name matches.

        A parameter matches when its qualified name (from ``named_parameters``)
        contains any substring in ``params``. A single ``str`` is treated as one
        pattern, not iterated character by character.
        """
        patterns = [params] if isinstance(params, str) else list(params)
        for name, param in self.named_parameters():
            if any(pattern in name for pattern in patterns):
                param.requires_grad = flag

    def _freeze_param(self, params):
        """Disable gradients for parameters whose name contains any of ``params``."""
        self._set_requires_grad(params, False)

    def _unfreeze_param(self, params):
        """Enable gradients for parameters whose name contains any of ``params``."""
        self._set_requires_grad(params, True)

    def generate(self, **kwargs):
        """Delegate generation to the wrapped (PEFT) model."""
        return self.model.generate(**kwargs)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Run the wrapped model and return its outputs unchanged.

        Args:
            input_ids: token id tensor. Presumably (batch, seq); confirm with callers.
            attention_mask: optional mask. Defaults to all ones with the same
                shape, dtype and device as ``input_ids``.
            **kwargs: forwarded to the wrapped model (e.g. ``labels``).
                ``output_hidden_states`` defaults to ``True``.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        kwargs.setdefault("output_hidden_states", True)
        return self.model(input_ids, attention_mask=attention_mask, **kwargs)

