import json
from multiprocessing import reduction
import random
from functools import partial
import pdb
from turtle import pd
import numpy as np
import redis
import sklearn
import torch
from eveliver import (Logger, load_model, tensor_to_obj)
#from _trainer import Trainer, TrainerCallback
from _trainer_entailment import Trainer, TrainerCallback
from transformers import AutoTokenizer, BertModel
from opt_einsum import contract
from allennlp.modules.matrix_attention import DotProductMatrixAttention, CosineMatrixAttention, BilinearMatrixAttention
from matrix_transformer import Encoder as MatTransformer
from torch import nn
import torch.nn.functional as F
import os
from tqdm import tqdm
from buffer import Buffer
from utils import CAPACITY, BLOCK_SIZE, DEFAULT_MODEL_NAME
from torch.nn import CrossEntropyLoss
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from sentence_reordering import SentReOrdering
def eval_performance(facts, pred_result):
    """Score ranked relation predictions against a collection of gold facts.

    Predictions are ranked by descending ``'score'``.  A prediction is
    counted as correct when ``(entpair[0], entpair[1], relation)`` is a
    member of ``facts``.

    Returns a dict with the running precision and recall after every rank
    (``'prec'``, ``'rec'``), their mean precision (``'mean_prec'``), the
    best F1 over all cut-offs (``'f1'``) and the area under the
    precision/recall curve (``'auc'``).
    """
    ranked = sorted(pred_result, key=lambda pred: pred['score'], reverse=True)
    n_facts = len(facts)
    hits = 0
    precisions, recalls = [], []
    for rank, pred in enumerate(ranked, start=1):
        triple = (pred['entpair'][0], pred['entpair'][1], pred['relation'])
        if triple in facts:
            hits += 1
        precisions.append(float(hits) / float(rank))
        recalls.append(float(hits) / float(n_facts))
    auc = sklearn.metrics.auc(x=recalls, y=precisions)
    prec_arr = np.array(precisions)
    rec_arr = np.array(recalls)
    # The tiny epsilon keeps the F1 division defined when both values are 0.
    f1_curve = 2 * prec_arr * rec_arr / (prec_arr + rec_arr + 1e-20)
    best_f1 = f1_curve.max()
    return {
        'prec': prec_arr.tolist(),
        'rec': rec_arr.tolist(),
        'mean_prec': prec_arr.mean(),
        'f1': best_f1,
        'auc': auc,
    }


def expand(start, end, total_len, max_size):
    """Grow the span ``[start, end)`` into a window of ``max_size`` positions.

    The extra room (``max_size - (end - start)``) is split around the span,
    with the odd position going to the right.  If the window would stick
    out past either edge of ``[0, total_len)`` it is shifted back inside
    while keeping its size; if it is larger than ``total_len`` the whole
    range ``(0, total_len)`` is returned.

    Returns:
        ``(window_start, window_end)`` as a half-open interval.
    """
    e_size = max_size - (end - start)
    left = start - (e_size // 2)
    right = end + (e_size - e_size // 2)
    if right - left > total_len:
        # Window cannot fit at all: fall back to the full range.
        return 0, total_len
    if left < 0:
        # ``left`` is negative here, so subtracting it moves the right edge
        # out by exactly the amount that overhangs the start.
        right -= left
        left = 0
    elif right > total_len:
        left -= right - total_len
        right = total_len
    return left, right


def place_train_data(dataset):
    """Group training rows into one bag per (entity-pair key, relation label).

    ``dataset`` yields ``(key, doc1, doc2, label)`` rows.  Rows are first
    collected per key and label as ``[doc1, doc2, label]`` entries.  A key
    whose only label is ``'n/a'`` yields a single ``'n/a'`` bag.  Otherwise
    one bag is produced for every non-``'n/a'`` label, and the key's
    ``'n/a'`` entries (if any) are appended to that bag's entries.

    Returns:
        A list of ``[key, label, entries, 'o']`` bags sorted by
        ``key + '#' + label``.
    """
    grouped = {}
    for key, doc1, doc2, label in dataset:
        grouped.setdefault(key, {}).setdefault(label, []).append([doc1, doc2, label])

    bags = []
    for key, by_label in grouped.items():
        if len(by_label) == 1 and 'n/a' in by_label:
            bags.append([key, 'n/a', by_label['n/a'], 'o'])
            continue
        for label in list(by_label):
            if label == 'n/a':
                continue
            entries = by_label[label]
            if 'n/a' in by_label:
                # Negatives for the same pair ride along with every positive bag.
                entries.extend(by_label['n/a'])
            bags.append([key, label, entries, 'o'])
    bags.sort(key=lambda bag: bag[0] + '#' + bag[1])
    return bags


def place_dev_data(dataset, single_path):
    """Group evaluation rows into one bag per entity-pair key.

    ``dataset`` yields ``(key, doc1, doc2, label)`` rows, collected per key
    and label as ``[doc1, doc2, label]`` entries.  A key whose only label is
    ``'n/a'`` yields a bag labelled ``['n/a']``.  Otherwise the bag holds
    the entries of all labels; when ``single_path`` is truthy only one
    randomly chosen entry is kept for each non-``'n/a'`` label.  The bag's
    label list contains every label except ``'n/a'``.

    Returns:
        A list of ``[key, labels, entries, 'o']`` bags sorted by
        ``key + '#' + '#'.join(labels)``.
    """
    grouped = {}
    for key, doc1, doc2, label in dataset:
        grouped.setdefault(key, {}).setdefault(label, []).append([doc1, doc2, label])

    bags = []
    for key, by_label in grouped.items():
        if len(by_label) == 1 and 'n/a' in by_label:
            bags.append([key, ['n/a'], by_label['n/a'], 'o'])
            continue
        entries = []
        for label, label_entries in by_label.items():
            if single_path and label != 'n/a':
                entries.append(random.choice(label_entries))
            else:
                entries.extend(label_entries)
        positive_labels = [label for label in by_label if label != 'n/a']
        bags.append([key, positive_labels, entries, 'o'])
    bags.sort(key=lambda bag: bag[0] + '#' + '#'.join(bag[1]))
    return bags


def gen_c(tokenizer, passage, span, max_len, bound_tokens, d_start, d_end, no_additional_marker, mask_entity):
    """Build a token window of at most ``max_len`` items around an entity mention.

    The mention ``passage[span[0]:span[1]]`` is wrapped in
    ``bound_tokens[0]`` / ``bound_tokens[1]`` (its tokens are replaced by
    ``'[MASK]'`` when ``mask_entity`` is set).  Context is then collected on
    both sides.  Unless ``no_additional_marker`` is set, positions found in
    ``d_start`` / ``d_end`` get an ``[unusedN]`` marker before / after the
    token, with ``N = (index + 2) * 2 + 1`` for starts and
    ``(index + 2) * 2 + 2`` for ends.  Finally the two sides are trimmed so
    the total fits the budget left after the mention.

    ``tokenizer`` is accepted for signature compatibility and not used.

    Returns:
        The list ``left context + marked mention + right context``.
    """
    with_markers = not no_additional_marker

    core = [bound_tokens[0]]
    core.extend('[MASK]' if mask_entity else passage[i] for i in range(span[0], span[1]))
    core.append(bound_tokens[1])

    # Left context is gathered walking backwards, so it is stored reversed:
    # an end marker is pushed before its token and a start marker after it.
    left = []
    pos = span[0] - 1
    while len(left) < max_len and pos >= 0:
        if with_markers and pos in d_end:
            left.append(f'[unused{(d_end[pos] + 2) * 2 + 2}]')
        left.append(passage[pos])
        if with_markers and pos in d_start:
            left.append(f'[unused{(d_start[pos] + 2) * 2 + 1}]')
        pos -= 1

    right = []
    pos = span[1]
    while len(right) < max_len and pos < len(passage):
        if with_markers and pos in d_start:
            right.append(f'[unused{(d_start[pos] + 2) * 2 + 1}]')
        right.append(passage[pos])
        if with_markers and pos in d_end:
            right.append(f'[unused{(d_end[pos] + 2) * 2 + 2}]')
        pos += 1

    budget = max_len - len(core)
    if len(left) + len(right) > budget:
        half = budget / 2
        if len(left) > half and len(right) > half:
            # Both sides are long: split the budget, odd item to the right.
            left = left[:budget // 2]
            right = right[:budget - budget // 2]
        elif len(left) <= len(right):
            right = right[:budget - len(left)]
        else:
            left = left[:budget - len(right)]
    return left[::-1] + core + right


def gen_c_complete(tokenizer, passage, span, max_len, bound_tokens, d_start, d_end, no_additional_marker, mask_entity, ht_start, ht_end, ht):
    """Mark up a passage around one entity mention, without final trimming.

    The mention ``passage[span[0]:span[1]]`` is wrapped in ``bound_tokens``
    (tokens replaced by ``'[MASK]'`` when ``mask_entity`` is set).  Context
    is collected on each side until that side holds ``max_len`` items or the
    passage ends.  Unless ``no_additional_marker`` is set, positions in
    ``d_start`` / ``d_end`` receive ``[unusedN]`` markers with
    ``N = (index + 2) * 2 + 1`` / ``(index + 2) * 2 + 2``.  Positions listed
    in ``ht_start`` / ``ht_end`` additionally receive ``[unused1]`` /
    ``[unused2]`` when ``ht == 'h'`` and ``[unused3]`` / ``[unused4]`` when
    ``ht == 't'``; these are added regardless of ``no_additional_marker``.

    ``tokenizer`` is accepted for signature compatibility and not used.

    Returns:
        The list ``left context + marked mention + right context``.
    """
    with_markers = not no_additional_marker
    if ht == 'h':
        ht_open, ht_close = '[unused1]', '[unused2]'
    elif ht == 't':
        ht_open, ht_close = '[unused3]', '[unused4]'
    else:
        ht_open = ht_close = None

    core = [bound_tokens[0]]
    core.extend('[MASK]' if mask_entity else passage[i] for i in range(span[0], span[1]))
    core.append(bound_tokens[1])

    # Walking backwards: closing markers are pushed before the token and
    # opening markers after it, then the whole side is reversed at the end.
    left = []
    pos = span[0] - 1
    while len(left) < max_len and pos >= 0:
        if with_markers and pos in d_end:
            left.append(f'[unused{(d_end[pos] + 2) * 2 + 2}]')
        if ht_close is not None and pos in ht_end:
            left.append(ht_close)
        left.append(passage[pos])
        if with_markers and pos in d_start:
            left.append(f'[unused{(d_start[pos] + 2) * 2 + 1}]')
        if ht_open is not None and pos in ht_start:
            left.append(ht_open)
        pos -= 1

    right = []
    pos = span[1]
    while len(right) < max_len and pos < len(passage):
        if with_markers and pos in d_start:
            right.append(f'[unused{(d_start[pos] + 2) * 2 + 1}]')
        if ht_open is not None and pos in ht_start:
            right.append(ht_open)
        right.append(passage[pos])
        if with_markers and pos in d_end:
            right.append(f'[unused{(d_end[pos] + 2) * 2 + 2}]')
        if ht_close is not None and pos in ht_end:
            right.append(ht_close)
        pos += 1

    return left[::-1] + core + right

def process(tokenizer, h, t, doc0, doc1):
    """Split two marked-up documents into blocks and sample a subset of them.

    Steps:
      1. ``doc0`` and ``doc1`` (token lists) are modified IN PLACE: every
         ``"."`` token between a paired opening/closing ``[unusedN]`` marker
         is replaced by ``"|"`` (see ``fix_entity``).
      2. A question built from ``h`` and ``t`` and both documents are split
         into blocks with ``Buffer.split_document_into_blocks``.
      3. Each document block gets a ``relevance`` value according to which
         marker ids it contains.
      4. Up to ``CAPACITY // (BLOCK_SIZE + 1) - 2`` further blocks are
         sampled at random, by decreasing relevance, on top of one head
         block and one tail block.

    Args:
        tokenizer: Tokenizer providing ``cls_token``, ``tokenize`` and
            ``convert_tokens_to_ids``.
        h: Head entity name, used in the question text.
        t: Tail entity name, used in the question text.
        doc0: Token list of the first document (mutated, see step 1).
        doc1: Token list of the second document (mutated, see step 1).

    Returns:
        ``(qbuf, dbuf, selected)``: the question buffer, the buffer with all
        document blocks, and the value returned by ``Buffer.sort_()`` on the
        sampled blocks, whose first block has the CLS id inserted at the
        front of its ``ids``.

    Raises:
        IndexError: presumably when no block at all was selected, since the
            first selected block is indexed unconditionally.
    """
    h_markers = ["[unused" + str(i) + "]" for i in range(1, 3)]
    t_markers = ["[unused" + str(i) + "]" for i in range(3, 5)]
    ht_markers = ["[unused" + str(i) + "]" for i in range(1, 5)]
    b_markers = ["[unused" + str(i) + "]" for i in range(5, 101)]
    max_blk_num = CAPACITY // (BLOCK_SIZE + 1)
    cnt = 0

    def fix_entity(doc, ht_markers, b_markers):
        """Replace "." tokens between paired markers with "|", in place.

        Only the first occurrence of each marker is looked at.  Markers are
        visited in numeric order; an odd marker immediately followed (in
        that list) by the next number is treated as an open/close pair.
        """
        markers = ht_markers + b_markers
        markers_pos = []
        if set(doc).intersection(markers):
            for marker in markers:
                try:
                    markers_pos.append((doc.index(marker), marker))
                except ValueError:
                    continue

        def marker_no(marker):
            return int(marker.replace("[unused", "").replace("]", ""))

        idx = 0
        # Explicit checks instead of assert/except so the pairing logic does
        # not change under ``python -O`` and real errors are not swallowed.
        while idx + 1 < len(markers_pos):
            open_pos, open_marker = markers_pos[idx]
            close_pos, close_marker = markers_pos[idx + 1]
            open_no = marker_no(open_marker)
            if open_no % 2 != 1 or marker_no(close_marker) - open_no != 1:
                idx += 1
                continue
            for i in range(open_pos + 1, close_pos):
                if doc[i] == ".":
                    doc[i] = "|"
            idx += 2
        return doc

    d0 = fix_entity(doc0, ht_markers, b_markers)
    d1 = fix_entity(doc1, ht_markers, b_markers)

    # NOTE(review): the question text has no spaces around ``h`` and 'and';
    # left unchanged because it alters what the tokenizer sees — confirm.
    question = [tokenizer.cls_token] + tokenizer.tokenize('what is the relation between' + h + 'and' + t)
    q, q_property = [question], [[('relevance', 2), ('blk_type', 1)]]
    qbuf, cnt = Buffer.split_document_into_blocks(q, tokenizer, properties=q_property, cnt=cnt)
    d0_buf, cnt = Buffer.split_document_into_blocks(d0, tokenizer, cnt=cnt, hard=False, docid=0)
    d1_buf, cnt = Buffer.split_document_into_blocks(d1, tokenizer, cnt=cnt, hard=False, docid=1)
    dbuf = Buffer()
    dbuf.blocks = d0_buf.blocks + d1_buf.blocks

    # Marker id sets are the same for every block, so compute them once.
    h_ids = set(tokenizer.convert_tokens_to_ids(h_markers))
    t_ids = set(tokenizer.convert_tokens_to_ids(t_markers))
    b_ids = set(tokenizer.convert_tokens_to_ids(b_markers))

    # A simple but fast sampling method as replacement
    for blk in dbuf:
        blk_ids = set(blk.ids)
        has_h = bool(h_ids & blk_ids)
        has_t = bool(t_ids & blk_ids)
        has_b = bool(b_ids & blk_ids)
        if has_h and has_b:
            blk.relevance = 6
        elif has_t and has_b:
            blk.relevance = 5
        elif has_h:
            blk.relevance = 3
        elif has_t:
            blk.relevance = 2
        elif has_b:
            blk.relevance = 1
        # Blocks without any marker keep whatever relevance they already had.

    lb = max_blk_num - 2
    pbuf_hb, _ = dbuf.filtered(lambda blk, idx: blk.relevance == 6, need_residue=True)
    pbuf_tb, _ = dbuf.filtered(lambda blk, idx: blk.relevance == 5, need_residue=True)
    pbuf_h, _ = dbuf.filtered(lambda blk, idx: blk.relevance == 3, need_residue=True)
    pbuf_t, _ = dbuf.filtered(lambda blk, idx: blk.relevance == 2, need_residue=True)
    pbuf_b, _ = dbuf.filtered(lambda blk, idx: blk.relevance == 1, need_residue=True)
    pbuf_any, _ = dbuf.filtered(lambda blk, idx: blk.relevance == 0, need_residue=True)

    # Draw order is kept as hb, h, tb, t so the random stream is consumed
    # exactly as before.
    _selected_hb_blk = random.sample(pbuf_hb.blocks, min(1, len(pbuf_hb.blocks)))
    _selected_h_blk = random.sample(pbuf_h.blocks, min(1, len(pbuf_h.blocks)))
    _selected_tb_blk = random.sample(pbuf_tb.blocks, min(1, len(pbuf_tb.blocks)))
    _selected_t_blk = random.sample(pbuf_t.blocks, min(1, len(pbuf_t.blocks)))
    # Prefer a block that also holds another entity marker over a plain one.
    _selected_h_context = _selected_hb_blk if len(_selected_hb_blk) >= 1 else _selected_h_blk
    _selected_t_context = _selected_tb_blk if len(_selected_tb_blk) >= 1 else _selected_t_blk

    already_used = set(_selected_h_context).union(_selected_t_context)
    rest_htb = list(set(pbuf_hb.blocks).union(set(pbuf_tb.blocks)).difference(already_used))
    _selected_htb_blks = random.sample(rest_htb, min(lb, len(rest_htb)))
    rest_ht = list(set(pbuf_h.blocks).union(set(pbuf_t.blocks)).difference(already_used))
    _selected_ht_blks = random.sample(rest_ht, min(lb - len(_selected_htb_blks), len(rest_ht)))
    _selected_pblks = random.sample(pbuf_b.blocks, min(lb - len(_selected_htb_blks) - len(_selected_ht_blks), len(pbuf_b)))
    _selected_nblks = random.sample(pbuf_any.blocks, min(lb - len(_selected_htb_blks) - len(_selected_ht_blks) - len(_selected_pblks), len(pbuf_any)))

    buf = Buffer()
    buf.blocks = _selected_h_context + _selected_t_context + _selected_htb_blks + _selected_ht_blks + _selected_pblks + _selected_nblks
    selected = buf.sort_()
    selected[0].ids.insert(0, tokenizer.convert_tokens_to_ids(tokenizer.cls_token))
    return qbuf, dbuf, selected


def split2block(tokenizer, h, t, doc0, doc1):
    """Split two marked-up documents into blocks that each start with id 101.

    ``doc0`` and ``doc1`` (token lists) are modified IN PLACE first: every
    ``"."`` token between a paired opening/closing ``[unusedN]`` marker is
    replaced by ``"|"`` (see ``fix_entity``).  Both documents are then split
    with ``Buffer.split_document_into_blocks`` and the resulting blocks are
    concatenated.

    Args:
        tokenizer: Tokenizer handed to ``Buffer.split_document_into_blocks``.
        h: Unused; kept for signature compatibility.
        t: Unused; kept for signature compatibility.
        doc0: Token list of the first document (mutated).
        doc1: Token list of the second document (mutated).

    Returns:
        A ``Buffer`` holding the blocks of both documents.
    """
    ht_markers = ["[unused" + str(i) + "]" for i in range(1, 5)]
    b_markers = ["[unused" + str(i) + "]" for i in range(5, 101)]
    cnt = 0

    def fix_entity(doc, ht_markers, b_markers):
        """Replace "." tokens between paired markers with "|", in place.

        Only the first occurrence of each marker is looked at.  Markers are
        visited in numeric order; an odd marker immediately followed (in
        that list) by the next number is treated as an open/close pair.
        """
        markers = ht_markers + b_markers
        markers_pos = []
        if set(doc).intersection(markers):
            for marker in markers:
                try:
                    markers_pos.append((doc.index(marker), marker))
                except ValueError:
                    continue

        def marker_no(marker):
            return int(marker.replace("[unused", "").replace("]", ""))

        idx = 0
        # Explicit checks instead of assert/except so the pairing logic does
        # not change under ``python -O`` and real errors are not swallowed.
        while idx + 1 < len(markers_pos):
            open_pos, open_marker = markers_pos[idx]
            close_pos, close_marker = markers_pos[idx + 1]
            open_no = marker_no(open_marker)
            if open_no % 2 != 1 or marker_no(close_marker) - open_no != 1:
                idx += 1
                continue
            for i in range(open_pos + 1, close_pos):
                if doc[i] == ".":
                    doc[i] = "|"
            idx += 2
        return doc

    d0 = fix_entity(doc0, ht_markers, b_markers)
    d1 = fix_entity(doc1, ht_markers, b_markers)

    d0_buf, cnt = Buffer.split_document_into_blocks(d0, tokenizer, cnt=cnt, hard=False, docid=0)
    d1_buf, cnt = Buffer.split_document_into_blocks(d1, tokenizer, cnt=cnt, hard=False, docid=1)
    dbuf = Buffer()
    dbuf.blocks = d0_buf.blocks + d1_buf.blocks
    for blk in dbuf.blocks:
        # 101 is presumably the [CLS] id of the BERT vocabulary in use —
        # TODO confirm against the tokenizer.
        if blk.ids[0] != 101:
            blk.ids = [101] + blk.ids
    return dbuf


def sent_order(tokenizer, h, t, doc0, doc1, encoder_name):
    """Run ``SentReOrdering`` over the blocks of two documents.

    A ``BertModel`` is loaded from ``encoder_name`` on every call.  The
    documents are split with ``split2block``; each block's ids are turned
    back into tokens and the first ``'[CLS]'`` and first ``'[SEP]'`` of each
    token list are dropped before the lists are handed to ``SentReOrdering``
    (always with ``device='cuda'``).

    Returns:
        Whatever ``SentReOrdering.sentence_ordering()`` returns.
    """
    encoder = BertModel.from_pretrained(encoder_name)
    blocks = split2block(tokenizer, h, t, doc0, doc1)
    sentences = []
    for blk in blocks:
        sentences.append(tokenizer.convert_ids_to_tokens(blk.ids))
    for sentence in sentences:
        for special in ('[CLS]', '[SEP]'):
            if special in sentence:
                sentence.remove(special)
    reorderer = SentReOrdering(sentences=sentences, encoder=encoder, device='cuda', tokenizer=tokenizer, h=h, t=t)
    return reorderer.sentence_ordering()

def process_example(h, t, doc1, doc2, tokenizer, max_len, redisd, no_additional_marker, mask_entity):
    """Build one padded model input from a (head document, tail document) pair.

    Both documents are read as JSON from ``redisd`` under the key
    ``'codred-doc-' + <doc name>``.  The head entity ``h`` is looked up in
    the first document and the tail entity ``t`` in the second (ids are
    compared as ``'Q' + str(entity['Q'])``).  Entities whose ``'Q'`` occurs
    in both documents are numbered (at most 40 are drawn when there are
    more) and passed to ``gen_c`` as extra marker positions.

    Returns:
        ``(tokens, token_ids, token_type_id, attention_mask)`` where
        ``tokens`` is ``[CLS] + head window + [SEP] + tail window + [SEP]``
        and the other three lists are padded up to ``max_len``.

    Raises:
        AssertionError: if the head or tail entity is not found.
    """
    def find_entity(entities, qid):
        # First entity whose prefixed id equals ``qid``; None when absent.
        for candidate in entities:
            if 'Q' in candidate and 'Q' + str(candidate['Q']) == qid:
                return candidate
        return None

    def by_qid(entities):
        return {ent['Q']: ent for ent in entities if 'Q' in ent}

    def marker_positions(entities, numbering):
        # Map first / last token position of every span to the entity number.
        starts, ends = {}, {}
        for ent in entities:
            if 'Q' in ent and ent['Q'] in numbering:
                for span in ent['spans']:
                    starts[span[0]] = numbering[ent['Q']]
                    ends[span[1] - 1] = numbering[ent['Q']]
        return starts, ends

    doc1 = json.loads(redisd.get('codred-doc-' + doc1))
    doc2 = json.loads(redisd.get('codred-doc-' + doc2))

    v_h = find_entity(doc1['entities'], h)
    assert v_h is not None
    v_t = find_entity(doc2['entities'], t)
    assert v_t is not None

    shared = set(by_qid(doc1['entities']).keys()) & set(by_qid(doc2['entities']).keys())
    if len(shared) > 40:
        shared = set(random.choices(list(shared), k=40))
    ma = {qid: number for number, qid in enumerate(list(shared))}

    d1_start, d1_end = marker_positions(doc1['entities'], ma)
    d2_start, d2_end = marker_positions(doc2['entities'], ma)

    k1 = gen_c(tokenizer, doc1['tokens'], v_h['spans'][0], max_len // 2 - 2, ['[unused1]', '[unused2]'], d1_start, d1_end, no_additional_marker, mask_entity)
    k2 = gen_c(tokenizer, doc2['tokens'], v_t['spans'][0], max_len // 2 - 1, ['[unused3]', '[unused4]'], d2_start, d2_end, no_additional_marker, mask_entity)

    tokens = ['[CLS]'] + k1 + ['[SEP]'] + k2 + ['[SEP]']
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    padding = [0] * (max_len - len(tokens))
    if len(token_ids) < max_len:
        token_ids = token_ids + padding
    attention_mask = [1] * len(tokens) + padding
    token_type_id = [0] * (len(k1) + 2) + [1] * (len(k2) + 1) + padding
    return tokens, token_ids, token_type_id, attention_mask


def process_example_complete(h, t, doc1, doc2, tokenizer, max_len, redisd, no_additional_marker, mask_entity):
    """Build a 512-token model input plus a sampled block buffer for a document pair.

    Both documents are read as JSON from ``redisd`` under the key
    ``'codred-doc-' + <doc name>``.  The head entity ``h`` is looked up in
    the first document and the tail entity ``t`` in the second.  Each full
    document is marked up with ``gen_c_complete``, ``process`` is run on the
    two token lists, and then, per document, the sentences containing a
    head/tail marker are kept, followed by sentences containing another
    entity marker while a token budget (255 for the head side, 254 for the
    tail side) allows.  A side that is still too long is cropped around its
    first entity marker pair.

    Args:
        h: Head entity id, compared as ``'Q' + str(entity['Q'])``.
        t: Tail entity id, compared the same way.
        doc1: Name of the head document in ``redisd``.
        doc2: Name of the tail document in ``redisd``.
        tokenizer: Tokenizer providing ``convert_tokens_to_ids``.
        max_len: Ignored; the value is overwritten with 99999.
        redisd: Object whose ``get(key)`` returns the document JSON.
        no_additional_marker: Forwarded to ``gen_c_complete``.
        mask_entity: Forwarded to ``gen_c_complete``.

    Returns:
        ``(tokens, token_ids, token_type_id, attention_mask, selector_ret)``
        where ``selector_ret`` is the third value returned by ``process``
        and the id/type/mask lists are padded to 512.

    Raises:
        AssertionError: if the head or tail entity is not found.
        ValueError: if the assembled sequence is longer than 512 tokens
            (this used to drop into the debugger).
    """
    # The caller's max_len is overridden so gen_c_complete marks up the
    # whole document; the real limit is bert_max_len below.
    max_len = 99999
    bert_max_len = 512
    ht_markers = ["[unused" + str(i) + "]" for i in range(1, 5)]
    b_markers = ["[unused" + str(i) + "]" for i in range(5, 101)]

    def _find_entity(entities, qid):
        # First entity whose prefixed id equals ``qid``; None when absent.
        for candidate in entities:
            if 'Q' in candidate and 'Q' + str(candidate['Q']) == qid:
                return candidate
        return None

    def _by_qid(entities):
        return {ent['Q']: ent for ent in entities if 'Q' in ent}

    def _marker_positions(entities, numbering):
        # Map first / last token position of every span to the entity number.
        starts, ends = {}, {}
        for ent in entities:
            if 'Q' in ent and ent['Q'] in numbering:
                for span in ent['spans']:
                    starts[span[0]] = numbering[ent['Q']]
                    ends[span[1] - 1] = numbering[ent['Q']]
        return starts, ends

    def _select_sentences(marked_tokens, limit):
        # Keep sentences with a head/tail marker, then sentences with any
        # other entity marker while the running size stays within ``limit``.
        sentences = " ".join(marked_tokens).split(".")
        chosen = []
        for sent in sentences:
            # NOTE(review): [:-1] drops the last space-separated piece; that
            # is the empty string for a sentence followed by " ." but looks
            # like a real token for a final sentence without one — confirm.
            sent_tokens = sent.split(" ")[:-1]
            joined = " ".join(sent_tokens)
            if joined not in chosen and any(tok in ht_markers for tok in sent_tokens):
                chosen.append(joined)
        for sent in sentences:
            sent_tokens = sent.split(" ")[:-1]
            joined = " ".join(sent_tokens)
            if joined in chosen or not any(tok in b_markers for tok in sent_tokens):
                continue
            if len(" .".join(chosen).split(" ")) + len(sent_tokens) <= limit:
                chosen.append(joined)
        return [tok for tok in " .".join(chosen).split(" ") if tok != ""]

    def _crop(tokens, limit, open_marker, close_marker, tail_shift):
        # Cut an over-long side down to a window around its first marker pair.
        if len(tokens) <= limit:
            return tokens
        first = tokens.index(open_marker)
        last = tokens.index(close_marker)
        half = int((last - first + 1) / 2)
        if first - 1 <= 126 - half - 1:
            pre_ptr = 0
        else:
            pre_ptr = first - 126 + half + 1
        if len(tokens) - last <= 127 - half - 1:
            # NOTE(review): -1 as slice end drops the final token — confirm
            # this is intended.
            nex_ptr = -1
        else:
            nex_ptr = last + 127 - half + tail_shift
        return tokens[pre_ptr:nex_ptr]

    doc1 = json.loads(redisd.get('codred-doc-' + doc1))
    doc2 = json.loads(redisd.get('codred-doc-' + doc2))
    v_h = _find_entity(doc1['entities'], h)
    assert v_h is not None
    v_t = _find_entity(doc2['entities'], t)
    assert v_t is not None

    # Number the entities that occur in both documents (at most 40 draws).
    ov = set(_by_qid(doc1['entities']).keys()) & set(_by_qid(doc2['entities']).keys())
    if len(ov) > 40:
        ov = set(random.choices(list(ov), k=40))
    ma = {qid: number for number, qid in enumerate(list(ov))}
    d1_start, d1_end = _marker_positions(doc1['entities'], ma)
    d2_start, d2_end = _marker_positions(doc2['entities'], ma)

    # The first span is the anchor mention; the remaining spans of the
    # head/tail entity are marked through ht_start / ht_end.
    h_start = [span[0] for span in v_h['spans'][1:]]
    h_end = [span[1] - 1 for span in v_h['spans'][1:]]
    t_start = [span[0] for span in v_t['spans'][1:]]
    t_end = [span[1] - 1 for span in v_t['spans'][1:]]
    k1 = gen_c_complete(tokenizer, doc1['tokens'], v_h['spans'][0], max_len, ['[unused1]', '[unused2]'], d1_start, d1_end, no_additional_marker, mask_entity, h_start, h_end, 'h')
    k2 = gen_c_complete(tokenizer, doc2['tokens'], v_t['spans'][0], max_len, ['[unused3]', '[unused4]'], d2_start, d2_end, no_additional_marker, mask_entity, t_start, t_end, 't')

    # process() must run before the sentence split: it rewrites "." inside
    # marked entity spans of k1 / k2 to "|" in place.
    qbuf, dbuf, selector_ret = process(tokenizer, v_h['name'], v_t['name'], k1, k2)

    # The tail-side end offset differs from the head side (+1 vs -1), as in
    # the original arithmetic.
    k1_c_t_ = _crop(_select_sentences(k1, 255), 255, '[unused1]', '[unused2]', -1)
    k2_c_t_ = _crop(_select_sentences(k2, 254), 254, '[unused3]', '[unused4]', 1)

    tokens = ['[CLS]'] + k1_c_t_ + ['[SEP]'] + k2_c_t_ + ['[SEP]']
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    if len(token_ids) > bert_max_len:
        raise ValueError(f'sequence of {len(token_ids)} tokens exceeds the limit of {bert_max_len}')
    if len(token_ids) < bert_max_len:
        token_ids = token_ids + [0] * (bert_max_len - len(tokens))
    attention_mask = [1] * len(tokens) + [0] * (bert_max_len - len(tokens))
    token_type_id = [0] * (len(k1_c_t_) + 2) + [1] * (len(k2_c_t_) + 1) + [0] * (bert_max_len - len(tokens))
    return tokens, token_ids, token_type_id, attention_mask, selector_ret


def process_example_order(h, t, doc1, doc2, tokenizer, max_len, redisd, no_additional_marker, mask_entity, encoder):
    """Build a 512-token model input plus a sentence ordering for a document pair.

    Both documents are read as JSON from ``redisd`` under the key
    ``'codred-doc-' + <doc name>``.  The head entity ``h`` is looked up in
    the first document and the tail entity ``t`` in the second.  Each full
    document is marked up with ``gen_c_complete``, ``sent_order`` is run on
    the two token lists, and then, per document, the sentences containing a
    head/tail marker are kept, followed by sentences containing another
    entity marker while a token budget (255 for the head side, 254 for the
    tail side) allows.  A side that is still too long is cropped around its
    first entity marker pair.

    Args:
        h: Head entity id, compared as ``'Q' + str(entity['Q'])``.
        t: Tail entity id, compared the same way.
        doc1: Name of the head document in ``redisd``.
        doc2: Name of the tail document in ``redisd``.
        tokenizer: Tokenizer providing ``convert_tokens_to_ids``.
        max_len: Ignored; the value is overwritten with 99999.
        redisd: Object whose ``get(key)`` returns the document JSON.
        no_additional_marker: Forwarded to ``gen_c_complete``.
        mask_entity: Forwarded to ``gen_c_complete``.
        encoder: Forwarded to ``sent_order`` as its ``encoder_name``.

    Returns:
        ``(tokens, token_ids, token_type_id, attention_mask, selected_order)``
        where ``selected_order`` is the value returned by ``sent_order`` and
        the id/type/mask lists are padded to 512.

    Raises:
        AssertionError: if the head or tail entity is not found.
        ValueError: if the assembled sequence is longer than 512 tokens
            (this used to drop into the debugger).
    """
    # The caller's max_len is overridden so gen_c_complete marks up the
    # whole document; the real limit is bert_max_len below.
    max_len = 99999
    bert_max_len = 512
    ht_markers = ["[unused" + str(i) + "]" for i in range(1, 5)]
    b_markers = ["[unused" + str(i) + "]" for i in range(5, 101)]

    def _find_entity(entities, qid):
        # First entity whose prefixed id equals ``qid``; None when absent.
        for candidate in entities:
            if 'Q' in candidate and 'Q' + str(candidate['Q']) == qid:
                return candidate
        return None

    def _by_qid(entities):
        return {ent['Q']: ent for ent in entities if 'Q' in ent}

    def _marker_positions(entities, numbering):
        # Map first / last token position of every span to the entity number.
        starts, ends = {}, {}
        for ent in entities:
            if 'Q' in ent and ent['Q'] in numbering:
                for span in ent['spans']:
                    starts[span[0]] = numbering[ent['Q']]
                    ends[span[1] - 1] = numbering[ent['Q']]
        return starts, ends

    def _select_sentences(marked_tokens, limit):
        # Keep sentences with a head/tail marker, then sentences with any
        # other entity marker while the running size stays within ``limit``.
        sentences = " ".join(marked_tokens).split(".")
        chosen = []
        for sent in sentences:
            # NOTE(review): [:-1] drops the last space-separated piece; that
            # is the empty string for a sentence followed by " ." but looks
            # like a real token for a final sentence without one — confirm.
            sent_tokens = sent.split(" ")[:-1]
            joined = " ".join(sent_tokens)
            if joined not in chosen and any(tok in ht_markers for tok in sent_tokens):
                chosen.append(joined)
        for sent in sentences:
            sent_tokens = sent.split(" ")[:-1]
            joined = " ".join(sent_tokens)
            if joined in chosen or not any(tok in b_markers for tok in sent_tokens):
                continue
            if len(" .".join(chosen).split(" ")) + len(sent_tokens) <= limit:
                chosen.append(joined)
        return [tok for tok in " .".join(chosen).split(" ") if tok != ""]

    def _crop(tokens, limit, open_marker, close_marker, tail_shift):
        # Cut an over-long side down to a window around its first marker pair.
        if len(tokens) <= limit:
            return tokens
        first = tokens.index(open_marker)
        last = tokens.index(close_marker)
        half = int((last - first + 1) / 2)
        if first - 1 <= 126 - half - 1:
            pre_ptr = 0
        else:
            pre_ptr = first - 126 + half + 1
        if len(tokens) - last <= 127 - half - 1:
            # NOTE(review): -1 as slice end drops the final token — confirm
            # this is intended.
            nex_ptr = -1
        else:
            nex_ptr = last + 127 - half + tail_shift
        return tokens[pre_ptr:nex_ptr]

    doc1 = json.loads(redisd.get('codred-doc-' + doc1))
    doc2 = json.loads(redisd.get('codred-doc-' + doc2))
    v_h = _find_entity(doc1['entities'], h)
    assert v_h is not None
    v_t = _find_entity(doc2['entities'], t)
    assert v_t is not None

    # Number the entities that occur in both documents (at most 40 draws).
    ov = set(_by_qid(doc1['entities']).keys()) & set(_by_qid(doc2['entities']).keys())
    if len(ov) > 40:
        ov = set(random.choices(list(ov), k=40))
    ma = {qid: number for number, qid in enumerate(list(ov))}
    d1_start, d1_end = _marker_positions(doc1['entities'], ma)
    d2_start, d2_end = _marker_positions(doc2['entities'], ma)

    # The first span is the anchor mention; the remaining spans of the
    # head/tail entity are marked through ht_start / ht_end.
    h_start = [span[0] for span in v_h['spans'][1:]]
    h_end = [span[1] - 1 for span in v_h['spans'][1:]]
    t_start = [span[0] for span in v_t['spans'][1:]]
    t_end = [span[1] - 1 for span in v_t['spans'][1:]]
    k1 = gen_c_complete(tokenizer, doc1['tokens'], v_h['spans'][0], max_len, ['[unused1]', '[unused2]'], d1_start, d1_end, no_additional_marker, mask_entity, h_start, h_end, 'h')
    k2 = gen_c_complete(tokenizer, doc2['tokens'], v_t['spans'][0], max_len, ['[unused3]', '[unused4]'], d2_start, d2_end, no_additional_marker, mask_entity, t_start, t_end, 't')

    # sent_order() must run before the sentence split: through split2block
    # it rewrites "." inside marked entity spans of k1 / k2 to "|" in place.
    selected_order = sent_order(tokenizer, v_h['name'], v_t['name'], k1, k2, encoder)

    # The tail-side end offset differs from the head side (+1 vs -1), as in
    # the original arithmetic.
    k1_c_t_ = _crop(_select_sentences(k1, 255), 255, '[unused1]', '[unused2]', -1)
    k2_c_t_ = _crop(_select_sentences(k2, 254), 254, '[unused3]', '[unused4]', 1)

    tokens = ['[CLS]'] + k1_c_t_ + ['[SEP]'] + k2_c_t_ + ['[SEP]']
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    if len(token_ids) > bert_max_len:
        raise ValueError(f'sequence of {len(token_ids)} tokens exceeds the limit of {bert_max_len}')
    if len(token_ids) < bert_max_len:
        token_ids = token_ids + [0] * (bert_max_len - len(tokens))
    attention_mask = [1] * len(tokens) + [0] * (bert_max_len - len(tokens))
    token_type_id = [0] * (len(k1_c_t_) + 2) + [1] * (len(k2_c_t_) + 1) + [0] * (bert_max_len - len(tokens))
    return tokens, token_ids, token_type_id, attention_mask, selected_order


def _replace_token_id(ids, old_id, new_id):
    """Overwrite every occurrence of ``old_id`` in ``ids`` with ``new_id``, in place."""
    for pos, tok in enumerate(ids):
        if tok == old_id:
            ids[pos] = new_id


def _export_selector_tensors(selector_rets, device):
    """Pack the selector buffers into one ``(4, len(selector_rets), CAPACITY)`` int64 tensor.

    Rows 0-2 are filled by ``export_01_turn`` and row 3 by ``export_relevance``.
    Returns ``(ids, token_type, att_mask)``, i.e. rows 0, 2 and 1 in that order.
    """
    packed = torch.zeros(4, len(selector_rets), CAPACITY, dtype=torch.int64)
    for row, buf in enumerate(selector_rets):
        buf.export_01_turn(out=(packed[0, row], packed[1, row], packed[2, row]))
        buf.export_relevance(device=device, out=packed[3, row])
    return packed[0], packed[2], packed[1]


def collate_fn(batch, args, relation2id, tokenizer, redisd):
    """Collate one training item into selector inputs and labels.

    Two item layouts are handled, told apart by ``batch[0][-1]``:

    * ``'o'``: ``batch[0]`` is a bag ``("h#t", relation, doc_pairs, ..., 'o')``.
      At most 8 document pairs are used (sampled with ``random.choices``, i.e.
      with replacement, when there are more).
    * otherwise: ``batch[0]`` is a list of DSRE examples
      ``(doc_key, h_start, h_end, t_start, t_end, relation)`` whose documents
      are fetched from ``redisd`` under ``dsre-doc-<doc_key>``.

    In both layouts every ``"|"`` token id in the selector blocks is replaced
    by the id of ``"."``.

    Returns:
        ``(selector_ids, selector_token_type, selector_att_mask, dplabel_t,
        rs_t, [r])``. ``rs_t`` and ``r`` are ``None`` for DSRE batches.
    """
    # Resolve the two ids once instead of on every loop iteration.
    pipe_id = tokenizer.convert_tokens_to_ids("|")
    period_id = tokenizer.convert_tokens_to_ids(".")

    if batch[0][-1] == 'o':
        bag = batch[0]
        h, t = bag[0].split('#')
        r = relation2id[bag[1]]
        dps = bag[2]
        if len(dps) > 8:
            dps = random.choices(dps, k=8)
        dplabel = []
        selector_rets = []
        for doc1, doc2, l in dps:
            # Only the selector buffer is consumed by the returned batch.
            _, _, _, _, selector_ret = process_example_complete(
                h, t, doc1, doc2, tokenizer, args.seq_len, redisd,
                args.no_additional_marker, args.mask_entity)
            for block in selector_ret:
                _replace_token_id(block.ids, pipe_id, period_id)
            dplabel.append(relation2id[l])
            selector_rets.append(selector_ret)
        dplabel_t = torch.tensor(dplabel, dtype=torch.int64)
        rs_t = torch.tensor([r], dtype=torch.int64)
    else:
        examples = batch[0]
        h_len = tokenizer.max_len_sentences_pair // 2 - 2
        t_len = tokenizer.max_len_sentences_pair - tokenizer.max_len_sentences_pair // 2 - 2
        labels = []
        selector_rets = []
        for example in examples:
            doc = json.loads(redisd.get(f'dsre-doc-{example[0]}'))
            _, h_start, h_end, t_start, t_end, rel = example
            # Unknown relations fall back to the integer id of 'n/a'; the
            # string 'n/a' itself cannot be stored in the int64 label tensor.
            labels.append(relation2id.get(rel, relation2id['n/a']))
            h_1, h_2 = expand(h_start, h_end, len(doc), h_len)
            t_1, t_2 = expand(t_start, t_end, len(doc), t_len)
            h_tokens = doc[h_1:h_start] + ['[unused1]'] + doc[h_start:h_end] + ['[unused2]'] + doc[h_end:h_2]
            t_tokens = doc[t_1:t_start] + ['[unused3]'] + doc[t_start:t_end] + ['[unused4]'] + doc[t_end:t_2]
            _, _, selector_ret = process(
                tokenizer, " ".join(doc[h_start:h_end]), " ".join(doc[t_start:t_end]),
                h_tokens, t_tokens)
            for block in selector_ret:
                _replace_token_id(block.ids, pipe_id, period_id)
            selector_rets.append(selector_ret)
        dplabel_t = torch.tensor(labels, dtype=torch.long)
        rs_t = None
        r = None

    selector_ids, selector_token_type, selector_att_mask = _export_selector_tensors(
        selector_rets, dplabel_t.device)
    return selector_ids, selector_token_type, selector_att_mask, dplabel_t, rs_t, [r]


def collate_fn_infer(batch, args, relation2id, tokenizer, redisd):
    """Collate one evaluation bag into selector inputs.

    ``batch[0]`` is a bag ``("h#t", relations, doc_pairs, ...)``. Every
    document pair is run through ``process_example_complete`` and only its
    selector buffer is kept; ``"|"`` token ids inside the buffer blocks are
    replaced by the id of ``"."``.

    Returns:
        ``(selector_ids, selector_token_type, selector_att_mask, h, rs, t)``
        where ``rs`` is the list of gold relation ids of the bag.
    """
    bag = batch[0]
    h, t = bag[0].split('#')
    rs = [relation2id[r] for r in bag[1]]
    dps = bag[2]

    # Resolve the two ids once instead of on every loop iteration.
    pipe_id = tokenizer.convert_tokens_to_ids("|")
    period_id = tokenizer.convert_tokens_to_ids(".")

    selector_rets = []
    for doc1, doc2, _ in dps:
        _, _, _, _, selector_ret = process_example_complete(
            h, t, doc1, doc2, tokenizer, args.seq_len, redisd,
            args.no_additional_marker, args.mask_entity)
        for block in selector_ret:
            # Single in-place pass instead of repeated `in` / `index` scans.
            for pos, tok in enumerate(block.ids):
                if tok == pipe_id:
                    block.ids[pos] = period_id
        selector_rets.append(selector_ret)

    # Rows: 0 = ids, 1 = attention mask, 2 = token types, 3 = relevance labels.
    selector_inputs = torch.zeros(4, len(dps), CAPACITY, dtype=torch.int64)
    for dp, buf in enumerate(selector_rets):
        buf.export_01_turn(out=(selector_inputs[0, dp], selector_inputs[1, dp], selector_inputs[2, dp]))
    for dp, buf in enumerate(selector_rets):
        buf.export_relevance(device=selector_inputs.device, out=selector_inputs[3, dp])

    selector_ids = selector_inputs[0]
    selector_att_mask = selector_inputs[1]
    selector_token_type = selector_inputs[2]

    return selector_ids, selector_token_type, selector_att_mask, h, rs, t






class Introspector(torch.nn.Module):
    """BERT encoder with a one-logit-per-token head trained with BCE-with-logits."""

    def __init__(self, args):
        """Load ``bert-base-cased`` and build the token-level classifier.

        ``args`` is accepted for interface compatibility and is not read here.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        # NOTE(review): this dropout layer is created but never applied in
        # forward(); confirm whether that is intentional.
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)
        # In-place initialiser; the un-suffixed xavier_uniform is deprecated.
        torch.nn.init.xavier_uniform_(self.classifier.weight)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        """Score every token and optionally compute the BCE loss.

        Returns:
            ``logits`` when ``labels`` is None, otherwise ``(loss, logits)``.
            ``logits`` is the classifier output on the first element returned
            by the BERT encoder. When ``attention_mask`` is given, only
            positions where the mask equals 1 contribute to the loss.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)
        if labels is None:
            return logits

        labels = labels.type_as(logits)
        loss_fct = torch.nn.BCEWithLogitsLoss()
        flat_logits = logits.view(-1)
        flat_labels = labels.view(-1)
        if attention_mask is not None:
            active = attention_mask.view(-1) == 1
            loss = loss_fct(flat_logits[active], flat_labels[active])
        else:
            loss = loss_fct(flat_logits, flat_labels)
        return (loss, logits)



class Codred(torch.nn.Module):
    def __init__(self, args, num_relations):
        """Build the BERT encoder, the relation predictor and the weighted loss.

        Reads ``args.aggregator`` and ``args.no_doc_pair_supervision``.
        Class 0 is down-weighted to 0.1 in the cross-entropy loss, which is
        created with ``reduction='none'`` and ``ignore_index=-1``.
        """
        super().__init__()
        # Sub-modules are registered in the same order as before: bert, predictor, loss.
        self.bert = BertModel.from_pretrained('bert-base-cased')
        hidden_size = self.bert.config.hidden_size
        self.predictor = torch.nn.Linear(hidden_size, num_relations)

        class_weights = torch.ones(num_relations, dtype=torch.float32)
        class_weights[0] = 0.1

        self.d_model = 768
        self.reduced_dim = 256
        self.loss = torch.nn.CrossEntropyLoss(
            ignore_index=-1,
            weight=class_weights,
            reduction='none',
        )
        self.aggregator = args.aggregator
        self.no_doc_pair_supervision = args.no_doc_pair_supervision


    def forward(self, input_ids, token_type_ids, attention_mask, dplabel=None, rs=None):

        embedding, _ = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=False)
        # r_embedding: T(num_sentences, embedding_size)
        r_embedding = embedding[:, 0, :]
        # logit: T(1, num_relations)
        # dp_logit: T(num_sentences, num_relations)
        logit, dp_logit = self.predict_logit(r_embedding, rs=rs)
        # prediction: T(1)
        _, prediction = torch.max(logit, dim=1)
        if dplabel is not None and rs is None:
            loss = self.loss(dp_logit, dplabel)
            # prediction: T(num_sentences)
            _, prediction = torch.max(dp_logit, dim=1)
            loss = loss.sum()
        elif rs is not None:
            if self.no_doc_pair_supervision:
                loss = self.loss(logit, rs)
            else:
                loss = self.loss(logit, rs) + self.loss(dp_logit, dplabel)
                loss = loss.sum()
        else:
            loss = None
        return loss, prediction, logit

    
    def predict_logit(self, r_embedding, rs=None):
        # r_embedding: T(num_sentences, embedding_size)
        # weight: T(num_relations, embedding_size)
        weight = self.predictor.weight
        if self.aggregator == 'max':
            # scores: T(num_sentences, num_relations)
            scores = self.predictor(r_embedding)
            # prob: T(num_sentences, num_relations)
            prob = torch.nn.functional.softmax(scores, dim=1)
            if rs is not None:
                _, idx = torch.max(prob[:, rs[0]], dim=0, keepdim=True)
                return scores[idx], scores
            else:
                # max_score: T(1, num_relations)
                max_score, _ = torch.max(scores, dim=0, keepdim=True)
                return max_score, scores
        elif self.aggregator == 'avg':
            # embedding: T(1, embedding_size)
            embedding = torch.sum(r_embedding, dim=1, keepdim=True) / r_embedding.shape[0]
            return self.predictor(embedding), self.predictor(r_embedding)
        elif self.aggregator == 'attention':
            # attention_score: T(num_sentences, num_relations)
            attention_score = torch.matmul(r_embedding, torch.t(weight))
            # attention_weight: T(num_sentences, num_relations)
            attention_weight = torch.nn.functional.softmax(attention_score, dim=0)
            # embedding: T(num_relations, embedding_size)
            embedding = torch.matmul(torch.transpose(attention_weight, 0, 1), r_embedding)
            # logit: T(num_relations, num_relations)
            logit = self.predictor(embedding)
            return torch.diag(logit).unsqueeze(0), self.predictor(r_embedding)
        else:
            assert False



    def get_htb_v4(self, input_ids):
        htb_mask_list = []
        htb_list_batch = []
        for pi in range(input_ids.size()[0]):
            #pdb.set_trace()
            tmp = torch.nonzero(input_ids[pi] - torch.full(([input_ids.size()[1]]), 1).to(input_ids.device))
            if tmp.size()[0] < input_ids.size()[0]:
                print(input_ids)
            try:
                pdb.set_trace()
                h_start = (input_ids[pi]==1).nonzero().detach().tolist()[0][0]
                h_end = (input_ids[pi]==2).nonzero().detach().tolist()[0][0]
                t_start = (input_ids[pi]==3).nonzero().detach().tolist()[0][0]
                t_end = (input_ids[pi]==4).nonzero().detach().tolist()[0][0]
                b_spans = torch.nonzero(torch.gt(torch.full(([input_ids.size()[1]]), 99).to(input_ids.device), input_ids[pi])).squeeze(0).squeeze(1).detach().tolist()
                token_len = input_ids[pi].nonzero().size()[0]
                b_spans = [i for i in b_spans if i <= token_len-1]
                assert len(b_spans) >= 4 #
                for i in [h_start, h_end, t_start, t_end]:
                    b_spans.remove(i)
                h_span = [h_pos for h_pos in range(h_start, h_end+1)]
                t_span = [t_pos for t_pos in range(t_start, t_end+1)]
                print(h_span, t_span)
                h_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(h_span).to(input_ids.device), 1)
                t_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(t_span).to(input_ids.device), 1)
            except Exception as e:
                print(e)
                #pdb.set_trace()
                h_span = []
                t_span = []
                h_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device)
                t_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device)
                b_spans = []
            b_span_ = []
            if len(b_spans) > 0 and len(b_spans)%2==0:
                b_span_chunks = [b_spans[i:i+2] for i in range(0,len(b_spans),2)]
                b_span = []
                for span in b_span_chunks:
                    b_span.extend([b_pos for b_pos in range(span[0], span[1]+1)])
                b_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(b_span).to(input_ids.device), 1)
                b_span_.extend(b_span)
            elif len(b_spans) > 0 and len(b_spans)%2==1:
                b_span = []
                ptr = 0
                #pdb.set_trace()
                while(ptr<=len(b_spans)-1):
                    try:
                        if input_ids[pi][b_spans[ptr+1]] - input_ids[pi][b_spans[ptr]] == 1:
                            b_span.append([b_spans[ptr], b_spans[ptr+1]])
                            ptr += 2
                        else:
                            ptr += 1
                    except IndexError as e:
                        #pdb.set_trace()
                        ptr += 1 # 
                for bs in b_span:
                    #pdb.set_trace()
                    #ex_bs = range(bs[0], bs[1])
                    b_span_.extend(bs)
                    if len(b_span_)%2 != 0:
                        print(b_spans)
                b_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(b_span_).to(input_ids.device), 1)
            else:
                b_mask = torch.zeros_like(input_ids[pi])
            htb_mask = torch.concat([h_mask.unsqueeze(0), t_mask.unsqueeze(0), b_mask.unsqueeze(0)], dim=0) 
            htb_mask_list.append(htb_mask)
            htb_list_batch.append([h_span, t_span, b_span_])
        # pdb.set_trace()
        htb_mask_batch = torch.stack(htb_mask_list,dim=0)
        return htb_mask_batch, htb_list_batch 

    def get_htb_v5(self, input_ids):
        """Collect head, tail and bridge marker pairs for each row of ``input_ids``.

        Token ids 1/2 and 3/4 are treated as the head and tail start/end
        markers. Every row may contain several head or tail mentions, so spans
        are returned as ``[start, end]`` position pairs.

        Returns:
            A list with one ``[h_span, t_span, b_span_chunks]`` entry per row,
            each element being a list of ``[start, end]`` pairs. If anything
            in the parsing step raises, the error is printed and the row gets
            three empty lists.
        """
        htb_list_batch = []
        for pi in range(input_ids.size()[0]):
            # Diagnostic only: dump the batch when the number of positions
            # whose id differs from 1 is below the batch size.
            tmp = torch.nonzero(input_ids[pi] - torch.full(([input_ids.size()[1]]), 1).to(input_ids.device))
            if tmp.size()[0] < input_ids.size()[0]:
                print(input_ids)
            try:
                h_span = []
                t_span = []
                # Each entry is a one-element list [position].
                h_starts = (input_ids[pi]==1).nonzero().detach().tolist()
                h_ends = (input_ids[pi]==2).nonzero().detach().tolist()
                t_starts = (input_ids[pi]==3).nonzero().detach().tolist()
                t_ends = (input_ids[pi]==4).nonzero().detach().tolist()


                """if len(h_starts)!=len(h_ends) or len(t_starts)!=len(t_ends):
                    pdb.set_trace()"""
                # Walk the sorted head-marker positions and pair neighbours
                # whose ids differ by exactly 1 (a start followed by its end).
                h_ptr = 0
                h_spans = [i[0] for i in (h_starts + h_ends)]
                h_spans.sort()
                while(h_ptr<=len(h_spans)-1):
                    try:
                        if input_ids[pi][h_spans[h_ptr+1]] - input_ids[pi][h_spans[h_ptr]] == 1:
                            h_span.append([h_spans[h_ptr], h_spans[h_ptr+1]])
                            h_ptr += 2
                        else:
                            h_ptr += 1
                    except IndexError as e: 
                        # The last position has no partner; skip it.
                        h_ptr += 1 

                # Same pairing for the tail markers.
                t_ptr = 0
                t_spans = [i[0] for i in (t_starts + t_ends)]
                t_spans.sort()
                while(t_ptr<=len(t_spans)-1):
                    try:
                        if input_ids[pi][t_spans[t_ptr+1]] - input_ids[pi][t_spans[t_ptr]] == 1:
                            t_span.append([t_spans[t_ptr], t_spans[t_ptr+1]])
                            t_ptr += 2
                        else:
                            t_ptr += 1
                    except IndexError as e:
                        # The last position has no partner; skip it.
                        t_ptr += 1

                
                # Positions of every id below 99. The squeeze chain raises when
                # there is exactly one such position, which is handled by the
                # except clause below.
                b_spans = torch.nonzero(torch.gt(torch.full(([input_ids.size()[1]]), 99).to(input_ids.device), input_ids[pi])).squeeze(0).squeeze(1).detach().tolist()
                # Keep only positions below the count of non-zero ids
                # (presumably this drops trailing padding; TODO confirm).
                token_len = input_ids[pi].nonzero().size()[0]
                b_spans = [i for i in b_spans if i <= token_len-1]
                assert len(b_spans) >= 4 # at least one head pair and one tail pair
                # What remains after removing head/tail markers are bridge markers.
                for i in h_spans+t_spans:
                    b_spans.remove(i)

            except Exception as e:
                # Any failure above discards the partial result for this row.
                print(e)
                h_span = []
                t_span = []
                b_spans = []
            b_span_ = []
            if len(b_spans) > 0 and len(b_spans)%2==0:
                # Even count: consecutive positions form the (start, end) pairs.
                b_span_chunks = [b_spans[i:i+2] for i in range(0,len(b_spans),2)]
                b_span = []
                for span in b_span_chunks:
                    b_span.extend([b_pos for b_pos in range(span[0], span[1]+1)])
                b_span_.extend(b_span)
            elif len(b_spans) > 0 and len(b_spans)%2==1:
                # Odd count: keep only neighbouring positions whose ids differ
                # by exactly 1, dropping the unmatched ones.
                b_span = []
                ptr = 0
                while(ptr<=len(b_spans)-1):
                    try:
                        if input_ids[pi][b_spans[ptr+1]] - input_ids[pi][b_spans[ptr]] == 1:
                            b_span.append([b_spans[ptr], b_spans[ptr+1]])
                            ptr += 2
                        else:
                            ptr += 1
                    except IndexError as e: 
                        ptr += 1 
                for bs in b_span:

                    b_span_.extend(bs)
                    if len(b_span_)%2 != 0:
                        print(b_spans)
                # Re-chunk the flattened endpoints back into [start, end] pairs.
                b_span_chunks = [b_span_[i:i+2] for i in range(0,len(b_span_),2)]
            else:
                b_span_ = []
                b_span_chunks = []
            htb_list_batch.append([h_span, t_span, b_span_chunks])

        return htb_list_batch

    def get_extra_ht(self, h_pattern, t_pattern, input_list, h_span, t_span):
        """Find additional, non-overlapping mentions of the head and tail patterns.

        ``input_list`` is scanned left to right. Whenever the pattern matches
        at a position, ``[start, end]`` (inclusive) is recorded and the scan
        jumps past the match. The marked mention itself, given by
        ``[span[1], span[-2]]`` (the positions just inside the marker
        tokens), is then removed from the result.

        Returns:
            ``(extra_h, extra_t)``, two lists of ``[start, end]`` pairs.

        If the marked mention is not among the matches, the error is printed
        and the matches are returned unchanged. An empty pattern yields no
        matches.
        """
        marked_h = [h_span[1], h_span[-2]]
        marked_t = [t_span[1], t_span[-2]]

        def find_mentions(pattern):
            width = len(pattern)
            hits = []
            if width == 0:
                return hits
            pos = 0
            last_start = len(input_list) - width
            while pos <= last_start:
                # Element-wise comparison, stopping at the first mismatch.
                if all(input_list[pos + k] == pattern[k] for k in range(width)):
                    hits.append([pos, pos + width - 1])
                    pos += width
                else:
                    pos += 1
            return hits

        extra_h = find_mentions(h_pattern)
        extra_t = find_mentions(t_pattern)
        for found, marked in ((extra_h, marked_h), (extra_t, marked_t)):
            try:
                found.remove(marked)
            except ValueError as e:
                # Best effort: report the mismatch and continue instead of
                # dropping into an interactive debugger.
                print(e)
        return extra_h, extra_t

    def chunkandmatch(self, h_pattern, t_pattern, input_list, h_span, t_span):
        """Find every (possibly overlapping) occurrence of the head and tail patterns.

        Each window of ``input_list`` with the pattern's length is compared to
        the pattern, and ``[start, end]`` (inclusive) is recorded for every
        match. The marked mention ``[span[1], span[-2]]`` is then removed from
        each result.

        Returns:
            ``(extra_h, extra_t)``, two lists of ``[start, end]`` pairs.

        Raises:
            ValueError: if the marked mention is not among the matches.
        """
        head_marked = [h_span[1], h_span[-2]]
        tail_marked = [t_span[1], t_span[-2]]

        def sliding_hits(pattern):
            width = len(pattern)
            return [
                [start, start + width - 1]
                for start in range(len(input_list) - width + 1)
                if input_list[start:start + width] == pattern
            ]

        extra_h = sliding_hits(h_pattern)
        extra_t = sliding_hits(t_pattern)
        extra_h.remove(head_marked)
        extra_t.remove(tail_marked)
        return extra_h, extra_t


class FocalLoss(nn.Module):
    """Multi-class focal loss: ``-alpha * (1 - p_t) ** gamma * log(p_t)``.

    ``p_t`` is the softmax probability assigned to the gold class. With
    ``gamma == 0`` and ``alpha == 1`` this reduces to plain cross-entropy.
    """

    def __init__(self, gamma = 2, alpha = 1, size_average = True):
        """Store the focusing exponent, the scale factor and the reduction mode.

        ``size_average=True`` averages over all labelled positions; otherwise
        the per-position losses are summed.
        """
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average
        self.elipson = 0.000001  # kept for compatibility; not used by forward()

    def forward(self, logits, labels):
        """
        Calculate the focal loss.
        logits: batch_size * labels_length * seq_length
        labels: batch_size * seq_length (integer class indices)

        Raises ValueError when the batch or sequence dimensions of the two
        inputs disagree.
        """
        # Higher-rank inputs are flattened down to the 3-D / 2-D layout above.
        if labels.dim() > 2:
            labels = labels.contiguous().view(labels.size(0), labels.size(1), -1)
            labels = labels.transpose(1, 2)
            labels = labels.contiguous().view(-1, labels.size(2)).squeeze()
        if logits.dim() > 3:
            logits = logits.contiguous().view(logits.size(0), logits.size(1), logits.size(2), -1)
            logits = logits.transpose(2, 3)
            logits = logits.contiguous().view(-1, logits.size(1), logits.size(3)).squeeze()
        if logits.size(0) != labels.size(0) or logits.size(2) != labels.size(1):
            raise ValueError(
                f'shape mismatch: logits {tuple(logits.shape)} vs labels {tuple(labels.shape)}')

        # Normalise over the class axis. Without an explicit dim, log_softmax
        # picks dim 0 for 3-D input, which is the batch axis here.
        log_p = F.log_softmax(logits, dim=1)
        # Log-probability of the gold class at every position: (batch, seq).
        # gather runs on whatever device the inputs live on.
        log_pt = log_p.gather(1, labels.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        fl = -self.alpha * (1 - pt) ** self.gamma * log_pt
        if self.size_average:
            return fl.mean()
        else:
            return fl.sum()


class CodredCallback(TrainerCallback):
    """Trainer hooks for the CodRED task: arguments, model/data loading, logging and dev evaluation."""

    def __init__(self):
        super().__init__()

    def on_argument(self, parser):
        """Register the task-specific command-line options on ``parser``."""
        parser.add_argument('--seq_len', type=int, default=512)
        parser.add_argument('--aggregator', type=str, default='attention')
        parser.add_argument('--positive_only', action='store_true')
        parser.add_argument('--positive_ep_only', action='store_true')
        parser.add_argument('--no_doc_pair_supervision', action='store_true')
        parser.add_argument('--no_additional_marker', action='store_true')
        parser.add_argument('--mask_entity', action='store_true')
        parser.add_argument('--single_path', action='store_true')
        parser.add_argument('--dsre_only', action='store_true')
        parser.add_argument('--raw_only', action='store_true')
        parser.add_argument('--load_model_path', type=str, default=None)
        parser.add_argument('--train_file', type=str, default='./data/rawdata/train_dataset.json')
        parser.add_argument('--dev_file', type=str, default='./data/rawdata/dev_dataset.json')
        parser.add_argument('--dsre_file', type=str, default='./data/dsre_train_examples.json')
        parser.add_argument('--model_name', type=str, default='bert')

    def load_model(self):
        """Build the relation vocabulary, the two models and the tokenizer.

        Relations are read from ``./data/rawdata/relations.json``, sorted, and
        prefixed with ``'n/a'`` so that ``'n/a'`` has id 0. Only the reasoner
        is restored from ``--load_model_path``.

        Returns:
            ``(reasoner, introspector)``.
        """
        with open('./data/rawdata/relations.json') as f:
            relations = json.load(f)
        relations.sort()
        self.relations = ['n/a'] + relations
        self.relation2id = {relation: index for index, relation in enumerate(self.relations)}
        with self.trainer.cache():
            reasoner = Codred(self.args, len(self.relations))
            introspector = Introspector(self.args)
            if self.args.load_model_path:
                load_model(reasoner, self.args.load_model_path)
            tokenizer = AutoTokenizer.from_pretrained('bert-base-cased', use_fast=True)
        self.tokenizer = tokenizer
        return reasoner, introspector

    def load_data(self):
        """Load the train/dev bags, mix in DSRE batches, open Redis and create loggers.

        Every 10th DSRE example is kept and grouped into batches of 8; a
        trailing incomplete group is dropped. ``--raw_only`` trains on the
        bags only, ``--dsre_only`` on the DSRE batches only; otherwise the
        DSRE batches come first, followed by the bags.

        Returns:
            ``(train_bags, dev_bags)``.
        """
        with open(self.args.train_file) as f:
            train_dataset = json.load(f)
        with open(self.args.dev_file) as f:
            dev_dataset = json.load(f)
        if self.args.positive_only:
            train_dataset = [d for d in train_dataset if d[3] != 'n/a']
            dev_dataset = [d for d in dev_dataset if d[3] != 'n/a']
        train_bags = place_train_data(train_dataset)
        dev_bags = place_dev_data(dev_dataset, self.args.single_path)
        if self.args.positive_ep_only:
            train_bags = [b for b in train_bags if b[1] != 'n/a']
            dev_bags = [b for b in dev_bags if 'n/a' not in b[1]]
        with open(self.args.dsre_file) as f:
            dsre_examples = json.load(f)
        self.dsre_train_dataset = dsre_examples[::10]
        dsre_batches = [
            self.dsre_train_dataset[8 * i:8 * i + 8]
            for i in range(len(self.dsre_train_dataset) // 8)
        ]
        if self.args.raw_only:
            pass
        elif self.args.dsre_only:
            train_bags = dsre_batches
        else:
            dsre_batches.extend(train_bags)
            train_bags = dsre_batches
        self.redisd = redis.Redis(host='localhost', port=6379, decode_responses=True, db=1)
        with self.trainer.once():
            self.train_logger = Logger(['train_loss', 'train_acc', 'train_pos_acc', 'train_dsre_acc'], self.trainer.writer, self.args.logging_steps, self.args.local_rank)
            self.dev_logger = Logger(['dev_mean_prec', 'dev_f1', 'dev_auc'], self.trainer.writer, 1, self.args.local_rank)
        return train_bags, dev_bags

    def collate_fn(self):
        """Return three collate functions: ``collate_fn`` followed by ``collate_fn_infer`` twice."""
        shared = dict(args=self.args, relation2id=self.relation2id, tokenizer=self.tokenizer, redisd=self.redisd)
        return partial(collate_fn, **shared), partial(collate_fn_infer, **shared), partial(collate_fn_infer, **shared)

    def on_train_epoch_start(self, epoch):
        pass

    def on_train_step(self, step, train_step, inputs, extra, loss, outputs):
        """Log the loss and the accuracy of the current training step."""
        with self.trainer.once():
            self.train_logger.log(train_loss=loss)
            _, prediction, logit = outputs
            if inputs['rs'] is not None:
                # Bag batch: compare the bag-level prediction with the gold relation.
                rs = extra['rs']
                prediction, logit = tensor_to_obj(prediction, logit)
                for p, score, gold in zip(prediction, logit, rs):
                    self.train_logger.log(train_acc=1 if p == gold else 0)
                    if gold > 0:
                        self.train_logger.log(train_pos_acc=1 if p == gold else 0)
            else:
                # DSRE batch: compare per-row predictions with per-row labels.
                dplabel = inputs['dplabel']
                prediction, logit, dplabel = tensor_to_obj(prediction, logit, dplabel)
                for p, l in zip(prediction, dplabel):
                    self.train_logger.log(train_dsre_acc=1 if p == l else 0)

    def on_train_epoch_end(self, epoch):
        pass

    def on_dev_epoch_start(self, epoch):
        """Reset the list of collected dev predictions."""
        self._prediction = list()

    def on_dev_step(self, step, inputs, extra, outputs):
        """Record the first prediction and logit row of this dev bag together with its entity pair."""
        _, prediction, logit = outputs
        h, t, rs = extra['h'], extra['t'], extra['rs']
        prediction, logit = tensor_to_obj(prediction, logit)
        self._prediction.append([prediction[0], logit[0], h, t, rs])

    def on_dev_epoch_end(self, epoch):
        """Score the collected dev predictions, log the metrics and dump them to ``output/``.

        Returns:
            The best F1 reported by ``eval_performance``.
        """
        self._prediction = self.trainer.distributed_broadcast(self._prediction)
        results = list()
        pred_result = list()
        facts = dict()
        for p, score, h, t, rs in self._prediction:
            rs = [self.relations[r] for r in rs]
            # Index 0 is 'n/a' and is not scored as a candidate fact.
            for i in range(1, len(score)):
                pred_result.append({'entpair': [h, t], 'relation': self.relations[i], 'score': score[i]})
            results.append([h, rs, t, self.relations[p]])
            for r in rs:
                if r != 'n/a':
                    facts[(h, t, r)] = 1
        stat = eval_performance(facts, pred_result)
        with self.trainer.once():
            self.dev_logger.log(dev_mean_prec=stat['mean_prec'], dev_f1=stat['f1'], dev_auc=stat['auc'])
            # Make sure the dump directory exists so an epoch of evaluation is
            # not lost to a missing folder.
            os.makedirs('output', exist_ok=True)
            with open(f'output/dev-stat-mat-intro-{epoch}.json', 'w') as f:
                json.dump(stat, f)
            with open(f'output/dev-results-mat-intro-{epoch}.json', 'w') as f:
                json.dump(results, f)
        return stat['f1']

    def process_train_data(self, data):
        """Map the 6-tuple produced by ``collate_fn`` to ``(model kwargs, extra)``."""
        selector_inputs = {
            'input_ids': data[0],
            'token_type_ids': data[1],
            'attention_mask': data[2],
            'dplabel': data[3],
            'rs': data[4]
        }

        return selector_inputs, {'rs': data[5]}

    def process_dev_data(self, data):
        """Map the 6-tuple produced by ``collate_fn_infer`` to ``(model kwargs, extra)``."""
        selector_inputs = {
            'input_ids': data[0],
            'token_type_ids': data[1],
            'attention_mask': data[2]
        }
        return selector_inputs, {'h': data[3], 'rs': data[4], 't': data[5]}


def main():
    """Create a Trainer driven by a fresh CodredCallback and run it."""
    callback = CodredCallback()
    Trainer(callback).run()


if __name__ == '__main__':
    main()
