from concurrent.futures.thread import _threads_queues
import json
import random
from functools import partial
import pdb
from turtle import pd
import numpy as np
import redis
import sklearn
import torch
from eveliver import (Logger, load_model, tensor_to_obj)
from _trainer import Trainer, TrainerCallback
from transformers import AutoTokenizer, BertModel
from opt_einsum import contract
from allennlp.modules.matrix_attention import DotProductMatrixAttention, CosineMatrixAttention, BilinearMatrixAttention
from matrix_transformer import Encoder as MatTransformer
from graph_encoder import Encoder as GraphEncoder
from bcrt_transformer import Encoder as BCRTransformer
from torch import nn
import torch.nn.functional as F
import os
from tqdm import tqdm
from buffer import Buffer
from utils import CAPACITY, BLOCK_SIZE, DEFAULT_MODEL_NAME, contrastive_pair
from torch.nn import CrossEntropyLoss
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from itertools import groupby
from pyg_graph import create_edges, create_graph, GCN, Attention, create_graph_single
from utils import DotProductSimilarity
def eval_performance(facts, pred_result):
    """Score ranked predictions against gold facts.

    Predictions are sorted by descending ``'score'`` and a precision/recall
    point is recorded after every prediction.

    Args:
        facts: Container of gold ``(head, tail, relation)`` triples. Only
            ``len(facts)`` and membership tests are used.
        pred_result: Iterable of dicts with the keys ``'score'``,
            ``'entpair'`` (items 0 and 1 are read) and ``'relation'``.

    Returns:
        Dict with ``'prec'`` and ``'rec'`` (the curves, as lists),
        ``'mean_prec'`` (mean of the precision curve), ``'f1'`` (maximum F1
        along the curve) and ``'auc'`` (area under the precision-recall curve).

    Raises:
        ZeroDivisionError: If ``facts`` is empty while ``pred_result`` is not.
    """
    # ``import sklearn`` by itself does not guarantee that the ``metrics``
    # submodule is loaded; import the function explicitly so the call below
    # does not depend on some other module having imported it first.
    from sklearn.metrics import auc as area_under_curve

    sorted_pred_result = sorted(pred_result, key=lambda x: x['score'], reverse=True)
    prec = []
    rec = []
    correct = 0
    total = len(facts)
    for i, item in enumerate(sorted_pred_result):
        if (item['entpair'][0], item['entpair'][1], item['relation']) in facts:
            correct += 1
        prec.append(float(correct) / float(i + 1))
        rec.append(float(correct) / float(total))
    auc = area_under_curve(x=rec, y=prec)
    np_prec = np.array(prec)
    np_rec = np.array(rec)
    # The tiny epsilon keeps the division defined where precision and recall are both 0.
    f1 = (2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).max()
    mean_prec = np_prec.mean()
    return {'prec': np_prec.tolist(), 'rec': np_rec.tolist(), 'mean_prec': mean_prec, 'f1': f1, 'auc': auc}

def expand(start, end, total_len, max_size):
    """Grow the span ``[start, end)`` into a window of ``max_size`` items.

    The extra room is split between both sides of the span (the right side
    gets the odd item). A window that sticks out of ``[0, total_len)`` is
    shifted back inside; if it cannot fit at all, the whole range is returned.

    Args:
        start: Inclusive start of the span.
        end: Exclusive end of the span.
        total_len: Length of the sequence the span lives in.
        max_size: Desired window size.

    Returns:
        Tuple ``(window_start, window_end)``.
    """
    e_size = max_size - (end - start)
    left = start - (e_size // 2)
    right = end + (e_size - e_size // 2)
    if right - left > total_len:
        # The window is larger than the sequence: take everything.
        return 0, total_len
    if left < 0:
        # Shift the window right by the amount it overhangs on the left.
        # ``left`` is negative here, so subtracting it moves ``right`` forward.
        right -= left
        left = 0
    elif right > total_len:
        # Shift the window left by the amount it overhangs on the right.
        left -= (right - total_len)
        right = total_len
    return left, right


def place_train_data(dataset):
    """Group training rows into bags, one bag per (entity pair, relation).

    Args:
        dataset: Iterable of ``(key, doc1, doc2, label)`` rows.

    Returns:
        List of ``[key, label, docs, 'o']`` bags sorted by ``key + '#' + label``.
        A key that only has ``'n/a'`` rows yields a single ``'n/a'`` bag.
        Otherwise every non-``'n/a'`` label yields one bag whose docs are that
        label's rows followed by the key's ``'n/a'`` rows (if any).
    """
    grouped = {}
    for key, doc1, doc2, label in dataset:
        grouped.setdefault(key, {}).setdefault(label, []).append([doc1, doc2, label])

    bags = []
    for key, by_label in grouped.items():
        if list(by_label) == ['n/a']:
            bags.append([key, 'n/a', by_label['n/a'], 'o'])
            continue
        negatives = by_label.get('n/a')
        for label in list(by_label):
            if label == 'n/a':
                continue
            docs = by_label[label]
            if negatives is not None:
                # Negative document pairs are shared by every positive bag of the key.
                docs.extend(negatives)
            bags.append([key, label, docs, 'o'])

    bags.sort(key=lambda bag: bag[0] + '#' + bag[1])
    return bags


def place_dev_data(dataset, single_path):
    """Group dev rows into one bag per entity pair.

    Args:
        dataset: Iterable of ``(key, doc1, doc2, label)`` rows.
        single_path: When truthy, each non-``'n/a'`` label contributes one
            randomly chosen row instead of all of its rows.

    Returns:
        List of ``[key, labels, docs, 'o']`` bags sorted by
        ``key + '#' + '#'.join(labels)``. ``labels`` is ``['n/a']`` for a key
        that only has ``'n/a'`` rows, otherwise the key's labels without
        ``'n/a'``.
    """
    grouped = {}
    for key, doc1, doc2, label in dataset:
        grouped.setdefault(key, {}).setdefault(label, []).append([doc1, doc2, label])

    bags = []
    for key, by_label in grouped.items():
        if list(by_label) == ['n/a']:
            bags.append([key, ['n/a'], by_label['n/a'], 'o'])
            continue
        docs = []
        for label, label_docs in by_label.items():
            if single_path and label != 'n/a':
                docs.append(random.choice(label_docs))
            else:
                docs.extend(label_docs)
        positive_labels = [label for label in by_label if label != 'n/a']
        bags.append([key, positive_labels, docs, 'o'])

    bags.sort(key=lambda bag: bag[0] + '#' + '#'.join(bag[1]))
    return bags

def place_test_data(dataset, single_path):
    """Group unlabeled test rows into one bag per entity pair.

    Every row is given the label ``'n/a'``, so each entity pair yields exactly
    one ``'n/a'`` bag. (The previous per-label sampling branch could never
    run for that reason and has been removed.)

    Args:
        dataset: Iterable of dicts with the keys ``'h_id'``, ``'t_id'`` and
            ``'doc'`` (items 0 and 1 of ``'doc'`` are read).
        single_path: Accepted for signature parity with ``place_dev_data``;
            it has no effect because there are no non-``'n/a'`` labels.

    Returns:
        List of ``[key, ['n/a'], docs, 'o']`` bags where ``key`` is
        ``h_id + '#' + t_id`` and ``docs`` is a list of
        ``[doc1, doc2, 'n/a']``, sorted by ``key + '#n/a'``.
    """
    grouped = {}
    for data in dataset:
        key = data['h_id'] + '#' + data['t_id']
        grouped.setdefault(key, []).append([data['doc'][0], data['doc'][1], 'n/a'])

    bags = [[key, ['n/a'], docs, 'o'] for key, docs in grouped.items()]
    # Keep the historical sort key (key + '#' + joined labels) so the bag
    # order is unchanged even for keys that are prefixes of one another.
    bags.sort(key=lambda bag: bag[0] + '#' + '#'.join(bag[1]))
    return bags

def gen_c(tokenizer, passage, span, max_len, bound_tokens, d_start, d_end, no_additional_marker, mask_entity):
    """Build a context window of at most ``max_len`` tokens around an entity.

    The entity ``passage[span[0]:span[1]]`` is wrapped in ``bound_tokens`` and
    surrounded by left and right context. Unless ``no_additional_marker`` is
    set, positions listed in ``d_start`` / ``d_end`` get an extra
    ``[unusedN]`` marker before / after the token.

    Args:
        tokenizer: Unused.
        passage: List of tokens.
        span: ``(start, end)`` of the entity, end exclusive.
        max_len: Token budget for the whole window.
        bound_tokens: Pair of tokens placed before and after the entity.
        d_start: Maps a token position to an index; emits
            ``[unused{(index + 2) * 2 + 1}]`` before that token.
        d_end: Maps a token position to an index; emits
            ``[unused{(index + 2) * 2 + 2}]`` after that token.
        no_additional_marker: Disable the ``d_start`` / ``d_end`` markers.
        mask_entity: Replace every entity token with ``'[MASK]'``.

    Returns:
        List of tokens: left context, bounded entity, right context.
    """
    with_markers = not no_additional_marker
    first, last = span[0], span[1]

    if mask_entity:
        entity = ['[MASK]' for _ in range(first, last)]
    else:
        entity = [passage[i] for i in range(first, last)]
    core = [bound_tokens[0]] + entity + [bound_tokens[1]]

    # Left context is collected walking backwards, i.e. in reversed order.
    left = []
    pos = first - 1
    while pos >= 0 and len(left) < max_len:
        if with_markers and pos in d_end:
            left.append(f'[unused{(d_end[pos] + 2) * 2 + 2}]')
        left.append(passage[pos])
        if with_markers and pos in d_start:
            left.append(f'[unused{(d_start[pos] + 2) * 2 + 1}]')
        pos -= 1

    right = []
    pos = last
    while pos < len(passage) and len(right) < max_len:
        if with_markers and pos in d_start:
            right.append(f'[unused{(d_start[pos] + 2) * 2 + 1}]')
        right.append(passage[pos])
        if with_markers and pos in d_end:
            right.append(f'[unused{(d_end[pos] + 2) * 2 + 2}]')
        pos += 1

    # Trim the contexts so that everything fits the budget: split evenly when
    # both sides are long, otherwise cut only the longer side.
    budget = max_len - len(core)
    if len(left) + len(right) > budget:
        if len(left) > budget / 2 and len(right) > budget / 2:
            left = left[:budget // 2]
            right = right[:budget - budget // 2]
        elif len(left) <= len(right):
            right = right[:budget - len(left)]
        else:
            left = left[:budget - len(right)]

    return left[::-1] + core + right


def gen_c_complete(tokenizer, passage, span, max_len, bound_tokens, d_start, d_end, no_additional_marker, mask_entity):
    """Build the marked-up context around an entity without a final trim.

    Same construction as ``gen_c`` except that the collected left and right
    contexts are not cut down to a shared budget: each side simply stops
    growing once it holds ``max_len`` or more items.

    Args:
        tokenizer: Unused.
        passage: List of tokens.
        span: ``(start, end)`` of the entity, end exclusive.
        max_len: Per-side limit at which context collection stops.
        bound_tokens: Pair of tokens placed before and after the entity.
        d_start: Maps a token position to an index; emits
            ``[unused{(index + 2) * 2 + 1}]`` before that token.
        d_end: Maps a token position to an index; emits
            ``[unused{(index + 2) * 2 + 2}]`` after that token.
        no_additional_marker: Disable the ``d_start`` / ``d_end`` markers.
        mask_entity: Replace every entity token with ``'[MASK]'``.

    Returns:
        List of tokens: left context, bounded entity, right context.
    """
    with_markers = not no_additional_marker
    first, last = span[0], span[1]

    if mask_entity:
        entity = ['[MASK]' for _ in range(first, last)]
    else:
        entity = [passage[i] for i in range(first, last)]
    core = [bound_tokens[0]] + entity + [bound_tokens[1]]

    # Walk backwards from the entity; the list is reversed at the end.
    left = []
    pos = first - 1
    while pos >= 0 and len(left) < max_len:
        if with_markers and pos in d_end:
            left.append(f'[unused{(d_end[pos] + 2) * 2 + 2}]')
        left.append(passage[pos])
        if with_markers and pos in d_start:
            left.append(f'[unused{(d_start[pos] + 2) * 2 + 1}]')
        pos -= 1

    right = []
    pos = last
    while pos < len(passage) and len(right) < max_len:
        if with_markers and pos in d_start:
            right.append(f'[unused{(d_start[pos] + 2) * 2 + 1}]')
        right.append(passage[pos])
        if with_markers and pos in d_end:
            right.append(f'[unused{(d_end[pos] + 2) * 2 + 2}]')
        pos += 1

    return left[::-1] + core + right

def process(tokenizer, h, t, doc0, doc1):
    """Split two marked-up documents into blocks and sample block buffers.

    Steps, in order:
      1. Inside every ``[unusedK]`` ... ``[unusedK+1]`` entity span (K odd),
         ``"."`` tokens are replaced by ``"|"``. ``doc0`` and ``doc1`` are
         modified in place.
      2. A question buffer and one buffer per document are built with
         ``Buffer.split_document_into_blocks``.
      3. Document blocks containing a head/tail marker id get
         ``relevance = 2``; blocks containing any other marker id get
         ``relevance = 1``.
      4. Block buffers are sampled with ``random``: marker blocks first, then
         other-marker blocks, then the rest.

    Args:
        tokenizer: Tokenizer providing ``cls_token``, ``tokenize`` and
            ``convert_tokens_to_ids``.
        h: Head entity name, used in the question text.
        t: Tail entity name, used in the question text.
        doc0: Token list of the first document (mutated in place).
        doc1: Token list of the second document (mutated in place).

    Returns:
        Tuple ``(qbuf, dbuf, with_question, without_question)``: the question
        buffer, the buffer holding all document blocks, a sampled buffer that
        starts with the question blocks, and a sampled buffer without them
        whose first block has the CLS id inserted at the front of its ids.
    """
    ht_markers = ["[unused" + str(i) + "]" for i in range(1, 5)]
    b_markers = ["[unused" + str(i) + "]" for i in range(5, 101)]
    max_blk_num = CAPACITY // (BLOCK_SIZE + 1)
    cnt = 0

    def marker_number(marker):
        # "[unused17]" -> 17
        return int(marker.replace("[unused", "").replace("]", ""))

    def fix_entity(doc, ht_markers, b_markers):
        # Replace "." by "|" between each opening marker (odd number) and its
        # closing marker (the next number), looking only at the first
        # occurrence of every marker. Explicit conditions are used instead of
        # assert/except so the behaviour does not change under ``python -O``
        # and unrelated errors are not silently swallowed.
        markers = ht_markers + b_markers
        present = set(doc).intersection(markers)
        markers_pos = [(doc.index(m), marker_number(m)) for m in markers if m in present]
        idx = 0
        while idx < len(markers_pos):
            start_pos, number = markers_pos[idx]
            has_closing = (
                number % 2 == 1
                and idx + 1 < len(markers_pos)
                and markers_pos[idx + 1][1] == number + 1
            )
            if not has_closing:
                idx += 1
                continue
            end_pos = markers_pos[idx + 1][0]
            for k in range(start_pos + 1, end_pos):
                if doc[k] == ".":
                    doc[k] = "|"
            idx += 2
        return doc

    d0 = fix_entity(doc0, ht_markers, b_markers)
    d1 = fix_entity(doc1, ht_markers, b_markers)

    # NOTE(review): there is no whitespace around h and 'and', so the question
    # reads "...between<h>and<t>". Left unchanged because it alters the text
    # fed to the tokenizer; confirm whether this is intended.
    question = [tokenizer.cls_token] + tokenizer.tokenize('what is the relation between' + h + 'and' + t)
    q, q_property = [question], [[('relevance', 2), ('blk_type', 1)]]
    qbuf, cnt = Buffer.split_document_into_blocks(q, tokenizer, properties=q_property, cnt=cnt)
    d0_buf, cnt = Buffer.split_document_into_blocks(d0, tokenizer, cnt=cnt, hard=False, docid=0)
    d1_buf, cnt = Buffer.split_document_into_blocks(d1, tokenizer, cnt=cnt, hard=False, docid=1)
    dbuf = Buffer()
    dbuf.blocks = d0_buf.blocks + d1_buf.blocks

    # Marker ids do not depend on the block, so look them up once.
    ht_ids = set(tokenizer.convert_tokens_to_ids(ht_markers))
    b_ids = set(tokenizer.convert_tokens_to_ids(b_markers))
    for blk in dbuf:
        blk_ids = set(blk.ids)
        if ht_ids & blk_ids:
            blk.relevance = 2
        elif b_ids & blk_ids:
            blk.relevance = 1

    ret = []
    n0 = 1
    n1 = 1
    n2 = 1
    lb = max_blk_num - len(qbuf)

    # Contiguous-window sample. It is not returned, but it is kept so the
    # sequence of ``random`` calls (and thus seeded runs) stays the same.
    st = random.randint(0, max(0, len(dbuf) - lb * n0))
    for i in range(n0):
        buf = Buffer()
        buf.blocks = qbuf.blocks + dbuf.blocks[st + i * lb:st + (i + 1) * lb]
        ret.append(buf)

    pbuf_ht, nbuf_ht = dbuf.filtered(lambda blk, idx: blk.relevance >= 2, need_residue=True)
    pbuf_b, nbuf_b = nbuf_ht.filtered(lambda blk, idx: blk.relevance >= 1, need_residue=True)

    # Question + up to ``lb`` document blocks, most relevant tiers first.
    for i in range(n1):
        selected_htblks = random.sample(pbuf_ht.blocks, min(lb, len(pbuf_ht)))
        selected_pblks = random.sample(pbuf_b.blocks, min(lb - len(selected_htblks), len(pbuf_b)))
        selected_nblks = random.sample(nbuf_b.blocks, min(lb - len(selected_pblks) - len(selected_htblks), len(nbuf_b)))
        buf = Buffer()
        buf.blocks = qbuf.blocks + selected_htblks + selected_pblks + selected_nblks
        ret.append(buf.sort_())

    # Same sampling without the question blocks and with one more block of room.
    for i in range(n2):
        _selected_htblks = random.sample(pbuf_ht.blocks, min(lb + 1, len(pbuf_ht)))
        _selected_pblks = random.sample(pbuf_b.blocks, min(lb + 1 - len(_selected_htblks), len(pbuf_b)))
        _selected_nblks = random.sample(nbuf_b.blocks, min(lb + 1 - len(_selected_pblks) - len(_selected_htblks), len(nbuf_b)))
        buf = Buffer()
        buf.blocks = _selected_htblks + _selected_pblks + _selected_nblks
        ret.append(buf.sort_())
    # The question-free buffer has no CLS yet; prepend its id to the first block.
    ret[2][0].ids.insert(0, tokenizer.convert_tokens_to_ids(tokenizer.cls_token))

    return qbuf, dbuf, ret[1], ret[2]

def process_example(h, t, doc1, doc2, tokenizer, max_len, redisd, no_additional_marker, mask_entity):
    """Encode one document pair as a single padded two-segment input.

    Both documents are loaded as JSON from ``redisd`` (keys
    ``'codred-doc-' + id``). A window around the head entity in the first
    document and one around the tail entity in the second are built with
    ``gen_c`` and joined as ``[CLS] k1 [SEP] k2 [SEP]``.

    Args:
        h: Head entity id as a string of the form ``'Q<id>'``.
        t: Tail entity id as a string of the form ``'Q<id>'``.
        doc1: Id of the document that must contain the head entity.
        doc2: Id of the document that must contain the tail entity.
        tokenizer: Object providing ``convert_tokens_to_ids``.
        max_len: Total length the id and mask lists are padded to.
        redisd: Store with a ``get(key)`` method returning JSON text.
        no_additional_marker: Forwarded to ``gen_c``.
        mask_entity: Forwarded to ``gen_c``.

    Returns:
        Tuple ``(tokens, token_ids, token_type_id, attention_mask)``.

    Raises:
        AssertionError: If the head or tail entity is not in its document.
    """
    def find_entity(entities, qid):
        # First entity whose 'Q' + id equals ``qid``; None when absent.
        return next((ent for ent in entities if 'Q' in ent and 'Q' + str(ent['Q']) == qid), None)

    def marker_maps(entities, slots):
        # Map span start / last-token positions of shared entities to their slot.
        starts, ends = {}, {}
        for ent in entities:
            if 'Q' in ent and ent['Q'] in slots:
                slot = slots[ent['Q']]
                for span in ent['spans']:
                    starts[span[0]] = slot
                    ends[span[1] - 1] = slot
        return starts, ends

    doc1 = json.loads(redisd.get('codred-doc-' + doc1))
    doc2 = json.loads(redisd.get('codred-doc-' + doc2))

    v_h = find_entity(doc1['entities'], h)
    assert v_h is not None
    v_t = find_entity(doc2['entities'], t)
    assert v_t is not None

    # Entities that occur in both documents get bridge markers (at most 40,
    # drawn with replacement, so fewer than 40 may remain after de-duplication).
    linked1 = {ent['Q']: ent for ent in doc1['entities'] if 'Q' in ent}
    linked2 = {ent['Q']: ent for ent in doc2['entities'] if 'Q' in ent}
    shared = set(linked1.keys()) & set(linked2.keys())
    if len(shared) > 40:
        shared = set(random.choices(list(shared), k=40))
    slots = {qid: slot for slot, qid in enumerate(list(shared))}

    d1_start, d1_end = marker_maps(doc1['entities'], slots)
    d2_start, d2_end = marker_maps(doc2['entities'], slots)

    k1 = gen_c(tokenizer, doc1['tokens'], v_h['spans'][0], max_len // 2 - 2, ['[unused1]', '[unused2]'],
               d1_start, d1_end, no_additional_marker, mask_entity)
    k2 = gen_c(tokenizer, doc2['tokens'], v_t['spans'][0], max_len // 2 - 1, ['[unused3]', '[unused4]'],
               d2_start, d2_end, no_additional_marker, mask_entity)

    tokens = ['[CLS]'] + k1 + ['[SEP]'] + k2 + ['[SEP]']
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    pad = max_len - len(tokens)
    if len(token_ids) < max_len:
        token_ids = token_ids + [0] * pad
    attention_mask = [1] * len(tokens) + [0] * pad
    # Segment 0 covers [CLS] k1 [SEP]; segment 1 covers k2 [SEP]; padding is 0.
    token_type_id = [0] * (len(k1) + 2) + [1] * (len(k2) + 1) + [0] * pad
    return tokens, token_ids, token_type_id, attention_mask


def _select_marker_sentences(tokens, ht_markers, b_markers, limit):
    """Pick the '.'-separated sentences of ``tokens`` that contain markers.

    Sentences containing a head/tail marker are always kept. Sentences
    containing another marker are then added while the running token count
    stays within ``limit``. Each sentence drops its last whitespace-separated
    piece (the ``[:-1]`` below), as the original inline code did.
    """
    sentences = " ".join(tokens).split(".")
    kept = []
    for sent in sentences:
        words = sent.split(" ")[:-1]
        joined = " ".join(words)
        if joined not in kept and any(w in ht_markers for w in words):
            kept.append(joined)
    for sent in sentences:
        words = sent.split(" ")[:-1]
        joined = " ".join(words)
        if joined in kept or not any(w in b_markers for w in words):
            continue
        if len(" .".join(kept).split(" ")) + len(words) <= limit:
            kept.append(joined)
    return kept


def _window_around_entity(tokens, open_marker, close_marker, right_shift):
    """Cut ``tokens`` down to a window centred on a marked entity.

    ``right_shift`` is the constant added to the right boundary; the two call
    sites historically use different values (-1 and +1) and that is preserved.
    """
    ent_start = tokens.index(open_marker)
    ent_end = tokens.index(close_marker)
    half = int((ent_end - ent_start + 1) / 2)
    pre_len = ent_start - 1
    nex_len = len(tokens) - ent_end
    if pre_len <= 126 - half - 1:
        pre_ptr = 0
    else:
        pre_ptr = ent_start - 126 + half + 1
    if nex_len <= 127 - half - 1:
        # NOTE(review): -1 as a slice end drops the final token; ``None``
        # would keep it. Preserved as-is because it changes model input.
        nex_ptr = -1
    else:
        nex_ptr = ent_end + 127 - half + right_shift
    return tokens[pre_ptr:nex_ptr]


def process_example_complete(h, t, doc1, doc2, tokenizer, max_len, redisd, no_additional_marker, mask_entity):
    """Encode a document pair from marker-bearing sentences and sample buffers.

    Both documents are loaded as JSON from ``redisd``. The full marked-up
    token sequences are built with ``gen_c_complete`` and passed to
    ``process`` (which also rewrites ``"."`` inside entity spans to ``"|"`` in
    place, so entity-internal periods do not split sentences below). Only
    sentences containing markers are then kept, cut to a window around the
    head/tail entity if still too long, and packed as
    ``[CLS] k1 [SEP] k2 [SEP]`` padded to 512.

    Args:
        h: Head entity id as a string of the form ``'Q<id>'``.
        t: Tail entity id as a string of the form ``'Q<id>'``.
        doc1: Id of the document that must contain the head entity.
        doc2: Id of the document that must contain the tail entity.
        tokenizer: Tokenizer forwarded to ``process``; also provides
            ``convert_tokens_to_ids``.
        max_len: Ignored; it is overwritten with 99999 below.
        redisd: Store with a ``get(key)`` method returning JSON text.
        no_additional_marker: Forwarded to ``gen_c_complete``.
        mask_entity: Forwarded to ``gen_c_complete``.

    Returns:
        Tuple ``(tokens, token_ids, token_type_id, attention_mask, intro_ret,
        intro_ret_no_q)``; the last two are the sampled buffers returned by
        ``process``.

    Raises:
        AssertionError: If the head or tail entity is not in its document.
        ValueError: If the packed sequence is longer than 512 tokens (this
            used to drop into ``pdb``), or if an entity marker is missing
            when a window has to be cut.
    """
    max_len = 99999
    bert_max_len = 512
    doc1 = json.loads(redisd.get('codred-doc-' + doc1))
    doc2 = json.loads(redisd.get('codred-doc-' + doc2))

    v_h = None
    for entity in doc1['entities']:
        if 'Q' in entity and 'Q' + str(entity['Q']) == h and v_h is None:
            v_h = entity
    assert v_h is not None
    v_t = None
    for entity in doc2['entities']:
        if 'Q' in entity and 'Q' + str(entity['Q']) == t and v_t is None:
            v_t = entity
    assert v_t is not None

    # Entities linked in both documents get bridge markers (at most 40; the
    # draw is with replacement, so fewer may remain after de-duplication).
    d1_v = dict()
    for entity in doc1['entities']:
        if 'Q' in entity:
            d1_v[entity['Q']] = entity
    d2_v = dict()
    for entity in doc2['entities']:
        if 'Q' in entity:
            d2_v[entity['Q']] = entity
    ov = set(d1_v.keys()) & set(d2_v.keys())
    if len(ov) > 40:
        ov = set(random.choices(list(ov), k=40))
    ov = list(ov)
    ma = dict()
    for e in ov:
        ma[e] = len(ma)

    d1_start = dict()
    d1_end = dict()
    for entity in doc1['entities']:
        if 'Q' in entity and entity['Q'] in ma:
            for span in entity['spans']:
                d1_start[span[0]] = ma[entity['Q']]
                d1_end[span[1] - 1] = ma[entity['Q']]
    d2_start = dict()
    d2_end = dict()
    for entity in doc2['entities']:
        if 'Q' in entity and entity['Q'] in ma:
            for span in entity['spans']:
                d2_start[span[0]] = ma[entity['Q']]
                d2_end[span[1] - 1] = ma[entity['Q']]

    k1 = gen_c_complete(tokenizer, doc1['tokens'], v_h['spans'][0], max_len, ['[unused1]', '[unused2]'], d1_start, d1_end, no_additional_marker, mask_entity)
    k2 = gen_c_complete(tokenizer, doc2['tokens'], v_t['spans'][0], max_len, ['[unused3]', '[unused4]'], d2_start, d2_end, no_additional_marker, mask_entity)

    # Must run before the sentence split: it mutates k1/k2 in place.
    _, _, intro_ret, intro_ret_no_q = process(tokenizer, v_h['name'], v_t['name'], k1, k2)

    ht_markers = ["[unused" + str(i) + "]" for i in range(1, 5)]
    b_markers = ["[unused" + str(i) + "]" for i in range(5, 101)]
    k1_c = _select_marker_sentences(k1, ht_markers, b_markers, 255)
    k2_c = _select_marker_sentences(k2, ht_markers, b_markers, 254)

    k1_c_t_ = [tok for tok in " .".join(k1_c).split(" ") if tok != ""]
    k2_c_t_ = [tok for tok in " .".join(k2_c).split(" ") if tok != ""]
    # NOTE(review): the right-boundary offset differs between the two
    # documents (-1 vs +1); preserved from the original code.
    if len(k1_c_t_) > 255:
        k1_c_t_ = _window_around_entity(k1_c_t_, '[unused1]', '[unused2]', -1)
    if len(k2_c_t_) > 254:
        k2_c_t_ = _window_around_entity(k2_c_t_, '[unused3]', '[unused4]', 1)

    tokens = ['[CLS]'] + k1_c_t_ + ['[SEP]'] + k2_c_t_ + ['[SEP]']
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    if len(token_ids) > bert_max_len:
        # Fail loudly instead of opening a debugger inside the data pipeline.
        raise ValueError(
            f'encoded pair has {len(token_ids)} tokens, more than the limit of {bert_max_len}'
        )
    if len(token_ids) < bert_max_len:
        token_ids = token_ids + [0] * (bert_max_len - len(tokens))
    attention_mask = [1] * len(tokens) + [0] * (bert_max_len - len(tokens))
    token_type_id = [0] * (len(k1_c_t_) + 2) + [1] * (len(k2_c_t_) + 1) + [0] * (bert_max_len - len(tokens))
    return tokens, token_ids, token_type_id, attention_mask, intro_ret, intro_ret_no_q


def collate_fn(batch, args, relation2id, tokenizer, redisd):
    """Collate one training item into model tensors.

    ``batch[0]`` is either a bag whose last element is ``'o'``
    (``[key, label, doc_pairs, 'o']``) or a list of sentence-level examples
    ``(doc_id, h_start, h_end, t_start, t_end, relation)``.

    Args:
        batch: Single-element batch as described above.
        args: Namespace providing ``seq_len``, ``no_additional_marker`` and
            ``mask_entity``.
        relation2id: Maps relation names to integer ids; it is expected to
            contain ``'n/a'``.
        tokenizer: Tokenizer used for id lookup and, in the sentence-level
            branch, for building and padding the paired input.
        redisd: Store with a ``get(key)`` method returning JSON text.

    Returns:
        Tuple ``(input_ids, token_type_ids, attention_mask, dplabel, rs, [r],
        intro_ids_no_q, intro_att_mask_no_q, intro_token_type_no_q,
        intro_labels_no_q, intro_rets_no_q)``. ``rs`` and ``r`` are ``None``
        in the sentence-level branch.
    """
    pipe_id = tokenizer.convert_tokens_to_ids("|")
    period_id = tokenizer.convert_tokens_to_ids(".")

    def restore_periods(ids):
        # Entity-internal "." tokens are rewritten to "|" while the documents
        # are split into blocks; turn their ids back, in place, in one pass.
        for pos, tok_id in enumerate(ids):
            if tok_id == pipe_id:
                ids[pos] = period_id

    def restore_blocks(blocks):
        for blk in blocks:
            restore_periods(blk.ids)

    if batch[0][-1] == 'o':
        batch = batch[0]
        h, t = batch[0].split('#')
        r = relation2id[batch[1]]
        dps = batch[2]
        if len(dps) > 8:
            # Cap the bag at 8 document pairs (drawn with replacement).
            dps = random.choices(dps, k=8)
        input_ids = list()
        token_type_ids = list()
        attention_mask = list()
        dplabel = list()
        intro_rets_no_q = list()
        for doc1, doc2, l in dps:
            _, token_ids, token_type_id, amask, intro_ret, intro_ret_no_q = process_example_complete(
                h, t, doc1, doc2, tokenizer, args.seq_len, redisd, args.no_additional_marker, args.mask_entity)
            restore_periods(token_ids)
            restore_blocks(intro_ret)
            restore_blocks(intro_ret_no_q)
            input_ids.append(token_ids)
            token_type_ids.append(token_type_id)
            attention_mask.append(amask)
            dplabel.append(relation2id[l])
            intro_rets_no_q.append(intro_ret_no_q)
        input_ids_t = torch.tensor(input_ids, dtype=torch.int64)
        token_type_ids_t = torch.tensor(token_type_ids, dtype=torch.int64)
        attention_mask_t = torch.tensor(attention_mask, dtype=torch.int64)
        dplabel_t = torch.tensor(dplabel, dtype=torch.int64)
        rs_t = torch.tensor([r], dtype=torch.int64)
        n_items = len(dps)
    else:
        examples = batch[0]
        h_len = tokenizer.max_len_sentences_pair // 2 - 2
        t_len = tokenizer.max_len_sentences_pair - tokenizer.max_len_sentences_pair // 2 - 2
        _input_ids = list()
        _token_type_ids = list()
        _attention_mask = list()
        _rs = list()
        intro_rets_no_q = list()
        for example in examples:
            doc = json.loads(redisd.get(f'dsre-doc-{example[0]}'))
            _, h_start, h_end, t_start, t_end, r = example
            # Unknown relations fall back to the id of 'n/a'. The previous
            # code stored the string 'n/a', which cannot go into a long tensor.
            r = relation2id[r] if r in relation2id else relation2id['n/a']
            h_1, h_2 = expand(h_start, h_end, len(doc), h_len)
            t_1, t_2 = expand(t_start, t_end, len(doc), t_len)
            h_tokens = doc[h_1:h_start] + ['[unused1]'] + doc[h_start:h_end] + ['[unused2]'] + doc[h_end:h_2]
            t_tokens = doc[t_1:t_start] + ['[unused3]'] + doc[t_start:t_end] + ['[unused4]'] + doc[t_end:t_2]
            # Ids are taken before process() rewrites entity periods in the token lists.
            h_token_ids = tokenizer.convert_tokens_to_ids(h_tokens)
            t_token_ids = tokenizer.convert_tokens_to_ids(t_tokens)
            _, _, intro_ret, intro_ret_no_q = process(
                tokenizer, " ".join(doc[h_start:h_end]), " ".join(doc[t_start:t_end]), h_tokens, t_tokens)
            restore_blocks(intro_ret)
            restore_blocks(intro_ret_no_q)
            input_ids = tokenizer.build_inputs_with_special_tokens(h_token_ids, t_token_ids)
            token_type_ids = tokenizer.create_token_type_ids_from_sequences(h_token_ids, t_token_ids)
            obj = tokenizer._pad({'input_ids': input_ids, 'token_type_ids': token_type_ids}, max_length=args.seq_len, padding_strategy='max_length')
            _input_ids.append(obj['input_ids'])
            _token_type_ids.append(obj['token_type_ids'])
            _attention_mask.append(obj['attention_mask'])
            _rs.append(r)
            intro_rets_no_q.append(intro_ret_no_q)
        input_ids_t = torch.tensor(_input_ids, dtype=torch.long)
        token_type_ids_t = torch.tensor(_token_type_ids, dtype=torch.long)
        attention_mask_t = torch.tensor(_attention_mask, dtype=torch.long)
        dplabel_t = torch.tensor(_rs, dtype=torch.long)
        rs_t = None
        r = None
        n_items = len(examples)

    # Rows: 0 = ids, 1 = attention mask, 2 = token types, 3 = relevance labels.
    intro_inputs_no_q = torch.zeros(4, n_items, CAPACITY, dtype=torch.int64)
    for pos, buf in enumerate(intro_rets_no_q):
        buf.export_01_turn(out=(intro_inputs_no_q[0, pos], intro_inputs_no_q[1, pos], intro_inputs_no_q[2, pos]))
    for pos, buf in enumerate(intro_rets_no_q):
        buf.export_relevance(device=input_ids_t.device, out=intro_inputs_no_q[3, pos])
    intro_ids_no_q = intro_inputs_no_q[0]
    intro_att_mask_no_q = intro_inputs_no_q[1]
    intro_token_type_no_q = intro_inputs_no_q[2]
    intro_labels_no_q = intro_inputs_no_q[3]
    return input_ids_t, token_type_ids_t, attention_mask_t, dplabel_t, rs_t, [r], intro_ids_no_q, intro_att_mask_no_q, intro_token_type_no_q, intro_labels_no_q, intro_rets_no_q


def collate_fn_infer(batch, args, relation2id, tokenizer, redisd):
    """Collate one inference bag into model tensors.

    Args:
        batch: Single-element batch; ``batch[0]`` is
            ``[key, labels, doc_pairs, ...]`` where ``key`` is
            ``'<head>#<tail>'``.
        args: Namespace providing ``seq_len``, ``no_additional_marker`` and
            ``mask_entity``.
        relation2id: Maps relation names to integer ids.
        tokenizer: Tokenizer used for id lookup and document encoding.
        redisd: Store with a ``get(key)`` method returning JSON text.

    Returns:
        Tuple ``(input_ids, token_type_ids, attention_mask, h, rs, t,
        intro_ids_no_q, intro_att_mask_no_q, intro_token_type_no_q,
        intro_labels_no_q, intro_rets_no_q)`` where ``rs`` is the list of
        gold relation ids of the bag.
    """
    pipe_id = tokenizer.convert_tokens_to_ids("|")
    period_id = tokenizer.convert_tokens_to_ids(".")

    def restore_periods(ids):
        # Entity-internal "." tokens are rewritten to "|" while the documents
        # are split into blocks; turn their ids back, in place, in one pass.
        for pos, tok_id in enumerate(ids):
            if tok_id == pipe_id:
                ids[pos] = period_id

    batch = batch[0]
    h, t = batch[0].split('#')
    rs = [relation2id[r] for r in batch[1]]
    dps = batch[2]
    input_ids = list()
    token_type_ids = list()
    attention_mask = list()
    intro_rets_no_q = list()
    for doc1, doc2, l in dps:
        _, token_ids, token_type_id, amask, intro_ret, intro_ret_no_q = process_example_complete(
            h, t, doc1, doc2, tokenizer, args.seq_len, redisd, args.no_additional_marker, args.mask_entity)
        restore_periods(token_ids)
        for blk in intro_ret:
            restore_periods(blk.ids)
        for blk in intro_ret_no_q:
            restore_periods(blk.ids)
        input_ids.append(token_ids)
        token_type_ids.append(token_type_id)
        attention_mask.append(amask)
        intro_rets_no_q.append(intro_ret_no_q)
    input_ids_t = torch.tensor(input_ids, dtype=torch.int64)
    token_type_ids_t = torch.tensor(token_type_ids, dtype=torch.int64)
    attention_mask_t = torch.tensor(attention_mask, dtype=torch.int64)

    # Rows: 0 = ids, 1 = attention mask, 2 = token types, 3 = relevance labels.
    intro_inputs_no_q = torch.zeros(4, len(dps), CAPACITY, dtype=torch.int64)
    for dp, buf in enumerate(intro_rets_no_q):
        buf.export_01_turn(out=(intro_inputs_no_q[0, dp], intro_inputs_no_q[1, dp], intro_inputs_no_q[2, dp]))
    for dp, buf in enumerate(intro_rets_no_q):
        buf.export_relevance(device=input_ids_t.device, out=intro_inputs_no_q[3, dp])
    intro_ids_no_q = intro_inputs_no_q[0]
    intro_att_mask_no_q = intro_inputs_no_q[1]
    intro_token_type_no_q = intro_inputs_no_q[2]
    intro_labels_no_q = intro_inputs_no_q[3]
    return input_ids_t, token_type_ids_t, attention_mask_t, h, rs, t, intro_ids_no_q, intro_att_mask_no_q, intro_token_type_no_q, intro_labels_no_q, intro_rets_no_q



class selfalignment(nn.Module):
    """Bidirectional soft alignment between two masked sequences.

    A bilinear score ``linear(text) @ text1^T`` is computed for every pair of
    positions. Softmax over the second sequence attends ``text`` to ``text1``;
    softmax over the first sequence attends ``text1`` to ``text``. Each
    attended summary is added back to its own sequence (residual).
    """

    def __init__(self, in_features, bias=False):
        """
        Args:
            in_features: Hidden size of both sequences.
            bias: When true, register a ``bias`` parameter of size
                ``in_features``. It is not used by ``forward``.
        """
        super().__init__()
        self.in_features = in_features
        self.dropout = nn.Dropout(0.1)  # defined but not applied in forward
        self.linear = torch.nn.Linear(in_features, in_features, bias=False)
        if bias:
            # Zero-initialized: ``torch.FloatTensor(n)`` would expose
            # uninitialized memory (possibly NaN/inf) as a trainable parameter.
            self.bias = nn.Parameter(torch.zeros(in_features, dtype=torch.float32))
        else:
            self.register_parameter('bias', None)

    def forward(self, text, text1, textmask1, textmask2):
        """Align ``text`` (b, s1, h) with ``text1`` (b, s2, h).

        Args:
            text: First sequence, (b, s1, h).
            text1: Second sequence, (b, s2, h).
            textmask1: Mask of the first sequence, (b, s1); 1 marks valid positions.
            textmask2: Mask of the second sequence, (b, s2); 1 marks valid positions.

        Returns:
            Tuple ``(text_out, text1_out, attention)`` with shapes
            (b, s1, h), (b, s2, h) and (b, s1, s2). ``attention`` is the
            text-to-text1 softmax with rows of masked ``text`` positions zeroed.
        """
        logit = torch.matmul(self.linear(text), text1.transpose(1, 2))  # b,s1,s2
        pair_mask = textmask1.unsqueeze(-1) * textmask2.unsqueeze(1)  # b,s1,s2
        # Large negative score for any pair that involves a padded position.
        scores = logit + (1 - pair_mask) * (-10000.0)
        logits = torch.softmax(scores, -1)  # normalize over s2
        logits1 = torch.softmax(scores, -2)  # normalize over s1
        output = torch.matmul(logits, text1) * textmask1.unsqueeze(-1)  # b,s1,h
        output1 = torch.matmul(logits1.transpose(1, 2), text) * textmask2.unsqueeze(-1)  # b,s2,h
        return output + text, output1 + text1, logits * textmask1.unsqueeze(-1)


def init_weights(module):
    """Re-initialize ``module`` in place according to its type.

    * ``nn.Linear`` / ``nn.Embedding``: weights drawn from N(0, 0.01).
    * ``torch.nn.LayerNorm``: bias set to 0 and weight set to 1.
    * ``nn.Linear`` with a bias: bias set to 0.

    Any other module is left untouched.
    """
    is_linear = isinstance(module, nn.Linear)
    if is_linear or isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.01)
    elif isinstance(module, torch.nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    if is_linear and module.bias is not None:
        module.bias.data.zero_()


class Introspector(torch.nn.Module):
    """BERT encoder with a per-token binary scoring head."""

    def __init__(self, args):
        """
        Args:
            args: Unused; kept so the constructor signature matches its callers.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.dropout = torch.nn.Dropout(0.1)  # defined but not applied in forward
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)
        # In-place initializer; the un-suffixed ``xavier_uniform`` is deprecated.
        torch.nn.init.xavier_uniform_(self.classifier.weight)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        """Score every token and optionally compute a BCE loss.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Optional mask; when given together with ``labels``
                only positions where the mask equals 1 contribute to the loss.
            token_type_ids: Optional segment ids passed to the encoder.
            labels: Optional per-token targets, cast to the logits' dtype.

        Returns:
            The logits (one score per token, trailing dimension of size 1)
            when ``labels`` is None, otherwise the tuple ``(loss, logits)``.
        """
        encoded = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        logits = self.classifier(encoded[0])
        if labels is None:
            return logits

        labels = labels.type_as(logits)
        loss_fct = torch.nn.BCEWithLogitsLoss()
        if attention_mask is not None:
            active = attention_mask.view(-1) == 1
            loss = loss_fct(logits.view(-1)[active], labels.view(-1)[active])
        else:
            loss = loss_fct(logits.view(-1), labels.view(-1))
        return (loss, logits)



class Codred(torch.nn.Module):
    def __init__(self, args, num_relations):
        """Create the BERT encoder and the layers used for entity-pair relation scoring.

        Args:
            args: parsed options; only ``args.aggregator`` and
                ``args.no_doc_pair_supervision`` are read here.
            num_relations: number of relation classes produced by the classifiers.

        NOTE(review): the order in which sub-modules are created fixes the
        parameter registration order and the order random initialisers run in,
        so statements here should not be reordered casually.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.predictor = torch.nn.Linear(self.bert.config.hidden_size, num_relations)
        # Per-class weights for self.loss: class index 0 is down-weighted to 0.1.
        weight = torch.ones(num_relations, dtype=torch.float32)
        weight[0] = 0.1
        # Hard-coded feature width; presumably matches the BERT hidden size -- TODO confirm.
        self.d_model = 768
        self.reduced_dim = 256
        # NOTE(review): self.loss is not referenced by the methods visible in this class.
        self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=weight, reduction='none')
        self.aggregator = args.aggregator
        self.no_doc_pair_supervision = args.no_doc_pair_supervision
        # NOTE(review): device index 0 is hard-coded for the pairwise encoder.
        self.matt = BCRTransformer(h = 8 , d_model = self.d_model , hidden_size = 768 , num_layers = 3 , device = torch.device(0))
        self.graph_enc = GraphEncoder(h = 8 , d_model = self.d_model , hidden_size = 768 , num_layers = 3)
        # wu / wv project entity embeddings before they are combined pairwise in forward().
        self.wu = nn.Linear(self.d_model , self.d_model)
        self.wv = nn.Linear(self.d_model , self.d_model)
        self.wi = nn.Linear(self.d_model , self.d_model)
        # Despite the name, ln1 is a Linear layer (not a LayerNorm).
        self.ln1 = nn.Linear(self.d_model , self.d_model)
        self.ln1_gnn = nn.Linear(2* self.d_model , self.d_model)
        self.dim_reduction = nn.Linear(self.d_model, self.reduced_dim)
        self.reduced_predictor = torch.nn.Linear(self.reduced_dim, num_relations)
        # NOTE(review): gamma / alpha / beta / d_k / gnn are stored but not read by
        # the methods visible in this class.
        self.gamma = 2
        self.alpha = 0.25
        self.beta = 0.01
        self.d_k = 64
        self.num_relations = num_relations
        # Two learned offsets: row 0 is added to entity encodings and row 1 to
        # relation encodings in graph_encode().
        self.ent_emb = nn.Parameter(torch.zeros(2 , self.d_model))
        self.gnn = True
        self.norm = nn.LayerNorm(self.d_model)
        self.att_net = Attention(h=self.d_model, d_model=self.d_model)
        self.s_linear = torch.nn.Linear(self.d_model, 2)
        self.dotsim = DotProductSimilarity(scale_output=False)


    def forward(self, input_ids, token_type_ids, attention_mask, dplabel=None, rs=None, train=True):
        """Encode a bag of token sequences and score relations for the head/tail pair.

        Two encoding paths exist:

        * ``rs is not None or not train``: the head, tail and other ("b") entity
          embeddings of every row in the bag are concatenated into one list, a
          pairwise feature matrix is built over that list and passed through
          ``self.matt`` together with positional structure, and the (head, tail)
          cell of each row is classified.
        * otherwise (``# Inner doc``): every row is processed on its own with
          the same pairwise construction and ``self.matt`` is called with
          ``None`` as positional structure.

        The loss is ``_multilabel_categorical_crossentropy`` with a constant
        lower bound of 8 and a ceiling of 10 (``ht_fixed_low`` / ``ht_fixed_low+2``).

        Args:
            input_ids: 2-D tensor unpacked as ``(bag_len, seq_len)``.
            token_type_ids: forwarded to BERT.
            attention_mask: forwarded to BERT.
            dplabel: optional per-row labels (iterated, ``.item()`` is called on each).
            rs: optional bag-level relation label(s); ``rs.item()`` / ``len(rs)`` are used.
            train: when False the joint (first) path is always taken.

        Returns:
            ``(loss, prediction, mtt_logit, ht_logits_flatten.transpose(0, 1),
            (ht_fixed_low + 2).transpose(0, 1))``. ``loss`` is None when neither
            ``dplabel`` (without ``rs``) nor ``rs`` is supplied. ``prediction``
            is always an empty list because it is reset just before returning.

        NOTE(review): several ``pdb.set_trace()`` calls remain in error paths and
        will block a non-interactive run.
        NOTE(review): if ``dplabel`` is given, ``rs`` is None and ``train`` is
        False, the first path runs but the loss code reads
        ``ht_logits_flatten_list``, which only the second path defines.
        """
        bag_len, seq_len = input_ids.size()
        embedding, _ = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=False)
        # Embedding at position 0 of every row.
        p_embedding = embedding[:, 0, :]
        if bag_len>8:
            print("bag_len:", bag_len)
        if rs is not None or not train:
            ht_threshold_flatten_list = []
            entity_mask, entity_span_list = self.get_htb_v4(input_ids) # T[num_sentences, 3, seq_len] 0-h; 1-t; 2-b
            h_embs = []
            t_embs = []
            b_embs = []
            dp_embs = []
            graph_output_r = []
            h_num = []
            t_num = []
            b_num = []
            for dp in range(0,bag_len):
                b_embs_dp = []
                try:
                    h_span = entity_span_list[dp][0]  
                    t_span = entity_span_list[dp][1]   
                    b_span_chunks = entity_span_list[dp][2]  
                    # Max-pool the token embeddings over each entity span.
                    h_emb = torch.max(embedding[dp, h_span[0]:h_span[-1]+1], dim=0)[0]
                    t_emb = torch.max(embedding[dp, t_span[0]:t_span[-1]+1], dim=0)[0]
                    h_embs.append(h_emb)
                    t_embs.append(t_emb)
                    for b_span in b_span_chunks:
                        b_emb = torch.max(embedding[dp, b_span[0]:b_span[1]+1], dim=0)[0]
                        b_embs_dp.append(b_emb)
                    # NOTE(review): more than 16 entries are cut to 15, while exactly
                    # 16 are kept as-is -- confirm the intended cap.
                    if len(b_embs_dp) > 16:
                        b_embs_dp = b_embs_dp[:15]
                    else:
                        b_embs_dp = b_embs_dp
                    b_embs.append(b_embs_dp)
                    h_num.append(1)
                    t_num.append(1)
                    b_num.append(len(b_embs_dp))
                    dp_embs.append(p_embedding[dp])
                except IndexError as e:
                    # Rows with empty head/tail spans (h_span[0] fails) are skipped.
                    continue
            print(bag_len, b_num)
            # htb_index holds one type code per collected embedding:
            # 1 = head, 2 = tail, 3 = other ("b") entity.
            htb_index = []
            htb_embs = []
            htb_start = [0]
            htb_end = []
            for h_emb, t_emb, b_emb in zip(h_embs, t_embs, b_embs):
                htb_embs.extend([h_emb,t_emb])
                htb_index.extend([1,2])
                htb_embs.extend(b_emb)
                htb_index.extend([3]*len(b_emb))
                htb_end.append(len(htb_index)-1)
                htb_start.append(len(htb_index))
            htb_start = htb_start[:-1]

            rel_mask = torch.ones(1,len(htb_index), len(htb_index)).to(embedding.device)
            # NOTE(review): if stacking fails (e.g. no row yielded entities) this only
            # prints, and the next statement then fails because htb_embs_t is unbound.
            try:
                htb_embs_t = torch.stack(htb_embs, dim=0).unsqueeze(0)
            except:
                print(input_ids)
            u = self.wu(htb_embs_t) 
            v = self.wv(htb_embs_t) 
            
            # Broadcast sum gives one feature vector per ordered entity pair: (1, n, n, d).
            alpha = u.view(1, len(htb_index), 1, htb_embs_t.size()[-1]) + v.view(1, 1, len(htb_index), htb_embs_t.size()[-1])
            alpha = F.relu(alpha)

            rel_enco = F.relu(self.ln1(alpha))
            bs,es,es,d = rel_enco.size()

            rel_mask = torch.ones(1,len(htb_index), len(htb_index)).to(embedding.device)
            # Inner_structure[p_i, p_j] is a code derived from the (row type, column type)
            # pair; how self.matt interprets the codes is not visible here.
            Inner_structure = torch.zeros(len(htb_index), len(htb_index)).to(embedding.device)
            row = list(htb_index)
            col = list(htb_index)
            for p_i, i in enumerate(row):
                for p_j, j in enumerate(col):
                    if (i,j)==(1,1): Inner_structure[p_i, p_j] = 2
                    elif(i,j)==(1,2): Inner_structure[p_i, p_j] = 1
                    elif(i,j)==(1,3): Inner_structure[p_i, p_j] = 2
                    elif(i,j)==(2,1): Inner_structure[p_i, p_j] = 3
                    elif(i,j)==(2,2): Inner_structure[p_i, p_j] = 2
                    elif(i,j)==(2,3): Inner_structure[p_i, p_j] = 3
                    elif(i,j)==(3,1): Inner_structure[p_i, p_j] = 3
                    elif(i,j)==(3,2): Inner_structure[p_i, p_j] = 2
                    elif(i,j)==(3,3): Inner_structure[p_i, p_j] = 3
                    else:pdb.set_trace()
            # NOTE(review): the result of unsqueeze is discarded, so this line has no effect.
            Inner_structure.unsqueeze(0)
            # e_num holds the cumulative boundaries of each row's block of entities
            # (head + tail + its "b" entities).
            e_num = [0]
            ptr = 0
            for b_n in b_num:
                ptr += (b_n+2)
                e_num.append(ptr)
            # Rows start:end of the matrix are filled with the block number i.
            Block_index_row = torch.zeros(len(htb_index), len(htb_index)).to(embedding.device)
            for i in range(0, len(e_num)-1):
                start = e_num[i]
                end = e_num[i+1]
                Block_index_row[start:end][:] = i
            # NOTE(review): ``[:][start:end]`` slices rows again, so this tensor ends up
            # identical to Block_index_row; a column index would be ``[:, start:end]``.
            # Changing it alters model outputs, so confirm against trained checkpoints.
            Block_index_col = torch.zeros(len(htb_index), len(htb_index)).to(embedding.device)
            for i in range(0, len(e_num)-1):
                start = e_num[i]
                end = e_num[i+1]
                Block_index_col[:][start:end] = i
            Pos_structure = torch.stack([Inner_structure, Block_index_row, Block_index_col])
            rel_enco_m = self.matt(rel_enco, rel_mask, Pos_structure)
            # Collect the positions of heads and tails; they are paired in order below.
            h_pos = []
            t_pos = []
            for i, e_type in enumerate(htb_index):
                if e_type == 1:
                    h_pos.append(i)
                elif e_type == 2:
                    t_pos.append(i)
                else:
                    continue
            assert len(h_pos) == len(t_pos)
            rel_enco_m_ht = []

            for i,j in zip(h_pos, t_pos):
                rel_enco_m_ht.append(rel_enco_m[0][i][j])
            t_feature_m = torch.stack(rel_enco_m_ht)
            predict_logits = self.predictor(t_feature_m)

            ht_logits = predict_logits
            # Bag-level logit: mean of the per-row logits, with a leading dim of 1.
            mtt_logit = torch.mean(ht_logits.reshape(-1, ht_logits.size(-1)), dim=0).unsqueeze(0)
            mtt_dp_logit = ht_logits
        else:     # Inner doc
            entity_mask, entity_span_list = self.get_htb_v4(input_ids) # T[num_sentences, 3, seq_len] 0-h; 1-t; 2-b
            mtt_dp_logits = []
            ht_logits_flatten_list = []
            ht_threshold_flatten_list = []
            for dp in range(0,bag_len):
                h_embs = []
                t_embs = []
                b_embs = []
                try:
                    h_span = entity_span_list[dp][0]   
                    t_span = entity_span_list[dp][1]   
                    b_span_chunks = entity_span_list[dp][2] 
                    h_emb = torch.max(embedding[dp, h_span[0]:h_span[-1]+1], dim=0)[0]
                    t_emb = torch.max(embedding[dp, t_span[0]:t_span[-1]+1], dim=0)[0]
                    h_embs.append(h_emb)
                    t_embs.append(t_emb)
                    for b_span in b_span_chunks:
                        b_emb = torch.max(embedding[dp, b_span[0]:b_span[1]+1], dim=0)[0]
                        b_embs.append(b_emb)
                    h_index = [1 for _ in h_embs]
                    t_index = [2 for _ in t_embs]
                    b_index = [3 for _ in b_embs]
                    # Entities are ordered heads, then tails, then "b" entities.
                    htb_index = []
                    htb_embs = []
                    for idx, embs in zip([h_index, t_index, b_index], [h_embs, t_embs, b_embs]):
                        htb_index.extend(idx)
                        htb_embs.extend(embs)
                    rel_mask = torch.ones(1,len(htb_index), len(htb_index)).to(embedding.device)

                    htb_embs_t = torch.stack(htb_embs, dim=0).unsqueeze(0)

                    u = self.wu(htb_embs_t) 
                    v = self.wv(htb_embs_t) 
                    alpha = u.view(1, len(htb_index), 1, htb_embs_t.size()[-1]) + v.view(1, 1, len(htb_index), htb_embs_t.size()[-1]) 
                    alpha = F.relu(alpha)

                    rel_enco = F.relu(self.ln1(alpha))

                    rel_enco_m = self.matt(rel_enco , rel_mask, None)

                    t_feature = rel_enco_m
                    bs,es,es,d = rel_enco.size()

                    predict_logits = self.predictor(t_feature.reshape(bs,es,es,d))
                    # Slice out the head-rows x tail-columns block of the pairwise logits.
                    ht_logits = predict_logits[0][:len(h_index), len(h_index):len(h_index)+len(t_index)] 
                    _ht_logits_flatten = ht_logits.reshape(1, -1, self.num_relations)
                    ht_logits = predict_logits[0][:len(h_index), len(h_index):len(h_index)+len(t_index)] 
                    mtt_dp_logits.append(ht_logits)
                    ht_logits_flatten_list.append(_ht_logits_flatten)
                except Exception as e:
                    # NOTE(review): debugger trap; after it the loop moves on to the next row.
                    print("Error here", e)
                    pdb.set_trace()
            try:
                mtt_dp_logit = torch.stack(mtt_dp_logits).reshape(1, 1, -1, self.num_relations).squeeze(0).squeeze(0)
            except Exception as e:
                print("Error here", e)
                pdb.set_trace()


        if dplabel is not None and rs is None:
            # Row-level supervision only: one-hot target per row from dplabel.
            ht_logits_flatten = torch.stack(ht_logits_flatten_list).squeeze(1)
            # Constant tensor of 8s with a trailing dim of 1; +2 gives the ceiling.
            ht_fixed_low = (torch.ones_like(ht_logits_flatten)*8)[:,:,0].unsqueeze(-1)
            y_true = torch.zeros_like(ht_logits_flatten)
            for idx, dpl in enumerate(dplabel):
                y_true[idx, 0, dpl.item()] = 1
            bag_loss = self._multilabel_categorical_crossentropy(ht_logits_flatten, y_true, ht_fixed_low+2, ht_fixed_low)
            mtt_logit = mtt_dp_logit
            loss = bag_loss
        elif rs is not None:
            _, prediction = torch.max(mtt_logit, dim=1)
            if self.no_doc_pair_supervision:
                # Every row gets the bag label rs and class 0 as positives.
                ht_logits_flatten = ht_logits.unsqueeze(1)
                y_true = torch.zeros_like(ht_logits_flatten)
                ht_fixed_low = (torch.ones_like(ht_logits_flatten)*8)[:,:,0].unsqueeze(-1)
                for idx, dpl in enumerate(dplabel):  
                    try:
                        y_true[idx, 0 , rs.item()] = 1
                        y_true[idx, 0 , 0] = 1
                    except:
                        pdb.set_trace()                
                bag_loss = self._multilabel_categorical_crossentropy(ht_logits_flatten, y_true, ht_fixed_low+2, ht_fixed_low)
                loss = bag_loss
            else:
                ht_logits_flatten = ht_logits.unsqueeze(1)
                y_true = torch.zeros_like(ht_logits_flatten)
                ht_fixed_low = (torch.ones_like(ht_logits_flatten)*8)[:,:,0].unsqueeze(-1)
                
                # Row-level targets are only set when the bag label is not class 0.
                if rs.item() != 0:
                    for idx, dpl in enumerate(dplabel):
                        try:
                            y_true[idx, : , dpl.item()] = torch.ones_like(y_true[idx, : , dpl.item()])
                        except:
                            pdb.set_trace()
                # Bag-level logits: element-wise max over the rows.
                bag_logits = torch.max(ht_logits,dim=0)[0].unsqueeze(0).unsqueeze(0)
                bag_true = torch.zeros_like(bag_logits)
                bag_fixed_low = (torch.ones_like(bag_logits)*8)[:,:,0].unsqueeze(-1)
                if len(rs)==1:
                    bag_true[0, 0, rs.item()] = 1
                else:
                    # NOTE(review): unreachable in practice -- rs.item() above already
                    # requires a single-element rs.
                    pdb.set_trace()
                    for r in rs.tolist():
                        bag_true[0, 0, r] = 1
                bag_loss_ = self._multilabel_categorical_crossentropy(bag_logits, bag_true, bag_fixed_low+2, bag_fixed_low)
                bag_loss = self._multilabel_categorical_crossentropy(ht_logits_flatten, y_true, ht_fixed_low+2, ht_fixed_low)
                loss = bag_loss + bag_loss_
        else:
            # Inference without labels: no loss.
            ht_logits_flatten = ht_logits.unsqueeze(1)
            ht_fixed_low = (torch.ones_like(ht_logits_flatten)*8)[:,:,0].unsqueeze(-1)
            _, prediction = torch.max(mtt_logit, dim=1)
            loss = None
        # NOTE(review): this discards any prediction computed above; callers derive
        # predictions from the returned logits and threshold instead.
        prediction = [] 
        return loss, prediction, mtt_logit, ht_logits_flatten.transpose(0,1), (ht_fixed_low+2).transpose(0,1)


    def _multilabel_categorical_crossentropy(self, y_pred, y_true, cr_ceil, cr_low, ghm=True, r_dropout=True):

        y_pred = (1 - 2 * y_true) * y_pred  # -1 -> pos classes, 1 -> neg classes
        y_pred_neg = y_pred - y_true * 1e12  # mask the pred oudtuts of pos classes
        y_pred_pos = y_pred - (1 - y_true) * 1e12  # mask the pred oudtuts of neg classes
        y_pred_neg = torch.cat([y_pred_neg, cr_ceil], dim=-1)
        y_pred_pos = torch.cat([y_pred_pos, -cr_low], dim=-1)
        neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
        pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
        return ((neg_loss + pos_loss + cr_low.squeeze(-1) - cr_ceil.squeeze(-1))).mean()

    def graph_encode(self , ent_encode , rel_encode , ent_mask , rel_mask):
        bs , ne , d = ent_encode.size()
        ent_encode = ent_encode + self.ent_emb[0].view(1,1,d)   
        rel_encode = rel_encode + self.ent_emb[1].view(1,1,1,d)
        rel_encode , ent_encode = self.graph_enc(rel_encode , ent_encode , rel_mask , ent_mask)
        return rel_encode


    def predict_logit(self, r_embedding, rs=None):
        weight = self.predictor.weight
        if self.aggregator == 'max':
            scores = self.predictor(r_embedding)
            prob = torch.nn.functional.softmax(scores, dim=1)
            if rs is not None:
                _, idx = torch.max(prob[:, rs[0]], dim=0, keepdim=True)
                return scores[idx], scores
            else:
                max_score, _ = torch.max(scores, dim=0, keepdim=True)
                return max_score, scores
        elif self.aggregator == 'avg':
            embedding = torch.sum(r_embedding, dim=1, keepdim=True) / r_embedding.shape[0]
            return self.predictor(embedding), self.predictor(r_embedding)
        elif self.aggregator == 'attention':
            attention_score = torch.matmul(r_embedding, torch.t(weight))
            attention_weight = torch.nn.functional.softmax(attention_score, dim=0)
            embedding = torch.matmul(torch.transpose(attention_weight, 0, 1), r_embedding)
            logit = self.predictor(embedding)
            return torch.diag(logit).unsqueeze(0), self.predictor(r_embedding)
        else:
            assert False
    
    def get_htb_v3(self, input_ids):
        """Locate head / tail / other-entity marker positions in every row.

        Marker conventions read from the code: token ids 1 and 2 delimit the
        head span, 3 and 4 the tail span, and any other id below 99 inside the
        non-zero part of the row is treated as an other-entity ("b") marker.

        Args:
            input_ids: 2-D tensor of token ids (rows x sequence length).

        Returns:
            ``(htb_mask_batch, htb_list_batch)``: a tensor stacking, per row,
            the head / tail / "b" 0-1 masks, and a list with
            ``[h_span, t_span, b_span_]`` per row.

        NOTE(review): this version is not called by the code visible in this
        class (forward uses get_htb_v4). When the number of "b" markers is
        even, ``b_span_`` is never filled, so the third list entry is empty
        even though the mask is populated.
        """
        htb_mask_list = []
        htb_list_batch = []
        for pi in range(input_ids.size()[0]):
            # Debug aid: dump the batch when fewer positions differ from id 1
            # than there are rows.
            tmp = torch.nonzero(input_ids[pi] - torch.full(([input_ids.size()[1]]), 1).to(input_ids.device))
            if tmp.size()[0] < input_ids.size()[0]:
                print(input_ids)
            try:
                # First occurrence of each marker; a missing marker raises IndexError.
                h_start = (input_ids[pi]==1).nonzero().detach().tolist()[0][0]
                h_end = (input_ids[pi]==2).nonzero().detach().tolist()[0][0]
                t_start = (input_ids[pi]==3).nonzero().detach().tolist()[0][0]
                t_end = (input_ids[pi]==4).nonzero().detach().tolist()[0][0]
                # Positions whose token id is below 99.
                b_spans = torch.nonzero(torch.gt(torch.full(([input_ids.size()[1]]), 99).to(input_ids.device), input_ids[pi])).squeeze(0).squeeze(1).detach().tolist()
                # Keep only positions below the count of non-zero tokens
                # (drops trailing zero ids, presumably padding).
                token_len = input_ids[pi].nonzero().size()[0]
                b_spans = [i for i in b_spans if i <= token_len-1]
                assert len(b_spans) >= 4 #
                # Remove the four head/tail markers; what remains are "b" markers.
                for i in [h_start, h_end, t_start, t_end]:
                    b_spans.remove(i)
                h_span = [h_pos for h_pos in range(h_start, h_end+1)]
                t_span = [t_pos for t_pos in range(t_start, t_end+1)]
                h_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(h_span).to(input_ids.device), 1)
                t_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(t_span).to(input_ids.device), 1)
            except:# 
                # Any failure above (missing marker, failed assert, ...) yields an
                # empty result for this row. NOTE(review): the bare except also
                # swallows KeyboardInterrupt / SystemExit.
                h_span = []
                t_span = []
                h_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device)
                t_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device)
                b_spans = []
            b_span_ = []
            if len(b_spans) > 0 and len(b_spans)%2==0:
                # Even count: consecutive markers are paired and every position
                # from start to end (inclusive) is set in the mask.
                b_span_chunks = [b_spans[i:i+2] for i in range(0,len(b_spans),2)]
                b_span = []
                for span in b_span_chunks:
                    b_span.extend([b_pos for b_pos in range(span[0], span[1]+1)])
                b_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(b_span).to(input_ids.device), 1)
            elif len(b_spans) > 0 and len(b_spans)%2==1:
                # Odd count: pair two neighbouring markers only when the second
                # token id is exactly one larger than the first.
                b_span = []
                ptr = 0
                while(ptr<=len(b_spans)-1):
                    try:
                        if input_ids[pi][b_spans[ptr+1]] - input_ids[pi][b_spans[ptr]] == 1:
                            b_span.append([b_spans[ptr], b_spans[ptr+1]])
                            ptr += 2
                        else:
                            ptr += 1
                    except IndexError as e: 
                        # Reached the last marker without a partner.
                        ptr += 1 
                for bs in b_span:
                    b_span_.extend(bs)
                    if len(b_span_)%2 != 0:
                        print(b_spans)
                # Only the pair end points are set in the mask here, not the interior.
                b_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(b_span_).to(input_ids.device), 1)
            else:
                b_mask = torch.zeros_like(input_ids[pi])
            htb_mask = torch.concat([h_mask.unsqueeze(0), t_mask.unsqueeze(0), b_mask.unsqueeze(0)], dim=0)
            htb_mask_list.append(htb_mask)
            htb_list_batch.append([h_span, t_span, b_span_])
        htb_mask_batch = torch.stack(htb_mask_list,dim=0)
        return htb_mask_batch, htb_list_batch 


    def get_htb_v4(self, input_ids):
        htb_mask_list = []
        htb_list_batch = []
        for pi in range(input_ids.size()[0]):
            tmp = torch.nonzero(input_ids[pi] - torch.full(([input_ids.size()[1]]), 1).to(input_ids.device))
            if tmp.size()[0] < input_ids.size()[0]:
                print(input_ids)
            try:
                h_start = (input_ids[pi]==1).nonzero().detach().tolist()[0][0]
                h_end = (input_ids[pi]==2).nonzero().detach().tolist()[0][0]
                t_start = (input_ids[pi]==3).nonzero().detach().tolist()[0][0]
                t_end = (input_ids[pi]==4).nonzero().detach().tolist()[0][0]
                b_spans = torch.nonzero(torch.gt(torch.full(([input_ids.size()[1]]), 99).to(input_ids.device), input_ids[pi])).squeeze(0).squeeze(1).detach().tolist()
                token_len = input_ids[pi].nonzero().size()[0]
                b_spans = [i for i in b_spans if i <= token_len-1]
                assert len(b_spans) >= 4 #
                for i in [h_start, h_end, t_start, t_end]:
                    b_spans.remove(i)
                h_span = [h_pos for h_pos in range(h_start, h_end+1)]
                t_span = [t_pos for t_pos in range(t_start, t_end+1)]
                h_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(h_span).to(input_ids.device), 1)
                t_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(t_span).to(input_ids.device), 1)
            except:# 
                h_span = []
                t_span = []
                h_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device)
                t_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device)
                b_spans = []
            b_span_ = []
            if len(b_spans) > 0 and len(b_spans)%2==0:
                b_span_chunks = [b_spans[i:i+2] for i in range(0,len(b_spans),2)]
                b_span = []
                for span in b_span_chunks:
                    b_span.extend([b_pos for b_pos in range(span[0], span[1]+1)])
                b_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(b_span).to(input_ids.device), 1)
                b_span_.extend(b_span)
            elif len(b_spans) > 0 and len(b_spans)%2==1:
                b_span = []
                ptr = 0
                #pdb.set_trace()
                while(ptr<=len(b_spans)-1):
                    try:
                        if input_ids[pi][b_spans[ptr+1]] - input_ids[pi][b_spans[ptr]] == 1:
                            b_span.append([b_spans[ptr], b_spans[ptr+1]])
                            ptr += 2
                        else:
                            ptr += 1
                    except IndexError as e: 
                        ptr += 1 
                for bs in b_span:
                    b_span_.extend(bs)
                    if len(b_span_)%2 != 0:
                        print(b_spans)
                b_span_chunks = [b_span_[i:i+2] for i in range(0,len(b_span_),2)]
                b_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(b_span_).to(input_ids.device), 1)
            else:
                b_span_ = []
                b_span_chunks = []
                b_mask = torch.zeros_like(input_ids[pi])
            htb_mask = torch.concat([h_mask.unsqueeze(0), t_mask.unsqueeze(0), b_mask.unsqueeze(0)], dim=0) 
            htb_mask_list.append(htb_mask)
            htb_list_batch.append([h_span, t_span, b_span_chunks])
        htb_mask_batch = torch.stack(htb_mask_list,dim=0)
        return htb_mask_batch, htb_list_batch 

class CodredCallback(TrainerCallback):
    def __init__(self):
        """Create the callback; relations, tokenizer and loggers are set later by load_model / load_data."""
        super().__init__()

    def on_argument(self, parser):
        """Register the command line options of this task on ``parser``.

        Args:
            parser: an object offering ``add_argument`` (argparse-style).
        """
        parser.add_argument('--seq_len', type=int, default=512)
        parser.add_argument('--aggregator', type=str, default='attention')

        # Boolean switches, all off by default.
        switches = (
            '--positive_only',
            '--positive_ep_only',
            '--no_doc_pair_supervision',
            '--no_additional_marker',
            '--mask_entity',
            '--single_path',
            '--dsre_only',
            '--raw_only',
        )
        for switch in switches:
            parser.add_argument(switch, action='store_true')

        # String-valued options with their defaults.
        text_options = (
            ('--load_model_path', None),
            ('--train_file', './data/rawdata/train_dataset.json'),
            ('--dev_file', './data/rawdata/dev_dataset.json'),
            ('--test_file', './data/rawdata/test_dataset.json'),
            ('--dsre_file', './data/dsre_train_examples.json'),
            ('--model_name', 'bert'),
        )
        for option, default in text_options:
            parser.add_argument(option, type=str, default=default)


    def load_model(self):
        """Build the relation vocabulary, the model and the tokenizer.

        Reads ``./data/rawdata/relations.json``, sorts the relation names and
        prepends ``'n/a'`` so that it receives id 0. Sets ``self.relations``,
        ``self.relation2id`` and ``self.tokenizer``.

        Returns:
            The ``Codred`` model, with weights loaded from
            ``args.load_model_path`` when that option is set.
        """
        # Context manager so the file handle is closed deterministically.
        with open('./data/rawdata/relations.json', encoding='utf-8') as relation_file:
            relations = json.load(relation_file)
        relations.sort()
        self.relations = ['n/a'] + relations
        self.relation2id = {relation: index for index, relation in enumerate(self.relations)}
        with self.trainer.cache():
            reasoner = Codred(self.args, len(self.relations))
            if self.args.load_model_path:
                # Module-level ``load_model`` helper, not this method.
                load_model(reasoner, self.args.load_model_path)
            tokenizer = AutoTokenizer.from_pretrained('bert-base-cased', use_fast=True)
        self.tokenizer = tokenizer
        return reasoner

    def load_data(self):
        """Load the train / dev / test sets and assemble the training bags.

        Steps:
            1. Read the three dataset files; with ``--positive_only`` drop
               examples whose field 3 is ``'n/a'``.
            2. Group the examples into bags via ``place_*_data``; with
               ``--positive_ep_only`` drop bags labelled ``'n/a'``.
            3. Read the DSRE file, keep every 10th example and cut the result
               into chunks of 8 (a trailing partial chunk is dropped).
            4. Combine according to ``--raw_only`` / ``--dsre_only``; by
               default the DSRE chunks come first, followed by the raw bags.
            5. Open the Redis connection and create the loggers.

        Returns:
            ``(train_bags, dev_bags, test_bags)``.
        """
        def read_json(path):
            # Context manager so each file handle is closed deterministically.
            with open(path, encoding='utf-8') as handle:
                return json.load(handle)

        train_dataset = read_json(self.args.train_file)
        dev_dataset = read_json(self.args.dev_file)
        test_dataset = read_json(self.args.test_file)
        if self.args.positive_only:
            train_dataset = [d for d in train_dataset if d[3] != 'n/a']
            dev_dataset = [d for d in dev_dataset if d[3] != 'n/a']
            test_dataset = [d for d in test_dataset if d[3] != 'n/a']
        train_bags = place_train_data(train_dataset)
        dev_bags = place_dev_data(dev_dataset, self.args.single_path)
        test_bags = place_test_data(test_dataset, self.args.single_path)
        if self.args.positive_ep_only:
            train_bags = [b for b in train_bags if b[1] != 'n/a']
            dev_bags = [b for b in dev_bags if 'n/a' not in b[1]]
            test_bags = [b for b in test_bags if 'n/a' not in b[1]]

        # Keep every 10th DSRE example, then batch them in groups of 8.
        self.dsre_train_dataset = [
            d for i, d in enumerate(read_json(self.args.dsre_file)) if i % 10 == 0
        ]
        dsre_chunks = [
            self.dsre_train_dataset[8 * i:8 * i + 8]
            for i in range(len(self.dsre_train_dataset) // 8)
        ]
        if self.args.raw_only:
            pass
        elif self.args.dsre_only:
            train_bags = dsre_chunks
        else:
            dsre_chunks.extend(train_bags)
            train_bags = dsre_chunks

        # NOTE(review): the Redis host is hard-coded; consider making it an option.
        self.redisd = redis.Redis(host='172.28.6.91', port=6379, decode_responses=True, db=0)
        with self.trainer.once():
            self.train_logger = Logger(['train_loss', 'train_acc', 'train_pos_acc', 'train_dsre_acc'], self.trainer.writer, self.args.logging_steps, self.args.local_rank)
            self.dev_logger = Logger(['dev_mean_prec', 'dev_f1', 'dev_auc'], self.trainer.writer, 1, self.args.local_rank)
            self.test_logger = Logger(['test_mean_prec', 'test_f1', 'test_auc'], self.trainer.writer, 1, self.args.local_rank)
        return train_bags, dev_bags, test_bags

    def collate_fn(self):
        """Return the collate functions for train, dev and test, in that order.

        Each one is the module-level collate function bound to the same
        ``args`` / ``relation2id`` / ``tokenizer`` / ``redisd`` keyword
        arguments; dev and test both use ``collate_fn_infer``.
        """
        shared = {
            'args': self.args,
            'relation2id': self.relation2id,
            'tokenizer': self.tokenizer,
            'redisd': self.redisd,
        }
        train_collate = partial(collate_fn, **shared)
        dev_collate = partial(collate_fn_infer, **shared)
        test_collate = partial(collate_fn_infer, **shared)
        return train_collate, dev_collate, test_collate

    def on_train_epoch_start(self, epoch):
        """Hook called at the start of a training epoch; intentionally does nothing."""
        pass

    def on_train_step(self, step, train_step, inputs, extra, loss, outputs):
        """Log the loss and accuracy of one training step.

        When ``inputs['rs']`` is set the step is bag-level: the best relation
        score (max over rows, then over relations) is compared with the
        threshold and logged as ``train_acc`` / ``train_pos_acc``. Otherwise
        one prediction per row is derived the same way and compared with
        ``inputs['dplabel']`` as ``train_dsre_acc``.

        Args:
            step: unused.
            train_step: unused.
            inputs: model inputs; ``rs`` and ``dplabel`` are read.
            extra: side information; ``rs`` holds the gold bag labels.
            loss: value logged as ``train_loss``.
            outputs: the 5-tuple returned by the model's forward.
        """
        with self.trainer.once():
            self.train_logger.log(train_loss=loss)
            bag_level = inputs['rs'] is not None
            _, prediction, logit, ht_logits_flatten, ht_threshold_flatten = outputs
            if bag_level:
                rs = extra['rs']
                if ht_logits_flatten is not None:
                    pooled = torch.max(ht_logits_flatten, dim=1)[0]
                    r_score, r_idx = torch.max(pooled, dim=-1)
                    # Predict class 0 unless the best score clears the threshold.
                    prediction = [r_idx.item()] if r_score > ht_threshold_flatten[0, 0, 0] else [0]

                for guess, _score, gold in zip(prediction, logit, rs):
                    self.train_logger.log(train_acc=1 if guess == gold else 0)
                    if gold > 0:
                        self.train_logger.log(train_pos_acc=1 if guess == gold else 0)
            else:
                dplabel = inputs['dplabel']
                logit, dplabel = tensor_to_obj(logit, dplabel)
                prediction = []
                if ht_logits_flatten is not None:
                    pooled = torch.max(ht_logits_flatten, dim=1)[0]
                    r_score, r_idx = torch.max(pooled, dim=-1)
                    for row, (score, rel) in enumerate(zip(r_score, r_idx)):
                        above = score > ht_threshold_flatten[row, 0, 0]
                        prediction.append(rel.item() if above else 0)
                for guess, gold in zip(prediction, dplabel):
                    self.train_logger.log(train_dsre_acc=1 if guess == gold else 0)

    def on_train_epoch_end(self, epoch):
        """Print the epoch number together with the training logger's ``d`` attribute."""
        logger_state = self.train_logger.d
        print(epoch, logger_state)

    def on_dev_epoch_start(self, epoch):
        """Reset the list that collects per-step predictions for the dev epoch."""
        self._prediction = []

    def on_dev_step(self, step, inputs, extra, outputs):
        """Record the prediction of one dev step.

        The relation scores are max-pooled over dimension 1; the best
        relation is predicted when its score exceeds the threshold, class 0
        otherwise. Appends ``[prediction, pooled scores of item 0, h, t, rs]``
        to ``self._prediction``.

        Args:
            step: unused.
            inputs: unused.
            extra: side information with keys ``h``, ``t`` and ``rs``.
            outputs: the 5-tuple returned by the model's forward.
        """
        _, _, logit, ht_logits_flatten, ht_threshold_flatten = outputs
        eval_logit = torch.max(ht_logits_flatten, dim=1)[0]
        r_score, r_idx = torch.max(eval_logit, dim=-1)
        predicted = r_idx.item() if r_score > ht_threshold_flatten[:, 0, 0] else 0
        head, tail, gold = extra['h'], extra['t'], extra['rs']
        # The converted logit is not used afterwards; the call is kept as in the original.
        logit = tensor_to_obj(logit)
        self._prediction.append([predicted, eval_logit[0], head, tail, gold])

    def on_dev_epoch_end(self, epoch):
        """Evaluate the collected dev predictions and write them to disk.

        Gathers the predictions from all workers, builds the ranked list of
        (entity pair, relation, score) candidates for every relation except
        index 0, and evaluates it against the gold facts with
        ``eval_performance``. The statistics and the per-bag results are
        logged and written to ``output/dev-stat-dual-K1-{epoch}.json`` and
        ``output/dev-results-dual-K1-{epoch}.json``.

        Args:
            epoch: epoch number, used in the output file names.

        Returns:
            The F1 value reported by ``eval_performance``.
        """
        self._prediction = self.trainer.distributed_broadcast(self._prediction)
        results = []
        pred_result = []
        facts = {}
        for p, score, h, t, rs in self._prediction:
            rs = [self.relations[r] for r in rs]
            # Index 0 ('n/a') is never proposed as a candidate.
            for i in range(1, len(score)):
                pred_result.append({'entpair': [h, t], 'relation': self.relations[i], 'score': score[i]})
            results.append([h, rs, t, self.relations[p]])
            for r in rs:
                if r != 'n/a':
                    facts[(h, t, r)] = 1
        stat = eval_performance(facts, pred_result)
        with self.trainer.once():
            self.dev_logger.log(dev_mean_prec=stat['mean_prec'], dev_f1=stat['f1'], dev_auc=stat['auc'])
            # Make sure the target directory exists, and close each file so its
            # contents are flushed before the next step.
            os.makedirs('output', exist_ok=True)
            with open(f'output/dev-stat-dual-K1-{epoch}.json', 'w') as stat_file:
                json.dump(stat, stat_file)
            with open(f'output/dev-results-dual-K1-{epoch}.json', 'w') as result_file:
                json.dump(results, result_file)
        return stat['f1']
    
    def on_test_epoch_start(self, epoch):
        """Reset the list that collects per-step predictions for the test epoch."""
        self._prediction = []

    def on_test_step(self, step, inputs, extra, outputs):
        """Record the prediction of one test step.

        The relation scores are max-pooled over dimension 1; the best
        relation is predicted when its score exceeds the threshold at
        ``[0, 0, 0]``, class 0 otherwise. Appends
        ``[prediction, pooled scores of item 0, h, t, rs]`` to
        ``self._prediction``.

        Args:
            step: unused.
            inputs: unused.
            extra: side information with keys ``h``, ``t`` and ``rs``.
            outputs: the 5-tuple returned by the model's forward.
        """
        _, _, logit, ht_logits_flatten, ht_threshold_flatten = outputs
        eval_logit = torch.max(ht_logits_flatten, dim=1)[0]
        r_score, r_idx = torch.max(eval_logit, dim=-1)

        if r_score > ht_threshold_flatten[0, 0, 0]:
            chosen = r_idx.item()
        else:
            chosen = 0
        head = extra['h']
        tail = extra['t']
        gold = extra['rs']
        # The converted logit is not used afterwards; the call is kept as in the original.
        logit = tensor_to_obj(logit)
        self._prediction.append([chosen, eval_logit[0], head, tail, gold])

    def on_test_epoch_end(self, epoch):
        """Write the collected test predictions to disk.

        Gathers the predictions from all workers and writes two files:

        * ``output/test-results-{epoch}.json``: ``[h, gold relations, t,
          predicted relation]`` per bag;
        * ``output/test-codalab-results-{epoch}.json``: a dict with
          ``setting`` = ``'closed'`` and one scored entry per bag and per
          relation except index 0.

        Args:
            epoch: epoch number, used in the output file names.

        Returns:
            Always True.
        """
        self._prediction = self.trainer.distributed_broadcast(self._prediction)
        results = []
        out_results = []
        for p, score, h, t, rs in self._prediction:
            rs = [self.relations[r] for r in rs]
            # Index 0 ('n/a') is never emitted as a candidate.
            for i in range(1, len(score)):
                out_results.append({'h_id':str(h), "t_id":str(t), "relation": str(self.relations[i]), "score": float(score[i])})
            results.append([h, rs, t, self.relations[p]])
        coda_file = {'setting': 'closed', 'predictions': out_results}
        with self.trainer.once():
            # Make sure the target directory exists, and close each file so its
            # contents are flushed.
            os.makedirs('output', exist_ok=True)
            with open(f'output/test-results-{epoch}.json', 'w') as result_file:
                json.dump(results, result_file)
            with open(f'output/test-codalab-results-{epoch}.json', 'w') as coda_out:
                json.dump(coda_file, coda_out)
        return True

    def process_train_data(self, data):
        """Split one collated training batch into model inputs and side information.

        Args:
            data: indexable batch; positions 0-8 and 10 are read.

        Returns:
            A 4-tuple: main model inputs, ``{'rs': data[5]}``, the inputs
            built from positions 6-8 (sharing ``rs`` / ``dplabel`` with the
            main inputs), and ``{'intro_rets_no_q': data[10]}``.
        """
        inputs = dict(
            input_ids=data[0],
            token_type_ids=data[1],
            attention_mask=data[2],
            dplabel=data[3],
            rs=data[4],
            train=True,
        )
        intro_inputs_no_q = dict(
            input_ids=data[6],
            attention_mask=data[7],
            token_type_ids=data[8],
            rs=data[4],
            dplabel=data[3],
            train=True,
        )
        extra = {'rs': data[5]}
        intro_extra = {'intro_rets_no_q': data[10]}
        return inputs, extra, intro_inputs_no_q, intro_extra

    def process_dev_data(self, data):
        """Split one collated dev batch into model inputs and side information.

        Args:
            data: indexable batch; positions 0-8 and 10 are read.

        Returns:
            A 4-tuple: main model inputs (``train=False``), the gold triple
            ``{'h', 'rs', 't'}`` from positions 3-5, the inputs built from
            positions 6-8, and ``{'intro_rets_no_q': data[10]}``.
        """
        inputs = dict(
            input_ids=data[0],
            token_type_ids=data[1],
            attention_mask=data[2],
            train=False,
        )
        intro_inputs_no_q = dict(
            input_ids=data[6],
            attention_mask=data[7],
            token_type_ids=data[8],
            train=False,
        )
        gold = dict(h=data[3], rs=data[4], t=data[5])
        return inputs, gold, intro_inputs_no_q, dict(intro_rets_no_q=data[10])
    
    def process_test_data(self, data):
        """Split one collated test batch into model inputs and side information.

        Args:
            data: indexable batch; positions 0-8 and 10 are read.

        Returns:
            A 4-tuple: main model inputs (``train=False``), the gold triple
            ``{'h', 'rs', 't'}`` from positions 3-5, the inputs built from
            positions 6-8, and ``{'intro_rets_no_q': data[10]}``.
        """
        main_keys = ('input_ids', 'token_type_ids', 'attention_mask')
        inputs = {key: data[position] for position, key in enumerate(main_keys)}
        inputs['train'] = False

        intro_keys = ('input_ids', 'attention_mask', 'token_type_ids')
        intro_inputs_no_q = {key: data[position] for position, key in enumerate(intro_keys, start=6)}
        intro_inputs_no_q['train'] = False

        gold = {'h': data[3], 'rs': data[4], 't': data[5]}
        return inputs, gold, intro_inputs_no_q, {'intro_rets_no_q': data[10]}


def main():
    """Create a trainer driven by ``CodredCallback`` and run it."""
    callback = CodredCallback()
    Trainer(callback).run()


if __name__ == '__main__':
    main()
