import os
import json
import logging
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch.nn.functional as F
import pytorch_lightning as pl
# Module-wide compute device: CUDA when PyTorch reports it available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class ABModel(pl.LightningModule):
    """LightningModule wrapping a pretrained sequence-classification model.

    Loads a tokenizer and an ``AutoModelForSequenceClassification`` from the
    local ``'DialoGPT/models/medium'`` directory and trains the model with a
    cross-entropy loss over its logits.

    Args:
        args: Hyper-parameter namespace. This class reads ``args.dropout``,
            ``args.lr`` and ``args.weight_decay``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # NOTE(review): the tokenizer is not used by any method in this class;
        # presumably callers use it for preprocessing — confirm.
        self.tokenizer = AutoTokenizer.from_pretrained('DialoGPT/models/medium')
        # NOTE(review): if this checkpoint is GPT-2 based (as the DialoGPT name
        # suggests), the HF classification head refuses batches larger than one
        # unless ``self.model.config.pad_token_id`` is set. DialoGPT also uses
        # the EOS token as a turn separator, so reusing EOS as padding would
        # shift the pooled position — choose a pad id deliberately.
        self.model = AutoModelForSequenceClassification.from_pretrained('DialoGPT/models/medium')
        # NOTE(review): not applied in forward(); kept so existing callers that
        # access ``self.dropout`` keep working.
        self.dropout = nn.Dropout(p=args.dropout, inplace=False)
        # Targets equal to -1 are excluded from the (mean-reduced) loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')

    def forward(self, inputs):
        """Run the wrapped model on ``inputs``.

        Args:
            inputs: Passed as the first positional argument of the wrapped
                model (``input_ids`` for typical HF models) — presumably a
                ``(batch, seq_len)`` tensor of token ids; confirm against caller.

        Returns:
            The wrapped model's output object; this class reads its ``.logits``.
        """
        return self.model(inputs)

    def _compute_loss(self, batch):
        """Return the cross-entropy loss for a batch indexed as (dialogs, labels)."""
        dialogs = batch[0]
        labels = batch[1]
        logits = self(dialogs).logits
        # CrossEntropyLoss requires integer (long) class targets.
        return self.criterion(logits, labels.long())

    def training_step(self, batch, batch_idx):
        """Compute and report the training loss for one batch.

        Returns:
            ``{'loss': loss}`` with the graph-attached loss tensor for backprop.
        """
        loss = self._compute_loss(batch)
        # The previous ``print("loss:{}", loss)`` never filled the placeholder.
        print("loss:{}".format(loss.item()))
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Compute and report the validation loss for one batch.

        Returns:
            ``{'loss': loss}``; collected per batch and passed to
            ``validation_epoch_end``.
        """
        loss = self._compute_loss(batch)
        print("loss:{}".format(loss.item()))
        return {'loss': loss}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses.

        Args:
            outputs: List of the dicts returned by ``validation_step``.

        Returns:
            ``{'val_loss': avg, 'log': {'loss': avg}}``.

        NOTE(review): this hook and the ``'log'`` return key follow the older
        Lightning API (the hook was removed in Lightning 2.0) — confirm the
        installed version.
        """
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        tensorboard_logs = {'loss': avg_loss}
        print("val_loss: ", avg_loss)
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``args.lr`` and ``args.weight_decay``."""
        return torch.optim.Adam(self.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
