import os
import json
import logging
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import RobertaTokenizer, RobertaModel, BertModel, BertTokenizer
import torch.nn.functional as F
import pytorch_lightning as pl

# Process-wide default device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class NUFScorer(pl.LightningModule):
    """RoBERTa encoder with a small MLP head that classifies a text into 5 classes.

    The encoder's token states are max-pooled over the non-padded positions
    and fed through four linear layers that emit one logit per class.
    ``predict`` turns the class probabilities into a single expected rating
    in the range 1..5.

    Fields read from ``hparams``: ``pretrained_model_path``, ``dropout``,
    ``res_token_len``, ``lr`` and ``weight_decay``.
    """

    # Written into padded positions before max-pooling so they cannot win the max.
    _PAD_FILL_VALUE = -100.0

    def __init__(self, hparams):
        """Load the pretrained tokenizer/encoder and build the classification head.

        Args:
            hparams: Namespace-like object carrying the fields listed on the class.
        """
        super().__init__()
        self.hparams = hparams
        self.tokenizer = RobertaTokenizer.from_pretrained(hparams.pretrained_model_path)
        self.bert = RobertaModel.from_pretrained(hparams.pretrained_model_path)

        num_classes = 5
        # Take the encoder width from its config instead of hard-coding 768, so
        # checkpoints with a different hidden size also work.
        hidden_size = self.bert.config.hidden_size
        # NOTE(review): forward() applies these layers back to back with no
        # nonlinearity in between; layer names and sizes are kept unchanged so
        # existing checkpoints still load.
        self.mlp_hidden_0 = nn.Linear(hidden_size, 64, bias=True)
        self.mlp_hidden_1 = nn.Linear(64, 32, bias=True)
        self.mlp_hidden_2 = nn.Linear(32, 16, bias=True)
        self.mlp_out = nn.Linear(16, num_classes, bias=True)

        self.dropout = nn.Dropout(p=hparams.dropout, inplace=False)

        self.criterion = nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')

    def _module_device(self):
        """Return the device holding this module's weights.

        Falls back to the module-level ``device`` when no parameter is visible.
        """
        first_param = next(self.parameters(), None)
        return device if first_param is None else first_param.device

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Compute class logits for a batch of token id sequences.

        Args:
            input_ids: Token ids passed to the encoder.
            token_type_ids: Optional segment ids passed to the encoder.
            attention_mask: Optional mask, non-zero at real tokens and zero at
                padding. When omitted, every position counts as a real token.

        Returns:
            Tensor of unnormalised class scores, one row per input sequence.
        """
        if attention_mask is None:
            # The signature advertises None as valid; treat it as "no padding".
            attention_mask = torch.ones_like(input_ids)

        encoder_outputs = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )
        # Index instead of tuple-unpacking: works whether the encoder returns a
        # plain tuple or a dict-like output object.
        last_hidden_state = encoder_outputs[0]

        # Max-pooling over real tokens only. The mask broadcasts over the
        # hidden dimension, so no full-size mask or fill tensor is built.
        is_padding = attention_mask.unsqueeze(-1) == 0
        hidden_state = last_hidden_state.masked_fill(is_padding, self._PAD_FILL_VALUE)
        hidden_state, _ = hidden_state.max(dim=1)

        hidden_state = self.dropout(hidden_state)
        hidden_0 = self.mlp_hidden_0(hidden_state)
        hidden_1 = self.mlp_hidden_1(hidden_0)
        hidden_2 = self.mlp_hidden_2(hidden_1)
        out = self.mlp_out(hidden_2)
        return out

    def predict(self, x):
        """Score one text as the expected rating under the predicted distribution.

        Args:
            x: The text to score; it is tokenized and padded to
                ``hparams.res_token_len``.

        Returns:
            A Python float ``sum(k * P(class k))`` for k = 1..number of classes.
        """
        instance = self.tokenizer.encode_plus(
            x,
            add_special_tokens=True,
            max_length=self.hparams.res_token_len,
            pad_to_max_length=True,
            return_tensors="pt"
        )
        target_device = self._module_device()
        input_ids = instance['input_ids'].to(target_device)
        attention_mask = instance['attention_mask'].to(target_device)
        # Some tokenizer versions do not emit token_type_ids for RoBERTa;
        # forward() accepts None in that case.
        token_type_ids = instance.get('token_type_ids')
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(target_device)

        # Score deterministically: disable dropout and autograd for the forward
        # pass, then restore every submodule's previous train/eval flag.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                logits = self(input_ids, token_type_ids, attention_mask)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training

        # e.g. tensor([[0.0049, 0.0041, 0.0182, 0.5308, 0.4420]]): the predicted
        # class distribution for a single sample.
        probs = F.softmax(logits, dim=1)
        return sum(rating * p for rating, p in enumerate(probs[0].tolist(), start=1))

    def _compute_loss(self, batch):
        """Run the model on one batch and return the classification loss.

        Args:
            batch: Four tensors ``(input_ids, token_type_ids, attention_mask,
                label)``. Dimension 1 of the first three is squeezed away
                (presumably a singleton added by the dataset -- TODO confirm).
        """
        target_device = self._module_device()
        input_ids, token_type_ids, attention_mask, label = [
            tensor.to(target_device) for tensor in batch
        ]
        logits = self(
            input_ids.squeeze(1),
            token_type_ids.squeeze(1),
            attention_mask.squeeze(1),
        )
        # Use the configured criterion so ignore_index=-1 is actually honoured.
        return self.criterion(logits, label)

    def training_step(self, batch, batch_nb):
        """Return the training loss together with its tensorboard log entry."""
        loss = self._compute_loss(batch)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        """Return the loss of one validation batch under the key ``val_loss``."""
        return {'val_loss': self._compute_loss(batch)}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses, print and log the mean."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        print ("val_loss: ", avg_loss)
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        """Build an Adam optimizer over all parameters from ``hparams``."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
