import json
import os
import random
import sys
import time
import argparse
import itertools
import numpy as np
from joblib import Parallel, delayed
import torch

import rule_application as ra
from grapher import Grapher
from temporal_walk import store_edges
from rule_learning import rules_statistics
from score_functions import score_12


# Apply the learned rules to every query, collect the candidate entities with
# their scores, and store them to files.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", "-d", default="", type=str,
                    help="dataset name (directory under ../data/)")
parser.add_argument("--test_data", default="test", type=str,
                    help="split the rules are applied to: train, valid, or test (anything else)")
parser.add_argument("--rules", "-r", default="", type=str,
                    help="rules JSON file inside ../output_rule/<dataset>/seed<seed>/")
# A list default matches nargs="+", so the parsed value is always a list.
parser.add_argument("--rule_lengths", "-l", default=[1], type=int, nargs="+",
                    help="rule body lengths to keep; the last one is used as the padded rule length")
parser.add_argument("--window", "-w", default=-1, type=int,
                    help="history window passed to ra.get_window_edges")
parser.add_argument("--top_k", default=20, type=int,
                    help="stop applying rules once the top_k candidates have top_k distinct score lists")
parser.add_argument("--num_processes", "-p", default=1, type=int,
                    help="number of parallel worker processes")
parser.add_argument('--seed', '-s', type=int, default=12,
                    help="random seed; also selects the seed<seed> output directory")
parser.add_argument('--negative', type=int, default=-1,
                    help="max candidates kept per query (gold answer first); -1 keeps all")
parser.add_argument('--rule_used', type=int, default=3,
                    help="number of top-scoring rules kept per candidate")
parsed = vars(parser.parse_args())

dataset = parsed["dataset"]
rules_file = parsed["rules"]
window = parsed["window"]
top_k = parsed["top_k"]
num_processes = parsed["num_processes"]
rule_lengths = parsed["rule_lengths"]
seed = parsed['seed']
negative = parsed['negative']
rule_used = parsed['rule_used']


def set_seed(seed=6666):
    """Seed every random number generator used here and force deterministic kernels.

    The same ``seed`` is handed to Python's ``random``, NumPy's global RNG,
    the PyTorch CPU generator and all PyTorch CUDA generators. Afterwards,
    cuDNN deterministic mode and ``torch.use_deterministic_algorithms`` are
    switched on.

    Parameters:
        seed (int): value given to every generator (default 6666)
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Deterministic cuDNN mode (original note: avoids a "no convolution
    # algorithm available" error).
    torch.backends.cudnn.deterministic = True
    # Restrict PyTorch to deterministic implementations (e.g. convolutions).
    torch.use_deterministic_algorithms(True)


set_seed(seed)
dataset_dir = "../data/" + dataset + "/"
# Directory holding the learned rules (and the filtered copy written below).
dir_path = f'../output_rule/{dataset}/seed{seed}'
data = Grapher(dataset_dir)
# Split the rules are applied to: 'train', 'valid', anything else -> test.
if parsed['test_data'] == 'train':
    test_data = data.train_idx
elif parsed['test_data'] == 'valid':
    test_data = data.valid_idx
else:
    test_data = data.test_idx

# Load the learned rules. JSON object keys are strings, so convert the
# relation ids back to int. `with` closes the file handle.
with open(os.path.join(dir_path, rules_file)) as rules_fp:
    rules_dict = {int(k): v for k, v in json.load(rules_fp).items()}
print("Rules statistics:")
rules_statistics(rules_dict)
rules_dict = ra.filter_rules(  # filter by minimum confidence, body support and rule length
    rules_dict, min_conf=0.01, min_body_supp=2, rule_lengths=rule_lengths
)
print("Rules statistics after pruning:")
rules_statistics(rules_dict)
learn_edges = store_edges(data.train_idx)  # dict: relation -> all of its training edges
rel_num = len(data.relation2id)

# Save the filtered rules to a separate file. Rule ids are assigned against
# the filtered lists (the rule's position within its relation's list), so
# later consumers must load this file for the ids to line up.
# NOTE(review): an existing file is never overwritten, so changing the filter
# thresholds requires deleting it first.
new_rules_file = os.path.splitext(rules_file)[0] + '_filtered.json'
new_dir_path = dir_path
new_rules_path = os.path.join(new_dir_path, new_rules_file)
if not os.path.exists(new_rules_path):
    with open(new_rules_path, 'w') as file:
        json.dump(rules_dict, file)
score_func = score_12
# It is possible to specify a list of list of arguments for tuning
args = [[0.1, 0.5]]


def get_rules_info(rules_info, rule_len, cur_ts, idx, *, window_size=None,
                   num_relations=None, max_rules=None):
    """Pad the rule records of one candidate entity into fixed-size lists.

    Each record in ``rules_info`` is read as ``record[2]`` (list of body
    relation ids) and ``record[3]`` (list of the matching timestamps).
    Relation chains are padded to ``rule_len`` with ``num_relations``, and
    timestamp chains with ``cur_ts - 1``. Timestamps are converted to
    positions inside the history window: ``cur_window - (cur_ts - ts)``, so
    the step right before ``cur_ts`` maps to ``cur_window - 1``. Missing
    rules, up to ``max_rules``, are filled with padding rows whose mask is 0.

    Parameters:
        rules_info (list): rule records of one candidate (callers pass at
            most ``max_rules`` of them)
        rule_len (int): length every relation/timestamp chain is padded to
        cur_ts (int): timestamp of the current query
        idx (int): row index of the candidate; unused, kept for API compatibility
        window_size (int, optional): overrides the module-level ``window``
        num_relations (int, optional): overrides the module-level ``rel_num``
        max_rules (int, optional): overrides the module-level ``rule_used``

    Returns:
        rels_info (list[list[int]]): padded relation ids per rule
        tss_info (list[list[int]]): padded window positions per rule
        rel_len_info (list[int]): body length minus one per rule (0 for padding)
        mask_info (list[int]): 1 for real rules, 0 for padding rows
    """
    if window_size is None:
        window_size = window
    if num_relations is None:
        num_relations = rel_num
    if max_rules is None:
        max_rules = rule_used

    # The history before cur_ts may be shorter than the window. A non-positive
    # window (the CLI default is -1) presumably means the history is not cut
    # to a window (TODO confirm against ra.get_window_edges). In that case,
    # min(window, cur_ts) would give negative positions, so the history is
    # taken to span all cur_ts steps instead.
    cur_window = cur_ts if window_size <= 0 else min(window_size, cur_ts)

    rels_info = []  # padded relation chains, one row per rule
    tss_info = []  # padded window positions, one row per rule
    rel_len_info = []  # body length - 1 per rule
    mask_info = []  # 1 = real rule, 0 = padding
    for rule_info in rules_info:
        body_rels, body_tss = rule_info[2], rule_info[3]
        rel_info = body_rels + (rule_len - len(body_rels)) * [num_relations]
        ts_info = body_tss + (rule_len - len(body_tss)) * [cur_ts - 1]
        rels_info.append(rel_info)
        tss_info.append([cur_window - (cur_ts - ts) for ts in ts_info])
        rel_len_info.append(len(body_rels) - 1)
        mask_info.append(1)

    # Padding rows: build fresh lists each time so rows never alias each other.
    for _ in range(max_rules - len(rules_info)):
        rels_info.append(rule_len * [num_relations])
        tss_info.append(rule_len * [cur_window - 1])
        rel_len_info.append(0)
        mask_info.append(0)
    # The lengths are turned into flat indices (with inner/outer offsets) during training.
    return rels_info, tss_info, rel_len_info, mask_info


def _collect_candidates(test_query, edges, cur_ts):
    """Apply every rule of the query relation and gather the candidate entities.

    For each rule of ``rules_dict[test_query[1]]``:
    1. Its body is matched against ``edges`` starting from the query subject
       ``test_query[0]``.
    2. The groundings that survive the variable constraints feed
       ``ra.get_candidates`` (scores) and ``ra.get_candidates_evolve`` (per-rule
       records).

    Rule application stops early for a scoring setting once its top ``top_k``
    candidates have ``top_k`` distinct score lists.

    Returns:
        (cands_dict, cands_dict_evolve), one dict per entry of ``args``:
        - ``cands_dict[s]`` maps candidate -> score list (descending) and is
          ordered by descending score list.
        - ``cands_dict_evolve[s]`` maps candidate -> rule records ordered by
          (record[0] desc, record[1] asc). The dict itself is ordered by
          (best record score desc, candidate id asc).
    """
    cands_dict = [dict() for _ in range(len(args))]
    cands_dict_evolve = [dict() for _ in range(len(args))]
    dicts_idx = list(range(len(args)))
    for rule_idx, rule in enumerate(rules_dict[test_query[1]]):
        # One edge collection per body relation.
        walk_edges = ra.match_body_relations(rule, edges, test_query[0])
        if any(len(x) == 0 for x in walk_edges):
            continue  # some body relation has no matching edge: the rule cannot fire

        rule_walks = ra.get_walks(rule, walk_edges)
        if rule["var_constraints"]:  # variable constraints filter the walks further
            rule_walks = ra.check_var_constraints(rule["var_constraints"], rule_walks)
        if rule_walks.empty:
            continue

        cands_dict = ra.get_candidates(
            rule, rule_walks, cur_ts, cands_dict, score_func, args, dicts_idx,
        )
        cands_dict_evolve = ra.get_candidates_evolve(
            rule, rule_walks, cur_ts, cands_dict_evolve, score_func, args,
            dicts_idx, rule_idx, rule['body_rels'],
        )
        # Iterate over a snapshot. Finished settings are removed from
        # dicts_idx inside the loop, and removing from the list being
        # iterated would skip the following setting.
        for s in list(dicts_idx):
            cands_dict[s] = {
                x: sorted(scores, reverse=True) for x, scores in cands_dict[s].items()
            }
            cands_dict[s] = dict(
                sorted(cands_dict[s].items(), key=lambda item: item[1], reverse=True)
            )
            # Records: higher rule score first, equal scores by ascending rule id.
            cands_dict_evolve[s] = {
                x: sorted(records, key=lambda rec: (rec[0], -rec[1]), reverse=True)
                for x, records in cands_dict_evolve[s].items()
            }
            # Candidates: compare only the best record's score (sorting on the
            # whole record list would compare element by element), equal scores
            # by ascending candidate id.
            cands_dict_evolve[s] = dict(
                sorted(
                    cands_dict_evolve[s].items(),
                    key=lambda item: (item[1][0][0], -item[0]),
                    reverse=True,
                )
            )
            top_k_scores = list(cands_dict[s].values())[:top_k]
            # groupby collapses consecutive equal score lists (the list is sorted).
            unique_scores = [scores for scores, _ in itertools.groupby(top_k_scores)]
            if len(unique_scores) >= top_k:  # enough distinct scores: stop this setting
                dicts_idx.remove(s)
        if not dicts_idx:
            break
    return cands_dict, cands_dict_evolve


def _noisy_or_scores(cands):
    """Combine each candidate's rule scores with noisy-or, ``1 - prod(1 - p)``.

    Returns a new dict ordered by descending combined score.
    """
    # np.prod: the np.product alias was deprecated in NumPy 1.25 and removed in 2.0.
    scores = {
        cand: 1 - np.prod(1 - np.array(rule_scores))
        for cand, rule_scores in cands.items()
    }
    return dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))


def _build_evolve_tensors(test_query, cands_evolve, cur_ts):
    """Encode the candidates of one query as the tensor tuple saved for training.

    The gold answer ``test_query[2]`` is placed in the first row when it is
    reachable, so capping the candidates with ``negative`` never drops it.
    The other candidates follow in ``cands_evolve`` order, up to
    ``negative`` rows in total (all rows when ``negative == -1``).

    Returns a 5-tuple, with n = number of kept candidates:
        - tail entity ids (LongTensor, n).
        - rule ids (LongTensor, n x rule_num), where rule_num is the number of
          rules of the query relation. The top ``rule_used`` ids are padded by
          repeating the last one.
        - window positions of the rule bodies (LongTensor), from get_rules_info.
        - body length - 1 per rule slot (LongTensor, n * rule_used).
        - rule-slot mask (FloatTensor, n * rule_used).
    """
    if negative == -1:
        tail_ent_reachable = len(cands_evolve)
    else:
        tail_ent_reachable = min(len(cands_evolve), negative)

    # Assumes rule_lengths[-1] is the longest rule length — TODO confirm.
    rule_len = rule_lengths[-1]
    # Rule ids replace the former per-rule relation-id tensor.
    rule_num = len(rules_dict[test_query[1]])
    tail_ent_rule_tensor = torch.zeros((tail_ent_reachable, rule_num))
    tail_ent_idx, tss_info_all, rel_len_info_all, mask_info_all = [], [], [], []

    answer = test_query[2]
    ordered = [answer] if answer in cands_evolve else []
    ordered.extend(k for k in cands_evolve if k != answer)

    # Slicing first keeps every write inside the allocated rows (the
    # previous check-after-write overflowed the tensor when negative == 1).
    for idx, k in enumerate(ordered[:tail_ent_reachable]):
        rules_info = cands_evolve[k][:rule_used]  # top rule_used rule records
        _, tss_info, rel_len_info, mask_info = get_rules_info(rules_info, rule_len, cur_ts, idx)

        rule_id = [rule[1] for rule in rules_info]
        rule_id_extend = rule_id + (rule_num - len(rule_id)) * [rule_id[-1]]  # pad with the last id
        tail_ent_rule_tensor[idx, :] = torch.Tensor(rule_id_extend)

        tss_info_all.append(tss_info)
        rel_len_info_all.extend(rel_len_info)
        mask_info_all.extend(mask_info)
        tail_ent_idx.append(k)

    return (
        torch.LongTensor(tail_ent_idx),
        tail_ent_rule_tensor.long(),
        torch.LongTensor(tss_info_all),
        torch.LongTensor(rel_len_info_all),
        torch.Tensor(mask_info_all),
    )


# Score the candidate entities of every query; also count queries without candidates.
def apply_rules(i, num_queries):
    """
    Apply rules (multiprocessing possible).

    Process ``i`` handles queries ``[i * num_queries, (i + 1) * num_queries)``.
    The last process (``i == num_processes - 1``) runs to the end of
    ``test_data``, so the remainder of the integer division is not lost.

    Parameters:
        i (int): process number
        num_queries (int): minimum number of queries for each process

    Returns:
        all_candidates (list): one dict per entry of ``args``; query index ->
            {candidate: noisy-or score}, ordered by descending score
        all_candidates_evolve (list): one dict per entry of ``args``; timestamp
            -> list of per-query tensor tuples (an empty tuple for queries
            without candidates), in query order
        no_cands_counter (int): number of queries with no answer candidates
    """
    print("Start process", i, "...")
    all_candidates = [dict() for _ in range(len(args))]
    all_candidates_evolve = [dict() for _ in range(len(args))]
    no_cands_counter = 0

    start_idx = i * num_queries
    num_rest_queries = len(test_data) - (i + 1) * num_queries
    is_last = i == num_processes - 1 or num_rest_queries < num_queries
    end_idx = len(test_data) if is_last else start_idx + num_queries
    test_queries_idx = range(start_idx, end_idx)
    if not test_queries_idx:
        return all_candidates, all_candidates_evolve, no_cands_counter

    cur_ts = test_data[test_queries_idx[0]][3]
    # Windowed edges per relation; recomputed only when the timestamp changes.
    edges = ra.get_window_edges(data.all_idx, cur_ts, learn_edges, window)

    it_start = time.time()
    for j in test_queries_idx:
        test_query = test_data[j]
        if test_query[3] != cur_ts:
            cur_ts = test_query[3]
            edges = ra.get_window_edges(data.all_idx, cur_ts, learn_edges, window)

        cands_dict = cands_dict_evolve = None
        if test_query[1] in rules_dict:
            cands_dict, cands_dict_evolve = _collect_candidates(test_query, edges, cur_ts)
        # No rules for the relation, or no rule produced a candidate.
        has_cands = cands_dict is not None and bool(cands_dict[0])
        if not has_cands:
            no_cands_counter += 1

        for s in range(len(args)):
            if has_cands:
                all_candidates[s][j] = _noisy_or_scores(cands_dict[s])
                res = _build_evolve_tensors(test_query, cands_dict_evolve[s], cur_ts)
            else:
                all_candidates[s][j] = dict()
                res = ()
            all_candidates_evolve[s].setdefault(cur_ts, []).append(res)  # keyed by timestamp

        # Progress report every 100 queries handled by this process.
        done = j - test_queries_idx[0] + 1
        if not done % 100:
            it_end = time.time()
            it_time = round(it_end - it_start, 6)
            print(
                "Process {0}: test samples finished: {1}/{2}, {3} sec".format(
                    i, done, len(test_queries_idx), it_time
                )
            )
            it_start = time.time()

    return all_candidates, all_candidates_evolve, no_cands_counter


def calc_perf_cap(final_all_candidates_with_rules, test_idx, dir_path):
    """Measure how often the gold tail is among the rule-reachable candidates.

    The result is the fraction of queries whose answer appears in the
    candidate tensor produced by rule application. No re-ranking of those
    candidates can exceed it at Hits@1, so it is reported as a cap. It is
    written to ``{dir_path}/perf_cap.txt`` as ``Hits@1: <value>``.

    Parameters:
        final_all_candidates_with_rules (dict): timestamp -> list of per-query
            tuples whose first element is a tensor of candidate tail ids (an
            empty tuple when a query has no candidates). Iterating the dict and
            its lists must visit the queries in the same order as ``test_idx``.
        test_idx (np.ndarray): query array; column 2 holds the gold tail
        dir_path (str): directory the report is written into

    Returns:
        float: the computed fraction (0.0 when there are no queries)
    """
    ans_tail = test_idx[:, 2]
    total_num = ans_tail.shape[0]

    per_query = itertools.chain.from_iterable(final_all_candidates_with_rules.values())
    correct_num = 0
    for idx, pair in enumerate(per_query):
        if len(pair) == 0:  # no candidates for this query
            continue
        tail_ent_ids = pair[0]
        if torch.any(tail_ent_ids == ans_tail[idx]):
            correct_num += 1

    # Guard against an empty query set instead of dividing by zero.
    perf = correct_num / total_num if total_num else 0.0
    file_path = f'{dir_path}/perf_cap.txt'
    with open(file_path, 'w') as file:
        file.write(f'Hits@1: {perf:.4f}')
    return perf


start = time.time()
# Never use more processes than queries: num_queries would otherwise be 0 and
# every process would receive an empty slice.
num_processes = max(1, min(num_processes, len(test_data)))
num_queries = len(test_data) // num_processes
output = Parallel(n_jobs=num_processes)(  # candidates and scores for all queries
    delayed(apply_rules)(i, num_queries) for i in range(num_processes)
)
end = time.time()

# One dict per scoring-function setting in `args`:
# query index -> {candidate: score}.
final_all_candidates = [dict() for _ in range(len(args))]
# One dict per setting: timestamp -> list of per-query tensor tuples.
final_all_candidates_evolve = [dict() for _ in range(len(args))]
for s in range(len(args)):
    for i in range(num_processes):
        final_all_candidates[s].update(output[i][0][s])

        # Neighbouring processes can share a timestamp, and `update` would
        # overwrite the earlier list. Instead, concatenate per-timestamp lists
        # in process order. Assuming test_data is sorted by timestamp,
        # flattening the dict then reproduces the query order that
        # calc_perf_cap relies on. This also covers processes whose queries
        # all share one timestamp and processes with no queries.
        for ts, ts_results in output[i][1][s].items():
            final_all_candidates_evolve[s].setdefault(ts, []).extend(ts_results)

        # Drop the per-process copies early.
        output[i][0][s].clear()
        output[i][1][s].clear()

final_no_cands_counter = sum(output[i][2] for i in range(num_processes))

total_time = round(end - start, 6)
print("Application finished in {} seconds.".format(total_time))
print("No candidates: ", final_no_cands_counter, " queries")

dir_path_new = f'../output_rule/{dataset}/seed{seed}/cands_with_rules'
for s in range(len(args)):
    score_func_str = (score_func.__name__ + str(args[s])).replace(" ", "")
    ra.save_candidates_static(
        rules_file,
        dir_path_new,
        final_all_candidates[s],
        final_all_candidates_evolve[s],
        rule_lengths,
        window,
        score_func_str,
        parsed['test_data'],
        negative,
        rule_used,
    )
    if parsed['test_data'] == 'test':
        calc_perf_cap(final_all_candidates_evolve[s], data.test_idx, dir_path)
