#coding:utf-8
import numpy as np
import json
import torch
from utils import padding, padding_triple_id, build_kb_adj_mat
import logging
from torch.utils.data import Dataset, DataLoader
logging.basicConfig(level = logging.INFO)


class OnehopDataset(Dataset):
    """Line-delimited JSON dialogue dataset batched with one-hop knowledge only.

    Every line of the input file is one JSON sample. Lines are stored as raw
    strings and decoded lazily in ``__getitem__``. ``collate_fn`` turns a list
    of decoded samples into padded tensors; the "only two-hop" outputs are
    fixed-size placeholders in this class.
    """

    def __init__(self, path_to_file, net_info):
        """Load the raw sample lines and keep references to the lookup tables.

        Args:
            path_to_file: path of the line-delimited JSON file.
            net_info: mapping providing ``csk_entities``, ``csk_triples``,
                ``word2id`` and ``entity2id``.
        """
        with open(path_to_file) as f:
            # Keep the raw text; each line is parsed on demand in __getitem__.
            self.dataset = list(f)

        self.csk_entities = net_info['csk_entities']
        self.csk_triples = net_info['csk_triples']
        self.word2id = net_info['word2id']
        self.entity2id = net_info['entity2id']

    def __len__(self):
        """Return the number of samples (lines) in the file."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Decode and return the JSON sample stored on line ``idx``."""
        return json.loads(self.dataset[idx])

    def _word_ids(self, words):
        """Map words to vocabulary ids, using the ``'_UNK'`` id for unknown words."""
        word2id = self.word2id
        return [word2id[w] if w in word2id else word2id['_UNK'] for w in words]

    def _extend_unique(self, target, seen, names):
        """Append to ``target`` the id of every known name not already in ``seen``."""
        for name in names:
            if name not in self.entity2id:
                continue
            entity_id = self.entity2id[name]
            if entity_id not in seen:
                seen.add(entity_id)
                target.append(entity_id)

    @staticmethod
    def _split_triple(triple_text):
        """Split a triple string into ``(subject, relation, object)``.

        The last character of the first two whitespace-separated tokens is
        dropped (presumably the separating comma of "head, rel, tail").
        """
        parts = triple_text.split()
        return parts[0][:-1], parts[1][:-1], parts[2]

    @staticmethod
    def _length_mask(lengths, width):
        """Return a LongTensor ``(len(lengths), width)`` with row i set to 1 on ``[:lengths[i]]``."""
        mask = np.zeros((len(lengths), width), dtype=np.int64)
        for row, length in enumerate(lengths):
            mask[row, :length] = 1
        return torch.from_numpy(mask)

    @staticmethod
    def _one_hot_from_matches(match, width):
        """One-hot encode ``match`` (2-D int array, ``-1`` = no match) along a new last axis."""
        one_hot = torch.zeros(match.shape[0], match.shape[1], width)
        rows, cols = np.nonzero(match != -1)
        if rows.size:
            index = tuple(
                torch.as_tensor(a, dtype=torch.long)
                for a in (rows, cols, match[rows, cols])
            )
            one_hot[index] = 1
        return one_hot

    def collate_fn(self, data):
        """Assemble a list of decoded samples into one padded batch.

        Args:
            data: non-empty list of sample dicts with the keys ``post``,
                ``response``, ``post_triples``, ``all_entities_one_hop``,
                ``all_triples_one_hop`` and ``match_response_index_one_hop``.

        Returns:
            A 14-tuple ``(query_text, answer_text, local_entity, decoder_mask,
            q2e_adj_mat, entity2fact_mat, fact2entity_mat, kb_fact_rel,
            one_hot_entities_local, only_two_entity, one_hot_entities_only_two,
            one_two_triples_id, local_entity_mask, only_two_entity_mask)``.

        Raises:
            ValueError: if ``data`` is empty, or a sample has an empty
                ``post_triples`` list (``max`` of an empty sequence).
        """
        batch_size = len(data)
        encoder_len = max(len(item['post']) for item in data) + 1
        decoder_len = max(len(item['response']) for item in data) + 1
        triple_num = max(len(item['all_triples_one_hop']) for item in data)
        # max(post_triples) is used as the number of entities contributed by
        # the post itself (presumably post_triples holds 1-based group ids).
        entity_len = max(len(item['all_entities_one_hop']) + max(item['post_triples']) for item in data)
        # No two-hop information in this dataset: fixed-size placeholders.
        only_two_entity_len = 3
        triple_num_one_two = 3
        triple_len_one_two = 3

        posts_id = np.full((batch_size, encoder_len), 0, dtype=int)
        responses_id = np.full((batch_size, decoder_len), 0, dtype=int)
        # Unused fact slots keep relation id 2 (presumably '_PAD_R').
        kb_fact_rels = np.full((batch_size, triple_num), 2, dtype=int)
        kb_adj_mats = np.empty(batch_size, dtype=object)
        q2e_adj_mats = np.full((batch_size, entity_len), 0, dtype=int)
        match_entity_one_hop = np.full((batch_size, decoder_len), -1, dtype=int)

        responses_length = []
        local_entity_length = []
        only_two_entity_length = []
        local_entity = []
        only_two_entity = []
        one_two_triples_id = []

        for row, item in enumerate(data):
            # posts / responses -> word ids
            post_ids = self._word_ids(padding(item['post'], encoder_len))
            posts_id[row, :len(post_ids)] = post_ids
            response_ids = self._word_ids(padding(item['response'], decoder_len))
            responses_id[row, :len(response_ids)] = response_ids
            responses_length.append(len(item['response']) + 1)

            # local entities: entities mentioned in the post first, then the
            # one-hop neighbours, de-duplicated while preserving order.
            entities, seen = [], set()
            self._extend_unique(
                entities, seen,
                (item['post'][pos] for pos, flag in enumerate(item['post_triples']) if flag != 0),
            )
            self._extend_unique(
                entities, seen,
                (self.csk_entities[index] for index in item['all_entities_one_hop']),
            )
            local_entity_length.append(len(entities))
            # Pad with entity id 1 (presumably '_PAD_H').
            entities += [1] * (entity_len - len(entities))
            local_entity.append(entities)

            # global entity id -> position inside this sample's padded list
            g2l = {entity_id: pos for pos, entity_id in enumerate(entities)}

            # kb_adj_mat and kb_fact_rel: keep only facts whose both ends are local.
            fact_heads, fact_tails = [], []
            for triple_index in item['all_triples_one_hop']:
                sbj, rel, obj = self._split_triple(self.csk_triples[triple_index])
                if (sbj not in self.entity2id) or (obj not in self.entity2id):
                    continue
                head, tail = self.entity2id[sbj], self.entity2id[obj]
                if (head not in g2l) or (tail not in g2l):
                    continue
                kb_fact_rels[row, len(fact_heads)] = self.entity2id[rel]  # relation of this fact
                fact_heads.append(g2l[head])  # local index of the head entity
                fact_tails.append(g2l[tail])  # local index of the tail entity

            fact_count = len(fact_heads)
            kb_adj_mats[row] = (
                (np.arange(fact_count, dtype=int), np.array(fact_heads, dtype=int), np.ones(fact_count)),
                (np.array(fact_tails, dtype=int), np.arange(fact_count, dtype=int), np.ones(fact_count)),
            )

            # q2e_adj_mat: mark the local entities that occur in the post.
            for pos, flag in enumerate(item['post_triples']):
                if flag == 0:
                    continue
                word = item['post'][pos]
                if word in self.entity2id:
                    q2e_adj_mats[row, g2l[self.entity2id[word]]] = 1

            # match_entity_one_hop: local entity index aligned with each response position.
            for pos, entity_index in enumerate(item['match_response_index_one_hop']):
                if entity_index == -1:
                    continue
                name = self.csk_entities[entity_index]
                if name not in self.entity2id:
                    continue
                entity_id = self.entity2id[name]
                if entity_id in g2l:
                    match_entity_one_hop[row, pos] = g2l[entity_id]

            # Two-hop placeholders: pad-only entities and two empty triple groups.
            only_two_entity.append([1] * only_two_entity_len)
            only_two_entity_length.append(3)
            one_two_triples_id.append(
                padding_triple_id(self.entity2id, [[] for _ in range(2)], triple_num_one_two, triple_len_one_two)
            )

        query_text = torch.LongTensor(posts_id)
        answer_text = torch.LongTensor(responses_id)
        local_entity = torch.LongTensor(np.array(local_entity))
        q2e_adj_mat = torch.LongTensor(q2e_adj_mats)
        kb_fact_rel = torch.LongTensor(kb_fact_rels)
        only_two_entity = torch.LongTensor(np.array(only_two_entity))
        one_two_triples_id = torch.LongTensor(np.array(one_two_triples_id))

        batch_size = local_entity.shape[0]
        max_fact = kb_fact_rel.shape[1]
        max_local_entity = local_entity.shape[1]
        max_only_two_entity = only_two_entity.shape[1]
        decoder_len = answer_text.shape[1]

        # NOTE(review): the meaning of the 0.0 argument is defined by
        # utils.build_kb_adj_mat and cannot be confirmed from this file.
        (e2f_batch, e2f_f, e2f_e, e2f_val), (f2e_batch, f2e_e, f2e_f, f2e_val) = build_kb_adj_mat(kb_adj_mats, 0.0)
        # entity2fact_mat[b, f, e]: which head entity each fact (edge) is attached to.
        entity2fact_mat = torch.sparse_coo_tensor(
            torch.LongTensor([e2f_batch, e2f_f, e2f_e]),
            torch.FloatTensor(e2f_val),
            torch.Size([batch_size, max_fact, max_local_entity]),
        ).to_dense()
        fact2entity_mat = torch.sparse_coo_tensor(
            torch.LongTensor([f2e_batch, f2e_e, f2e_f]),
            torch.FloatTensor(f2e_val),
            torch.Size([batch_size, max_local_entity, max_fact]),
        ).to_dense()

        # 1 on real (non-padding) positions, 0 elsewhere.
        local_entity_mask = self._length_mask(local_entity_length, max_local_entity)
        only_two_entity_mask = self._length_mask(only_two_entity_length, max_only_two_entity)
        decoder_mask = self._length_mask(responses_length, decoder_len)

        one_hot_entities_local = self._one_hot_from_matches(match_entity_one_hop, max_local_entity)
        # No two-hop matches exist in this dataset, so this stays all zeros.
        one_hot_entities_only_two = torch.zeros(batch_size, decoder_len, max_only_two_entity)

        return query_text, answer_text, local_entity, decoder_mask, q2e_adj_mat, entity2fact_mat, fact2entity_mat, kb_fact_rel, one_hot_entities_local, \
            only_two_entity, one_hot_entities_only_two, one_two_triples_id, local_entity_mask, only_two_entity_mask



class DialogDataset(Dataset):
    """Line-delimited JSON dialogue dataset batched with one-hop and two-hop knowledge.

    Every line of the input file is one JSON sample. Lines are stored as raw
    strings and decoded lazily in ``__getitem__``. ``collate_fn`` turns a list
    of decoded samples into padded tensors.
    """

    def __init__(self, path_to_file, net_info):
        """Load the raw sample lines and keep references to the lookup tables.

        Args:
            path_to_file: path of the line-delimited JSON file.
            net_info: mapping providing ``csk_entities``, ``csk_triples``,
                ``word2id`` and ``entity2id``.
        """
        with open(path_to_file) as f:
            # Keep the raw text; each line is parsed on demand in __getitem__.
            self.dataset = list(f)

        self.csk_entities = net_info['csk_entities']
        self.csk_triples = net_info['csk_triples']
        self.word2id = net_info['word2id']
        self.entity2id = net_info['entity2id']

    def __len__(self):
        """Return the number of samples (lines) in the file."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Decode and return the JSON sample stored on line ``idx``."""
        return json.loads(self.dataset[idx])

    def _word_ids(self, words):
        """Map words to vocabulary ids, using the ``'_UNK'`` id for unknown words."""
        word2id = self.word2id
        return [word2id[w] if w in word2id else word2id['_UNK'] for w in words]

    def _extend_unique(self, target, seen, names):
        """Append to ``target`` the id of every known name not already in ``seen``."""
        for name in names:
            if name not in self.entity2id:
                continue
            entity_id = self.entity2id[name]
            if entity_id not in seen:
                seen.add(entity_id)
                target.append(entity_id)

    @staticmethod
    def _split_triple(triple_text):
        """Split a triple string into ``(subject, relation, object)``.

        The last character of the first two whitespace-separated tokens is
        dropped (presumably the separating comma of "head, rel, tail").
        """
        parts = triple_text.split()
        return parts[0][:-1], parts[1][:-1], parts[2]

    @staticmethod
    def _length_mask(lengths, width):
        """Return a LongTensor ``(len(lengths), width)`` with row i set to 1 on ``[:lengths[i]]``."""
        mask = np.zeros((len(lengths), width), dtype=np.int64)
        for row, length in enumerate(lengths):
            mask[row, :length] = 1
        return torch.from_numpy(mask)

    @staticmethod
    def _one_hot_from_matches(match, width):
        """One-hot encode ``match`` (2-D int array, ``-1`` = no match) along a new last axis."""
        one_hot = torch.zeros(match.shape[0], match.shape[1], width)
        rows, cols = np.nonzero(match != -1)
        if rows.size:
            index = tuple(
                torch.as_tensor(a, dtype=torch.long)
                for a in (rows, cols, match[rows, cols])
            )
            one_hot[index] = 1
        return one_hot

    def collate_fn(self, data):
        """Assemble a list of decoded samples into one padded batch.

        Args:
            data: non-empty list of sample dicts with the keys ``post``,
                ``response``, ``post_triples``, ``all_entities_one_hop``,
                ``all_triples_one_hop``, ``match_response_index_one_hop``,
                ``only_two``, ``one_two_triple`` and
                ``match_response_index_only_two``.

        Returns:
            A 14-tuple ``(query_text, answer_text, local_entity, decoder_mask,
            q2e_adj_mat, entity2fact_mat, fact2entity_mat, kb_fact_rel,
            one_hot_entities_local, only_two_entity, one_hot_entities_only_two,
            one_two_triples_id, local_entity_mask, only_two_entity_mask)``.

        Raises:
            ValueError: if ``data`` is empty, a sample has an empty
                ``post_triples`` list, or no sample in the batch has any
                ``one_two_triple`` group (``max`` of an empty sequence).
        """
        batch_size = len(data)
        encoder_len = max(len(item['post']) for item in data) + 1
        decoder_len = max(len(item['response']) for item in data) + 1
        triple_num = max(len(item['all_triples_one_hop']) for item in data)
        # max(post_triples) is used as the number of entities contributed by
        # the post itself (presumably post_triples holds 1-based group ids).
        entity_len = max(len(item['all_entities_one_hop']) + max(item['post_triples']) for item in data)
        only_two_entity_len = max(len(item['only_two']) for item in data)
        triple_num_one_two = max(len(item['one_two_triple']) for item in data)
        triple_len_one_two = max(len(tri) for item in data for tri in item['one_two_triple'])

        posts_id = np.full((batch_size, encoder_len), 0, dtype=int)
        responses_id = np.full((batch_size, decoder_len), 0, dtype=int)
        # Unused fact slots keep relation id 2 (presumably '_PAD_R').
        kb_fact_rels = np.full((batch_size, triple_num), 2, dtype=int)
        kb_adj_mats = np.empty(batch_size, dtype=object)
        q2e_adj_mats = np.full((batch_size, entity_len), 0, dtype=int)
        match_entity_one_hop = np.full((batch_size, decoder_len), -1, dtype=int)
        match_entity_only_two = np.full((batch_size, decoder_len), -1, dtype=int)

        responses_length = []
        local_entity_length = []
        only_two_entity_length = []
        local_entity = []
        only_two_entity = []
        one_two_triples_id = []

        for row, item in enumerate(data):
            # posts / responses -> word ids
            post_ids = self._word_ids(padding(item['post'], encoder_len))
            posts_id[row, :len(post_ids)] = post_ids
            response_ids = self._word_ids(padding(item['response'], decoder_len))
            responses_id[row, :len(response_ids)] = response_ids
            responses_length.append(len(item['response']) + 1)

            # local entities: entities mentioned in the post first, then the
            # one-hop neighbours, de-duplicated while preserving order.
            entities, seen = [], set()
            self._extend_unique(
                entities, seen,
                (item['post'][pos] for pos, flag in enumerate(item['post_triples']) if flag != 0),
            )
            self._extend_unique(
                entities, seen,
                (self.csk_entities[index] for index in item['all_entities_one_hop']),
            )
            local_entity_length.append(len(entities))
            # Pad with entity id 1 (presumably '_PAD_H').
            entities += [1] * (entity_len - len(entities))
            local_entity.append(entities)

            # global entity id -> position inside this sample's padded list
            g2l = {entity_id: pos for pos, entity_id in enumerate(entities)}

            # kb_adj_mat and kb_fact_rel: keep only facts whose both ends are local.
            fact_heads, fact_tails = [], []
            for triple_index in item['all_triples_one_hop']:
                sbj, rel, obj = self._split_triple(self.csk_triples[triple_index])
                if (sbj not in self.entity2id) or (obj not in self.entity2id):
                    continue
                head, tail = self.entity2id[sbj], self.entity2id[obj]
                if (head not in g2l) or (tail not in g2l):
                    continue
                kb_fact_rels[row, len(fact_heads)] = self.entity2id[rel]  # relation of this fact
                fact_heads.append(g2l[head])  # local index of the head entity
                fact_tails.append(g2l[tail])  # local index of the tail entity

            fact_count = len(fact_heads)
            kb_adj_mats[row] = (
                (np.arange(fact_count, dtype=int), np.array(fact_heads, dtype=int), np.ones(fact_count)),
                (np.array(fact_tails, dtype=int), np.arange(fact_count, dtype=int), np.ones(fact_count)),
            )

            # q2e_adj_mat: mark the local entities that occur in the post.
            for pos, flag in enumerate(item['post_triples']):
                if flag == 0:
                    continue
                word = item['post'][pos]
                if word in self.entity2id:
                    q2e_adj_mats[row, g2l[self.entity2id[word]]] = 1

            # match_entity_one_hop: local entity index aligned with each response position.
            for pos, entity_index in enumerate(item['match_response_index_one_hop']):
                if entity_index == -1:
                    continue
                name = self.csk_entities[entity_index]
                if name not in self.entity2id:
                    continue
                entity_id = self.entity2id[name]
                if entity_id in g2l:
                    match_entity_one_hop[row, pos] = g2l[entity_id]

            # only_two_entity: de-duplicated two-hop-only entities, padded with id 1.
            two_hop_entities, two_hop_seen = [], set()
            self._extend_unique(
                two_hop_entities, two_hop_seen,
                (self.csk_entities[index] for index in item['only_two']),
            )
            only_two_entity_length.append(len(two_hop_entities))
            two_hop_entities += [1] * (only_two_entity_len - len(two_hop_entities))
            only_two_entity.append(two_hop_entities)

            # match_entity_two_hop
            g2l_only_two = {entity_id: pos for pos, entity_id in enumerate(two_hop_entities)}
            for pos, entity_index in enumerate(item['match_response_index_only_two']):
                if entity_index == -1:
                    continue
                name = self.csk_entities[entity_index]
                if name not in self.entity2id:
                    continue
                entity_id = self.entity2id[name]
                # Skip entities that are not part of this sample's two-hop
                # list instead of raising KeyError, mirroring the guard used
                # for the one-hop matches above.
                if entity_id in g2l_only_two:
                    match_entity_only_two[row, pos] = g2l_only_two[entity_id]

            # one_two_triple: every group of triple indices becomes a list of
            # [head, relation, tail] string triples before id conversion/padding.
            one_two_triples_id.append(padding_triple_id(self.entity2id, [[self.csk_triples[x].split(', ') for x in triple] for triple in item['one_two_triple']], triple_num_one_two, triple_len_one_two))

        query_text = torch.LongTensor(posts_id)
        answer_text = torch.LongTensor(responses_id)
        local_entity = torch.LongTensor(np.array(local_entity))
        q2e_adj_mat = torch.LongTensor(q2e_adj_mats)
        kb_fact_rel = torch.LongTensor(kb_fact_rels)
        only_two_entity = torch.LongTensor(np.array(only_two_entity))
        one_two_triples_id = torch.LongTensor(np.array(one_two_triples_id))

        batch_size = local_entity.shape[0]
        max_fact = kb_fact_rel.shape[1]
        max_local_entity = local_entity.shape[1]
        max_only_two_entity = only_two_entity.shape[1]
        decoder_len = answer_text.shape[1]

        # NOTE(review): the meaning of the 0.0 argument is defined by
        # utils.build_kb_adj_mat and cannot be confirmed from this file.
        (e2f_batch, e2f_f, e2f_e, e2f_val), (f2e_batch, f2e_e, f2e_f, f2e_val) = build_kb_adj_mat(kb_adj_mats, 0.0)
        # entity2fact_mat[b, f, e]: which head entity each fact (edge) is attached to.
        entity2fact_mat = torch.sparse_coo_tensor(
            torch.LongTensor([e2f_batch, e2f_f, e2f_e]),
            torch.FloatTensor(e2f_val),
            torch.Size([batch_size, max_fact, max_local_entity]),
        ).to_dense()
        fact2entity_mat = torch.sparse_coo_tensor(
            torch.LongTensor([f2e_batch, f2e_e, f2e_f]),
            torch.FloatTensor(f2e_val),
            torch.Size([batch_size, max_local_entity, max_fact]),
        ).to_dense()

        # 1 on real (non-padding) positions, 0 elsewhere.
        local_entity_mask = self._length_mask(local_entity_length, max_local_entity)
        only_two_entity_mask = self._length_mask(only_two_entity_length, max_only_two_entity)
        decoder_mask = self._length_mask(responses_length, decoder_len)

        one_hot_entities_local = self._one_hot_from_matches(match_entity_one_hop, max_local_entity)
        one_hot_entities_only_two = self._one_hot_from_matches(match_entity_only_two, max_only_two_entity)

        return query_text, answer_text, local_entity, decoder_mask, q2e_adj_mat, entity2fact_mat, fact2entity_mat, kb_fact_rel, one_hot_entities_local, \
            only_two_entity, one_hot_entities_only_two, one_two_triples_id, local_entity_mask, only_two_entity_mask


def load_vocab(config):
    """Load the knowledge-base resource bundle from ``<config.data_dir>/resource.txt``.

    Only the first line of the file is read; it is parsed as a single JSON
    object.

    Args:
        config: object exposing a ``data_dir`` attribute.

    Returns:
        The tuple ``(csk_entities, csk_triples, kb_dict, dict_csk_entities,
        dict_csk_triples, raw_vocab)`` taken from the JSON keys
        ``csk_entities``, ``csk_triples``, ``dict_csk``, ``dict_csk_entities``,
        ``dict_csk_triples`` and ``vocab_dict`` respectively.
    """
    resource_path = '%s/resource.txt' % config.data_dir
    with open(resource_path) as resource_file:
        first_line = resource_file.readline()
    resource = json.loads(first_line)

    wanted_keys = (
        'csk_triples',        # all triples
        'csk_entities',       # all entities
        'vocab_dict',         # raw word vocabulary
        'dict_csk',           # for each entity, all of its triples
        'dict_csk_entities',  # ids of all entities
        'dict_csk_triples',
    )
    (csk_triples, csk_entities, raw_vocab,
     kb_dict, dict_csk_entities, dict_csk_triples) = (resource[key] for key in wanted_keys)

    return csk_entities, csk_triples, kb_dict, dict_csk_entities, dict_csk_triples, raw_vocab
    

def build_vocab(raw_vocab, config, trans='transE'):
    """Build the word / entity / relation vocabularies and their embedding matrices.

    Files read from ``config.data_dir``: ``entity.txt``, ``relation.txt``,
    ``glove.840B.300d.txt``, ``entity_<trans>.txt`` and ``relation_<trans>.txt``.

    Args:
        raw_vocab: mapping word -> count; words are ranked by descending count.
        config: object exposing ``data_dir``, ``symbols`` (maximum word
            vocabulary size, special tokens included) and ``embed_units``
            (size of the zero vector used for words without a GloVe entry).
        trans: suffix selecting the entity/relation embedding files.

    Returns:
        ``(word2id, entity2id, relation2id, vocab_list, embed, entity_list,
        entity_embed, relation_list, relation_embed, entity_relation_embed)``.
        ``entity2id`` indexes entities followed by relations; the returned
        ``relation_list`` is prefixed with four special relations.
    """
    logging.info("Creating word vocabulary...")
    vocab_list = ['_PAD', '_GO', '_EOS', '_UNK'] + sorted(raw_vocab, key=raw_vocab.get, reverse=True)
    if len(vocab_list) > config.symbols:
        vocab_list = vocab_list[:config.symbols]

    logging.info("Creating entity vocabulary...")
    entity_list = ['_NONE', '_PAD_H', '_PAD_R', '_PAD_T', '_NAF_H', '_NAF_R', '_NAF_T']
    with open('%s/entity.txt' % config.data_dir) as f:
        entity_list.extend(line.strip() for line in f)

    logging.info("Creating relation vocabulary...")
    with open('%s/relation.txt' % config.data_dir) as f:
        relation_list = [line.strip() for line in f]

    logging.info("Loading word vectors...")
    # Only vocabulary words are ever looked up below, so keep just those rows
    # instead of holding the whole GloVe table in memory.
    wanted_words = set(vocab_list)
    vectors = {}
    error_line = 0
    with open('%s/glove.840B.300d.txt' % config.data_dir, encoding = 'utf8', errors='ignore') as f:
        for i, line in enumerate(f):
            if i % 100000 == 0:
                logging.info("    processing line %d" % i)
            s = line.strip()
            # A well-formed row is one token followed by 300 numbers.
            if len(s.split()) != 301:
                logging.info(i)
                error_line += 1
                continue
            first_space = s.find(' ')
            word = s[:first_space]
            if word in wanted_words:
                vectors[word] = s[first_space + 1:]
    logging.info("error line: %d" %error_line)

    # Words without a pre-trained vector get an all-zero row; the numeric
    # strings are converted to float32 by np.array below.
    embed = []
    for word in vocab_list:
        if word in vectors:
            vector = vectors[word].split()
        else:
            vector = np.zeros((config.embed_units), dtype=np.float32)
        embed.append(vector)
    embed = np.array(embed, dtype=np.float32)

    logging.info("Loading entity vectors...")
    with open('%s/entity_%s.txt' % (config.data_dir, trans)) as f:
        entity_embed = [line.strip().split('\t') for line in f]

    logging.info("Loading relation vectors...")
    with open('%s/relation_%s.txt' % (config.data_dir, trans)) as f:
        relation_embed = [line.strip().split('\t') for line in f]

    entity_relation_embed = np.array(entity_embed+relation_embed, dtype=np.float32)
    entity_embed = np.array(entity_embed, dtype=np.float32)
    relation_embed = np.array(relation_embed, dtype=np.float32)

    # Ids are assigned in first-seen order (the current dict size is the id).
    word2id = dict()
    entity2id = dict()
    relation2id = dict()
    for word in vocab_list:
        word2id[word] = len(word2id)
    # Entities and relations share one id space: relations follow entities.
    for entity in entity_list + relation_list:
        entity2id[entity] = len(entity2id)
    relation_list = ['_PAD_R', 'Selfto', 'TextFrom', 'TextTo'] + relation_list
    for relation in relation_list:
        relation2id[relation] = len(relation2id)

    return word2id, entity2id, relation2id, vocab_list, embed, entity_list, entity_embed, relation_list, relation_embed, entity_relation_embed


class BertDataset(Dataset):
    """Line-delimited JSON dialogue dataset batched as a node graph with hop labels.

    Every line of the input file is one JSON sample, kept as raw text and
    decoded lazily in ``__getitem__``.
    """

    def __init__(self, path_to_file, net_info):
        """Read the raw sample lines and keep references to the lookup tables.

        Args:
            path_to_file: path of the line-delimited JSON file.
            net_info: mapping providing ``csk_entities``, ``csk_triples``,
                ``word2id``, ``entity2id`` and ``relation2id``.
        """
        with open(path_to_file) as f:
            self.dataset = [line for line in f]

        self.csk_entities = net_info['csk_entities']
        self.csk_triples = net_info['csk_triples']
        self.word2id = net_info['word2id']
        self.entity2id = net_info['entity2id']
        self.relation2id = net_info['relation2id']

    def __len__(self):
        """Return the number of samples (lines) in the file."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Decode and return the JSON sample stored on line ``idx``."""
        return json.loads(self.dataset[idx])

    def collate_fn(self, data):
        """Assemble a list of decoded samples into one padded batch.

        Args:
            data: non-empty list of sample dicts with the keys ``post``,
                ``response``, ``extend_nodes`` (``zero_hop``), ``golden_nodes``
                (``one_hop``/``two_hop``/``three_hop``) and ``golden_tris``
                (``zero_hop`` .. ``three_hop``).

        Returns:
            ``(query_text, answer_text, node, node_hop, node_mask,
            attention_mask)``. ``node_mask`` is a LongTensor of shape
            ``(batch, entity_len + 1, entity_len + 1)``; row/column 0 is an
            extra slot linked to every kept node with value 2, the diagonal of
            kept nodes is 1, and connected node pairs hold the relation id.
            ``attention_mask`` is ``-2**32 + 1`` wherever ``node_mask`` is 0
            and 0 elsewhere.
        """
        num_items = len(data)
        encoder_len = 1 + max(len(sample['post']) for sample in data)
        decoder_len = 1 + max(len(sample['response']) for sample in data)

        # Upper bound on the number of nodes of any sample in the batch.
        entity_len = max(
            len(sample['extend_nodes']['zero_hop'])
            + len(sample['golden_nodes']['one_hop'])
            + len(sample['golden_nodes']['two_hop'])
            + len(sample['golden_nodes']['three_hop'])
            for sample in data
        )

        posts_id = np.zeros((num_items, encoder_len), dtype=int)
        responses_id = np.zeros((num_items, decoder_len), dtype=int)
        node_ids = np.zeros((num_items, entity_len), dtype=int)
        node_hop = np.zeros((num_items, entity_len), dtype=int)
        node_mask = np.zeros((num_items, entity_len + 1, entity_len + 1), dtype=float)

        for row, sample in enumerate(data):
            # posts
            for col, token in enumerate(padding(sample['post'], encoder_len)):
                key = token if token in self.word2id else '_UNK'
                posts_id[row, col] = self.word2id[key]

            # responses
            for col, token in enumerate(padding(sample['response'], decoder_len)):
                key = token if token in self.word2id else '_UNK'
                responses_id[row, col] = self.word2id[key]

            # nodes: (source list, hop label) pairs processed in hop order;
            # entities unknown to entity2id are dropped.
            hop_sources = (
                (sample['extend_nodes']['zero_hop'], 1),
                (sample['golden_nodes']['one_hop'], 2),
                (sample['golden_nodes']['two_hop'], 3),
                (sample['golden_nodes']['three_hop'], 4),
            )
            kept_entities = []
            for entity_indices, hop_label in hop_sources:
                for entity_index in entity_indices:
                    name = self.csk_entities[entity_index]
                    if name not in self.entity2id:
                        continue
                    slot = len(kept_entities)
                    node_ids[row, slot] = self.entity2id[name]
                    node_hop[row, slot] = hop_label
                    kept_entities.append(self.entity2id[name])

            # Slot 0 of the mask is an extra node attached to every kept node.
            node_mask[row, 0, 0] = 1
            g2l = dict()  # global entity id -> local node position
            for slot, entity_id in enumerate(kept_entities):
                g2l[entity_id] = slot
                node_mask[row, slot + 1, slot + 1] = 1
                node_mask[row, 0, slot + 1] = 2
                node_mask[row, slot + 1, 0] = 2

            # Edges: store the relation id symmetrically for every triple
            # whose both ends are kept nodes.
            triple_indices = sample['golden_tris']['zero_hop'] + sample['golden_tris']['one_hop'] + \
                sample['golden_tris']['two_hop'] + sample['golden_tris']['three_hop']
            for tri_index in triple_indices:
                parts = self.csk_triples[tri_index].split()
                sbj, rel, obj = parts[0][:-1], parts[1][:-1], parts[2]

                if (sbj not in self.entity2id) or (obj not in self.entity2id):
                    continue
                head_id, tail_id = self.entity2id[sbj], self.entity2id[obj]
                if (head_id not in g2l) or (tail_id not in g2l):
                    continue

                head_slot = g2l[head_id] + 1
                tail_slot = g2l[tail_id] + 1
                relation_id = self.relation2id[rel]

                node_mask[row, head_slot, tail_slot] = relation_id
                node_mask[row, tail_slot, head_slot] = relation_id

        query_text = torch.LongTensor(np.array(posts_id))
        answer_text = torch.LongTensor(np.array(responses_id))
        node = torch.LongTensor(np.array(node_ids))
        node_hop = torch.LongTensor(np.array(node_hop))
        node_mask = torch.LongTensor(np.array(node_mask))

        # Additive attention mask: a very negative value where no link exists.
        padding_num = -2 ** 32 + 1
        blocked = torch.ones_like(node_mask, dtype=torch.float32) * padding_num
        allowed = torch.zeros_like(node_mask, dtype=torch.float32)
        attention_mask = torch.where(node_mask==0, blocked, allowed)

        return query_text, answer_text, node, node_hop, node_mask, attention_mask


class MultiTaskDataset(Dataset):
    """Line-delimited JSON dataset collated into a golden graph plus four hop graphs.

    Every non-blank line of the input file is one JSON record. Lines are kept
    as raw text and parsed lazily in ``__getitem__``. ``collate_fn`` converts a
    list of parsed records into padded word-id matrices for the post and the
    response, one "golden" graph, and one candidate graph per hop level
    (zero/one/two/three).
    """

    # Hop levels in processing order; these are also the key names read from
    # each record's 'length', 'extend_nodes', 'golden_nodes', 'golden_tris'
    # and 'extend_tris' dictionaries.
    _HOP_NAMES = ('zero_hop', 'one_hop', 'two_hop', 'three_hop')
    # Upper bound on the width of every candidate (non-golden) graph.
    _MAX_HOP_NODES = 500
    # Value placed in the attention mask wherever the edge matrix holds 0.
    _MASK_FILL = -2 ** 32 + 1

    def __init__(self, path_to_file, net_info):
        """Load the raw sample lines and keep references to the vocabularies.

        Args:
            path_to_file: path of a text file holding one JSON record per line.
            net_info: mapping providing 'csk_entities', 'csk_triples',
                'word2id', 'entity2id' and 'relation2id'.
        """
        # JSON text is UTF-8 by specification, so do not depend on the
        # locale's default encoding when reading it.
        with open(path_to_file, encoding='utf-8') as f:
            # A blank line holds no record and would make json.loads fail
            # later in __getitem__, so it is dropped here.
            self.dataset = [line for line in f if line.strip()]

        self.csk_entities = net_info['csk_entities']
        self.csk_triples = net_info['csk_triples']
        self.word2id = net_info['word2id']
        self.entity2id = net_info['entity2id']
        self.relation2id = net_info['relation2id']

    def __len__(self):
        """Return the number of stored sample lines."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Parse and return the JSON record stored at position ``idx``."""
        return json.loads(self.dataset[idx])

    @staticmethod
    def _new_graph(batch_size, width):
        """Allocate the zero-filled arrays that back one graph for a whole batch."""
        return {
            'nodes': np.full((batch_size, width), 0, dtype=int),
            'hops': np.full((batch_size, width), 0, dtype=int),
            'labels': np.full((batch_size, width), 0, dtype=int),
            # One extra row/column: slot 0 is reserved and linked to every node.
            'edges': np.full((batch_size, width + 1, width + 1), 0, dtype=int),
        }

    def _encode_tokens(self, tokens, length, out_row):
        """Write the word ids of ``padding(tokens, length)`` into ``out_row``."""
        for pos, word in enumerate(padding(tokens, length)):
            if word in self.word2id:
                out_row[pos] = self.word2id[word]
            else:
                out_row[pos] = self.word2id['_UNK']

    def _add_nodes(self, row, graph, node_groups, label_source=None,
                   label_key=None, limit=None):
        """Insert the known entities of ``node_groups`` into row ``row`` of ``graph``.

        Args:
            row: batch row to fill.
            graph: array dictionary produced by ``_new_graph``.
            node_groups: sequence of entity-index lists; the n-th list
                (1-based) is stored with hop value n.
            label_source: mapping holding the golden entity-index lists, or
                None when no labels are wanted.
            label_key: key of ``label_source`` whose entries get label 1.
            limit: maximum number of nodes to insert, or None for no cap.

        Returns:
            dict mapping entity id -> local position (0-based) in the row.
        """
        nodes, hops = graph['nodes'], graph['hops']
        labels, edges = graph['labels'], graph['edges']
        edges[row, 0, 0] = 1
        local_index = dict()
        golden = None
        count = 0
        for hop, group in enumerate(node_groups, start=1):
            for entity_index in group:
                # Stop filling altogether once the row is full. Leaving only
                # the inner loop would let the next group write past the end
                # of the row and raise IndexError.
                if limit is not None and count >= limit:
                    return local_index
                name = self.csk_entities[entity_index]
                if name not in self.entity2id:
                    continue
                if label_key is not None:
                    if golden is None:
                        # Built on first use: set lookup replaces a list scan
                        # per node.
                        golden = set(label_source[label_key])
                    if entity_index in golden:
                        labels[row, count] = 1
                entity = self.entity2id[name]
                nodes[row, count] = entity
                local_index[entity] = count
                hops[row, count] = hop
                count += 1
                # Edge slots are shifted by one because slot 0 is reserved.
                edges[row, count, count] = 1
                edges[row, 0, count] = 2
                edges[row, count, 0] = 3
        return local_index

    def _add_relations(self, row, edges, tri_indices, local_index):
        """Write the relation id of every triple whose two ends are in the graph."""
        for tri_index in tri_indices:
            parts = self.csk_triples[tri_index].split()
            # The first two fields carry one trailing character that is
            # stripped (presumably a comma separator - confirm against the
            # triple file format).
            sbj, rel, obj = parts[0][:-1], parts[1][:-1], parts[2]
            if (sbj not in self.entity2id) or (obj not in self.entity2id):
                continue
            head = local_index.get(self.entity2id[sbj])
            tail = local_index.get(self.entity2id[obj])
            if head is None or tail is None:
                continue
            edges[row, tail + 1, head + 1] = self.relation2id[rel]

    @classmethod
    def _graph_tensors(cls, graph, with_labels):
        """Convert one graph's arrays to tensors: nodes, hops, edges, mask[, labels]."""
        nodes = torch.LongTensor(graph['nodes'])
        hops = torch.LongTensor(graph['hops'])
        edges = torch.LongTensor(graph['edges'])
        very_neg_num = torch.ones_like(edges, dtype=torch.float32) * cls._MASK_FILL
        zero_num = torch.zeros_like(edges, dtype=torch.float32)
        mask = torch.where(edges == 0, very_neg_num, zero_num)
        tensors = [nodes, hops, edges, mask]
        if with_labels:
            tensors.append(torch.LongTensor(graph['labels']))
        return tensors

    def collate_fn(self, data):
        """Collate parsed records into the tensors of one training batch.

        Args:
            data: list of parsed records. Each must provide 'post',
                'response', 'length', 'extend_nodes', 'golden_nodes',
                'golden_tris' and 'extend_tris'.

        Returns:
            A 26-tuple: ``query_text, answer_text``, then for the golden graph
            ``nodes, hops, edges, mask``, then for each of the zero/one/two/
            three hop graphs ``nodes, hops, edges, mask, labels``.
        """
        batch_size = len(data)
        hop_names = self._HOP_NAMES

        encoder_len = max([len(item['post']) for item in data]) + 1
        decoder_len = max([len(item['response']) for item in data]) + 1
        golden_len = max([int(item['length']['golden']) for item in data])
        hop_widths = dict()
        for name in hop_names:
            longest = max([int(item['length'][name]) for item in data])
            hop_widths[name] = min(self._MAX_HOP_NODES, longest)

        posts_id = np.full((batch_size, encoder_len), 0, dtype=int)
        responses_id = np.full((batch_size, decoder_len), 0, dtype=int)
        golden = self._new_graph(batch_size, golden_len)
        candidates = dict()
        for name in hop_names:
            candidates[name] = self._new_graph(batch_size, hop_widths[name])

        for row, item in enumerate(data):
            self._encode_tokens(item['post'], encoder_len, posts_id[row])
            self._encode_tokens(item['response'], decoder_len, responses_id[row])

            extend_nodes = item['extend_nodes']
            golden_nodes = item['golden_nodes']
            golden_tris = item['golden_tris']

            # Golden graph: extended zero-hop nodes followed by the golden
            # nodes of the deeper hops, connected by all golden triples.
            groups = [extend_nodes['zero_hop']]
            groups += [golden_nodes[name] for name in hop_names[1:]]
            local_index = self._add_nodes(row, golden, groups)
            tris = []
            for name in hop_names:
                tris += golden_tris[name]
            self._add_relations(row, golden['edges'], tris, local_index)

            # Candidate graph of depth k: golden context of the shallower
            # hops plus the extended nodes/triples of hop k itself.
            for depth, name in enumerate(hop_names):
                groups = [extend_nodes['zero_hop']]
                tris = list(golden_tris['zero_hop'])
                if depth > 0:
                    for earlier in hop_names[1:depth]:
                        groups.append(golden_nodes[earlier])
                        tris += golden_tris[earlier]
                    groups.append(extend_nodes[name])
                    tris += item['extend_tris'][name]
                local_index = self._add_nodes(
                    row, candidates[name], groups,
                    label_source=golden_nodes, label_key=name,
                    limit=hop_widths[name])
                self._add_relations(row, candidates[name]['edges'], tris, local_index)

        batch = [torch.LongTensor(np.array(posts_id)),
                 torch.LongTensor(np.array(responses_id))]
        batch.extend(self._graph_tensors(golden, with_labels=False))
        for name in hop_names:
            batch.extend(self._graph_tensors(candidates[name], with_labels=True))
        return tuple(batch)


class InferenceDataset(Dataset):
    """Line-delimited JSON dataset collated for inference.

    Every non-blank line of the input file is one JSON record. Lines are kept
    as raw text and parsed lazily in ``__getitem__``. ``collate_fn`` returns
    padded word-id tensors plus plain Python lists describing the zero-hop
    graph and the gold entities of each sample.
    """

    def __init__(self, path_to_file, net_info):
        """Load the raw sample lines and keep references to the vocabularies.

        Args:
            path_to_file: path of a text file holding one JSON record per line.
            net_info: mapping providing 'csk_entities', 'csk_triples',
                'word2id', 'entity2id' and 'relation2id'.
        """
        # JSON text is UTF-8 by specification, so do not depend on the
        # locale's default encoding when reading it.
        with open(path_to_file, encoding='utf-8') as f:
            # A blank line holds no record and would make json.loads fail
            # later in __getitem__, so it is dropped here.
            self.dataset = [line for line in f if line.strip()]

        self.csk_entities = net_info['csk_entities']
        self.csk_triples = net_info['csk_triples']
        self.word2id = net_info['word2id']
        self.entity2id = net_info['entity2id']
        self.relation2id = net_info['relation2id']

    def __len__(self):
        """Return the number of stored sample lines."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Parse and return the JSON record stored at position ``idx``."""
        return json.loads(self.dataset[idx])

    def _encode_tokens(self, tokens, length, out_row):
        """Write the word ids of ``padding(tokens, length)`` into ``out_row``."""
        for pos, word in enumerate(padding(tokens, length)):
            if word in self.word2id:
                out_row[pos] = self.word2id[word]
            else:
                out_row[pos] = self.word2id['_UNK']

    def _gold_entities(self, entity_indices):
        """Map knowledge-base entity indices to entity ids.

        No membership check is done: an entity missing from ``entity2id``
        raises KeyError.
        """
        return [self.entity2id[self.csk_entities[e]] for e in entity_indices]

    def collate_fn(self, data):
        """Collate parsed records into one inference batch.

        Args:
            data: list of parsed records. Each must provide 'post',
                'response', 'zero_nodes', 'gold_tris' and 'gold'.

        Returns:
            A 10-tuple ``(query_text, answer_text, gold_0, gold_1, gold_2,
            gold_3, zero_nodes, zero_hops, zero_nodes_dict, zero_tris)``. The
            first two are LongTensors; the rest are per-sample Python lists
            (``zero_nodes_dict`` holds dicts mapping entity id -> position).
        """
        batch_size = len(data)
        encoder_len = max([len(item['post']) for item in data]) + 1
        decoder_len = max([len(item['response']) for item in data]) + 1
        posts_id = np.full((batch_size, encoder_len), 0, dtype=int)
        responses_id = np.full((batch_size, decoder_len), 0, dtype=int)

        zero_nodes, zero_hops, zero_tris, zero_nodes_dict = [], [], [], []
        gold_0, gold_1, gold_2, gold_3 = [], [], [], []

        for row, item in enumerate(data):
            self._encode_tokens(item['post'], encoder_len, posts_id[row])
            self._encode_tokens(item['response'], decoder_len, responses_id[row])

            # Zero-hop graph: keep only entities known to entity2id and
            # remember each one's position in the kept list.
            node_ids = []
            node_position = dict()
            for entity_index in item['zero_nodes']:
                name = self.csk_entities[entity_index]
                if name in self.entity2id:
                    entity = self.entity2id[name]
                    node_position[entity] = len(node_ids)
                    node_ids.append(entity)
            triples = list(item['gold_tris']['zero_hop'])

            zero_nodes.append(node_ids)
            # Every zero-hop node carries hop value 1.
            zero_hops.append([1] * len(node_ids))
            zero_nodes_dict.append(node_position)
            zero_tris.append(triples)

            gold = item['gold']
            gold_0.append(self._gold_entities(gold['zero_hop']))
            gold_1.append(self._gold_entities(gold['one_hop']))
            gold_2.append(self._gold_entities(gold['two_hop']))
            gold_3.append(self._gold_entities(gold['three_hop']))

        query_text = torch.LongTensor(np.array(posts_id))
        answer_text = torch.LongTensor(np.array(responses_id))

        return query_text, answer_text, gold_0, gold_1, gold_2, gold_3,\
            zero_nodes, zero_hops, zero_nodes_dict, zero_tris


class NodeClassifierDataset(Dataset):
    """Line-delimited JSON dataset collated into four per-hop classification graphs.

    Every non-blank line of the input file is one JSON record. Lines are kept
    as raw text and parsed lazily in ``__getitem__``. ``collate_fn`` converts a
    list of parsed records into padded word-id matrices for the post and the
    response plus one candidate graph per hop level (zero/one/two/three),
    each with node labels and a mask selecting the nodes of the deepest hop.
    """

    # Hop levels in processing order; these are also the key names read from
    # each record's 'extend_nodes', 'golden_nodes', 'golden_tris' and
    # 'extend_tris' dictionaries.
    _HOP_NAMES = ('zero_hop', 'one_hop', 'two_hop', 'three_hop')
    # Edge-matrix slots 0 and 1 are reserved; real nodes start at this offset.
    _RESERVED_SLOTS = 2
    # Value placed in the attention mask wherever the edge matrix holds 0.
    _MASK_FILL = -2 ** 32 + 1

    def __init__(self, path_to_file, net_info):
        """Load the raw sample lines and keep references to the vocabularies.

        Args:
            path_to_file: path of a text file holding one JSON record per line.
            net_info: mapping providing 'csk_entities', 'csk_triples',
                'word2id', 'entity2id' and 'relation2id'.
        """
        # JSON text is UTF-8 by specification, so do not depend on the
        # locale's default encoding when reading it.
        with open(path_to_file, encoding='utf-8') as f:
            # A blank line holds no record and would make json.loads fail
            # later in __getitem__, so it is dropped here.
            self.dataset = [line for line in f if line.strip()]

        self.csk_entities = net_info['csk_entities']
        self.csk_triples = net_info['csk_triples']
        self.word2id = net_info['word2id']
        self.entity2id = net_info['entity2id']
        self.relation2id = net_info['relation2id']

    def __len__(self):
        """Return the number of stored sample lines."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Parse and return the JSON record stored at position ``idx``."""
        return json.loads(self.dataset[idx])

    @classmethod
    def _new_graph(cls, batch_size, width):
        """Allocate the zero-filled arrays that back one graph for a whole batch."""
        side = width + cls._RESERVED_SLOTS
        return {
            'nodes': np.full((batch_size, width), 0, dtype=int),
            'hops': np.full((batch_size, width), 0, dtype=int),
            'labels': np.full((batch_size, width), 0, dtype=int),
            'edges': np.full((batch_size, side, side), 0, dtype=int),
        }

    def _encode_tokens(self, tokens, length, out_row):
        """Write the word ids of ``padding(tokens, length)`` into ``out_row``."""
        for pos, word in enumerate(padding(tokens, length)):
            if word in self.word2id:
                out_row[pos] = self.word2id[word]
            else:
                out_row[pos] = self.word2id['_UNK']

    def _add_nodes(self, row, graph, node_groups, label_source, label_key, limit):
        """Insert the known entities of ``node_groups`` into row ``row`` of ``graph``.

        Args:
            row: batch row to fill.
            graph: array dictionary produced by ``_new_graph``.
            node_groups: sequence of entity-index lists; the n-th list
                (1-based) is stored with hop value n.
            label_source: mapping holding the golden entity-index lists.
            label_key: key of ``label_source`` whose entries get label 1.
            limit: maximum number of nodes to insert for this sample.

        Returns:
            dict mapping entity id -> local position (0-based) in the row.
        """
        nodes, hops = graph['nodes'], graph['hops']
        labels, edges = graph['labels'], graph['edges']
        # Wire the two reserved slots to themselves and to each other.
        edges[row, 0, 0] = 1
        edges[row, 1, 1] = 1
        edges[row, 0, 1] = 2
        edges[row, 1, 0] = 3
        local_index = dict()
        golden = None
        count = 0
        for hop, group in enumerate(node_groups, start=1):
            for entity_index in group:
                # The cap is checked before inserting so that a cap of 0 is
                # honoured too. An equality test made only after an insert
                # never fires for 0 and lets the row overfill.
                if count >= limit:
                    return local_index
                name = self.csk_entities[entity_index]
                if name not in self.entity2id:
                    continue
                if golden is None:
                    # Built on first use: set lookup replaces a list scan
                    # per node.
                    golden = set(label_source[label_key])
                if entity_index in golden:
                    labels[row, count] = 1
                entity = self.entity2id[name]
                nodes[row, count] = entity
                local_index[entity] = count
                hops[row, count] = hop
                slot = count + self._RESERVED_SLOTS
                count += 1
                edges[row, slot, slot] = 1
                edges[row, 0, slot] = 2
                edges[row, slot, 0] = 3
                edges[row, 1, slot] = 2
                edges[row, slot, 1] = 3
        return local_index

    def _add_relations(self, row, edges, tri_indices, local_index):
        """Write the relation id of every triple whose two ends are in the graph."""
        shift = self._RESERVED_SLOTS
        for tri_index in tri_indices:
            parts = self.csk_triples[tri_index].split()
            # The first two fields carry one trailing character that is
            # stripped (presumably a comma separator - confirm against the
            # triple file format).
            sbj, rel, obj = parts[0][:-1], parts[1][:-1], parts[2]
            if (sbj not in self.entity2id) or (obj not in self.entity2id):
                continue
            head = local_index.get(self.entity2id[sbj])
            tail = local_index.get(self.entity2id[obj])
            if head is None or tail is None:
                continue
            edges[row, tail + shift, head + shift] = self.relation2id[rel]

    @classmethod
    def _graph_tensors(cls, graph, target_hop):
        """Convert one graph to tensors: nodes, hops, edges, mask, labels, labels_mask.

        ``labels_mask`` is 1.0 where the hop value equals ``target_hop`` and
        0.0 elsewhere.
        """
        nodes = torch.LongTensor(graph['nodes'])
        hops = torch.LongTensor(graph['hops'])
        labels = torch.LongTensor(graph['labels'])
        labels_mask = torch.where(
            hops == target_hop,
            torch.ones_like(labels, dtype=torch.float32),
            torch.zeros_like(labels, dtype=torch.float32))
        edges = torch.LongTensor(graph['edges'])
        very_neg_num = torch.ones_like(edges, dtype=torch.float32) * cls._MASK_FILL
        zero_num = torch.zeros_like(edges, dtype=torch.float32)
        mask = torch.where(edges == 0, very_neg_num, zero_num)
        return [nodes, hops, edges, mask, labels, labels_mask]

    def collate_fn(self, data):
        """Collate parsed records into the tensors of one training batch.

        Args:
            data: list of parsed records. Each must provide 'post',
                'response', 'extend_nodes', 'golden_nodes', 'golden_tris'
                and 'extend_tris'.

        Returns:
            A 26-tuple: ``query_text, answer_text``, then for each of the
            zero/one/two/three hop graphs
            ``nodes, hops, edges, mask, labels, labels_mask``.
        """
        batch_size = len(data)
        hop_names = self._HOP_NAMES

        encoder_len = max([len(item['post']) for item in data]) + 1
        decoder_len = max([len(item['response']) for item in data]) + 1
        posts_id = np.full((batch_size, encoder_len), 0, dtype=int)
        responses_id = np.full((batch_size, decoder_len), 0, dtype=int)

        # Per-sample node caps. The zero-hop graph keeps every extended
        # zero-hop node; each deeper graph keeps the context collected so far
        # plus twice the number of golden nodes of its own hop.
        caps = dict((name, []) for name in hop_names)
        for item in data:
            kept = len(item['extend_nodes']['zero_hop'])
            caps['zero_hop'].append(kept)
            for name in hop_names[1:]:
                golden_count = len(item['golden_nodes'][name])
                caps[name].append(kept + 2 * golden_count)
                kept += golden_count

        graphs = dict()
        for name in hop_names:
            graphs[name] = self._new_graph(batch_size, max(caps[name]))

        for row, item in enumerate(data):
            self._encode_tokens(item['post'], encoder_len, posts_id[row])
            self._encode_tokens(item['response'], decoder_len, responses_id[row])

            extend_nodes = item['extend_nodes']
            golden_nodes = item['golden_nodes']
            golden_tris = item['golden_tris']

            # Graph of depth k: golden context of the shallower hops plus the
            # extended nodes/triples of hop k itself.
            for depth, name in enumerate(hop_names):
                groups = [extend_nodes['zero_hop']]
                tris = list(golden_tris['zero_hop'])
                if depth > 0:
                    for earlier in hop_names[1:depth]:
                        groups.append(golden_nodes[earlier])
                        tris += golden_tris[earlier]
                    groups.append(extend_nodes[name])
                    tris += item['extend_tris'][name]
                local_index = self._add_nodes(
                    row, graphs[name], groups, golden_nodes, name, caps[name][row])
                self._add_relations(row, graphs[name]['edges'], tris, local_index)

        batch = [torch.LongTensor(np.array(posts_id)),
                 torch.LongTensor(np.array(responses_id))]
        for depth, name in enumerate(hop_names):
            # Nodes of the deepest group of graph k carry hop value k + 1.
            batch.extend(self._graph_tensors(graphs[name], depth + 1))
        return tuple(batch)


class NodeEvaluateDataset(Dataset):
    """Line-per-sample JSON dataset whose ``collate_fn`` builds multi-hop node graphs.

    Every line of the input file is stored as a raw string and parsed with
    ``json.loads`` when indexed. ``collate_fn`` turns a batch of parsed
    samples into word-id matrices for post/response plus, for each of four
    hop levels (zero to three), a node-id matrix, a hop-index matrix, a
    label matrix and an edge (relation id) matrix.

    Sample keys read by ``collate_fn``: 'post', 'response', 'zero_nodes',
    'gold', 'rest', 'gold_tris' and 'rest_tris' (the last four are mappings
    keyed by hop name).
    """

    # Hop level names in build order. At level k the "target" nodes are the
    # ones whose hop index equals k + 1 (i.e. the last node group of the level).
    _HOPS = ('zero_hop', 'one_hop', 'two_hop', 'three_hop')

    # Value placed in the edge mask wherever the edge matrix is 0.
    _PADDING_NUM = -2 ** 32 + 1

    def __init__(self, path_to_file, net_info):
        """Read the raw sample lines and keep the lookups from ``net_info``.

        Args:
            path_to_file: text file holding one JSON-encoded sample per line.
            net_info: mapping providing 'csk_entities', 'csk_triples',
                'word2id', 'entity2id' and 'relation2id'.
        """
        with open(path_to_file) as f:
            self.dataset = list(f)

        self.csk_entities = net_info['csk_entities']
        self.csk_triples = net_info['csk_triples']
        self.word2id = net_info['word2id']
        self.entity2id = net_info['entity2id']
        self.relation2id = net_info['relation2id']

    def __len__(self):
        """Return the number of sample lines read from the file."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Parse and return the sample stored at line ``idx``."""
        return json.loads(self.dataset[idx])

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    def _graph_sizes(self, item):
        """Return the node count of ``item`` at each of the four hop levels.

        Level zero counts ``item['zero_nodes']``; level k adds the gold nodes
        of hops 1..k and the rest nodes of hop k.
        """
        gold_len = len(item['zero_nodes'])
        sizes = [gold_len]
        for hop_name in self._HOPS[1:]:
            gold_len += len(item['gold'][hop_name])
            sizes.append(gold_len + len(item['rest'][hop_name]))
        return sizes

    def _node_groups(self, item, level):
        """Return the ordered node groups of ``item`` for hop ``level``.

        The first group is always ``item['zero_nodes']``; intermediate groups
        are the gold nodes of the earlier hops; the last group (for
        level >= 1) is gold + rest nodes of the level's own hop.
        """
        groups = [item['zero_nodes']]
        if level == 0:
            return groups
        for hop_name in self._HOPS[1:level]:
            groups.append(item['gold'][hop_name])
        own_hop = self._HOPS[level]
        groups.append(item['gold'][own_hop] + item['rest'][own_hop])
        return groups

    def _triple_indices(self, item, level):
        """Return the triple indices used to fill the edge matrix of ``level``.

        Level zero uses the zero-hop gold triples only; level k uses the gold
        triples of hops 0..k-1 followed by the rest triples of hop k.
        """
        if level == 0:
            return list(item['gold_tris'][self._HOPS[0]])
        indices = []
        for hop_name in self._HOPS[:level]:
            indices.extend(item['gold_tris'][hop_name])
        indices.extend(item['rest_tris'][self._HOPS[level]])
        return indices

    def _encode_words(self, words, length, out_row):
        """Write the word ids of ``padding(words, length)`` into ``out_row``.

        Words missing from ``word2id`` are mapped to the id of '_UNK'.
        """
        for pos, word in enumerate(padding(words, length)):
            if word in self.word2id:
                out_row[pos] = self.word2id[word]
            else:
                out_row[pos] = self.word2id['_UNK']

    def _gold_entity_ids(self, entity_indices):
        """Map KB entity indices to entity ids, dropping entities without an id."""
        ids = []
        for entity_index in entity_indices:
            name = self.csk_entities[entity_index]
            if name in self.entity2id:
                ids.append(self.entity2id[name])
        return ids

    def _fill_graph(self, idx, node_groups, gold_indices, tri_indices,
                    graph, hop, labels, edges):
        """Fill row ``idx`` of one hop level's arrays in place.

        Args:
            idx: batch row to write.
            node_groups: list of lists of KB entity indices; nodes of the
                n-th group (1-based) receive hop index n.
            gold_indices: KB entity indices whose nodes get label 1.
            tri_indices: indices into ``csk_triples`` to convert to edges.
            graph, hop, labels: (batch, N) integer arrays.
            edges: (batch, N + 2, N + 2) integer array. Slots 0 and 1 are two
                extra nodes connected to every real node; the node stored at
                local position p lives at edge slot p + 2.
        """
        # Self loops and mutual links of the two extra leading slots.
        edges[idx, 0, 0] = 1
        edges[idx, 1, 1] = 1
        edges[idx, 0, 1] = 2
        edges[idx, 1, 0] = 3

        gold = set(gold_indices)  # O(1) membership instead of scanning the list per node
        g2l = {}                  # entity id -> local node position
        slot = 0
        for hop_idx, nodes in enumerate(node_groups, start=1):
            for entity_index in nodes:
                name = self.csk_entities[entity_index]
                if name not in self.entity2id:
                    continue
                if entity_index in gold:
                    labels[idx, slot] = 1
                entity = self.entity2id[name]
                graph[idx, slot] = entity
                # Sanity check: an entity must not occur twice within one graph.
                assert entity not in g2l
                g2l[entity] = slot
                hop[idx, slot] = hop_idx

                edge_slot = slot + 2
                edges[idx, edge_slot, edge_slot] = 1
                edges[idx, 0, edge_slot] = 2
                edges[idx, edge_slot, 0] = 3
                edges[idx, 1, edge_slot] = 2
                edges[idx, edge_slot, 1] = 3
                slot += 1

        for tri_index in tri_indices:
            # Triples are stored as "subject, relation, object"; the slices
            # drop the trailing separator character of the first two fields.
            parts = self.csk_triples[tri_index].split()
            sbj = parts[0][:-1]
            rel = parts[1][:-1]
            obj = parts[2]
            if (sbj not in self.entity2id) or (obj not in self.entity2id):
                continue
            sbj_id = self.entity2id[sbj]
            obj_id = self.entity2id[obj]
            if (sbj_id not in g2l) or (obj_id not in g2l):
                continue
            # Row is the object node, column is the subject node.
            edges[idx, g2l[obj_id] + 2, g2l[sbj_id] + 2] = self.relation2id[rel]

    @classmethod
    def _to_tensors(cls, graph, hop, labels, edges, target_hop):
        """Convert one hop level's numpy arrays to tensors.

        Returns:
            ``(nodes, hops, edges, mask, labels, labels_mask, label)`` where
            ``mask`` is ``_PADDING_NUM`` where ``edges == 0`` and 0 elsewhere,
            ``labels_mask`` is 1.0 where the hop index equals ``target_hop``
            and 0.0 elsewhere, and ``label`` is the same indicator as an
            integer numpy array.
        """
        label = (hop == target_hop).astype(int)
        nodes_t = torch.LongTensor(graph)
        hops_t = torch.LongTensor(hop)
        labels_t = torch.LongTensor(labels)
        labels_mask = torch.where(
            hops_t == target_hop,
            torch.ones_like(labels_t, dtype=torch.float32),
            torch.zeros_like(labels_t, dtype=torch.float32))
        edges_t = torch.LongTensor(edges)
        very_neg_num = torch.ones_like(edges_t, dtype=torch.float32) * cls._PADDING_NUM
        zero_num = torch.zeros_like(edges_t, dtype=torch.float32)
        mask = torch.where(edges_t == 0, very_neg_num, zero_num)
        return nodes_t, hops_t, edges_t, mask, labels_t, labels_mask, label

    # ------------------------------------------------------------------
    # collate
    # ------------------------------------------------------------------
    def collate_fn(self, data):
        """Collate a list of parsed samples into batch tensors.

        Returns a flat tuple:
            ``query_text, answer_text`` (LongTensors of word ids), then for
            each hop level zero..three the six tensors
            ``nodes, hops, edges, mask, labels, labels_mask``, then for each
            hop level the triple ``gold_ids, graph, label`` where ``gold_ids``
            is a per-sample list of gold entity ids, ``graph`` is the numpy
            node-id matrix and ``label`` is the numpy target indicator.
        """
        batch = len(data)
        encoder_len = max([len(item['post']) for item in data]) + 1
        decoder_len = max([len(item['response']) for item in data]) + 1
        posts_id = np.zeros((batch, encoder_len), dtype=int)
        responses_id = np.zeros((batch, decoder_len), dtype=int)

        # Widest graph of the batch at every hop level decides the padding.
        per_item_sizes = [self._graph_sizes(item) for item in data]
        hop_lens = [max(sizes) for sizes in zip(*per_item_sizes)]

        # One (graph, hop, labels, edges) buffer set per hop level.
        buffers = []
        for hop_len in hop_lens:
            buffers.append((
                np.zeros((batch, hop_len), dtype=int),
                np.zeros((batch, hop_len), dtype=int),
                np.zeros((batch, hop_len), dtype=int),
                np.zeros((batch, hop_len + 2, hop_len + 2), dtype=int),
            ))

        gold_ids = [[] for _ in self._HOPS]

        for idx, item in enumerate(data):
            self._encode_words(item['post'], encoder_len, posts_id[idx])
            self._encode_words(item['response'], decoder_len, responses_id[idx])

            for level, hop_name in enumerate(self._HOPS):
                gold_ids[level].append(self._gold_entity_ids(item['gold'][hop_name]))

            for level, hop_name in enumerate(self._HOPS):
                graph, hop, labels, edges = buffers[level]
                self._fill_graph(
                    idx,
                    self._node_groups(item, level),
                    item['gold'][hop_name],
                    self._triple_indices(item, level),
                    graph, hop, labels, edges)

        query_text = torch.LongTensor(np.array(posts_id))
        answer_text = torch.LongTensor(np.array(responses_id))

        tensor_part = []
        eval_part = []
        for level in range(len(self._HOPS)):
            graph, hop, labels, edges = buffers[level]
            converted = self._to_tensors(graph, hop, labels, edges, level + 1)
            tensor_part.extend(converted[:6])
            eval_part.extend((gold_ids[level], graph, converted[6]))

        return tuple([query_text, answer_text] + tensor_part + eval_part)
            