import torch
import torch.nn as nn
from opt_einsum import contract
from long_seq import process_long_input
from losses import ATLoss
import torch.nn.functional as F
from axial_attention import AxialAttention, AxialImageTransformer
import numpy as np
from itertools import accumulate

def batch_index(tensor, index, pad=False):
    """Select ``tensor[i][index[i]]`` for every batch row ``i`` and stack the results.

    Args:
        tensor: Batched tensor; the first dimension is the batch.
        index: Per-row index tensor (integer or boolean) with the same batch size.
        pad: When true, per-row selections that differ in shape (e.g. produced by
            boolean masks) are zero-padded along their first dimension before
            being stacked. Selections that already share a shape are stacked as-is.

    Returns:
        A tensor whose first dimension is the batch.

    Raises:
        ValueError: If ``tensor`` and ``index`` disagree on the batch size.
    """
    if tensor.shape[0] != index.shape[0]:
        raise ValueError(
            "batch size mismatch: tensor has %d rows but index has %d"
            % (tensor.shape[0], index.shape[0])
        )

    rows = [tensor[i][index[i]] for i in range(index.shape[0])]
    if not pad or all(row.shape == rows[0].shape for row in rows):
        return torch.stack(rows)
    # The padded path used to call an undefined ``padded_stack`` helper; torch's own
    # pad_sequence provides the zero-padding for variable-length rows.
    return torch.nn.utils.rnn.pad_sequence(rows, batch_first=True)

class AxialTransformer_by_entity(nn.Module):
    """Stack of axial-attention layers, each followed by a linear map and LayerNorm.

    Every layer computes ``x + dropout(axial_attention(x))`` and then passes the
    result through ``Linear -> Dropout -> LayerNorm`` (no residual around the
    linear part).
    """

    def  __init__(self, emb_size = 768, dropout = 0.1, num_layers = 2, dim_index = -1, heads = 8, num_dimensions=2, ):
        super().__init__()
        # Plain hyper-parameters, kept for introspection.
        self.num_layers = num_layers
        self.dim_index = dim_index
        self.heads = heads
        self.emb_size = emb_size
        self.dropout = dropout
        self.max_ent = 42
        self.num_dimensions = num_dimensions

        # Sub-modules are created group by group (all attentions first, then all
        # linears, ...) so that parameter initialisation order stays the same.
        depth = range(num_layers)
        self.axial_attns = nn.ModuleList([
            AxialAttention(
                dim=self.emb_size,
                dim_index=dim_index,
                heads=heads,
                num_dimensions=num_dimensions,
            )
            for _ in depth
        ])
        self.ffns = nn.ModuleList([nn.Linear(self.emb_size, self.emb_size) for _ in depth])
        self.lns = nn.ModuleList([nn.LayerNorm(self.emb_size) for _ in depth])
        self.attn_dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in depth])
        self.ffn_dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in depth])

    def forward(self, x):
        """Apply all layers in order and return the transformed tensor."""
        for layer in range(self.num_layers):
            attended = self.axial_attns[layer](x)
            x = x + self.attn_dropouts[layer](attended)
            projected = self.ffn_dropouts[layer](self.ffns[layer](x))
            x = self.lns[layer](projected)
        return x

class DocREModel(nn.Module):
    def __init__(self, config, model, emb_size=1024, block_size=64, num_labels=-1):
        """Build the relation-extraction heads on top of an encoder.

        Args:
            config: Encoder configuration; ``hidden_size`` and ``num_labels`` are
                read here, ``transformer_type`` and the special-token ids later.
            model: Encoder module, stored as ``self.model`` and used by ``encode``.
            emb_size: Output width of the head/tail extractors.
            block_size: Block width of the grouped bilinear product in ``forward``.
            num_labels: Passed to ``self.loss_fnt.get_label`` at prediction time.
        """
        super().__init__()
        self.config = config
        self.model = model
        self.hidden_size = config.hidden_size
        self.loss_fnt = ATLoss()

        self.head_extractor = nn.Linear(2 * config.hidden_size, emb_size)
        self.mention_head_extractor = nn.Linear(2 * config.hidden_size, emb_size)
        self.tail_extractor = nn.Linear(2 * config.hidden_size, emb_size)
        self.mention_tail_extractor = nn.Linear(2 * config.hidden_size, emb_size)
        self.bilinear = nn.Linear(emb_size * block_size, config.num_labels)
        self.projection = nn.Linear(emb_size * block_size, config.hidden_size)
        self.mention_projection = nn.Linear(emb_size * block_size, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size , config.num_labels)
        self.axial_transformer = AxialTransformer_by_entity(emb_size = config.hidden_size, dropout=0.0, num_layers=6, heads=8)
        # get_entity_table reads self.max_ent, which was never set on this class;
        # reuse the limit the axial transformer already carries.
        self.max_ent = self.axial_transformer.max_ent
        self.emb_size = emb_size
        self.block_size = block_size
        self.num_labels = num_labels

    def encode(self, input_ids, attention_mask):
        config = self.config
        if config.transformer_type == "bert":
            start_tokens = [config.cls_token_id]
            end_tokens = [config.sep_token_id]
        elif config.transformer_type == "roberta":
            start_tokens = [config.cls_token_id]
            end_tokens = [config.sep_token_id, config.sep_token_id]
        sequence_output, attention = process_long_input(self.model, input_ids, attention_mask, start_tokens, end_tokens)
        return sequence_output, attention

    def get_hrt(self, sequence_output, attention, entity_pos, hts, sentid_mask):
        """Collect head, context and tail vectors for every entity pair in ``hts``.

        Entity pooling and the per-token sentence embeddings are identical to what
        ``entities`` computes, so this method delegates to it instead of repeating
        the same loop, and only adds the pair-level step.

        Args:
            sequence_output: Token representations, indexed as ``[doc, token]``.
            attention: Attention tensor unpacked as four dimensions; indexed as
                ``[doc, :, token]``.
            entity_pos: Per document, a list of entities, each a list of
                ``(start, end)`` mentions.
            hts: Per document, a list of ``(head_idx, tail_idx)`` entity pairs.
            sentid_mask: Per document, a tensor of sentence ids, one per token.

        Returns:
            ``(hss, rss, tss, sent_embs)`` where the first three are concatenated
            over documents (one row per pair) and ``sent_embs`` is the value
            returned by ``entities``.
        """
        b_ent_embs, b_ent_atts, sent_embs = self.entities(
            sequence_output, attention, entity_pos, hts, sentid_mask
        )
        # Zeroes non-positive products; created once instead of once per document.
        clip_negative = torch.nn.Threshold(0, 0)

        hss, tss, rss = [], [], []
        for i, (entity_embs, entity_atts) in enumerate(zip(b_ent_embs, b_ent_atts)):
            # entity_embs: [n_e, d]; entity_atts: [n_e, h, seq_len]
            ht_i = torch.LongTensor(hts[i]).to(sequence_output.device)
            hs = torch.index_select(entity_embs, 0, ht_i[:, 0])
            ts = torch.index_select(entity_embs, 0, ht_i[:, 1])

            h_att = torch.index_select(entity_atts, 0, ht_i[:, 0])
            t_att = torch.index_select(entity_atts, 0, ht_i[:, 1])
            # Sum head*tail attention over dim 1, then normalise each row to ~1.
            ht_att = clip_negative((h_att * t_att).sum(1))
            ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)
            # Attention-weighted sum of token vectors: one context row per pair.
            rs = contract("ld,rl->rd", sequence_output[i], ht_att)
            hss.append(hs)
            tss.append(ts)
            rss.append(rs)
        hss = torch.cat(hss, dim=0)
        tss = torch.cat(tss, dim=0)
        rss = torch.cat(rss, dim=0)
        return hss, rss, tss, sent_embs


    def entities(self, sequence_output, attention, entity_pos, hts, sentid_mask):
        offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
        n, h, _, c = attention.size()
        sent_embs = []
        b_ent_embs = []
        b_ent_atts = []
        for i in range(len(entity_pos)):
            entity_embs, entity_atts = [], []
            entity_lens = []
            sid_mask = sentid_mask[i]
            sentids = [x for x in range(torch.max(sid_mask).cpu().long() + 1)]
            local_mask  = torch.tensor([sentids] * sid_mask.size()[0] ).T
            local_mask = torch.eq(sid_mask , local_mask).long().to(sequence_output)
            sentence_embs = local_mask.unsqueeze(2) * sequence_output[i]
            sentence_embs = torch.sum(sentence_embs, dim=1)/local_mask.unsqueeze(2).sum(dim=1)
            seq_sent_embs = sentence_embs.unsqueeze(1) * local_mask.unsqueeze(2)
            seq_sent_embs = torch.sum(seq_sent_embs, dim=0)
            sent_embs.append(seq_sent_embs)
            for e in entity_pos[i]:
                #entity_lens.append(self.ent_num_emb(torch.tensor(len(e)).to(sequence_output).long()))
                if len(e) > 1:
                    e_emb, e_att = [], []
                    for start, end in e:
                        if start + offset < c:
                            # In case the entity mention is truncated due to limited max seq length.
                            #e_emb.append(sequence_output[i, start + offset] + seq_sent_embs[start + offset])
                            e_emb.append(sequence_output[i, start + offset])
                            e_att.append(attention[i, :, start + offset])
                    if len(e_emb) > 0:
                        e_emb = torch.logsumexp(torch.stack(e_emb, dim=0), dim=0)
                        e_att = torch.stack(e_att, dim=0).mean(0)
                    else:
                        e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                        e_att = torch.zeros(h, c).to(attention)
                else:
                    start, end = e[0]
                    if start + offset < c:
                        #e_emb = sequence_output[i, start + offset] + seq_sent_embs[start + offset]
                        e_emb = sequence_output[i, start + offset]
                        e_att = attention[i, :, start + offset]
                    else:
                        e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                        e_att = torch.zeros(h, c).to(attention)
                entity_embs.append(e_emb)
                entity_atts.append(e_att)

            entity_embs = torch.stack(entity_embs, dim=0)  # [n_e, d]
            entity_atts = torch.stack(entity_atts, dim=0)  # [n_e, h, seq_len]

            b_ent_embs.append(entity_embs)
            b_ent_atts.append(entity_atts)
        sent_embs = torch.cat(sent_embs, dim=0)
        return b_ent_embs, b_ent_atts, sent_embs




    def get_mention_hrt(self, sequence_output, attention, mention_pos, mention_hts, sent_embs):
        """Collect head, context and tail vectors for the pairs in ``mention_hts``.

        Dead intermediate values of the earlier version (position lookups whose
        results were never used) have been removed; the returned tensors are
        computed exactly as before.

        NOTE(review): each entry of ``mention_hts[i]`` is first unpacked as
        ``(start, end)`` and its first element used as a token position, and the
        same entries are then used as row indices into the stacked vectors. That
        double use looks unintended (``mention_pos`` only contributes the number
        of documents) - confirm against the data pipeline before relying on it.

        Args:
            sequence_output: Token representations, indexed as ``[doc, token]``.
            attention: Attention tensor unpacked as four dimensions.
            mention_pos: Only its length (number of documents) is used.
            mention_hts: Per document, a list of integer pairs.
            sent_embs: Unused; kept for signature compatibility.

        Returns:
            ``(hss, rss, tss)`` concatenated over documents, one row per pair.
        """
        offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
        _, h, _, c = attention.size()
        # Zeroes non-positive products; created once instead of once per document.
        clip_negative = torch.nn.Threshold(0, 0)
        hss, tss, rss = [], [], []

        for i in range(len(mention_pos)):
            entity_embs, entity_atts = [], []
            for start, _ in mention_hts[i]:
                if start + offset < c:
                    e_emb = sequence_output[i, start + offset]
                    e_att = attention[i, :, start + offset]
                else:
                    # Position lies beyond the available sequence length.
                    e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                    e_att = torch.zeros(h, c).to(attention)
                entity_embs.append(e_emb)
                entity_atts.append(e_att)

            entity_embs = torch.stack(entity_embs, dim=0)
            entity_atts = torch.stack(entity_atts, dim=0)

            ht_i = torch.LongTensor(mention_hts[i]).to(sequence_output.device)
            hs = torch.index_select(entity_embs, 0, ht_i[:, 0])
            ts = torch.index_select(entity_embs, 0, ht_i[:, 1])

            h_att = torch.index_select(entity_atts, 0, ht_i[:, 0])
            t_att = torch.index_select(entity_atts, 0, ht_i[:, 1])
            # Sum head*tail attention over dim 1, then normalise each row to ~1.
            ht_att = clip_negative((h_att * t_att).sum(1))
            ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)
            rs = contract("ld,rl->rd", sequence_output[i], ht_att)
            hss.append(hs)
            tss.append(ts)
            rss.append(rs)
        hss = torch.cat(hss, dim=0)
        tss = torch.cat(tss, dim=0)
        rss = torch.cat(rss, dim=0)
        return hss, rss, tss

    def compute_kl_loss(self, p, q, pad_mask=None):
    
        p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
        q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')
    
        # pad_mask is for seq-level tasks
        if pad_mask is not None:
            p_loss.masked_fill_(pad_mask, 0.)
            q_loss.masked_fill_(pad_mask, 0.)

        # You can choose whether to use function "sum" and "mean" depending on your task
        p_loss = p_loss.sum()
        q_loss = q_loss.sum()

        loss = (p_loss + q_loss) / 2
        return loss

    def get_mask(self, ents, bs, ne, run_device):
        ent_mask = torch.zeros(bs, ne, device=run_device)
        rel_mask = torch.zeros(bs, ne, ne, device=run_device)
        for _b in range(bs):
            ent_mask[_b, :len(ents[_b])] = 1
            rel_mask[_b, :len(ents[_b]), :len(ents[_b])] = 1
        return ent_mask, rel_mask


    def get_ht(self, rel_enco, hts):
        htss = []
        for i in range(len(hts)):
            ht_index = hts[i]
            for (h_index, t_index) in ht_index:
                htss.append(rel_enco[i,h_index,t_index])
        htss = torch.stack(htss,dim=0)
        return htss

    def get_entity_table(self, sequence_output, entity_embs):

        bs,_,d = sequence_output.size()
        ne = self.max_ent

        index_pair = []
        for i in range(ne):
            tmp = torch.cat((torch.ones((ne, 1), dtype=int) * i, torch.arange(0, ne).unsqueeze(1)), dim=-1)
            index_pair.append(tmp)
        index_pair = torch.stack(index_pair, dim=0).reshape(-1, 2).to(sequence_output.device)
        map_rss = []
        for b in range(bs):
            entity_atts = entity_as[b]
            h_att = torch.index_select(entity_atts, 0, index_pair[:, 0])
            t_att = torch.index_select(entity_atts, 0, index_pair[:, 1])
            ht_att = (h_att * t_att).mean(1)
            ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)
            rs = contract("ld,rl->rd", sequence_output[b], ht_att)
            map_rss.append(rs)
        map_rss = torch.cat(map_rss, dim=0).reshape(bs, ne, ne, d)
        return map_rss


    def forward(self,
                input_ids=None,
                attention_mask=None,
                labels=None,
                entity_pos=None,
                hts=None,
                mention_pos=None,
                mention_hts=None,
                padded_mention=None,
                padded_mention_mask=None,
                sentid_mask=None,
                instance_mask=None,
                return_logits=False,
                ):
        """Score every entity pair and optionally compute the training loss.

        Fixes over the earlier body: a call to an undefined ``get_entities`` with
        undefined arguments (its results were never used) is gone, a stray
        ``exit()`` that terminated the process before the loss could be returned
        is removed, and the gradient-norm diagnostic no longer crashes when
        autograd is disabled.

        The pair features are reshaped to ``(1, ent_nums, ent_nums, -1)`` with
        ``ent_nums = sqrt(#pairs)``, so the pair count must be a perfect square -
        presumably one document with all ordered entity pairs (TODO confirm).

        ``mention_pos``, ``mention_hts``, ``padded_mention``,
        ``padded_mention_mask`` and ``instance_mask`` are accepted but unused.

        Returns:
            * ``labels`` given: ``(loss, g_x_norm)`` where ``g_x_norm`` is the mean
              row-wise norm of d(loss)/d(logits) (zero when autograd is off).
            * otherwise, ``return_logits`` truthy: the logits tensor.
            * otherwise: ``(predicted_labels, logits)`` with the predictions taken
              from ``self.loss_fnt.get_label``.
        """
        sequence_output_0, attention_0 = self.encode(input_ids, attention_mask)
        hs_0, rs_0, ts_0, _ = self.get_hrt(sequence_output_0, attention_0, entity_pos, hts, sentid_mask)
        hs_0 = torch.tanh(self.head_extractor(torch.cat([hs_0, rs_0], dim=1)))
        ts_0 = torch.tanh(self.tail_extractor(torch.cat([ts_0, rs_0], dim=1)))

        # Grouped bilinear features: outer product inside each block of size block_size.
        b1_0 = hs_0.view(-1, self.emb_size // self.block_size, self.block_size)
        b2_0 = ts_0.view(-1, self.emb_size // self.block_size, self.block_size)
        bl_0 = (b1_0.unsqueeze(3) * b2_0.unsqueeze(2)).view(-1, self.emb_size * self.block_size)
        ent_nums = int(np.sqrt(hs_0.size()[0]))

        # Arrange pair features as an entity-by-entity grid for axial attention.
        feature = self.projection(bl_0)
        feature = self.axial_transformer(feature.view(1, ent_nums, ent_nums, -1))
        logits_0 = self.classifier(feature)
        logits_0 = logits_0.view(-1, self.config.num_labels)

        output = (self.loss_fnt.get_label(logits_0, num_labels=self.num_labels), logits_0)
        if return_logits:
            output = logits_0
        if labels is not None:
            labels = [torch.tensor(label) for label in labels]
            labels = torch.cat(labels, dim=0).to(logits_0)
            loss_0 = self.loss_fnt(logits_0.float(), labels.float())
            if logits_0.requires_grad:
                # Diagnostic only: retain the graph so the caller can still backward().
                g_x = torch.autograd.grad(loss_0, logits_0, retain_graph=True)[0]
                g_x_norm = torch.mean(torch.norm(g_x, dim=1))
            else:
                # e.g. evaluation under torch.no_grad(): no gradient is available.
                g_x_norm = logits_0.new_zeros(())

            output_0 = loss_0.to(sequence_output_0)
            output = (output_0, g_x_norm)
        return output


class DocREModel_KD(nn.Module):
    def __init__(self, config, model, emb_size=1024, block_size=64, num_labels=-1, teacher_model=None):
        """Build the student heads and optionally attach a frozen teacher.

        Args:
            config: Encoder configuration; ``hidden_size`` and ``num_labels`` are
                read here, ``transformer_type`` and the special-token ids later.
            model: Encoder module, stored as ``self.model`` and used by ``encode``.
            emb_size: Output width of the head/tail extractors.
            block_size: Block width of the grouped bilinear product in ``forward``.
            num_labels: Passed to ``self.loss_fnt.get_label`` at prediction time.
            teacher_model: Optional module queried in ``forward`` for distillation
                targets when labels are given.
        """
        super().__init__()
        self.config = config
        self.model = model
        self.hidden_size = config.hidden_size
        self.loss_fnt = ATLoss()
        if teacher_model is not None:
            self.teacher_model = teacher_model
            # Assigning `.requires_grad = False` on a module merely creates an
            # attribute; requires_grad_() is what freezes the parameters.
            self.teacher_model.requires_grad_(False)
            # NOTE(review): the teacher is a registered sub-module, so calling
            # .train() on this model switches it back to training mode.
            self.teacher_model.eval()
        else:
            self.teacher_model = None
        self.head_extractor = nn.Linear(2 * config.hidden_size, emb_size)
        self.mention_head_extractor = nn.Linear(2 * config.hidden_size, emb_size)
        self.tail_extractor = nn.Linear(2 * config.hidden_size, emb_size)
        self.mention_tail_extractor = nn.Linear(2 * config.hidden_size, emb_size)
        self.bilinear = nn.Linear(emb_size * block_size, config.num_labels)
        self.projection = nn.Linear(emb_size * block_size, config.hidden_size)
        self.mention_projection = nn.Linear(emb_size * block_size, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size , config.num_labels)
        self.axial_transformer = AxialTransformer_by_entity(emb_size = config.hidden_size, dropout=0.0, num_layers=6, heads=8)
        # get_entity_table reads self.max_ent, which was never set on this class;
        # reuse the limit the axial transformer already carries.
        self.max_ent = self.axial_transformer.max_ent
        self.emb_size = emb_size
        self.block_size = block_size
        self.num_labels = num_labels
    def encode(self, input_ids, attention_mask):
        config = self.config
        if config.transformer_type == "bert":
            start_tokens = [config.cls_token_id]
            end_tokens = [config.sep_token_id]
        elif config.transformer_type == "roberta":
            start_tokens = [config.cls_token_id]
            end_tokens = [config.sep_token_id, config.sep_token_id]
        sequence_output, attention = process_long_input(self.model, input_ids, attention_mask, start_tokens, end_tokens)
        return sequence_output, attention

    def get_hrt(self, sequence_output, attention, entity_pos, hts):
        """Collect head, context and tail vectors for every entity pair in ``hts``.

        Each entity is pooled from its mentions (logsumexp over embeddings, mean
        over attention; zeros when no mention start is inside the sequence). For
        each ``(head, tail)`` pair the context vector is the token average weighted
        by the normalised head*tail attention.

        Returns:
            ``(hss, rss, tss)`` concatenated over documents, one row per pair.
        """
        shift = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
        _, h, _, c = attention.size()
        clip_negative = torch.nn.Threshold(0, 0)
        heads, contexts, tails = [], [], []

        for doc, doc_entities in enumerate(entity_pos):
            pooled_embs, pooled_atts = [], []
            for mentions in doc_entities:
                if len(mentions) > 1:
                    # Keep mentions whose shifted start was not lost to truncation.
                    kept = [begin + shift for begin, _ in mentions if begin + shift < c]
                    if len(kept) > 0:
                        emb = torch.logsumexp(
                            torch.stack([sequence_output[doc, pos] for pos in kept], dim=0), dim=0
                        )
                        att = torch.stack([attention[doc, :, pos] for pos in kept], dim=0).mean(0)
                    else:
                        emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                        att = torch.zeros(h, c).to(attention)
                else:
                    begin, _ = mentions[0]
                    if begin + shift < c:
                        emb = sequence_output[doc, begin + shift]
                        att = attention[doc, :, begin + shift]
                    else:
                        emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                        att = torch.zeros(h, c).to(attention)
                pooled_embs.append(emb)
                pooled_atts.append(att)

            pooled_embs = torch.stack(pooled_embs, dim=0)  # [n_e, d]
            pooled_atts = torch.stack(pooled_atts, dim=0)  # [n_e, h, seq_len]

            pair_ids = torch.LongTensor(hts[doc]).to(sequence_output.device)
            head_ids, tail_ids = pair_ids[:, 0], pair_ids[:, 1]
            heads.append(torch.index_select(pooled_embs, 0, head_ids))
            tails.append(torch.index_select(pooled_embs, 0, tail_ids))

            joint = torch.index_select(pooled_atts, 0, head_ids) * torch.index_select(pooled_atts, 0, tail_ids)
            weights = clip_negative(joint.sum(1))
            weights = weights / (weights.sum(1, keepdim=True) + 1e-5)
            contexts.append(contract("ld,rl->rd", sequence_output[doc], weights))

        hss = torch.cat(heads, dim=0)
        tss = torch.cat(tails, dim=0)
        rss = torch.cat(contexts, dim=0)
        return hss, rss, tss


    def get_mention_hrt(self, sequence_output, attention, mention_pos, mention_hts):
        """Collect head, context and tail vectors for the pairs in ``mention_hts[0]``.

        Only the first document is processed. Dead intermediate values of the
        earlier version (position lookups whose results were never used and a
        disabled block kept as a string literal) have been removed; the returned
        tensors are computed exactly as before.

        NOTE(review): each entry of ``mention_hts[0]`` is first unpacked as
        ``(start, end)`` and its first element used as a token position, and the
        same entries are then used as row indices into the stacked vectors. That
        double use looks unintended (``mention_pos`` is no longer read at all) -
        confirm against the data pipeline before relying on it.

        Args:
            sequence_output: Token representations, indexed as ``[doc, token]``.
            attention: Attention tensor unpacked as four dimensions.
            mention_pos: Unused; kept for signature compatibility.
            mention_hts: Per document, a list of integer pairs.

        Returns:
            ``(hss, rss, tss)`` with one row per pair.
        """
        offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
        _, h, _, c = attention.size()
        clip_negative = torch.nn.Threshold(0, 0)
        hss, tss, rss = [], [], []

        for i in range(1):
            entity_embs, entity_atts = [], []
            for start, _ in mention_hts[i]:
                if start + offset < c:
                    e_emb = sequence_output[i, start + offset]
                    e_att = attention[i, :, start + offset]
                else:
                    # Position lies beyond the available sequence length.
                    e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                    e_att = torch.zeros(h, c).to(attention)
                entity_embs.append(e_emb)
                entity_atts.append(e_att)

            entity_embs = torch.stack(entity_embs, dim=0)
            entity_atts = torch.stack(entity_atts, dim=0)

            ht_i = torch.LongTensor(mention_hts[i]).to(sequence_output.device)
            hs = torch.index_select(entity_embs, 0, ht_i[:, 0])
            ts = torch.index_select(entity_embs, 0, ht_i[:, 1])

            h_att = torch.index_select(entity_atts, 0, ht_i[:, 0])
            t_att = torch.index_select(entity_atts, 0, ht_i[:, 1])
            # Sum head*tail attention over dim 1, then normalise each row to ~1.
            ht_att = clip_negative((h_att * t_att).sum(1))
            ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)
            rs = contract("ld,rl->rd", sequence_output[i], ht_att)
            hss.append(hs)
            tss.append(ts)
            rss.append(rs)
        hss = torch.cat(hss, dim=0)
        tss = torch.cat(tss, dim=0)
        rss = torch.cat(rss, dim=0)
        return hss, rss, tss

    def get_mask(self, ents, bs, ne, run_device):
        ent_mask = torch.zeros(bs, ne, device=run_device)
        rel_mask = torch.zeros(bs, ne, ne, device=run_device)
        for _b in range(bs):
            ent_mask[_b, :len(ents[_b])] = 1
            rel_mask[_b, :len(ents[_b]), :len(ents[_b])] = 1
        return ent_mask, rel_mask


    def get_ht(self, rel_enco, hts):
        htss = []
        for i in range(len(hts)):
            ht_index = hts[i]
            for (h_index, t_index) in ht_index:
                htss.append(rel_enco[i,h_index,t_index])
        htss = torch.stack(htss,dim=0)
        return htss

    def get_entity_table(self, sequence_output, entity_embs):

        bs,_,d = sequence_output.size()
        ne = self.max_ent

        index_pair = []
        for i in range(ne):
            tmp = torch.cat((torch.ones((ne, 1), dtype=int) * i, torch.arange(0, ne).unsqueeze(1)), dim=-1)
            index_pair.append(tmp)
        index_pair = torch.stack(index_pair, dim=0).reshape(-1, 2).to(sequence_output.device)
        map_rss = []
        for b in range(bs):
            entity_atts = entity_as[b]
            h_att = torch.index_select(entity_atts, 0, index_pair[:, 0])
            t_att = torch.index_select(entity_atts, 0, index_pair[:, 1])
            ht_att = (h_att * t_att).mean(1)
            ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)
            rs = contract("ld,rl->rd", sequence_output[b], ht_att)
            map_rss.append(rs)
        map_rss = torch.cat(map_rss, dim=0).reshape(bs, ne, ne, d)
        return map_rss


    def compute_kl_loss(self, p, q, pad_mask=None):
    
        p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
        q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')
    
        # pad_mask is for seq-level tasks
        if pad_mask is not None:
            p_loss.masked_fill_(pad_mask, 0.)
            q_loss.masked_fill_(pad_mask, 0.)

        # You can choose whether to use function "sum" and "mean" depending on your task
        p_loss = p_loss.sum()
        q_loss = q_loss.sum()

        loss = (p_loss + q_loss) / 2
        return loss



    def forward(self,
                input_ids=None,
                attention_mask=None,
                labels=None,
                entity_pos=None,
                hts=None,
                mention_pos=None,
                mention_hts=None,
                padded_mention=None,
                padded_mention_mask=None,
                instance_mask=None,
                ):
        """Score every entity pair; with labels, add the distillation loss.

        Fixes over the earlier body: ``feature`` was read before it was ever
        assigned (its only assignment sat inside a disabled string block), and the
        residual added a ``[1, n, n, d]`` tensor to a ``[n*n, d]`` one, which does
        not broadcast. The projection is now computed, reshaped to the entity grid,
        and the residual is taken on the grid. The teacher is queried without
        building an autograd graph, and the gradient-norm diagnostic no longer
        crashes when autograd is disabled.

        The pair features are reshaped to ``(1, ent_nums, ent_nums, -1)`` with
        ``ent_nums = sqrt(#pairs)``, so the pair count must be a perfect square -
        presumably one document with all ordered entity pairs (TODO confirm).

        ``instance_mask`` is accepted but unused; the mention arguments are only
        forwarded to the teacher.

        Returns:
            * ``labels`` given: ``(loss, g_x_norm)`` where ``loss`` includes the KL
              term against the teacher when one is attached and ``g_x_norm`` is
              the mean row-wise norm of d(task loss)/d(logits) (zero when
              autograd is off).
            * otherwise: ``(predicted_labels, logits)`` with the predictions taken
              from ``self.loss_fnt.get_label``.
        """
        sequence_output_0, attention_0 = self.encode(input_ids, attention_mask)

        hs_0, rs_0, ts_0 = self.get_hrt(sequence_output_0, attention_0, entity_pos, hts)
        hs_0 = torch.tanh(self.head_extractor(torch.cat([hs_0, rs_0], dim=1)))
        ts_0 = torch.tanh(self.tail_extractor(torch.cat([ts_0, rs_0], dim=1)))

        # Grouped bilinear features: outer product inside each block of size block_size.
        b1_0 = hs_0.view(-1, self.emb_size // self.block_size, self.block_size)
        b2_0 = ts_0.view(-1, self.emb_size // self.block_size, self.block_size)
        bl_0 = (b1_0.unsqueeze(3) * b2_0.unsqueeze(2)).view(-1, self.emb_size * self.block_size)
        ent_nums = int(np.sqrt(hs_0.size()[0]))

        # Entity-by-entity grid with a residual connection around axial attention.
        feature = self.projection(bl_0).view(1, ent_nums, ent_nums, -1)
        feature = self.axial_transformer(feature) + feature
        logits_0 = self.classifier(feature)
        logits_0 = logits_0.view(-1, self.config.num_labels)

        output = (self.loss_fnt.get_label(logits_0, num_labels=self.num_labels), logits_0)
        if labels is not None:
            labels = [torch.tensor(label) for label in labels]
            labels = torch.cat(labels, dim=0).to(logits_0)
            loss_0 = self.loss_fnt(logits_0.float(), labels.float())
            if logits_0.requires_grad:
                # Diagnostic only: retain the graph so the caller can still backward().
                g_x = torch.autograd.grad(loss_0, logits_0, retain_graph=True)[0]
                g_x_norm = torch.mean(torch.norm(g_x, dim=1))
            else:
                # e.g. evaluation under torch.no_grad(): no gradient is available.
                g_x_norm = logits_0.new_zeros(())
            output_0 = loss_0.to(sequence_output_0)
            if self.teacher_model is not None:
                # The teacher only supplies targets, so no graph is needed for it.
                with torch.no_grad():
                    teacher_logits = self.teacher_model(input_ids=input_ids, attention_mask=attention_mask,
                                                        entity_pos=entity_pos, hts=hts,
                                                        mention_pos = mention_pos, mention_hts=mention_hts,
                                                        padded_mention=padded_mention, padded_mention_mask=padded_mention_mask,
                                                        return_logits=True)
                teacher_logits = teacher_logits.detach()
                # NOTE(review): reduction='mean' averages over every element; the
                # PyTorch docs recommend 'batchmean' for a true KL value. Left as
                # is because switching would rescale the training loss.
                kl_loss = F.kl_div(F.log_softmax(logits_0, dim=-1), F.softmax(teacher_logits, dim=-1), reduction='mean')
                output_0 = output_0 + kl_loss
            output = (output_0, g_x_norm)

        return output
