import dgl
from dgl.nn import GraphConv

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from transformers import AutoConfig
from transformers import AutoModelWithLMHead
from src.utils import *
from src.dataloader import *

import logging
# Module-level logger; records propagate to the root logger's handlers, so
# existing root-logger configuration still applies.
logger = logging.getLogger(__name__)

class BertTagger(nn.Module):
    """Token tagger built on a pretrained transformer encoder.

    Two output modes share the encoder, selected by ``self.is_source``:

    * source mode (``is_source=True``): one linear layer over the encoder's
      last hidden layer, ``num_class_source`` logits per token.
    * target mode (``is_source=False``): the linear head ``linear_target``
      plus a label-aware head. Learnable label embeddings (``label_repre``)
      attend over the tokens. The resulting per-label vectors are passed
      through a GCN over ``source_graph_gcn`` with edge weights
      ``source_edges``; the caller must assign both before the first
      target-mode forward. Each token then attends back over those label
      vectors.

    Args:
        src_dm: source domain key into ``domain2labels``.
        tgt_dm: target domain key into ``domain2labels`` and
            ``domain2entitylabels``.
        hidden_dim: width of the encoder hidden states. Presumably equal to
            the encoder config's ``hidden_size``; this is not checked here.
        model_name: name or path passed to ``AutoConfig.from_pretrained`` and
            ``AutoModelWithLMHead.from_pretrained``.
        ckpt: path to an encoder ``state_dict`` to reload, or ``''`` to skip.
    """

    def __init__(self, src_dm, tgt_dm, hidden_dim, model_name, ckpt):
        super().__init__()
        self.num_class_source = len(domain2labels[src_dm])
        self.num_class_target = len(domain2labels[tgt_dm])
        self.num_entity_class_target = len(domain2entitylabels[tgt_dm])
        self.hidden_dim = hidden_dim
        config = AutoConfig.from_pretrained(model_name)
        # forward() reads the per-layer hidden states, not the LM-head logits.
        config.output_hidden_states = True
        # NOTE(review): AutoModelWithLMHead is deprecated in recent transformers
        # releases; AutoModelForMaskedLM is the usual replacement -- confirm.
        self.encoder = AutoModelWithLMHead.from_pretrained(model_name, config=config)

        if ckpt != '':
            logger.info("Reloading encoder from %s", ckpt)
            # load_state_dict copies into the encoder's existing parameters, so
            # mapping to CPU here only lets GPU-saved checkpoints load on hosts
            # without CUDA.
            # NOTE(review): torch.load unpickles the file -- load trusted files only.
            encoder_ckpt = torch.load(ckpt, map_location='cpu')
            self.encoder.load_state_dict(encoder_ckpt)

        # Classifier heads. The registration order below is kept stable so that
        # state_dict keys/order and RNG-driven initialisation do not change.
        self.is_source = True
        self.linear_source = nn.Linear(self.hidden_dim, self.num_class_source)
        self.linear_target = nn.Linear(self.hidden_dim, self.num_class_target)
        self.linear_target_2 = nn.Linear(self.hidden_dim, self.num_class_target)
        # Sentence-level head over the flattened per-label vectors.
        self.linear_target_bce = nn.Linear(self.hidden_dim * self.num_entity_class_target,
                                           self.num_entity_class_target)

        # Projection of token states into the label-embedding space.
        self.proj_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim)
        )
        # One learnable embedding per target entity label.
        self.label_repre = nn.Parameter(torch.zeros(self.num_entity_class_target,
                                                    self.hidden_dim),
                                        requires_grad=True)
        nn.init.normal_(self.label_repre, mean=0.0, std=0.1)
        self.gcn = GraphConv(self.hidden_dim,
                             self.hidden_dim,
                             norm='none',
                             weight=True,
                             bias=True)
        # Label graph and its edge weights; assigned externally before target mode.
        self.source_graph_gcn = None
        self.source_edges = None

    def forward(self, X, auxillary_task=False, return_hiddens=False):
        """Tag every token of ``X``.

        Args:
            X: token ids of shape (bsz, seq_len). Id 0 is treated as padding
                (assumed to be the tokenizer's pad id -- confirm for
                non-BERT vocabularies).
            auxillary_task: target mode only. Also return the token-to-label
                attention logits and the sentence-level BCE logits.
            return_hiddens: also return the encoder's last hidden layer.

        Returns:
            Source mode: ``prediction`` (bsz, seq_len, num_class_source),
            followed by ``hiddens`` if ``return_hiddens``.

            Target mode: ``prediction`` (bsz, seq_len, num_class_target).
            If ``auxillary_task``, this is followed by
            ``token_attention_logits`` (bsz, seq_len, num_entity_class_target)
            and ``sentence_pred_bce`` (bsz, num_entity_class_target). If
            ``return_hiddens``, ``hiddens`` comes last.

            A bare tensor is returned when nothing extra is requested;
            otherwise a tuple.

        Raises:
            RuntimeError: target mode with ``source_graph_gcn`` or
                ``source_edges`` unset.
        """
        hiddens = self._encode(X)

        if self.is_source:
            prediction = self.linear_source(hiddens)
            return (prediction, hiddens) if return_hiddens else prediction

        prediction, token_attention_logits, sentence_pred_bce = self._target_heads(
            X, hiddens, auxillary_task)
        result = (prediction,)
        if auxillary_task:
            result += (token_attention_logits, sentence_pred_bce)
        if return_hiddens:
            result += (hiddens,)
        return result if len(result) > 1 else prediction

    def _encode(self, X):
        """Return the encoder's last hidden layer (bsz, seq_len, hidden_dim)."""
        # With output_hidden_states=True (set in __init__) and no labels passed,
        # element 1 of the encoder output is the tuple of per-layer hidden states
        # and [-1] is the top layer. Element 0 (LM-head logits) is unused.
        return self.encoder(X)[1][-1]

    def _target_heads(self, X, hiddens, auxillary_task):
        """Compute target-mode outputs from encoder states.

        Returns ``(prediction, token_attention_logits, sentence_pred_bce)``.
        ``sentence_pred_bce`` is None unless ``auxillary_task`` is set.
        """
        if self.source_graph_gcn is None or self.source_edges is None:
            raise RuntimeError(
                "source_graph_gcn and source_edges must be set before running "
                "BertTagger in target mode (is_source=False)")

        batch_size = hiddens.shape[0]
        device = hiddens.device
        label_repre = self.label_repre.expand((batch_size, -1, -1))  # (bsz, E, H)
        hiddens_proj = self.proj_layer(hiddens)                      # (bsz, seq_len, H)

        # Label -> token attention. Padding positions get the most negative
        # finite logit, so softmax assigns them exactly zero weight. A fill
        # value near 0 (such as 1e-9) would leave padding weighted like any
        # neutral token. The finite minimum also avoids NaNs for all-padding rows.
        label_attention_logits = torch.bmm(label_repre, hiddens_proj.transpose(1, 2))
        label_attention_logits = label_attention_logits.masked_fill(
            (X == 0).unsqueeze(1),
            torch.finfo(label_attention_logits.dtype).min)
        label_attention = F.softmax(label_attention_logits, dim=-1)
        label_semantic_repre = torch.bmm(label_attention, hiddens_proj)  # (bsz, E, H)

        # GCN over the label graph. Labels are the graph nodes, so they go on
        # axis 0. The graph and edge weights follow the activations' device
        # instead of a hard-coded CUDA device.
        label_semantic_repre_2 = self.gcn(
            self.source_graph_gcn.to(device),
            label_semantic_repre.transpose(0, 1),
            edge_weight=self.source_edges.to(device),
        ).transpose(0, 1)

        # The sentence-level BCE head is only needed for the auxiliary loss.
        sentence_pred_bce = None
        if auxillary_task:
            sentence_pred_bce = self.linear_target_bce(
                label_semantic_repre_2.reshape((batch_size, -1)))

        # Token -> label attention over the GCN-refined label vectors.
        token_attention_logits = torch.bmm(hiddens_proj, label_repre.transpose(1, 2))
        token_attention = F.softmax(token_attention_logits, dim=-1)
        token_label_repre = torch.bmm(token_attention, label_semantic_repre_2)

        prediction = self.linear_target(hiddens) + self.linear_target_2(token_label_repre)
        return prediction, token_attention_logits, sentence_pred_bce