import torch
import torch.nn.functional as F
from torch import nn

from transformers import RobertaTokenizer, RobertaConfig, BertLMHeadModel
from .modeling_roberta import RobertaModel, RobertaPromptModel

# Adaptation for RoBERTa
class LanguageRoBERTaModel(nn.Module):
    """RoBERTa encoder with a linear head for sentence classification or regression.

    The last-hidden-state vector at the first sequence position is fed to a
    single ``nn.Linear`` head. With ``class_num == 1`` the model runs in
    regression mode (the STS-B setup): in training, predictions are clipped to
    ``[0, 5]`` and scored with MSE. Otherwise cross-entropy is used.

    Args:
        text_encoder: Pretrained checkpoint identifier passed to
            ``from_pretrained``. Required despite the ``None`` default.
        config: Unused; kept for backward compatibility.
        use_prompt: Load ``RobertaPromptModel`` instead of ``RobertaModel``.
        only_prompt: With ``use_prompt``, call ``tune_only_prompt()`` on the encoder.
        fix_prompt: With ``use_prompt``, call ``fix_prompt()`` on the encoder.
        prompt_config: Stored under ``'prompt_config'`` in the RoBERTa config
            when ``use_prompt`` is set.
        fix_word_embeddings: Set ``requires_grad = False`` on every parameter of
            ``text_encoder.embeddings``. This is the whole embeddings submodule,
            presumably not only the word-embedding table.
        class_num: Output size of the head; ``1`` selects regression.

    Raises:
        ValueError: If ``text_encoder`` is ``None``.
    """

    def __init__(self,
                 text_encoder=None,
                 config=None,
                 use_prompt=False,
                 only_prompt=False,
                 fix_prompt=False,
                 prompt_config=None,
                 fix_word_embeddings=False,
                 class_num=2
                 ):
        super().__init__()
        if text_encoder is None:
            # Fail early with a clear message instead of inside from_pretrained.
            raise ValueError('text_encoder (a pretrained checkpoint identifier) is required')
        # config & model init
        roberta_config = RobertaConfig.from_pretrained(text_encoder)
        if not use_prompt:
            self.text_encoder = RobertaModel.from_pretrained(text_encoder, config=roberta_config)
        else:
            roberta_config.update({'prompt_config': prompt_config})
            self.text_encoder = RobertaPromptModel.from_pretrained(text_encoder, config=roberta_config)
            if only_prompt:
                self.text_encoder.tune_only_prompt()
            if fix_prompt:
                self.text_encoder.fix_prompt()
        # cls head
        text_width = self.text_encoder.config.hidden_size
        self.cls_head = nn.Linear(text_width, class_num)
        self.class_num = class_num
        # Freeze the whole embeddings submodule (all of its parameters).
        if fix_word_embeddings:
            for name, para in self.text_encoder.embeddings.named_parameters():
                print('[FIX PARA]', name)
                para.requires_grad = False
        # Report the number of trainable parameters, in millions.
        num_params = sum(para.numel() for para in self.parameters() if para.requires_grad)
        print('[INFO] Finetune para num:', num_params / 1e6)

    def forward(self, text, targets=None, train=True):
        """Compute the training loss, or the raw head outputs.

        NOTE: this switches the module's mode as a side effect. It calls
        ``self.train()`` when ``train`` is true and ``self.eval()`` otherwise.

        Args:
            text: Batch exposing ``input_ids`` and ``attention_mask``
                (presumably a tokenizer ``BatchEncoding``; confirm with callers).
            targets: Labels for the loss; required when ``train`` is true.
                Use class indices for classification and scores for
                regression. 1-D regression targets are reshaped to ``(-1, 1)``.
            train: Return the loss if true, otherwise the head outputs.

        Returns:
            A scalar loss when ``train`` is true. Otherwise the ``cls_head``
            output of shape ``(batch, class_num)``; this is not clipped, even
            in regression mode.

        Raises:
            ValueError: If ``train`` is true and ``targets`` is ``None``.
        """
        if not train:
            self.eval()
            return self._predict(text)
        if targets is None:
            raise ValueError('targets must be provided when train=True')
        self.train()
        prediction = self._predict(text)
        return self._compute_loss(prediction, targets)

    def _predict(self, text):
        """Run the encoder and apply ``cls_head`` to the first-position hidden state."""
        text_output = self.text_encoder(text.input_ids,
                                        attention_mask=text.attention_mask,
                                        return_dict=True)
        # First sequence position (presumably RoBERTa's <s> token).
        text_embeds = text_output.last_hidden_state[:, 0, :]
        return self.cls_head(text_embeds)

    def _compute_loss(self, prediction, targets):
        """Return MSE for regression (``class_num == 1``), else cross-entropy."""
        if self.class_num != 1:
            return F.cross_entropy(prediction, targets)
        # Regression (STS-B): predictions are clipped to the [0, 5] score range.
        # NOTE(review): clipping before the loss gives zero gradient for
        # predictions outside [0, 5]. Confirm this is intended.
        prediction = prediction.clip(0., 5.)
        targets = targets.to(torch.float)
        if targets.dim() == 1:
            targets = targets.view(-1, 1)
        return F.mse_loss(prediction, targets)
