import pickle
from typing import Dict, Tuple, List
import os

import numpy as np
import torch
from models import KBCModel
from collections import defaultdict

class Dataset(object):
    """Knowledge-base-completion dataset with rule-derived extra triples.

    All files are read from ``root = os.path.join(data_path, name)``:

    * ``<split>.pickle`` for ``train``/``valid``/``test`` - arrays that are
      overwritten in place with ids re-read from ``rule/<split>.txt``;
    * ``rule/entity2id.txt`` and ``rule/relation2id.txt`` - ``<name> <id>``
      tables used for that remapping;
    * ``rule/to_skip-new.pickle`` - the per-direction dictionaries handed to
      ``model.get_ranking`` during evaluation;
    * ``rule/rule_<N>-new.pickle`` and ``rule/rule_inv_<N>-new.pickle`` -
      precomputed rule-derived triples, exposed through :meth:`get_rule`.

    Triples are stored as rows ``(head_id, relation_id, tail_id)``; inverse
    relations are encoded as ``relation_id + n_predicates // 2``.

    The caches are loaded with :mod:`pickle`, which can execute arbitrary code
    on load: only point this class at trusted dataset directories.
    """

    def __init__(self, data_path: str, name: str, rule_num: str = '50'):
        """Load the splits, id tables, filtering dictionaries and rule triples.

        Args:
            data_path: directory containing one sub-directory per dataset.
            name: name of the dataset sub-directory.
            rule_num: suffix selecting which rule cache pair to load
                (``rule_<rule_num>-new.pickle`` / ``rule_inv_<rule_num>-new.pickle``).
                Defaults to ``'50'``, the value previously hard-coded here;
                ints are accepted too.

        Raises:
            OSError: a required file is missing or unreadable.
            KeyError: ``rule/<split>.txt`` names an entity or relation absent
                from the id tables.
            ValueError: an id-table line is not ``<name> <id>``, or
                ``rule/<split>.txt`` and ``<split>.pickle`` disagree in shape.
        """
        self.root = os.path.join(data_path, name)
        rule_dir = os.path.join(self.root, 'rule')

        self.data = {}
        for split in ('train', 'test', 'valid'):
            with open(os.path.join(self.root, split + '.pickle'), 'rb') as in_file:
                self.data[split] = pickle.load(in_file)

        self.relation2id: Dict[str, int] = self._read_id_map(os.path.join(rule_dir, 'relation2id.txt'))
        self.entity2id: Dict[str, int] = self._read_id_map(os.path.join(rule_dir, 'entity2id.txt'))

        # The text files are authoritative: replace the pickled ids with the
        # ids obtained through the tables above.
        for split in ('train', 'test', 'valid'):
            self._remap_split(split, os.path.join(rule_dir, split + '.txt'))

        print(self.data['train'].shape)

        # Entities are counted over every split, relations over train only.
        splits = [self.data[s] for s in ('train', 'valid', 'test')]
        self.n_entities = int(max(np.max(s[:, [0, 2]]) for s in splits) + 1)
        self.n_relations = int(np.max(self.data['train'][:, 1]) + 1)
        # Every relation also gets an inverse id, offset by n_relations.
        self.n_predicates = 2 * self.n_relations

        # The caches below were built offline by code that used to live here,
        # commented out:
        #   * to_skip: 'rhs' maps (head, rel) -> tails and 'lhs' maps
        #     (tail, rel + len(relation2id)) -> heads over all three splits;
        #     'rhs-train' / 'lhs-train' are the same restricted to train.
        #   * rule / rule-inv: for every key of to_skip['rhs-train']
        #     (resp. 'lhs-train'), candidate answers were gathered through
        #     one-step rules from rule/rule_relation<N>.txt
        #     ("<premise_rel> <target_rel> <confidence>") and two-step paths
        #     from rule/rule_path<N>.txt ("<rel1> <rel2> <target_rel> <confidence>"),
        #     answers already known in train were removed, the highest
        #     confidence per candidate was kept, and rows
        #     [entity, rel, candidate, confidence] were saved as float32 arrays.
        #   The rule_relation/rule_path text files are not read at load time.
        # NOTE(review): when emitting rows, that code wrote ``tail_pro[tail]``
        # (the last tail visited in an earlier loop) instead of
        # ``tail_pro[ss]`` for every candidate, so caches built with it may
        # carry wrong confidences - TODO verify or regenerate.
        # NOTE(review): those caches offset inverse relations by
        # len(self.relation2id), whereas get_train()/eval() use
        # n_predicates // 2 (1 + largest train relation id). They agree only
        # when the two numbers are equal - TODO confirm per dataset.
        with open(os.path.join(rule_dir, 'to_skip-new.pickle'), 'rb') as in_file:
            self.to_skip: Dict[str, Dict[Tuple[int, int], List[int]]] = pickle.load(in_file)

        rule_num = str(rule_num)
        with open(os.path.join(rule_dir, 'rule_' + rule_num + '-new.pickle'), 'rb') as in_file:
            self.data['rule'] = pickle.load(in_file)
        with open(os.path.join(rule_dir, 'rule_inv_' + rule_num + '-new.pickle'), 'rb') as in_file:
            self.data['rule-inv'] = pickle.load(in_file)

    @staticmethod
    def _read_rows(path: str) -> List[List[str]]:
        """Return the whitespace-separated fields of each non-blank line in *path*."""
        with open(path, 'r') as f:
            return [fields for fields in (line.split() for line in f) if fields]

    @classmethod
    def _read_id_map(cls, path: str) -> Dict[str, int]:
        """Parse a file of ``<name> <id>`` lines into a ``name -> id`` dict."""
        return {key: int(idx) for key, idx in cls._read_rows(path)}

    def _remap_split(self, split: str, path: str) -> None:
        """Overwrite ``self.data[split]`` in place with the triples listed in *path*.

        Each line of *path* is ``<head> <tail> <relation>``; it becomes the row
        ``(entity2id[head], relation2id[relation], entity2id[tail])``, cast to
        the pickled array's dtype.
        """
        rows = self._read_rows(path)
        remapped = np.array(
            [[self.entity2id[r[0]], self.relation2id[r[2]], self.entity2id[r[1]]] for r in rows],
            dtype=np.int64,
        ).reshape(-1, 3)
        target = self.data[split]
        # Require an exact match: writing fewer rows would silently leave
        # stale pickled ids behind.
        if remapped.shape != target.shape:
            raise ValueError(
                f'{path} has {len(remapped)} triples but {split}.pickle has shape '
                f'{target.shape}; expected ({len(remapped)}, 3)')
        target[...] = remapped

    def get_weight(self):
        """Return a per-entity weight in [0.1, 1.0] from train-split frequency.

        Each entity's number of appearances as head or tail in the train split
        is scaled linearly so the most frequent entity gets 1.0 and an entity
        that never appears gets 0.1.
        """
        train = self.data['train']
        appear_list = np.zeros(self.n_entities)
        # np.add.at accumulates repeated indices (plain fancy-index += would not).
        np.add.at(appear_list, train[:, 0], 1)
        np.add.at(appear_list, train[:, 2], 1)
        return appear_list / np.max(appear_list) * 0.9 + 0.1

    def get_examples(self, split):
        """Return the stored array for *split* (not a copy)."""
        return self.data[split]

    def get_train(self):
        """Return train triples stacked with their inverses.

        The inverse of ``(h, r, t)`` is ``(t, r + n_predicates // 2, h)``.
        """
        inverse = np.copy(self.data['train'])
        inverse[:, [0, 2]] = inverse[:, [2, 0]]  # RHS fancy index is a copy, so the swap is safe
        inverse[:, 1] += self.n_predicates // 2  # n_predicates already counts both directions
        return np.vstack((self.data['train'], inverse))

    def get_rule(self):
        """Return the rule-derived rows followed by the inverse-direction rule rows."""
        return np.vstack((self.data['rule'], self.data['rule-inv']))

    def eval(
            self, model: 'KBCModel', split: str, n_queries: int = -1, missing_eval: str = 'both',
            at: Tuple[int, ...] = (1, 3, 10), log_result=False, save_path=None
    ):
        """Compute MRR and Hits@k of *model* on *split*.

        Ranks come from ``model.get_ranking(queries, self.to_skip[m], batch_size=500)``;
        ``to_skip`` presumably lists known answers to filter out - confirm in
        the model's ``get_ranking``.

        Args:
            model: model exposing ``eval()`` and ``get_ranking``.
            split: key of ``self.data`` to evaluate on.
            n_queries: if > 0, evaluate on a random subset of this many triples,
                drawn independently for each direction.
            missing_eval: ``'rhs'``, ``'lhs'`` (predict heads) or ``'both'``.
            at: cut-offs k for Hits@k.
            log_result: if true, collect each query together with its rank.
            save_path: currently unused.

        Returns:
            ``(mrr, hits)``: dicts keyed by direction; ``mrr[m]`` is a float and
            ``hits[m]`` a FloatTensor aligned with *at*.
        """
        model.eval()
        examples = torch.from_numpy(self.get_examples(split).astype('int64')).cuda()
        missing = ['rhs', 'lhs'] if missing_eval == 'both' else [missing_eval]

        mean_reciprocal_rank = {}
        hits_at = {}
        logged = []  # per-direction arrays of [query columns..., rank] rows

        for m in missing:
            if n_queries > 0:
                q = examples[torch.randperm(len(examples))[:n_queries]]
            else:
                q = examples.clone()  # clone: the 'lhs' branch edits q in place
            if m == 'lhs':
                # Predict heads: swap head/tail and query the inverse relation.
                q[:, [0, 2]] = q[:, [2, 0]]
                q[:, 1] += self.n_predicates // 2
            ranks = model.get_ranking(q, self.to_skip[m], batch_size=500)

            if log_result:
                logged.append(np.concatenate(
                    (q.cpu().detach().numpy(),
                     np.expand_dims(ranks.cpu().detach().numpy(), axis=1)),
                    axis=1))

            mean_reciprocal_rank[m] = torch.mean(1. / ranks).item()
            hits_at[m] = torch.FloatTensor(
                [torch.mean((ranks <= k).float()).item() for k in at])

        if logged:
            # NOTE(review): these results are neither returned nor written and
            # save_path is unused; presumably they were meant to be saved to
            # save_path - TODO confirm with callers before wiring that up.
            results = np.concatenate(logged, axis=0)  # noqa: F841

        return mean_reciprocal_rank, hits_at

    def get_shape(self):
        """Return ``(n_entities, n_predicates, n_entities)``."""
        return self.n_entities, self.n_predicates, self.n_entities
