import torch
import torch.nn.functional as F
from torch import nn

from transformers import PhiConfig
from .modeling_phi import PhiForCausalLM


class LanguagePhiModel(nn.Module):
    """Thin ``nn.Module`` wrapper around a pretrained Phi causal language model.

    The wrapped model is stored as ``self.text_encoder`` and its config as
    ``self.config``. ``forward`` either runs a regular forward pass of the
    wrapped model (``train=True``) or calls its ``generate`` (``train=False``).
    """

    def __init__(self,
                 text_encoder=None,
                 use_prompt=False,
                 only_prompt=False,
                 fix_prompt=False,
                 prompt_config=None,
                 fix_word_embeddings=False,
                 ):
        """Load the Phi config and weights, optionally freezing word embeddings.

        Args:
            text_encoder: name or path passed to ``from_pretrained`` for both
                ``PhiConfig`` and ``PhiForCausalLM``. Required.
            use_prompt: request the prompt-tuning variant. Not implemented in
                this module (the prompt model class is not available here).
            only_prompt: prompt-tuning option; currently unused.
            fix_prompt: prompt-tuning option; currently unused.
            prompt_config: prompt-tuning configuration; currently unused.
            fix_word_embeddings: if True, set ``requires_grad=False`` on the
                text encoder's input (word) embedding parameters.

        Raises:
            ValueError: if ``text_encoder`` is None.
            NotImplementedError: if ``use_prompt`` is True.
        """
        super().__init__()
        if text_encoder is None:
            raise ValueError('text_encoder must be a pretrained Phi model name or path')
        if use_prompt:
            # The prompt-tuning model (PhiPromptForCausalLM) is not wired up here.
            # Previously this branch silently skipped creating self.text_encoder,
            # producing an object that failed later with AttributeError.
            # TODO: restore once the prompt model class is available, passing
            # prompt_config via the config and honouring only_prompt / fix_prompt.
            raise NotImplementedError(
                'use_prompt=True is not supported: PhiPromptForCausalLM is not available')

        # config & model init
        phi_config = PhiConfig.from_pretrained(text_encoder)
        self.config = phi_config
        self.text_encoder = PhiForCausalLM.from_pretrained(text_encoder, config=phi_config)

        if fix_word_embeddings:
            self._freeze_word_embeddings()
        print('[INFO] Finetune para num:', self._count_trainable_params() / 1e6)

    def _freeze_word_embeddings(self):
        """Set ``requires_grad=False`` on the text encoder's input embedding params."""
        # Prefer an explicit ``embeddings`` attribute if the encoder has one;
        # otherwise use the standard PreTrainedModel accessor. HF-style causal-LM
        # wrappers keep embeddings on an inner module, so ``.embeddings`` alone
        # raises AttributeError.
        embeddings = getattr(self.text_encoder, 'embeddings', None)
        if embeddings is None:
            embeddings = self.text_encoder.get_input_embeddings()
        for name, para in embeddings.named_parameters():
            print('[FIX PARA]', name)
            para.requires_grad = False

    def _count_trainable_params(self):
        """Return the total element count of parameters with ``requires_grad=True``."""
        return sum(para.numel() for para in self.parameters() if para.requires_grad)

    def forward(self, text, train=True, max_new_tokens=32):
        """Run the wrapped model in training or generation mode.

        Side effect: switches this module to train mode (``train=True``) or
        eval mode (``train=False``) before running.

        Args:
            text: mapping of model inputs. In training it is unpacked as keyword
                arguments into the encoder (presumably ``input_ids``,
                ``attention_mask`` and optionally ``labels`` — confirm with
                callers). In generation only ``input_ids`` and an optional
                ``attention_mask`` are read.
            train: select the forward pass (True) or generation (False).
            max_new_tokens: generation length limit, used only when
                ``train=False``.

        Returns:
            The encoder's output called with ``return_dict=True`` when
            training; otherwise the result of ``text_encoder.generate``.
        """
        if train:
            self.train()
            return self.text_encoder(**text, return_dict=True)

        self.eval()
        # attention_mask is optional: fall back to None instead of raising KeyError.
        return self.text_encoder.generate(text['input_ids'],
                                          attention_mask=text.get('attention_mask'),
                                          max_new_tokens=max_new_tokens)

    def generate(self, **kwargs):
        """Forward all keyword arguments to ``self.text_encoder.generate``."""
        return self.text_encoder.generate(**kwargs)