import os
import warnings

import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, BertModel
from transformers import GPT2LMHeadModel
from transformers import AutoTokenizer, GPT2Tokenizer
from transformers.modeling_outputs import SequenceClassifierOutput

import losses

class Model(torch.nn.Module):
    """Binary sequence classifier built on a Hugging Face backbone.

    The backbone is chosen from ``model_name``:

    * names containing ``'gpt'``: a GPT-2 style
      ``AutoModelForSequenceClassification`` (2 labels) with a ``[PAD]`` token
      added to its tokenizer. The loss is the one the HF model computes when
      ``labels`` are passed.
    * anything else: a ``BertModel`` encoder whose pooled output goes through
      dropout -> BiLSTM -> Linear -> ReLU -> Linear to produce 2 logits. The
      loss is ``0.9 * cross-entropy + 0.1 * losses.SupConLoss``.

    Args:
        args: Run configuration; only ``args.seed`` is read (in ``save_model``).
        model_name: HF model id or directory used to build the backbone.
        from_check_point: If True, load the tokenizer from ``tokenizer_dir`` and
            weights from ``model_dir`` (files written by ``save_model``).
        tokenizer_dir: Tokenizer directory, used only when ``from_check_point``.
        model_dir: Path of the ``.pth`` weight file, used only when
            ``from_check_point``.

    Raises:
        TypeError: if ``from_check_point`` is not a bool.
    """

    # Task-head submodules stacked on top of BertModel. They are not part of
    # self.model, so save_model writes them to a separate "<weights>_head.pth".
    _HEAD_MODULES = ("dropout", "bilstm", "line1", "relu", "classifier2")

    def __init__(self, args, model_name, from_check_point=False, tokenizer_dir=None, model_dir=None):
        super().__init__()
        if not isinstance(from_check_point, bool):
            raise TypeError(
                f"from_check_point must be a bool, got {type(from_check_point).__name__}"
            )

        self.args = args
        self.is_gpt = 'gpt' in model_name
        # A checkpoint carries its own tokenizer (e.g. with the added [PAD] token).
        tokenizer_source = tokenizer_dir if from_check_point else model_name
        if self.is_gpt:
            self._init_gpt(model_name, tokenizer_source, from_check_point, model_dir)
        else:
            self._init_bert(model_name, tokenizer_source, from_check_point, model_dir)

    def _init_gpt(self, model_name, tokenizer_source, from_check_point, model_dir):
        """Build the GPT-2 sequence classifier and its tokenizer."""
        self.tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_source, do_lower_case=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
        if not from_check_point:
            # forward pads to max_length, so a pad token is needed; a tokenizer
            # saved by save_model already contains it.
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # Resize before load_state_dict so the embedding shapes match the checkpoint.
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        if from_check_point:
            # load_state_dict copies into the existing parameters, so loading to
            # CPU first works on CPU-only hosts and leaves parameter devices unchanged.
            self.model.load_state_dict(torch.load(model_dir, map_location='cpu'))

    def _init_bert(self, model_name, tokenizer_source, from_check_point, model_dir):
        """Build the BertModel encoder, the BiLSTM classification head and the tokenizer."""
        self.model = BertModel.from_pretrained(model_name)
        hidden_size = self.model.config.hidden_size
        lstm_hidden = hidden_size // 2
        self.dropout = nn.Dropout(0.2)
        self.bilstm = nn.LSTM(bidirectional=True, input_size=hidden_size, hidden_size=lstm_hidden, batch_first=True)
        # The bidirectional LSTM emits 2 * lstm_hidden features per example.
        self.line1 = nn.Linear(2 * lstm_hidden, 128)
        self.relu = nn.ReLU()
        self.classifier2 = nn.Linear(128, 2)

        if from_check_point:
            self.model.load_state_dict(torch.load(model_dir, map_location='cpu'))
            head_path = self._head_path(model_dir)
            if os.path.exists(head_path):
                head_state = torch.load(head_path, map_location='cpu')
                for name in self._HEAD_MODULES:
                    getattr(self, name).load_state_dict(head_state[name])
            else:
                # Checkpoints written by older versions of save_model only
                # contain the encoder weights.
                warnings.warn(
                    f"No head weights found at {head_path}; the BiLSTM/classifier "
                    "head is randomly initialised."
                )
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, do_lower_case=True)

    @staticmethod
    def _has_labels(label):
        """Return True if ``label`` holds at least one target.

        ``None``, an empty list/tuple or an empty tensor all mean
        "inference, no loss".
        """
        if label is None:
            return False
        if torch.is_tensor(label):
            return label.numel() > 0
        return len(label) > 0

    @staticmethod
    def _head_path(weights_path):
        """Return the path of the head-weights file that sits next to ``weights_path``."""
        root, ext = os.path.splitext(weights_path)
        return f"{root}_head{ext}"

    def _classify_pooled(self, pooled_output):
        """Run the BiLSTM classification head on a 2-D pooled tensor.

        Args:
            pooled_output: tensor indexed as (batch, hidden_size).

        Returns:
            ``(features, logits)``: ``features`` is the BiLSTM output with
            shape (batch, 2 * lstm_hidden), which feeds SupConLoss; ``logits``
            has shape (batch, 2).
        """
        features = self.dropout(pooled_output)
        # nn.LSTM treats a 2-D input as ONE unbatched sequence, which would run
        # the recurrence across the batch. A length-1 time axis keeps every
        # example independent.
        features, _ = self.bilstm(features.unsqueeze(1))
        features = features.squeeze(1)
        logits = self.classifier2(self.relu(self.line1(features)))
        return features, logits

    def forward(self, sent, label, device):
        """Tokenize ``sent``, classify it and, when labels are given, compute the loss.

        Args:
            sent: iterable of input strings.
            label: target class ids (tensor or sequence of ints). ``[]``/``None``
                means no loss is computed.
            device: device the tokenized inputs (and labels) are moved to.

        Returns:
            ``SequenceClassifierOutput`` with ``logits`` of shape (batch, 2) and
            ``loss`` (``None`` when no labels are given).
        """
        sent = list(sent)
        token = self.tokenizer(sent, padding='max_length', truncation=True, max_length=512, return_tensors="pt").to(device)
        has_labels = self._has_labels(label)
        if has_labels:
            label = torch.as_tensor(label, dtype=torch.long, device=device)

        if self.is_gpt:
            # The HF sequence-classification model computes logits (and the
            # loss when labels are passed); it has no pooled output or BiLSTM head.
            output = self.model(**token, labels=label if has_labels else None)
            return SequenceClassifierOutput(loss=output.loss, logits=output.logits)

        output = self.model(**token)
        features, logits = self._classify_pooled(output[1])

        loss = None
        if has_labels:
            ce_loss = nn.CrossEntropyLoss()(logits.view(-1, 2), label.view(-1))
            scl_loss = losses.SupConLoss()(features, label)
            loss = 0.9 * ce_loss + 0.1 * scl_loss
        return SequenceClassifierOutput(loss=loss, logits=logits)

    def save_model(self, dir):
        """Save the tokenizer and weights into directory ``dir``.

        Writes ``dev_best_seed{args.seed}.pth`` containing ``self.model``'s state
        dict. For the BERT variant it also writes
        ``dev_best_seed{args.seed}_head.pth`` with the BiLSTM/classifier head,
        which the backbone file does not include.
        """
        self.tokenizer.save_pretrained(dir)
        weights_path = os.path.join(dir, f"dev_best_seed{self.args.seed}.pth")
        torch.save(self.model.state_dict(), weights_path)
        if not self.is_gpt:
            head_state = {name: getattr(self, name).state_dict() for name in self._HEAD_MODULES}
            torch.save(head_state, self._head_path(weights_path))