import torch
import torch.nn as nn
from opt_einsum import contract
from long_seq import process_long_input
from losses import ATLoss
import torch.nn.functional as F
from axial_attention import AxialAttention, AxialImageTransformer
import numpy as np
import math
from itertools import accumulate
import copy

def batch_index(tensor, index, pad=False):
    """Index every batch element of ``tensor`` with its own ``index`` row.

    For each position ``i`` along the first dimension, ``tensor[i]`` is indexed
    with ``index[i]`` and the per-element selections are combined.

    Args:
        tensor: Tensor whose first dimension is the batch dimension.
        index: Per-batch index; ``index.shape[0]`` must equal
            ``tensor.shape[0]``.
        pad: When false, the selections are combined with ``torch.stack`` (so
            they must all share one shape).  When true, they are passed to
            ``padded_stack``, which is not defined inside this function and
            must be available at module level.

    Returns:
        The combined per-batch selections.

    Raises:
        ValueError: If the two batch sizes differ.
    """
    if tensor.shape[0] != index.shape[0]:
        raise ValueError(
            f"batch size mismatch: tensor has {tensor.shape[0]} entries, "
            f"index has {index.shape[0]}"
        )

    selections = [tensor[i][index[i]] for i in range(index.shape[0])]
    if pad:
        return padded_stack(selections)
    return torch.stack(selections)


class AxialTransformer_by_entity(nn.Module):
    """Stack of axial-attention layers applied to a feature tensor.

    Every layer adds the (dropped-out) axial attention output to its input,
    then passes the sum through a linear layer, a dropout and a layer norm.
    There is no residual connection around the linear layer.
    """

    def __init__(self, emb_size=768, dropout=0.1, num_layers=2, dim_index=-1, heads=8, num_dimensions=2):
        """Create ``num_layers`` attention / linear / norm / dropout groups.

        Args:
            emb_size: Feature width used by the attention, linear and norm layers.
            dropout: Dropout probability for both dropout stages.
            num_layers: Number of stacked layers.
            dim_index: Forwarded to ``AxialAttention`` as ``dim_index``.
            heads: Forwarded to ``AxialAttention`` as ``heads``.
            num_dimensions: Forwarded to ``AxialAttention`` as ``num_dimensions``.
        """
        super().__init__()
        self.emb_size = emb_size
        self.dropout = dropout
        self.num_layers = num_layers
        self.dim_index = dim_index
        self.heads = heads
        self.num_dimensions = num_dimensions

        # Sub-modules are created group by group, in a fixed order.
        attention_layers = []
        for _ in range(num_layers):
            attention_layers.append(
                AxialAttention(
                    dim=self.emb_size,
                    dim_index=dim_index,
                    heads=heads,
                    num_dimensions=num_dimensions,
                )
            )
        self.axial_attns = nn.ModuleList(attention_layers)
        self.ffns = nn.ModuleList([nn.Linear(self.emb_size, self.emb_size) for _ in range(num_layers)])
        self.lns = nn.ModuleList([nn.LayerNorm(self.emb_size) for _ in range(num_layers)])
        self.attn_dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in range(num_layers)])
        self.ffn_dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in range(num_layers)])

    def forward(self, x):
        """Run ``x`` through every layer in turn and return the result."""
        for layer in range(self.num_layers):
            attended = self.axial_attns[layer](x)
            x = x + self.attn_dropouts[layer](attended)
            projected = self.ffns[layer](x)
            x = self.lns[layer](self.ffn_dropouts[layer](projected))
        return x


class AxialEntityTransformer(nn.Module):
    """Stack of axial-attention layers applied to a feature tensor.

    Every layer adds the (dropped-out) axial attention output to its input,
    then passes the sum through a linear layer, a dropout and a layer norm.
    """

    def __init__(self, emb_size=768, dropout=0.1, num_layers=2, dim_index=-1, heads=8, num_dimensions=2):
        """Create ``num_layers`` attention / linear / norm / dropout groups.

        Args:
            emb_size: Feature width used by the attention, linear and norm layers.
            dropout: Dropout probability for both dropout stages.
            num_layers: Number of stacked layers.
            dim_index: Forwarded to ``AxialAttention`` as ``dim_index``.
            heads: Forwarded to ``AxialAttention`` as ``heads``.
            num_dimensions: Forwarded to ``AxialAttention`` as ``num_dimensions``.
        """
        super().__init__()
        self.emb_size = emb_size
        self.dropout = dropout
        self.num_layers = num_layers
        self.dim_index = dim_index
        self.heads = heads
        self.num_dimensions = num_dimensions

        # NOTE(review): constructed without arguments and never used in
        # forward() -- confirm AxialImageTransformer accepts a no-argument call
        # and whether this attribute is needed at all.
        self.axial_img_transformer = AxialImageTransformer()

        attention_layers = []
        for _ in range(num_layers):
            attention_layers.append(
                AxialAttention(
                    dim=self.emb_size,
                    dim_index=dim_index,
                    heads=heads,
                    num_dimensions=num_dimensions,
                )
            )
        self.axial_attns = nn.ModuleList(attention_layers)
        self.ffns = nn.ModuleList([nn.Linear(self.emb_size, self.emb_size) for _ in range(num_layers)])
        self.lns = nn.ModuleList([nn.LayerNorm(self.emb_size) for _ in range(num_layers)])
        self.attn_dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in range(num_layers)])
        self.ffn_dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in range(num_layers)])

    def forward(self, x):
        """Run ``x`` through every layer in turn and return the result."""
        for layer in range(self.num_layers):
            attended = self.axial_attns[layer](x)
            x = x + self.attn_dropouts[layer](attended)
            projected = self.ffns[layer](x)
            x = self.lns[layer](self.ffn_dropouts[layer](projected))
        return x

class DocREModel(nn.Module):
    def __init__(self, config, model, emb_size=1024, block_size=64, num_labels=-1):
        """Set up the relation-extraction heads around an encoder.

        Args:
            config: Object read for ``hidden_size`` and ``num_labels`` here
                (other attributes are read by the other methods).
            model: Encoder module; stored as ``self.model``.
            emb_size: Output width of the head/tail extractor layers.
            block_size: Block width of the grouped bilinear product; the
                bilinear feature has ``emb_size * block_size`` entries.
            num_labels: Stored as ``self.num_labels`` and later passed to
                ``self.loss_fnt.get_label``.
        """
        super().__init__()
        hidden = config.hidden_size
        pair_width = 2 * hidden
        bilinear_width = emb_size * block_size

        # Plain (non-module) attributes.
        self.config = config
        self.hidden_size = hidden
        self.max_ent = 42
        self.emb_size = emb_size
        self.block_size = block_size
        self.num_labels = num_labels

        # Sub-modules, created in a fixed order.
        self.model = model
        self.loss_fnt = ATLoss()
        self.head_extractor = nn.Linear(pair_width, emb_size)
        self.mention_head_extractor = nn.Linear(pair_width, emb_size)
        self.tail_extractor = nn.Linear(pair_width, emb_size)
        self.mention_tail_extractor = nn.Linear(pair_width, emb_size)
        self.bilinear = nn.Linear(bilinear_width, config.num_labels)
        self.projection = nn.Linear(bilinear_width, hidden)
        self.mention_projection = nn.Linear(bilinear_width, hidden)
        self.classifier = nn.Linear(hidden, config.num_labels)
        self.axial_transformer = AxialTransformer_by_entity(
            emb_size=hidden, dropout=0.0, num_layers=6, heads=8
        )

    def encode(self, input_ids, attention_mask):
        """Encode the inputs with ``self.model`` via ``process_long_input``.

        The start / end special tokens handed to ``process_long_input`` depend
        on ``config.transformer_type``: BERT uses one separator token at the
        end, RoBERTa uses two.

        Args:
            input_ids: Forwarded unchanged to ``process_long_input``.
            attention_mask: Forwarded unchanged to ``process_long_input``.

        Returns:
            The ``(sequence_output, attention)`` pair produced by
            ``process_long_input``.

        Raises:
            ValueError: If ``config.transformer_type`` is neither ``"bert"``
                nor ``"roberta"``.
        """
        config = self.config
        if config.transformer_type == "bert":
            start_tokens = [config.cls_token_id]
            end_tokens = [config.sep_token_id]
        elif config.transformer_type == "roberta":
            start_tokens = [config.cls_token_id]
            end_tokens = [config.sep_token_id, config.sep_token_id]
        else:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError(
                f"unsupported transformer_type: {config.transformer_type!r} "
                "(expected 'bert' or 'roberta')"
            )
        sequence_output, attention = process_long_input(self.model, input_ids, attention_mask, start_tokens, end_tokens)
        return sequence_output, attention

    def get_hrt(self, sequence_output, attention, entity_pos, hts, sentid_mask):
        """Build padded head / context / tail feature tables for a batch.

        For every document, each entity is pooled from the encoder output at
        its mention start tokens (log-sum-exp over the mentions when the entity
        has several, the raw token embedding when it has one) and its attention
        rows are averaged over mentions.  For every ``(head, tail)`` pair in
        ``hts`` a context vector is the attention-weighted sum of the token
        embeddings; the weights are the product of the two entities' attention
        rows summed over heads, then normalised to sum to one.

        Args:
            sequence_output: Encoder output, unpacked here as
                ``(batch, seq_len, hidden)``.
            attention: Encoder attention, unpacked here as a 4-D tensor whose
                second dimension is the number of heads and whose last
                dimension is the sequence length.
            entity_pos: Per document, a list of entities; each entity is a
                list of ``(start, end)`` mention spans.  Only ``start`` is used.
            hts: Per document, a list of ``(head_index, tail_index)`` pairs.
                They are reshaped to ``(s_ne, s_ne)``, so each document must
                list exactly ``s_ne * s_ne`` pairs, where ``s_ne`` is that
                document's entity count.
            sentid_mask: Unused.  Accepted so existing callers keep working.

        Returns:
            ``(hss, rss, tss)``: head, context and tail tables, each of shape
            ``(batch, n_e, n_e, hidden)`` where ``n_e`` is the largest entity
            count in the batch.  Cells outside a document's own entities are
            zero.
        """
        offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
        _, h, _, c = attention.size()
        _, _, h_size = sequence_output.size()
        n_e = max(len(x) for x in entity_pos)
        # Stateless module; built once instead of once per document.
        threshold = torch.nn.Threshold(0, 0)

        hss, tss, rss = [], [], []
        for i in range(len(entity_pos)):
            entity_embs, entity_atts = [], []
            for e in entity_pos[i]:
                # Mentions starting past the encoded length were truncated by
                # the max sequence length and are skipped.
                kept = [start + offset for start, _ in e if start + offset < c]
                if not kept:
                    # No usable mention (all truncated, or none listed).
                    e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                    e_att = torch.zeros(h, c).to(attention)
                elif len(e) > 1:
                    e_emb = torch.logsumexp(
                        torch.stack([sequence_output[i, pos] for pos in kept], dim=0), dim=0
                    )
                    e_att = torch.stack([attention[i, :, pos] for pos in kept], dim=0).mean(0)
                else:
                    e_emb = sequence_output[i, kept[0]]
                    e_att = attention[i, :, kept[0]]
                entity_embs.append(e_emb)
                entity_atts.append(e_att)

            entity_embs = torch.stack(entity_embs, dim=0)  # [s_ne, d]
            entity_atts = torch.stack(entity_atts, dim=0)  # [s_ne, h, seq_len]
            s_ne, _ = entity_embs.size()

            ht_i = torch.LongTensor(hts[i]).to(sequence_output.device)
            hs = torch.index_select(entity_embs, 0, ht_i[:, 0])
            ts = torch.index_select(entity_embs, 0, ht_i[:, 1])

            h_att = torch.index_select(entity_atts, 0, ht_i[:, 0])
            t_att = torch.index_select(entity_atts, 0, ht_i[:, 1])
            ht_att = threshold((h_att * t_att).sum(1))
            # Small epsilon keeps all-zero attention rows from dividing by zero.
            ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-10)
            rs = contract("ld,rl->rd", sequence_output[i], ht_att)

            # Copy each per-pair table into the top-left corner of a zero table
            # sized for the largest document in the batch.
            for flat, bucket in ((hs, hss), (ts, tss), (rs, rss)):
                padded = torch.zeros((n_e, n_e, h_size)).to(sequence_output)
                padded[:s_ne, :s_ne, :] = flat.view(s_ne, s_ne, h_size)
                bucket.append(padded)

        hss = torch.stack(hss, dim=0)
        tss = torch.stack(tss, dim=0)
        rss = torch.stack(rss, dim=0)
        return hss, rss, tss


    def get_entities(self, sequence_output, attention, entity_pos, hts, sentid_mask):
        """Pool one embedding and one attention map per entity.

        Each entity is pooled from the encoder output at its mention start
        tokens: log-sum-exp over the mentions when the entity has several, the
        raw token embedding when it has one.  Its attention rows are averaged
        over mentions.  Entities with no usable mention get zeros.

        Args:
            sequence_output: Encoder output, unpacked here as
                ``(batch, seq_len, hidden)``.
            attention: Encoder attention, unpacked here as a 4-D tensor whose
                second dimension is the number of heads and whose last
                dimension is the sequence length.
            entity_pos: Per document, a list of entities; each entity is a
                list of ``(start, end)`` mention spans.  Only ``start`` is used.
            hts: Unused.  Accepted so existing callers keep working.
            sentid_mask: Unused.  Accepted so existing callers keep working.

        Returns:
            ``(b_ent_embs, b_ent_atts, sent_embs)``.  The first two are lists
            with one tensor per document, of shape ``(entities, hidden)`` and
            ``(entities, heads, seq_len)``.  ``sent_embs`` is an empty
            ``(0, hidden)`` tensor: no sentence embeddings are computed here.
        """
        offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
        _, h, _, c = attention.size()
        _, _, h_size = sequence_output.size()

        b_ent_embs = []
        b_ent_atts = []
        for i in range(len(entity_pos)):
            entity_embs, entity_atts = [], []
            for e in entity_pos[i]:
                # Mentions starting past the encoded length were truncated by
                # the max sequence length and are skipped.
                kept = [start + offset for start, _ in e if start + offset < c]
                if not kept:
                    e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                    e_att = torch.zeros(h, c).to(attention)
                elif len(e) > 1:
                    e_emb = torch.logsumexp(
                        torch.stack([sequence_output[i, pos] for pos in kept], dim=0), dim=0
                    )
                    e_att = torch.stack([attention[i, :, pos] for pos in kept], dim=0).mean(0)
                else:
                    e_emb = sequence_output[i, kept[0]]
                    e_att = attention[i, :, kept[0]]
                entity_embs.append(e_emb)
                entity_atts.append(e_att)

            b_ent_embs.append(torch.stack(entity_embs, dim=0))  # [n_e, d]
            b_ent_atts.append(torch.stack(entity_atts, dim=0))  # [n_e, h, seq_len]

        # Nothing is collected for sentence embeddings, and torch.cat rejects an
        # empty list, so return an explicit empty tensor instead.
        sent_embs = torch.zeros((0, h_size)).to(sequence_output)
        return b_ent_embs, b_ent_atts, sent_embs




    def get_mention_hrt(self, sequence_output, attention, entity_pos, hts, sentid_mask):
        """Build padded head / context / tail feature tables for a batch.

        Each entry of ``entity_pos[doc]`` is pooled from the encoder output at
        its span start tokens (log-sum-exp when it has several spans, the raw
        token embedding when it has one); attention rows are averaged over the
        spans.  For each ``(head, tail)`` pair in ``hts`` a context vector is
        the attention-weighted sum of token embeddings, with weights given by
        the head-summed product of both attention rows, normalised to one.

        Args:
            sequence_output: Encoder output, unpacked as ``(batch, seq_len, hidden)``.
            attention: 4-D encoder attention; dimension 1 is heads, the last
                dimension is the sequence length.
            entity_pos: Per document, a list of entries, each a list of
                ``(start, end)`` spans.  Only ``start`` is used.
            hts: Per document, ``(head_index, tail_index)`` pairs; reshaped to
                ``(s_ne, s_ne)``, so exactly ``s_ne * s_ne`` pairs are needed.
            sentid_mask: Not used by this method.

        Returns:
            ``(hss, rss, tss)``, each ``(batch, n_e, n_e, hidden)`` with ``n_e``
            the largest entry count in the batch; unused cells are zero.
        """
        offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
        _, h, _, c = attention.size()
        _, _, h_size = sequence_output.size()
        n_e = max([len(x) for x in entity_pos])
        clip_negative = torch.nn.Threshold(0, 0)

        head_tables, tail_tables, ctx_tables = [], [], []
        for doc_idx, doc_entities in enumerate(entity_pos):
            pooled_embs, pooled_atts = [], []
            for e in doc_entities:
                if len(e) > 1:
                    # Spans starting past the encoded length were truncated away.
                    kept = [start + offset for start, end in e if start + offset < c]
                    if len(kept) > 0:
                        stacked_emb = torch.stack([sequence_output[doc_idx, pos] for pos in kept], dim=0)
                        stacked_att = torch.stack([attention[doc_idx, :, pos] for pos in kept], dim=0)
                        e_emb = torch.logsumexp(stacked_emb, dim=0)
                        e_att = stacked_att.mean(0)
                    else:
                        e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                        e_att = torch.zeros(h, c).to(attention)
                else:
                    start, end = e[0]
                    pos = start + offset
                    if pos < c:
                        e_emb = sequence_output[doc_idx, pos]
                        e_att = attention[doc_idx, :, pos]
                    else:
                        e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                        e_att = torch.zeros(h, c).to(attention)
                pooled_embs.append(e_emb)
                pooled_atts.append(e_att)

            entity_embs = torch.stack(pooled_embs, dim=0)  # [s_ne, d]
            entity_atts = torch.stack(pooled_atts, dim=0)  # [s_ne, h, seq_len]
            s_ne, _ = entity_embs.size()

            def to_table(flat):
                # Place the (s_ne, s_ne) table in the corner of a zero table.
                table = torch.zeros((n_e, n_e, h_size)).to(sequence_output)
                table[:s_ne, :s_ne, :] = flat.view(s_ne, s_ne, h_size)
                return table

            pairs = torch.LongTensor(hts[doc_idx]).to(sequence_output.device)
            head_idx, tail_idx = pairs[:, 0], pairs[:, 1]
            pad_hs = to_table(torch.index_select(entity_embs, 0, head_idx))
            pad_ts = to_table(torch.index_select(entity_embs, 0, tail_idx))

            h_att = torch.index_select(entity_atts, 0, head_idx)
            t_att = torch.index_select(entity_atts, 0, tail_idx)
            ht_att = clip_negative((h_att * t_att).sum(1))
            ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-10)

            pad_rs = to_table(contract("ld,rl->rd", sequence_output[doc_idx], ht_att))

            head_tables.append(pad_hs)
            tail_tables.append(pad_ts)
            ctx_tables.append(pad_rs)

        hss = torch.stack(head_tables, dim=0)
        tss = torch.stack(tail_tables, dim=0)
        rss = torch.stack(ctx_tables, dim=0)
        return hss, rss, tss

    def compute_kl_loss(self, p, q, pad_mask=None):
        """Return the symmetric KL divergence between two sets of logits.

        Both directions of the KL divergence between ``softmax(p)`` and
        ``softmax(q)`` (over the last dimension) are computed element-wise,
        optionally zeroed where ``pad_mask`` is true, summed over all elements
        and averaged.

        Args:
            p: First logits tensor.
            q: Second logits tensor, same shape as ``p``.
            pad_mask: Optional boolean mask; masked positions contribute zero.

        Returns:
            Scalar tensor ``(KL(q || p) + KL(p || q)) / 2`` using summed terms.
        """
        terms = []
        for logits, target_logits in ((p, q), (q, p)):
            pointwise = F.kl_div(
                F.log_softmax(logits, dim=-1),
                F.softmax(target_logits, dim=-1),
                reduction='none',
            )
            if pad_mask is not None:
                # Used for sequence-level tasks: ignore padded positions.
                pointwise = pointwise.masked_fill(pad_mask, 0.)
            terms.append(pointwise.sum())

        p_loss, q_loss = terms
        return (p_loss + q_loss) / 2

    def get_mask(self, ents, bs, ne, run_device):
        """Build entity and entity-pair validity masks for a batch.

        Args:
            ents: Per document, a sized collection of entities; only its
                length is used.
            bs: Number of documents (rows) to fill.
            ne: Padded entity count (mask width).
            run_device: Device on which both masks are allocated.

        Returns:
            ``(ent_mask, rel_mask)`` of shapes ``(bs, ne)`` and ``(bs, ne, ne)``;
            entries are 1 for real entities / pairs of real entities, else 0.
        """
        ent_mask = torch.zeros(bs, ne, device=run_device)
        rel_mask = torch.zeros(bs, ne, ne, device=run_device)
        for row in range(bs):
            count = len(ents[row])
            ent_mask[row, :count] = 1
            rel_mask[row, :count, :count] = 1
        return ent_mask, rel_mask


    def get_ht(self, rel_enco, hts):
        """Gather the table cells addressed by every (head, tail) pair.

        Args:
            rel_enco: Tensor indexed as ``rel_enco[doc, head, tail]``.
            hts: Per document, an iterable of ``(head_index, tail_index)`` pairs.

        Returns:
            The selected cells stacked along a new first dimension, in
            document order and, within a document, in pair order.
        """
        picked = []
        for doc, pairs in enumerate(hts):
            picked.extend(rel_enco[doc, head, tail] for head, tail in pairs)
        return torch.stack(picked, dim=0)

    def get_entity_table(self, sequence_output, entity_as):
        """Compute a context vector for every ordered entity pair.

        All ``(head, tail)`` index pairs over ``self.max_ent`` entities are
        enumerated in row-major order.  For each pair the two entities'
        attention rows are multiplied, averaged over dimension 1, normalised
        to sum to (almost) one, and used to weight the token embeddings.

        Args:
            sequence_output: Encoder output, unpacked as ``(batch, seq_len, hidden)``.
            entity_as: Per document, an attention tensor indexed along
                dimension 0 by entity.  Presumably ``(entities, heads,
                seq_len)`` with at least ``self.max_ent`` rows -- confirm
                against the caller.

        Returns:
            Tensor of shape ``(batch, max_ent, max_ent, hidden)``.
        """
        bs, _, d = sequence_output.size()
        ne = self.max_ent

        # Row-major grid of (head, tail) indices, flattened to (ne * ne, 2).
        grid = torch.arange(0, ne)
        ones = torch.ones((ne, ne), dtype=int)
        head_idx = (ones * grid.unsqueeze(1)).reshape(-1)
        tail_idx = (ones * grid.unsqueeze(0)).reshape(-1)
        index_pair = torch.stack((head_idx, tail_idx), dim=-1).to(sequence_output.device)

        per_doc = []
        for doc in range(bs):
            doc_atts = entity_as[doc]
            pair_att = (
                torch.index_select(doc_atts, 0, index_pair[:, 0])
                * torch.index_select(doc_atts, 0, index_pair[:, 1])
            ).mean(1)
            pair_att = pair_att / (pair_att.sum(1, keepdim=True) + 1e-5)
            per_doc.append(contract("ld,rl->rd", sequence_output[doc], pair_att))
        return torch.cat(per_doc, dim=0).reshape(bs, ne, ne, d)


    def forward(self,
                input_ids=None,
                attention_mask=None,
                labels=None,
                entity_pos=None,
                hts=None,
                mention_pos=None,
                mention_hts=None,
                padded_mention=None,
                padded_mention_mask=None,
                sentid_mask=None,
                instance_mask=None,
                teacher_logits=None,
                return_logits=False,
                ):
        """Score every entity pair of every document in the batch.

        The encoder output is turned into padded head / context / tail tables
        (``get_hrt``), combined with a grouped bilinear product, projected to
        the hidden size, refined by the axial transformer and classified.  The
        logits of the real (non-padded) entity pairs of each document are then
        flattened and concatenated.

        Args:
            input_ids: Forwarded to ``encode``.
            attention_mask: Forwarded to ``encode``.
            labels: Optional per-document label arrays; when given, the loss is
                returned.
            entity_pos: Per document, the list of entities (see ``get_hrt``).
            hts: Per document, the ``(head, tail)`` index pairs (see ``get_hrt``).
            sentid_mask: Forwarded to ``get_hrt``.
            return_logits: When truthy and ``labels`` is ``None``, return only
                the logits.
            mention_pos, mention_hts, padded_mention, padded_mention_mask,
            instance_mask, teacher_logits: Accepted but not used.

        Returns:
            * ``labels`` given: ``(loss, loss1, loss2)`` as returned by
              ``self.loss_fnt``, with ``loss`` cast to the encoder output's
              dtype and device.
            * otherwise, ``return_logits`` truthy: the logits tensor of shape
              ``(total_pairs, config.num_labels)``.
            * otherwise: ``(predictions, logits)`` where predictions come from
              ``self.loss_fnt.get_label``.
        """
        sequence_output_0, attention_0 = self.encode(input_ids, attention_mask)
        bs, _, _ = sequence_output_0.size()

        nes = [len(x) for x in entity_pos]
        ne = max(nes)
        hs_0, rs_0, ts_0 = self.get_hrt(sequence_output_0, attention_0, entity_pos, hts, sentid_mask)

        hs_0 = torch.tanh(self.head_extractor(torch.cat([hs_0, rs_0], dim=3)))
        ts_0 = torch.tanh(self.tail_extractor(torch.cat([ts_0, rs_0], dim=3)))

        # Grouped bilinear: split the features into blocks and take the outer
        # product inside each block.
        n_blocks = self.emb_size // self.block_size
        b1_0 = hs_0.view(bs, ne, ne, n_blocks, self.block_size)
        b2_0 = ts_0.view(bs, ne, ne, n_blocks, self.block_size)
        bl_0 = (b1_0.unsqueeze(5) * b2_0.unsqueeze(4)).view(bs, ne, ne, self.emb_size * self.block_size)

        feature = self.projection(bl_0)
        feature = self.axial_transformer(feature)
        logits_0 = self.classifier(feature)

        # Drop the padded rows/columns of each document and flatten the rest.
        logits_0 = torch.cat([
            logits_0[x, :nes[x], :nes[x], :].reshape(-1, self.config.num_labels)
            for x in range(len(nes))
        ])

        output = (self.loss_fnt.get_label(logits_0, num_labels=self.num_labels), logits_0)
        if return_logits:
            output = logits_0
        if labels is not None:
            labels = [torch.tensor(label) for label in labels]
            labels = torch.cat(labels, dim=0).to(logits_0)
            loss_0, loss1, loss2 = self.loss_fnt(logits_0.float(), labels.float())
            output = (loss_0.to(sequence_output_0), loss1, loss2)

        return output


class DocREModel_KD(nn.Module):
    def __init__(self, config, model, emb_size=1024, block_size=64, num_labels=-1, teacher_model=None):
        """Set up the relation-extraction heads and an optional frozen teacher.

        Args:
            config: Object read for ``hidden_size`` and ``num_labels`` here
                (other attributes are read by the other methods).
            model: Encoder module; stored as ``self.model``.
            emb_size: Output width of the head/tail extractor layers.
            block_size: Block width of the grouped bilinear product; the
                bilinear feature has ``emb_size * block_size`` entries.
            num_labels: Stored as ``self.num_labels``.
            teacher_model: Optional teacher module.  When given, its parameters
                are frozen and it is put in eval mode.
        """
        super().__init__()
        self.config = config
        self.model = model
        self.hidden_size = config.hidden_size
        self.emb_size = emb_size
        self.block_size = block_size
        self.loss_fnt = ATLoss()
        if teacher_model is not None:
            self.teacher_model = teacher_model
            # Assigning ``.requires_grad = False`` on a Module only creates a
            # plain attribute; ``requires_grad_`` actually freezes the parameters.
            self.teacher_model.requires_grad_(False)
            # NOTE(review): the teacher is registered as a sub-module, so a
            # later ``.train()`` on this model switches it back to training
            # mode -- confirm the training loop re-applies ``eval()`` if needed.
            self.teacher_model.eval()
        else:
            self.teacher_model = None
        self.head_extractor = nn.Linear(2 * config.hidden_size, emb_size)
        self.mention_head_extractor = nn.Linear(2 * config.hidden_size, emb_size)
        self.tail_extractor = nn.Linear(2 * config.hidden_size, emb_size)
        self.bin_classifier = nn.Linear(config.hidden_size, 5)
        self.entity_type_embeddings = nn.Embedding(7, config.hidden_size)
        self.entity_criterion = nn.CrossEntropyLoss()
        self.bin_criterion = nn.CrossEntropyLoss()
        self.mention_tail_extractor = nn.Linear(2 * config.hidden_size, emb_size)
        self.bilinear = nn.Linear(emb_size * block_size, config.num_labels)
        self.projection = nn.Linear(emb_size * block_size, config.hidden_size)
        self.mention_projection = nn.Linear(emb_size * block_size, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.mse_criterion = nn.MSELoss()
        self.axial_transformer = AxialTransformer_by_entity(emb_size=config.hidden_size, dropout=0.0, num_layers=6, heads=8)
        self.threshold = nn.Threshold(0, 0)
        self.num_labels = num_labels
    def encode(self, input_ids, attention_mask):
        """Encode the inputs with ``self.model`` via ``process_long_input``.

        The start / end special tokens handed to ``process_long_input`` depend
        on ``config.transformer_type``: BERT uses one separator token at the
        end, RoBERTa uses two.

        Args:
            input_ids: Forwarded unchanged to ``process_long_input``.
            attention_mask: Forwarded unchanged to ``process_long_input``.

        Returns:
            The ``(sequence_output, attention)`` pair produced by
            ``process_long_input``.

        Raises:
            ValueError: If ``config.transformer_type`` is neither ``"bert"``
                nor ``"roberta"``.
        """
        config = self.config
        if config.transformer_type == "bert":
            start_tokens = [config.cls_token_id]
            end_tokens = [config.sep_token_id]
        elif config.transformer_type == "roberta":
            start_tokens = [config.cls_token_id]
            end_tokens = [config.sep_token_id, config.sep_token_id]
        else:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError(
                f"unsupported transformer_type: {config.transformer_type!r} "
                "(expected 'bert' or 'roberta')"
            )
        sequence_output, attention = process_long_input(self.model, input_ids, attention_mask, start_tokens, end_tokens)
        return sequence_output, attention

    def get_hrt(self, sequence_output, attention, entity_pos, hts):
        """Build fixed-size head / context / tail tables plus raw entity embeddings.

        Each entity is pooled from the encoder output at its mention start
        tokens (log-sum-exp when it has several mentions, the raw token
        embedding when it has one); attention rows are averaged over mentions.
        For each ``(head, tail)`` pair in ``hts`` a context vector is the
        attention-weighted sum of token embeddings, with weights given by the
        head-summed product of both attention rows, normalised to one.

        Args:
            sequence_output: Encoder output, unpacked as ``(batch, seq_len, hidden)``.
            attention: 4-D encoder attention; dimension 1 is heads, the last
                dimension is the sequence length.
            entity_pos: Per document, a list of entities, each a list of
                ``(start, end)`` mention spans.  Only ``start`` is used.
            hts: Per document, ``(head_index, tail_index)`` pairs; reshaped to
                ``(s_ne, s_ne)``, so exactly ``s_ne * s_ne`` pairs are needed.

        Returns:
            ``(hss, rss, tss, batch_entity_embs)``.  The first three have shape
            ``(batch, 42, 42, hidden)`` with unused cells zero; the last is the
            concatenation of every document's ``(s_ne, hidden)`` entity
            embeddings along dimension 0.
        """
        offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
        _, h, _, c = attention.size()
        _, _, h_size = sequence_output.size()
        # Tables have a fixed side of 42; a document with more entities fails
        # at the slice assignment in to_table below.
        n_e = 42
        clip_negative = torch.nn.Threshold(0, 0)

        head_tables, tail_tables, ctx_tables = [], [], []
        batch_entity_embs = []
        for doc_idx, doc_entities in enumerate(entity_pos):
            pooled_embs, pooled_atts = [], []
            for e in doc_entities:
                if len(e) > 1:
                    # Mentions starting past the encoded length were truncated away.
                    kept = [start + offset for start, end in e if start + offset < c]
                    if len(kept) > 0:
                        stacked_emb = torch.stack([sequence_output[doc_idx, pos] for pos in kept], dim=0)
                        stacked_att = torch.stack([attention[doc_idx, :, pos] for pos in kept], dim=0)
                        e_emb = torch.logsumexp(stacked_emb, dim=0)
                        e_att = stacked_att.mean(0)
                    else:
                        e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                        e_att = torch.zeros(h, c).to(attention)
                else:
                    start, end = e[0]
                    pos = start + offset
                    if pos < c:
                        e_emb = sequence_output[doc_idx, pos]
                        e_att = attention[doc_idx, :, pos]
                    else:
                        e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                        e_att = torch.zeros(h, c).to(attention)
                pooled_embs.append(e_emb)
                pooled_atts.append(e_att)

            entity_embs = torch.stack(pooled_embs, dim=0)  # [s_ne, d]
            entity_atts = torch.stack(pooled_atts, dim=0)  # [s_ne, h, seq_len]
            s_ne, _ = entity_embs.size()

            def to_table(flat):
                # Place the (s_ne, s_ne) table in the corner of a zero table.
                table = torch.zeros((n_e, n_e, h_size)).to(sequence_output)
                table[:s_ne, :s_ne, :] = flat.view(s_ne, s_ne, h_size)
                return table

            pairs = torch.LongTensor(hts[doc_idx]).to(sequence_output.device)
            head_idx, tail_idx = pairs[:, 0], pairs[:, 1]
            pad_hs = to_table(torch.index_select(entity_embs, 0, head_idx))
            pad_ts = to_table(torch.index_select(entity_embs, 0, tail_idx))

            h_att = torch.index_select(entity_atts, 0, head_idx)
            t_att = torch.index_select(entity_atts, 0, tail_idx)
            ht_att = clip_negative((h_att * t_att).sum(1))
            ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-10)

            pad_rs = to_table(contract("ld,rl->rd", sequence_output[doc_idx], ht_att))

            head_tables.append(pad_hs)
            tail_tables.append(pad_ts)
            ctx_tables.append(pad_rs)
            batch_entity_embs.append(entity_embs)

        hss = torch.stack(head_tables, dim=0)
        tss = torch.stack(tail_tables, dim=0)
        rss = torch.stack(ctx_tables, dim=0)
        batch_entity_embs = torch.cat(batch_entity_embs, dim=0)
        return hss, rss, tss, batch_entity_embs

    def get_hrt_by_segment(self, sequence_output, attention, entity_pos, hts, segment_span):
        offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
        n, h, _, c = attention.size()
        hss, tss, rss = [], [], []
        sent_embs = []
        batch_entity_embs = []
        #print(sequence_output.size(), attention.size())
        segment_start, segment_end = segment_span
        seg_start_idx = 0
        b, seq_l, h_size = sequence_output.size()
        n_e = max([len(x) for x in entity_pos])
        for i in range(len(entity_pos)):
            entity_embs, entity_atts = [], []
            entity_lens = []
            mask = []
            logit_mask = torch.zeros((n_e, n_e))
            for e_pos in entity_pos[i]:
                #entity_lens.append(self.ent_num_emb(torch.tensor(len(e)).to(sequence_output).long()))
                e_pos = [x for x in e_pos if (x[0] >= segment_start)  and (x[1] < segment_end)]
                #print(len(e_pos))
                
                if len(e_pos) > 1:
                    e_emb, e_att = [], []
                    exist = 1 
                    for start, end in e_pos:
                        start = start - segment_start
                        end = start - segment_start
                        if   start + offset < c :
                            # In case the entity mention is truncated due to limited max seq length.
                            e_emb.append(sequence_output[i, start + offset])
                            e_att.append(attention[i, :, start + offset])
                    if len(e_emb) > 0:
                        e_emb = torch.logsumexp(torch.stack(e_emb, dim=0), dim=0)
                        e_att = torch.stack(e_att, dim=0).mean(0)
                    else:
                        e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                        e_att = torch.zeros(h, c).to(attention)
                elif len(e_pos) == 1:
                    start, end = e_pos[0]
                    start = start - segment_start
                    end = start - segment_start
                    exist = 1 
                    if  start + offset < c :
                        #e_emb = sequence_output[i, start + offset] + seq_sent_embs[start + offset]
                        e_emb = sequence_output[i, start + offset]
                        e_att = attention[i, :, start + offset]
                    else:
                        e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                        e_att = torch.zeros(h, c).to(attention)
                elif len(e_pos) == 0:
                    exist = 0
                    e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                    e_att = torch.zeros(h, c).to(attention)
                entity_embs.append(e_emb)
                entity_atts.append(e_att)
                mask.append(exist)
               
            for i_e in range(n_e):
                for j_e in range(n_e):
                    if mask[i_e]==1 and mask[j_e]==1:
                        logit_mask[i_e, j_e] = 1
            entity_embs = torch.stack(entity_embs, dim=0)  # [n_e, d]

            #entity_embs = entity_embs + entity_type_embs
            entity_atts = torch.stack(entity_atts, dim=0)  # [n_e, h, seq_len]
            s_ne, _ = entity_embs.size()

            ht_i = torch.LongTensor(hts[0]).to(sequence_output.device)
                        
            hs = torch.index_select(entity_embs, 0, ht_i[:, 0]).view(s_ne, s_ne, h_size)
            ts = torch.index_select(entity_embs, 0, ht_i[:, 1]).view(s_ne, s_ne, h_size)
            
            h_att = torch.index_select(entity_atts, 0, ht_i[:, 0])
            t_att = torch.index_select(entity_atts, 0, ht_i[:, 1])
            #ht_att = (h_att * t_att).mean(1)
            m = torch.nn.Threshold(0,0)
            ht_att = m((h_att * t_att).sum(1))
            ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-10)
            rs = contract("ld,rl->rd", sequence_output[0], ht_att).view(s_ne, s_ne, h_size)

            hss.append(hs)
            tss.append(ts)
            rss.append(rs)
        hss = torch.stack(hss, dim=0)
        tss = torch.stack(tss, dim=0)
        rss = torch.stack(rss, dim=0)
        return hss, rss, tss, logit_mask
    
    def get_hrt_by_two_segment(self, sequence_output, attention, entity_pos, hts, segment_span):
        """Build head / context / tail pair features from two encoded segments.

        ``sequence_output`` and ``attention`` are indexed as if the two
        segments were laid out back to back: a mention that starts at document
        token ``t`` inside the first segment is read at ``t - segment0_start``;
        one inside the second segment at ``t - segment1_start + segment0_len``.
        Only mentions with ``start >= seg_start`` and ``end < seg_end`` for one
        of the segments are used.

        Args:
            sequence_output: encoder states; unpacked as (batch, seq_len, hidden).
            attention: attention weights; unpacked as (batch, heads, _, ctx).
            entity_pos: per document, per entity, a list of (start, end)
                mention spans in document token coordinates.
            hts: per document, (head, tail) entity-index pairs. The gathered
                rows are viewed as (s_ne, s_ne, hidden), so the list has to
                contain exactly s_ne * s_ne pairs.
            segment_span: ``(start0, end0), (start1, end1)``.

        Returns:
            ``(hss, rss, tss, logit_mask)``: stacked head, context and tail
            features, plus an (n_e, n_e) float mask on the CPU that is 1 where
            both entities of a pair have a mention in the segments. As in the
            previous implementation, the mask is the one built for the last
            document of the batch.
        """
        offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
        _, h, _, c = attention.size()
        _, _, h_size = sequence_output.size()
        segment0_start, segment0_end = segment_span[0]
        segment1_start, segment1_end = segment_span[1]
        segment0_len = segment0_end - segment0_start

        n_e = max([len(x) for x in entity_pos])
        threshold = torch.nn.Threshold(0, 0)
        hss, tss, rss = [], [], []
        for i in range(len(entity_pos)):
            entity_embs, entity_atts, mask = [], [], []
            for e_pos in entity_pos[i]:
                # Segment-local start index of every mention inside either
                # segment; first-segment mentions come first.
                local_starts = [
                    x[0] - segment0_start
                    for x in e_pos
                    if (x[0] >= segment0_start) and (x[1] < segment0_end)
                ]
                local_starts += [
                    x[0] - segment1_start + segment0_len
                    for x in e_pos
                    if (x[0] >= segment1_start) and (x[1] < segment1_end)
                ]
                exist = 1 if local_starts else 0

                # In case the entity mention is truncated due to limited max seq length.
                usable = [p + offset for p in local_starts if p + offset < c]
                if not usable:
                    e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                    e_att = torch.zeros(h, c).to(attention)
                elif len(local_starts) == 1:
                    e_emb = sequence_output[i, usable[0]]
                    e_att = attention[i, :, usable[0]]
                else:
                    # Several mentions: logsumexp-pool states, average attention.
                    e_emb = torch.logsumexp(
                        torch.stack([sequence_output[i, p] for p in usable], dim=0), dim=0)
                    e_att = torch.stack([attention[i, :, p] for p in usable], dim=0).mean(0)

                entity_embs.append(e_emb)
                entity_atts.append(e_att)
                mask.append(exist)

            # Outer product of the per-entity existence flags, zero-padded to
            # n_e (replaces the former element-wise double loop).
            exists = torch.zeros(n_e)
            exists[:len(mask)] = torch.tensor(mask, dtype=exists.dtype)
            logit_mask = exists.unsqueeze(1) * exists.unsqueeze(0)

            entity_embs = torch.stack(entity_embs, dim=0)  # [n_e, d]
            entity_atts = torch.stack(entity_atts, dim=0)  # [n_e, h, seq_len]
            s_ne, _ = entity_embs.size()

            # Use this document's own pairs and states; item 0 was previously
            # hard-coded here while the embeddings above were read from item i.
            ht_i = torch.LongTensor(hts[i]).to(sequence_output.device)

            hs = torch.index_select(entity_embs, 0, ht_i[:, 0]).view(s_ne, s_ne, h_size)
            ts = torch.index_select(entity_embs, 0, ht_i[:, 1]).view(s_ne, s_ne, h_size)

            h_att = torch.index_select(entity_atts, 0, ht_i[:, 0])
            t_att = torch.index_select(entity_atts, 0, ht_i[:, 1])
            # Head/tail attention product summed over heads, then normalised
            # over tokens to weight the context vector of each pair.
            ht_att = threshold((h_att * t_att).sum(1))
            ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-10)
            rs = contract("ld,rl->rd", sequence_output[i], ht_att).view(s_ne, s_ne, h_size)

            hss.append(hs)
            tss.append(ts)
            rss.append(rs)
        hss = torch.stack(hss, dim=0)
        tss = torch.stack(tss, dim=0)
        rss = torch.stack(rss, dim=0)
        return hss, rss, tss, logit_mask
    
    
    def get_mention_hrt(self, sequence_output, attention, entity_pos, hts, sentid_mask):
        """Build padded head / context / tail features for mention pairs.

        Each element of ``entity_pos[i]`` is unpacked as one ``(start, end)``
        span, and the encoder state / attention row at ``start + offset``
        represents it. Spans whose shifted start falls outside the attention
        width get zero vectors.

        Args:
            sequence_output: encoder states; unpacked as (batch, seq_len, hidden).
            attention: attention weights; unpacked as (batch, heads, _, ctx).
            entity_pos: per document, a list of (start, end) spans.
            hts: per document, (head, tail) index pairs. The gathered rows are
                viewed as (s_ne, s_ne, hidden), so each list has to contain
                exactly s_ne * s_ne pairs.
            sentid_mask: not read by this method.

        Returns:
            ``(hss, rss, tss, batch_entity_embs)``: the first three are
            (batch, n_e, n_e, hidden) tensors zero-padded to the largest span
            count in the batch; the last concatenates the span embeddings of
            all documents along dim 0.
        """
        offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
        _, h, _, c = attention.size()
        _, _, h_size = sequence_output.size()
        n_e = max([len(x) for x in entity_pos])
        threshold = torch.nn.Threshold(0, 0)
        hss, tss, rss = [], [], []
        batch_entity_embs = []
        for i in range(len(entity_pos)):
            entity_embs, entity_atts = [], []
            for start, end in entity_pos[i]:
                if start + offset < c:
                    e_emb = sequence_output[i, start + offset]
                    e_att = attention[i, :, start + offset]
                else:
                    # Span start lies beyond the encoded width: use zeros.
                    e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
                    e_att = torch.zeros(h, c).to(attention)
                entity_embs.append(e_emb)
                entity_atts.append(e_att)

            entity_embs = torch.stack(entity_embs, dim=0)  # [n_e, d]
            entity_atts = torch.stack(entity_atts, dim=0)  # [n_e, h, seq_len]
            s_ne, _ = entity_embs.size()

            ht_i = torch.LongTensor(hts[i]).to(sequence_output.device)
            hs = torch.index_select(entity_embs, 0, ht_i[:, 0])
            ts = torch.index_select(entity_embs, 0, ht_i[:, 1])

            pad_hs = torch.zeros((n_e, n_e, h_size)).to(sequence_output)
            pad_ts = torch.zeros((n_e, n_e, h_size)).to(sequence_output)
            pad_hs[:s_ne, :s_ne, :] = hs.view(s_ne, s_ne, h_size)
            pad_ts[:s_ne, :s_ne, :] = ts.view(s_ne, s_ne, h_size)

            h_att = torch.index_select(entity_atts, 0, ht_i[:, 0])
            t_att = torch.index_select(entity_atts, 0, ht_i[:, 1])
            # Head/tail attention product summed over heads, then normalised
            # over tokens to weight the context vector of each pair.
            ht_att = threshold((h_att * t_att).sum(1))
            ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-10)
            rs = contract("ld,rl->rd", sequence_output[i], ht_att)
            pad_rs = torch.zeros((n_e, n_e, h_size)).to(sequence_output)
            pad_rs[:s_ne, :s_ne, :] = rs.view(s_ne, s_ne, h_size)

            hss.append(pad_hs)
            tss.append(pad_ts)
            rss.append(pad_rs)
            batch_entity_embs.append(entity_embs)
        hss = torch.stack(hss, dim=0)
        tss = torch.stack(tss, dim=0)
        rss = torch.stack(rss, dim=0)
        batch_entity_embs = torch.cat(batch_entity_embs, dim=0)
        # The return previously referenced the undefined name
        # `batch_entity_emb`, which raised NameError on every call.
        return hss, rss, tss, batch_entity_embs


    def encode_by_segment(self, input_ids, attention_mask, sentid_mask, ctx_window, stride):
        """Encode a sequence as overlapping windows of ``ctx_window`` tokens.

        A sequence no longer than ``ctx_window`` is passed to ``self.encode``
        unchanged. Otherwise the ids and mask are squeezed on dim 0, copied
        into zero-filled buffers long enough for a whole number of strides,
        cut into windows starting every ``stride`` tokens, and encoded as one
        batch of windows.

        Args:
            input_ids: 2-D id tensor; the long-sequence path squeezes dim 0,
                so it presumably expects a single document (verify against caller).
            attention_mask: mask matching ``input_ids``.
            sentid_mask: not read by this method.
            ctx_window: window length in tokens.
            stride: distance in tokens between consecutive window starts.

        Returns:
            ``(segment_output, segment_attn, segment_spans)`` where
            ``segment_spans`` lists the ``(start, end)`` token range of each
            window.
        """
        _, seq_len = input_ids.size()
        if seq_len <= ctx_window:
            encoded, encoded_attn = self.encode(input_ids, attention_mask)
            return encoded, encoded_attn, [(0, seq_len)]

        n_strides = math.ceil((seq_len - ctx_window) / stride)
        padded_len = stride * n_strides + ctx_window

        # Both buffers take dtype and device from input_ids.
        padded_ids = torch.zeros((padded_len)).to(input_ids)
        padded_mask = torch.zeros((padded_len)).to(input_ids)
        padded_ids[:seq_len] = input_ids.squeeze(0)
        padded_mask[:seq_len] = attention_mask.squeeze(0)

        segment_spans = [
            (k * stride, k * stride + ctx_window) for k in range(n_strides + 1)
        ]
        window_ids = torch.stack([padded_ids[lo:hi] for lo, hi in segment_spans], dim=0)
        window_mask = torch.stack([padded_mask[lo:hi] for lo, hi in segment_spans], dim=0)

        encoded, encoded_attn = self.encode(window_ids, window_mask)
        return encoded, encoded_attn, segment_spans
 
    def encode_by_sentence(self, input_ids, attention_mask, sentid_mask, ctx_window, stride):
        """Encode a sequence as consecutive, non-overlapping windows.

        The ids and mask are squeezed on dim 0, copied into zero-filled
        buffers of ``segments * ctx_window`` tokens, reshaped to
        ``(segments, ctx_window)`` and passed to ``self.encode``.

        Args:
            input_ids: 2-D id tensor; dim 0 is squeezed, so it presumably
                holds a single document (verify against caller).
            attention_mask: mask matching ``input_ids``.
            sentid_mask: not read by this method.
            ctx_window: window length in tokens.
            stride: only enters the legacy segment-count formula below.

        Returns:
            ``(segment_output, segment_attn)`` as returned by ``self.encode``.
        """
        _, seq_len = input_ids.size()

        # The historical count is the sliding-window formula. On its own it
        # can give fewer than ceil(seq_len / ctx_window) windows (e.g. when
        # stride == ctx_window, or seq_len <= ctx_window), which made the
        # buffer too short and the slice assignment below fail. Taking the
        # larger of the two keeps the output unchanged wherever the old
        # formula already sufficed.
        legacy_segments = math.ceil((seq_len - ctx_window) / stride)
        needed_segments = math.ceil(seq_len / ctx_window)
        segments = max(legacy_segments, needed_segments)

        max_len = ctx_window * segments
        segment_input = torch.zeros((max_len)).to(input_ids)
        segment_attention = torch.zeros((max_len)).to(input_ids)
        segment_input[:seq_len] = input_ids.squeeze(0)
        segment_attention[:seq_len] = attention_mask.squeeze(0)
        segment_input = segment_input.view(segments, ctx_window).long()
        segment_attention = segment_attention.view(segments, ctx_window).long()
        segment_output, segment_attn = self.encode(segment_input, segment_attention)
        return segment_output, segment_attn


    def get_ensembled_logits(self, segment_output, segment_attention, segment_spans, entity_pos, hts):
        """Average per-segment pair logits over the segments covering each pair.

        ``get_logits_by_segment`` is called once per entry of
        ``segment_spans`` with the matching slice of ``segment_output`` and
        ``segment_attention``. Logits and masks are each stacked on a new
        second-to-last axis and summed over it; the logit sum is divided by
        the mask sum (plus 1e-10 to avoid division by zero).

        Returns:
            ``(ensemble_logits, ensemble_masks)``: the averaged logits and the
            summed masks with a trailing singleton axis.
        """
        per_segment = [
            self.get_logits_by_segment(
                segment_spans[k], segment_output[k], segment_attention[k], entity_pos, hts)
            for k in range(len(segment_spans))
        ]
        stacked_logits = torch.stack([seg_logits for seg_logits, _ in per_segment], dim=-2)
        stacked_masks = torch.stack([seg_mask for _, seg_mask in per_segment], dim=-2)

        # Summed masks: how many segments contributed to each pair.
        coverage = stacked_masks.sum(-2).unsqueeze(-1)
        averaged_logits = stacked_logits.sum(-2) / (coverage + 1e-10)
        return averaged_logits, coverage

    def create_copy_mod(self):
        """Create ``local_*`` twins of the pair-scoring modules and copy their weights.

        New head/tail extractors, classifier, projection and axial transformer
        are built on the device of ``self.classifier``'s parameters and then
        loaded with the state dicts of the corresponding existing modules. The
        local axial transformer is built with ``dropout=0.0``, 6 layers and
        8 heads.
        """
        device = next(self.classifier.parameters()).device
        hidden_size = self.config.hidden_size

        self.local_head_extractor = nn.Linear(2 * hidden_size, self.emb_size).to(device)
        self.local_tail_extractor = nn.Linear(2 * hidden_size, self.emb_size).to(device)
        self.local_classifier = nn.Linear(hidden_size, self.config.num_labels).to(device)
        self.local_projection = nn.Linear(self.emb_size * self.block_size, hidden_size).to(device)
        self.local_axial_transformer = AxialTransformer_by_entity(
            emb_size=hidden_size, dropout=0.0, num_layers=6, heads=8).to(device)

        # Copy weights from each source module into its local_* counterpart.
        source_names = (
            "classifier",
            "head_extractor",
            "tail_extractor",
            "projection",
            "axial_transformer",
        )
        for source_name in source_names:
            source_state = getattr(self, source_name).state_dict()
            getattr(self, "local_" + source_name).load_state_dict(source_state)
    
    def get_logits_by_segment(self, segment_span, sequence_output, attention, entity_pos, hts):
        """Score all entity pairs using a single encoded segment.

        Args:
            segment_span: span forwarded unchanged to ``get_hrt_by_segment``.
            sequence_output: encoder states of one segment; a batch axis of
                size 1 is added in front.
            attention: attention weights of the same segment; a batch axis is
                added in front.
            entity_pos: per document, per entity, mention spans. The entity
                count is taken from ``entity_pos[0]``.
            hts: (head, tail) pair indices, forwarded to ``get_hrt_by_segment``.

        Returns:
            ``(logits, logit_mask)``: logits with the batch axis removed,
            zeroed on the diagonal (entity paired with itself) and wherever
            ``logit_mask`` is 0; and the mask itself, cast to the dtype and
            device of ``sequence_output``.
        """
        sequence_output = sequence_output.unsqueeze(0)
        attention = attention.unsqueeze(0)
        bs = sequence_output.size(0)
        ne = len(entity_pos[0])
        hs_e, rs_e, ts_e, logit_mask = self.get_hrt_by_segment(
            sequence_output, attention, entity_pos, hts, segment_span)
        # as_tensor + clone().detach() copies without the copy-construct
        # warning that torch.tensor(<tensor>) emits.
        logit_mask = torch.as_tensor(logit_mask).clone().detach().to(sequence_output)

        hs_e = torch.tanh(self.head_extractor(torch.cat([hs_e, rs_e], dim=3)))
        ts_e = torch.tanh(self.tail_extractor(torch.cat([ts_e, rs_e], dim=3)))
        # Grouped bilinear interaction: split the embedding into blocks and
        # take the outer product of head and tail within each block.
        b1_e = hs_e.view(bs, ne, ne, self.emb_size // self.block_size, self.block_size)
        b2_e = ts_e.view(bs, ne, ne, self.emb_size // self.block_size, self.block_size)
        bl_e = (b1_e.unsqueeze(5) * b2_e.unsqueeze(4)).view(bs, ne, ne, self.emb_size * self.block_size)

        feature = self.projection(bl_e)
        feature = self.axial_transformer(feature) + feature
        # Drop only the batch axis. A bare squeeze() would also drop the
        # label axis when num_labels == 1, and the mask multiplication below
        # would then broadcast to (ne, ne, ne).
        logits = self.classifier(feature).squeeze(0)
        self_mask = (1 - torch.diag(torch.ones((ne)))).unsqueeze(-1).to(sequence_output)
        logits = logits * logit_mask.unsqueeze(-1)
        logits = logits * self_mask

        return logits, logit_mask

    def forward(self,
                input_ids=None,
                attention_mask=None,
                labels=None,
                entity_pos=None,
                hts=None,
                mention_pos=None,
                mention_hts=None,
                padded_mention=None,
                padded_mention_mask=None,
                sentid_mask=None,
                return_logits = None,
                teacher_logits = None,
                entity_types = None,
                segment_spans = None,
                instance_mask=None,
                ):
        """Score every ordered entity pair of each document.

        Args:
            input_ids, attention_mask: passed to ``self.encode``.
            labels: optional per-document label arrays; concatenated along
                dim 0 and compared against the flattened pair logits.
            entity_pos: per document, the list of entities; only its lengths
                are used here, the content is forwarded to ``self.get_hrt``.
            hts: per-document (head, tail) pairs, forwarded to ``self.get_hrt``.
            teacher_logits: optional list of tensors; when given together with
                ``labels``, ``self.mse_criterion(logits, teacher_logits)`` is
                added to the loss.
            mention_pos, mention_hts, padded_mention, padded_mention_mask,
            sentid_mask, return_logits, entity_types, segment_spans,
            instance_mask: accepted for interface compatibility; not read by
                this method.

        Returns:
            Without ``labels``: ``(predictions, logits)`` where predictions
            come from ``self.loss_fnt.get_label``.
            With ``labels``: ``(loss, loss1, loss2)`` where the last two are
            the extra values returned by ``self.loss_fnt``.
        """
        sequence_output, attention = self.encode(input_ids, attention_mask)
        bs = sequence_output.size(0)

        nes = [len(x) for x in entity_pos]
        hs_e, rs_e, ts_e, _ = self.get_hrt(sequence_output, attention, entity_pos, hts)
        # Padded entity count, read from the features themselves. This was a
        # hard-coded 42, which only worked when get_hrt padded to exactly 42.
        ne = hs_e.size(1)

        hs_e = torch.tanh(self.head_extractor(torch.cat([hs_e, rs_e], dim=3)))
        ts_e = torch.tanh(self.tail_extractor(torch.cat([ts_e, rs_e], dim=3)))
        # Grouped bilinear interaction: split the embedding into blocks and
        # take the outer product of head and tail within each block.
        b1_e = hs_e.view(bs, ne, ne, self.emb_size // self.block_size, self.block_size)
        b2_e = ts_e.view(bs, ne, ne, self.emb_size // self.block_size, self.block_size)
        bl_e = (b1_e.unsqueeze(5) * b2_e.unsqueeze(4)).view(bs, ne, ne, self.emb_size * self.block_size)

        feature = self.projection(bl_e)
        feature = self.axial_transformer(feature) + feature
        logits_e = self.classifier(feature)
        # Zero the diagonal: an entity is never paired with itself.
        self_mask = (1 - torch.diag(torch.ones((ne)))).unsqueeze(0).unsqueeze(-1).to(sequence_output)
        logits_e = logits_e * self_mask
        # Strip the padding: keep each document's real nes[x] x nes[x] block
        # and flatten all documents into one (num_pairs, num_labels) matrix.
        logits_e = torch.cat([
            logits_e[x, :nes[x], :nes[x], :].reshape(-1, self.config.num_labels)
            for x in range(len(nes))
        ])
        logits = logits_e

        output = (self.loss_fnt.get_label(logits, num_labels=self.num_labels), logits)
        if labels is not None:
            labels = [torch.tensor(label) for label in labels]
            labels = torch.cat(labels, dim=0).to(logits)

            loss_e, loss1, loss2 = self.loss_fnt(logits_e.float(), labels.float())
            output = loss_e.to(sequence_output)

            if teacher_logits is not None:
                teacher_logits = torch.cat(teacher_logits, dim=0).to(logits)
                mse_loss = self.mse_criterion(logits, teacher_logits)
                # Add the distillation term to the loss. It was previously
                # added to `logits`, which is not used after this point, so
                # the term never reached the returned loss.
                output = output + 1.0 * mse_loss
            output = (output, loss1, loss2)
        return output
