import torch
import torch.nn.functional as F
from torch import nn

from transformers import Qwen2Tokenizer, Qwen2Config
from .modeling_qwen1_5 import Qwen2Model, Qwen2PromptModel, Qwen2ForSequenceClassification, Qwen2PromptForSequenceClassification


class LanguageQwenModel(nn.Module):
    """Sequence classifier built on a pretrained Qwen2-architecture checkpoint.

    Wraps either ``Qwen2ForSequenceClassification`` or, when ``use_prompt`` is
    set, ``Qwen2PromptForSequenceClassification`` (both imported from the
    local ``modeling_qwen1_5`` module). ``forward`` returns the training loss
    or the prediction logits.
    """

    def __init__(self,
                 text_encoder=None,
                 use_prompt=False,
                 only_prompt=False,
                 fix_prompt=False,
                 prompt_config=None,
                 fix_word_embeddings=False,
                 class_num=2
                 ):
        """Load the config and pretrained weights, then configure trainable parameters.

        Args:
            text_encoder: Checkpoint name or path passed to ``from_pretrained``
                for both the config and the weights. Required despite the
                ``None`` default.
            use_prompt: Build ``Qwen2PromptForSequenceClassification`` instead
                of ``Qwen2ForSequenceClassification``.
            only_prompt: With ``use_prompt``, call ``tune_only_prompt()`` on the
                encoder. Ignored otherwise.
            fix_prompt: With ``use_prompt``, call ``fix_prompt()`` on the
                encoder. Ignored otherwise.
            prompt_config: Stored on the config under ``prompt_config`` when
                ``use_prompt`` is set.
            fix_word_embeddings: Set ``requires_grad = False`` on the encoder's
                input word-embedding parameters.
            class_num: Number of labels, written to the config as ``num_labels``.

        Raises:
            ValueError: If ``text_encoder`` is ``None``.
        """
        super().__init__()
        if text_encoder is None:
            raise ValueError("text_encoder must be a pretrained checkpoint name or path")

        # config & model init
        qwen1_5_config = Qwen2Config.from_pretrained(text_encoder)
        # Same object that is handed to from_pretrained below, so every update
        # made to it is also visible through self.config.
        self.config = qwen1_5_config
        qwen1_5_config.update({"num_labels": class_num})
        if use_prompt:
            qwen1_5_config.update({'prompt_config': prompt_config})
            self.text_encoder = Qwen2PromptForSequenceClassification.from_pretrained(text_encoder, config=qwen1_5_config)
            if only_prompt:
                self.text_encoder.tune_only_prompt()
            if fix_prompt:
                self.text_encoder.fix_prompt()
        else:
            self.text_encoder = Qwen2ForSequenceClassification.from_pretrained(text_encoder, config=qwen1_5_config)

        # fix word embedding
        if fix_word_embeddings:
            self._freeze_word_embeddings()

        # count trainable parameters (reported in millions)
        num_params = sum(para.numel() for para in self.parameters() if para.requires_grad)
        print('[INFO] Finetune para num:', num_params / 1e6)

    def _freeze_word_embeddings(self):
        """Set ``requires_grad = False`` on the wrapped encoder's input embedding parameters."""
        # In upstream transformers, Qwen2ForSequenceClassification stores its
        # embedding table at ``model.embed_tokens`` and has no ``embeddings``
        # attribute, so reading ``self.text_encoder.embeddings`` raised
        # AttributeError. (Assumed to hold for the local modeling_qwen1_5 copy
        # as well; verify.) Use an ``embeddings`` attribute if the local model
        # defines one. Otherwise fall back to the standard accessor.
        embeddings = getattr(self.text_encoder, 'embeddings', None)
        if embeddings is None:
            embeddings = self.text_encoder.get_input_embeddings()
        for name, para in embeddings.named_parameters():
            print('[FIX PARA]', name)
            para.requires_grad = False

    def forward(self, text, targets=None, train=True):
        """Run the wrapped classifier on a tokenized batch.

        This method itself switches the whole module into train or eval mode
        (``self.train()`` / ``self.eval()``), depending on ``train``.

        Args:
            text: Tokenizer output exposing ``input_ids`` and ``attention_mask``
                attributes. Presumably a ``BatchEncoding``; confirm against callers.
            targets: Passed to the encoder as ``labels=`` when ``train`` is
                True. Unused otherwise.
            train: If True, enter train mode and return the loss. Otherwise
                enter eval mode and return the logits.

        Returns:
            The encoder output's ``loss`` when ``train`` is True, else its ``logits``.
        """
        if train:
            self.train()
            # get output embedding
            text_output = self.text_encoder(text.input_ids,
                                            attention_mask=text.attention_mask,
                                            labels=targets,
                                            return_dict=True,
                                            )
            return text_output.loss

        self.eval()
        # NOTE(review): autograd is not disabled here. Callers that only need
        # predictions may want to wrap this call in torch.no_grad().
        text_output = self.text_encoder(text.input_ids,
                                        attention_mask=text.attention_mask,
                                        return_dict=True
                                        )
        return text_output.logits


