import torch
import torch.nn as nn

from torch.nn.utils.rnn import pad_sequence,pack_padded_sequence,pad_packed_sequence
# from transformers import RobertaTokenizer, RobertaForTokenClassification
from transformers import BertTokenizer, BertForTokenClassification,BertModel, BertConfig
from transformers import XLNetTokenizer, XLNetModel,XLNetConfig

from torchcrf import CRF

from model_GCN import GCN
from batch import mini_batch, mini_batch_pad,mini_batch_list


class BERT_LSTM_GCN(nn.Module):
    """Token tagger built from a pretrained encoder plus optional extras.

    The pipeline, as wired up by ``config``:

    1. A BERT or XLNet encoder (``config["use_xlnet"]``) loaded from
       ``config["plm"]`` and truncated to ``config["bert_layer"]`` layers.
    2. Optionally a GCN over the token sequence (``use_gcn``), whose output
       is either concatenated to (``cat``) or added to the encoder output.
    3. Optionally a bigram embedding concatenated to the features
       (``use_bigram``).
    4. Optionally a BiLSTM (``use_bilstm``), otherwise optional dropout
       (``use_output_dropout``).
    5. A linear classifier, decoded with a CRF (``use_crf``) or trained with
       cross-entropy (class index 0 is ignored by the loss).

    NOTE: ``self.bert_tokenizer`` holds the encoder *model* (not a tokenizer)
    and ``self.bigarm_embed`` is misspelled; both names are kept because they
    are ``state_dict`` keys and renaming them would break saved checkpoints.
    """

    def __init__(self, config):
        """Build all sub-modules from the ``config`` mapping.

        Keys read: ``dropout``, ``bert_layer``, ``use_xlnet``, ``plm``,
        ``finetune``, ``use_gcn``, ``use_gcn_embed``, ``use_gcn_bertembed``,
        ``hidden_size``, ``use_bigram``, ``use_bilstm``,
        ``use_output_dropout``, ``use_crf``, ``tagset_size`` and, depending
        on the flags, ``bigram_dim``, ``bi_vocab_size``, ``graph_embed_dim``,
        ``graph_hidden_dim``, ``vocab_size``, ``cat``, ``lstm_layer``.
        """
        super(BERT_LSTM_GCN, self).__init__()
        self.config = config

        self.dropout = config["dropout"]

        # --- pretrained encoder, truncated to the requested depth ----------
        self.bert_layer = config["bert_layer"]
        self.use_xlnet = config["use_xlnet"]
        if self.use_xlnet:
            xlnetconfig = XLNetConfig.from_pretrained(config["plm"])
            xlnetconfig.n_layer = config["bert_layer"]
            self.bert_tokenizer = XLNetModel.from_pretrained(config["plm"], config=xlnetconfig)
        else:
            bertconfig = BertConfig.from_pretrained(config["plm"])
            bertconfig.num_hidden_layers = config["bert_layer"]
            self.bert_tokenizer = BertModel.from_pretrained(config["plm"], config=bertconfig)

        self.finetune = config["finetune"]
        if not self.finetune:
            # Freeze every encoder weight. (The message is presumably
            # romanised Chinese for "freeze the BERT layers".)
            print("dongjie bert层")
            for param in self.bert_tokenizer.parameters():
                param.requires_grad = False

        self.use_gcn = config["use_gcn"]
        self.use_gcn_embed = config["use_gcn_embed"]
        self.use_gcn_bertembed = config["use_gcn_bertembed"]

        # Running width of the per-token feature vector fed to LSTM/classifier.
        self.output_dim = config["hidden_size"]

        # --- optional bigram embedding -------------------------------------
        self.use_bigram = config["use_bigram"]
        if self.use_bigram:
            self.bigram_dim = config["bigram_dim"]
            self.bigarm_embed = nn.Embedding(config["bi_vocab_size"], self.bigram_dim)
            self.output_dim += self.bigram_dim

        # --- optional GCN ---------------------------------------------------
        if self.use_gcn:
            self.node_dim = config["graph_embed_dim"]
            self.graph_hidden_dim = config["graph_hidden_dim"]

            if self.use_gcn_embed:
                self.graph_embed = nn.Embedding(config["vocab_size"], self.node_dim)

            self.graph_model = GCN(self.node_dim, self.graph_hidden_dim, self.dropout)

            # cat=True: concatenate GCN output; cat=False: add it to the
            # encoder output (which then requires matching widths).
            self.cat = config["cat"]
            if self.cat:
                self.output_dim += self.graph_hidden_dim

        # --- optional BiLSTM / dropout -------------------------------------
        self.use_bilstm = config["use_bilstm"]
        self.use_output_dropout = config["use_output_dropout"]
        if self.use_bilstm:
            self.lstm_layer = config["lstm_layer"]
            # The LSTM reuses the graph hidden width. Read it from config
            # directly so that BiLSTM-without-GCN no longer hits a missing
            # self.graph_hidden_dim attribute.
            lstm_hidden_dim = config["graph_hidden_dim"]
            self.lstm = nn.LSTM(
                self.output_dim,
                lstm_hidden_dim,
                num_layers=self.lstm_layer,
                dropout=self.dropout,
                bidirectional=True,
                batch_first=True,
            )
            self.linear_input_dim = lstm_hidden_dim * 2
        else:
            self.linear_input_dim = self.output_dim
            if self.use_output_dropout:
                self.output_dropout = nn.Dropout(self.dropout)

        # --- output layer ---------------------------------------------------
        # The +3 / +1 offsets presumably reserve ids for padding/special
        # tags -- TODO confirm against the tag vocabulary.
        self.use_crf = config["use_crf"]
        if self.use_crf:
            self.classifier = nn.Linear(self.linear_input_dim, config['tagset_size'] + 3)
            self.CRF = CRF(config["tagset_size"] + 3, batch_first=True)
        else:
            self.classifier = nn.Linear(self.linear_input_dim, config['tagset_size'] + 1)
            self.loss_function = nn.CrossEntropyLoss(ignore_index=0)

    @staticmethod
    def _restore_order(indices_2, lengths, *tensors):
        """Reorder ``lengths`` and batch-first ``tensors`` by ``indices_2``.

        Returns ``(lengths_as_int_list, tuple_of_tensors)``. When
        ``indices_2`` is None the batch order is left untouched.
        """
        if lengths is None:
            raise ValueError("lengths is required outside of training")
        if indices_2 is None:
            return [int(n) for n in lengths], tensors
        # Plain list indexing avoids building a tensor on a device that may
        # not match the one indices_2 lives on.
        order = indices_2.tolist()
        reordered_lengths = [int(lengths[i]) for i in order]
        return reordered_lengths, tuple(t.index_select(0, indices_2) for t in tensors)

    @staticmethod
    def _trim(tag_rows, lengths):
        """Cut each row of predicted tags down to its sequence length."""
        return [row[:n] for row, n in zip(tag_rows, lengths)]

    def forward(self, inputs, bi_inputs, golds, mask, token_type_ids, edges_index, indices_2=None, lengths=None, device=torch.device('cpu'), attention_mask=None, output_hidden_states=True, mode="train"):
        """Run the tagger and return a loss (training) or predictions.

        Args:
            inputs: token ids, indexed as (batch, seq).
            bi_inputs: bigram ids; only read when ``use_bigram`` is set.
                Assumed to already match the sequence with the first and
                last positions removed -- TODO confirm.
            golds: gold tag ids (training only).
            mask: CRF mask (CRF path only).
            token_type_ids: forwarded to the encoder.
            edges_index: graph edges, forwarded to ``mini_batch_list``.
            indices_2: optional 1-D index tensor used to reorder the batch
                before predictions are returned (presumably undoing a
                sort-by-length done by the caller). None keeps the order.
            lengths: per-sequence lengths; used for packing (BiLSTM) and to
                trim predictions. Required outside of training.
            device: device the batched edge index is moved to.
            attention_mask: forwarded to the encoder.
            output_hidden_states: forwarded to the encoder; the hidden
                states are read afterwards, so this must stay truthy.
            mode: ``"train"``, ``"valid"``, ``"test"`` or ``"demo"``.

        Returns:
            * CRF, ``mode == "train"``: the negated output of
              ``self.CRF(..., reduction="mean")``.
            * CRF, any other mode: decoded tag lists trimmed to ``lengths``.
            * no CRF, ``"train"``: the cross-entropy loss.
            * no CRF, ``"valid"``/``"test"``: arg-max tag lists trimmed to
              ``lengths``.
            * no CRF, ``"demo"``: ``(tag_lists, (attention, output_softmax))``
              where ``attention`` is the last entry of the encoder's
              attentions and ``output_softmax`` the softmax of the logits.

        Raises:
            ValueError: on an unknown ``mode`` (non-CRF path) or when
                ``lengths`` is missing outside of training.
        """
        batch_size, seq_len = inputs.size(0), inputs.size(1)

        # Attention weights are only ever returned in non-CRF demo mode, so
        # only ask the encoder to materialise them there.
        need_attention = mode == "demo" and not self.use_crf
        encoder_out = self.bert_tokenizer(
            input_ids=inputs,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=need_attention,
        )
        attention = encoder_out.attentions[-1] if need_attention else None
        bert_output = encoder_out.hidden_states

        # Drop the first and last positions (presumably the special tokens
        # wrapped around the sentence -- TODO confirm for XLNet inputs).
        out = bert_output[-1][:, 1:-1, :]
        seq_len = seq_len - 2

        graph_outs = None
        if self.use_gcn:
            if self.use_gcn_embed:
                graph_embedding = self.graph_embed(inputs)[:, 1:-1, :]
            elif self.use_gcn_bertembed:
                # First entry of the encoder's hidden states.
                graph_embedding = bert_output[0][:, 1:-1, :]
            else:
                graph_embedding = out

            graph_embedding, edges_index = mini_batch_list(graph_embedding, edges_index, edge_atter=None)
            edges_index = edges_index.to(device)
            graph_outs = self.graph_model(graph_embedding, edges_index).view(batch_size, seq_len, -1)

            if not self.cat:
                out = out + graph_outs

        # Assemble the feature vector in the same order __init__ used to
        # accumulate output_dim: encoder, bigram, graph.
        features = [out]
        if self.use_bigram:
            features.append(self.bigarm_embed(bi_inputs))
        if self.use_gcn and self.cat:
            features.append(graph_outs)
        if len(features) > 1:
            out = torch.cat(features, dim=2)

        # Drop the last three positions as well (presumably extra trailing
        # tokens added by the data pipeline -- TODO confirm).
        out = out[:, :-3, :]
        seq_len = seq_len - 3

        if self.use_bilstm:
            out = pack_padded_sequence(out, lengths, batch_first=True)
            out, _ = self.lstm(out)
            out, _ = pad_packed_sequence(out, batch_first=True)
        elif self.use_output_dropout:
            out = self.output_dropout(out)

        out = self.classifier(out)

        if self.use_crf:
            if mode == "train":
                loss = self.CRF(out, golds, mask, reduction="mean")
                return -loss
            lengths, (out, mask) = self._restore_order(indices_2, lengths, out, mask)
            return self._trim(self.CRF.decode(out, mask), lengths)

        if mode == "train":
            logits = out.reshape(batch_size * seq_len, -1)
            return self.loss_function(logits, golds.view(batch_size * seq_len))

        if mode not in ("valid", "test", "demo"):
            raise ValueError("unknown mode: {!r}".format(mode))

        lengths, (out,) = self._restore_order(indices_2, lengths, out)
        tag_space = self._trim(out.argmax(dim=2).tolist(), lengths)
        if mode == "demo":
            output_softmax = nn.functional.softmax(out, dim=2)
            return tag_space, (attention, output_softmax)
        return tag_space
def list_index_select(tensor_list, indices):
    """Pick items out of ``tensor_list`` at the positions given by ``indices``.

    ``indices`` must expose ``.tolist()`` (e.g. a 1-D index tensor). The
    result is a new list holding ``tensor_list[i]`` for each ``i``, in the
    order the indices appear; repeated indices yield repeated items.
    """
    return [tensor_list[position] for position in indices.tolist()]
