import json
import os
import random
import sys
import numpy as np
import pandas as pd
from collections import Counter

import torch

from temporal_walk import store_edges


def filter_rules(rules_dict, min_conf, min_body_supp, rule_lengths):
    """
    Keep only the rules that reach a minimum confidence and a minimum body
    support and whose body has one of the requested lengths.

    Parameters:
        rules_dict (dict): maps a key to a list of rule dicts; each rule is read
            through its "conf", "body_supp" and "body_rels" entries
        min_conf (float): minimum confidence value
        min_body_supp (int): minimum body support value
        rule_lengths (list): accepted rule lengths (number of body relations)

    Returns:
        new_rules_dict (dict): filtered rules; keys whose rule list ends up
            empty are left out
    """

    new_rules_dict = {}
    for head, rules in rules_dict.items():
        kept = [
            rule
            for rule in rules
            if rule["conf"] >= min_conf
            and rule["body_supp"] >= min_body_supp
            and len(rule["body_rels"]) in rule_lengths
        ]
        # A key that no longer has any supporting rule is dropped entirely.
        if kept:
            new_rules_dict[head] = kept

    return new_rules_dict


# Select the edges that are available for rule application in a given time window
def get_window_edges(all_data, test_query_ts, learn_edges, window=-1):
    """
    Get the edges in the data (for rule application) that occur in the specified time window.
    If window is 0, all edges before the test query timestamp are included.
    If window is -1, the edges on which the rules are learned are used
    (learn_edges is returned as is).
    If window is an integer n > 0, all edges within n timestamps before the test query
    timestamp are included.

    Parameters:
        all_data (np.ndarray): complete dataset (train/valid/test); column 3 is
            compared against the query timestamp
        test_query_ts: test query timestamp
        learn_edges (dict): edges on which the rules are learned
        window (int): time window used for rule application

    Returns:
        window_edges (dict): edges in the window for rule application

    Raises:
        ValueError: if window is neither -1, 0 nor a positive number
    """

    if window == -1:
        return learn_edges

    timestamps = all_data[:, 3] if window >= 0 else None
    if window == 0:
        mask = timestamps < test_query_ts
    elif window > 0:
        mask = (timestamps < test_query_ts) & (timestamps >= test_query_ts - window)
    else:
        # Previously such values fell through every branch and surfaced as an
        # UnboundLocalError on return; fail with a clear message instead.
        raise ValueError(
            "window must be -1, 0 or a positive integer, got {!r}".format(window)
        )

    return store_edges(all_data[mask])


# Returns a list with one entry per body relation: the matching edges stored as a 2-D
# [sub, obj, ts] array. All edges of the first relation have the query subject as subject.
def match_body_relations(rule, edges, test_query_sub):
    """
    Find edges that could constitute walks (starting from the test query subject)
    that match the rule.
    First, find edges whose subject matches the query subject and whose relation matches
    the first relation in the rule body. Then, find edges whose subjects match the
    current targets and whose relation is the next relation in the rule body.
    Memory-efficient implementation: only the subject, object and timestamp columns
    (columns 0, 2 and 3) of each edge are kept.

    Parameters:
        rule (dict): rule from rules_dict (only "body_rels" is read)
        edges (dict): edges for rule application, one 2-D array per relation
        test_query_sub (int): test query subject

    Returns:
        walk_edges (list of np.ndarrays): edges that could constitute rule walks,
            one [sub, obj, ts] array per matched body relation. If the first body
            relation has no edges, [[]] is returned; if a later relation has no
            edges, an empty list is appended for it and matching stops there.
    """

    rels = rule["body_rels"]

    # Match query subject and first body relation
    try:
        rel_edges = edges[rels[0]]
    except KeyError:
        return [[]]

    new_edges = rel_edges[rel_edges[:, 0] == test_query_sub]
    walk_edges = [np.hstack((new_edges[:, 0:1], new_edges[:, 2:4]))]  # [sub, obj, ts]
    cur_targets = np.unique(walk_edges[0][:, 1])  # distinct objects reached so far

    for rel in rels[1:]:
        # Match current targets and next body relation
        try:
            rel_edges = edges[rel]
        except KeyError:
            walk_edges.append([])
            break

        # np.isin yields the same mask as comparing against every target via
        # broadcasting, without materialising a (targets x edges) boolean matrix.
        new_edges = rel_edges[np.isin(rel_edges[:, 0], cur_targets)]
        step_edges = np.hstack((new_edges[:, 0:1], new_edges[:, 2:4]))  # [sub, obj, ts]
        walk_edges.append(step_edges)
        cur_targets = np.unique(step_edges[:, 1])

    return walk_edges


def match_body_relations_complete(rule, edges, test_query_sub):
    """
    Find edges that could constitute walks (starting from the test query subject)
    that match the rule.
    First, find edges whose subject matches the query subject and whose relation matches
    the first relation in the rule body. Then, find edges whose subjects match the
    current targets and whose relation is the next relation in the rule body.
    Unlike match_body_relations, the matched edge rows are kept with all their columns.

    Parameters:
        rule (dict): rule from rules_dict (only "body_rels" is read)
        edges (dict): edges for rule application, one 2-D array per relation
        test_query_sub (int): test query subject

    Returns:
        walk_edges (list of np.ndarrays): edges that could constitute rule walks,
            one array of complete edge rows per matched body relation. If the first
            body relation has no edges, [[]] is returned; if a later relation has
            no edges, an empty list is appended for it and matching stops there.
    """

    rels = rule["body_rels"]

    # Match query subject and first body relation
    try:
        rel_edges = edges[rels[0]]
    except KeyError:
        return [[]]

    walk_edges = [rel_edges[rel_edges[:, 0] == test_query_sub]]
    cur_targets = np.unique(walk_edges[0][:, 2])  # column 2 holds the edge targets

    for rel in rels[1:]:
        # Match current targets and next body relation
        try:
            rel_edges = edges[rel]
        except KeyError:
            walk_edges.append([])
            break

        # np.isin yields the same mask as comparing against every target via
        # broadcasting, without materialising a (targets x edges) boolean matrix.
        new_edges = rel_edges[np.isin(rel_edges[:, 0], cur_targets)]
        walk_edges.append(new_edges)
        cur_targets = np.unique(new_edges[:, 2])

    return walk_edges


# Build the DataFrame of entity chains (walks) that satisfy the rule
def get_walks(rule, walk_edges):
    """
    Get walks for a given rule. Take the time constraints into account.
    Memory-efficient implementation.

    Consecutive steps are joined on their shared entity and only combinations with
    non-decreasing timestamps are kept. All timestamp columns are kept in the result.

    Parameters:
        rule (dict): rule from rules_dict (only "var_constraints" is read)
        walk_edges (list of np.ndarrays): edges from match_body_relations

    Returns:
        rule_walks (pd.DataFrame): all walks matching the rule
    """

    def edge_frame(step):
        # One frame per body relation with columns entity_<i>, entity_<i+1>, timestamp_<i>.
        return pd.DataFrame(
            walk_edges[step],
            columns=[f"entity_{step}", f"entity_{step + 1}", f"timestamp_{step}"],
            dtype=np.uint16,
        )  # Change type if necessary for better memory efficiency

    first_frame = edge_frame(0)
    if not rule["var_constraints"]:
        # Without a same-entity requirement the start entity is not needed.
        del first_frame["entity_0"]
    frames = [first_frame] + [edge_frame(step) for step in range(1, len(walk_edges))]

    rule_walks = frames[0]
    frames[0] = frames[0][0:0]  # Memory efficiency: keep only an empty slice in the list

    # Merge the edges on the shared entity, then filter by timestamp.
    # Without variable constraints only the last entity matters, so intermediate
    # entities are discarded; otherwise every entity is kept so that the
    # same-entity conditions can be checked afterwards.
    for step in range(1, len(frames)):
        join_col = f"entity_{step}"
        merged = pd.merge(rule_walks, frames[step], on=[join_col])
        chronological = merged[f"timestamp_{step - 1}"] <= merged[f"timestamp_{step}"]
        rule_walks = merged[chronological]
        if not rule["var_constraints"]:
            del rule_walks[join_col]
        frames[step] = frames[step][0:0]  # Memory efficiency

    return rule_walks


def get_walks_complete(rule, walk_edges):
    """
    Get complete walks for a given rule. Take the time constraints into account.

    Every step contributes the columns entity_<i>, relation_<i>, entity_<i+1> and
    timestamp_<i>. Consecutive steps are joined on their shared entity and only
    combinations with non-decreasing timestamps are kept.

    Parameters:
        rule (dict): rule from rules_dict (not read by this function)
        walk_edges (list of np.ndarrays): edges from match_body_relations_complete

    Returns:
        rule_walks (pd.DataFrame): all walks matching the rule
    """

    def edge_frame(step):
        return pd.DataFrame(
            walk_edges[step],
            columns=[
                f"entity_{step}",
                f"relation_{step}",
                f"entity_{step + 1}",
                f"timestamp_{step}",
            ],
            dtype=np.uint16,
        )  # Change type if necessary for better memory efficiency

    frames = [edge_frame(step) for step in range(len(walk_edges))]

    rule_walks = frames[0]
    for step, frame in enumerate(frames[1:], start=1):
        merged = pd.merge(rule_walks, frame, on=[f"entity_{step}"])
        # A later step must not happen before the step that precedes it.
        chronological = merged[f"timestamp_{step - 1}"] <= merged[f"timestamp_{step}"]
        rule_walks = merged[chronological]

    return rule_walks


def check_var_constraints(var_constraints, rule_walks):
    """
    Check variable constraints of the rule.

    Each constraint is a sequence of entity positions that must all hold the same
    entity; neighbouring positions are compared pairwise and walks that differ in
    any pair are removed.

    Parameters:
        var_constraints (list): variable constraints from the rule
        rule_walks (pd.DataFrame): all walks matching the rule

    Returns:
        rule_walks (pd.DataFrame): all walks matching the rule including the variable constraints
    """

    for const in var_constraints:
        for left, right in zip(const, const[1:]):
            same_entity = rule_walks[f"entity_{left}"] == rule_walks[f"entity_{right}"]
            rule_walks = rule_walks[same_entity]

    return rule_walks


# Fill the candidate dictionaries for a single rule with the score of every candidate entity
def get_candidates(
    rule, rule_walks, test_query_ts, cands_dict,  score_func, args, dicts_idx
):
    """
    Get from the walks that follow the rule the answer candidates.
    Add the score of the rule that leads to these candidates.

    The candidates are the distinct values of the last entity column of rule_walks.
    For every candidate and every index s in dicts_idx the score (cast to float32)
    is appended to cands_dict[s][candidate].

    Parameters:
        rule (dict): rule from rules_dict
        rule_walks (pd.DataFrame): rule walks (satisfying all constraints from the rule)
        test_query_ts (int): test query timestamp
        cands_dict (dict): candidates along with the scores of the rules that generated these candidates
        score_func (function): function for calculating the candidate score
        args (list): arguments for the scoring function, indexed like cands_dict
        dicts_idx (list): indices for candidate dictionaries

    Returns:
        cands_dict (dict): updated candidates (the same object, modified in place)
    """

    answer_col = "entity_" + str(len(rule["body_rels"]))

    for cand in set(rule_walks[answer_col]):
        # All walks that end in this candidate.
        cands_walks = rule_walks[rule_walks[answer_col] == cand]
        for s in dicts_idx:
            raw_score = score_func(rule, cands_walks, test_query_ts, *args[s])
            score = raw_score.astype(np.float32)
            cands_dict[s].setdefault(cand, []).append(score)

    return cands_dict


# Fill the candidate dictionaries for a single rule; every candidate maps to a list of
# [score, rule_idx, body_rels, timestamps] records (a two-dimensional list)
def get_candidates_evolve(
    rule, rule_walks, test_query_ts, cands_dict,  score_func, args, dicts_idx, rule_idx, body_rels
):
    """
    Get from the walks that follow the rule the answer candidates.
    Add the score of the rule that leads to these candidates together with the
    rule index, the rule body and the timestamps of one representative walk.

    The representative walk of a candidate is chosen deterministically: among the
    candidate's walks with the largest "timestamp_0", the one that comes first when
    sorting all timestamp columns in descending order.

    Parameters:
        rule (dict): rule from rules_dict
        rule_walks (pd.DataFrame): rule walks (satisfying all constraints from the rule)
        test_query_ts (int): test query timestamp
        cands_dict (dict): candidates along with the records of the rules that generated these candidates
        score_func (function): function for calculating the candidate score
        args (list): arguments for the scoring function, indexed like cands_dict
        dicts_idx (list): indices for candidate dictionaries
        rule_idx: identifier of the rule, stored unchanged in every record
        body_rels (list): body relations stored in every record; its length
            determines how many timestamp columns are read

    Returns:
        cands_dict (dict): updated candidates (the same object, modified in place);
            each record is [score, rule_idx, body_rels, ts_ls]
    """

    max_entity = "entity_" + str(len(rule["body_rels"]))
    cands = set(rule_walks[max_entity])

    for cand in cands:
        cands_walks = rule_walks[rule_walks[max_entity] == cand]
        # The representative walk does not depend on s, so it is computed once per
        # candidate (on first use) instead of once per candidate dictionary.
        ts_ls = None
        for s in dicts_idx:
            score = score_func(rule, cands_walks, test_query_ts, *args[s]).astype(
                np.float32
            )
            if ts_ls is None:
                ts_cols = [f"timestamp_{i}" for i in range(len(body_rels))]
                max_cands_ts = max(cands_walks["timestamp_0"])
                # Walks whose first timestamp is the largest one.
                newest_walks = cands_walks[cands_walks["timestamp_0"] == max_cands_ts]
                # Sort instead of picking randomly so that the choice is reproducible;
                # the first row after a descending sort is the most recent walk.
                selected = newest_walks.sort_values(
                    by=ts_cols, ascending=[False] * len(ts_cols)
                ).iloc[0]
                ts_ls = [selected[col] for col in ts_cols]

            # Each record gets its own timestamp list, as before.
            res = [score, rule_idx, body_rels, list(ts_ls)]
            cands_dict[s].setdefault(cand, []).append(res)

    return cands_dict


def get_candidates_with_rules(
    rule, rule_walks, test_query_ts, cands_dict, score_func, args, dicts_idx, rule_idx,
):
    """
    Get from the walks that follow the rule the answer candidates.
    Add the score of the rule that leads to these candidates together with the
    index of that rule.

    For every candidate (distinct value of the last entity column) and every index
    s in dicts_idx, the pair [score, rule_idx] is appended to cands_dict[s][candidate];
    the score is cast to float32.

    Parameters:
        rule (dict): rule from rules_dict
        rule_walks (pd.DataFrame): rule walks (satisfying all constraints from the rule)
        test_query_ts (int): test query timestamp
        cands_dict (dict): candidates along with the records of the rules that generated these candidates
        score_func (function): function for calculating the candidate score
        args (list): arguments for the scoring function, indexed like cands_dict
        dicts_idx (list): indices for candidate dictionaries
        rule_idx: identifier of the rule, stored unchanged in every record

    Returns:
        cands_dict (dict): updated candidates (the same object, modified in place)
    """

    answer_col = "entity_" + str(len(rule["body_rels"]))

    for cand in set(rule_walks[answer_col]):
        # All walks that end in this candidate.
        cands_walks = rule_walks[rule_walks[answer_col] == cand]
        for s in dicts_idx:
            raw_score = score_func(rule, cands_walks, test_query_ts, *args[s])
            record = [raw_score.astype(np.float32), rule_idx]
            cands_dict[s].setdefault(cand, []).append(record)

    return cands_dict


# Write all queries and their candidate entities to a file
def save_candidates(
    rules_file, dir_path, all_candidates, rule_lengths, window, score_func_str
):
    """
    Save the candidates as a JSON file.

    The file name is built from rules_file without its last 11 characters (the
    length of a "_rules.json" suffix, presumably what rule files end with), the
    window and the scoring function string, with all spaces removed. The output
    directory is created if it does not exist yet.

    Parameters:
        rules_file (str): name of rules file
        dir_path (str): path to output directory
        all_candidates (dict): candidates for all test queries; outer keys and
            candidate keys are converted with int() before writing
        rule_lengths (list): rule lengths (accepted for interface compatibility, not used)
        window (int): time window used for rule application
        score_func_str (str): scoring function

    Returns:
        None
    """

    all_candidates = {
        int(query): {int(cand): v for cand, v in cands.items()}
        for query, cands in all_candidates.items()
    }
    filename = "{0}_win{1}_{2}_cands.json".format(
        rules_file[:-11], window, score_func_str
    )
    filename = filename.replace(" ", "")
    # Create the output directory like save_candidates_static does; an empty
    # dir_path means the current directory, which os.makedirs would reject.
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)
    with open(os.path.join(dir_path, filename), "w", encoding="utf-8") as fout:
        json.dump(all_candidates, fout)


# Write all queries and their candidate entities (with rule information) to a file
def save_candidates_with_rules(
    rules_file,
    dir_path,
    all_candidates,
    all_candidates_with_rules,
    rule_lengths,
    window,
    score_func_str,
    type,
    negative,
    rule_used,
):
    """
    Save the candidates with rule information using torch.save.

    The file name is built from rules_file without its last 11 characters (the
    length of a "_rules.json" suffix, presumably what rule files end with), the
    window, the scoring function string, negative, rule_used and type, with all
    spaces removed. The output directory is created if it does not exist yet.

    Parameters:
        rules_file (str): name of rules file
        dir_path (str): path to output directory
        all_candidates (dict): candidates for all test queries (accepted for
            interface compatibility, not written)
        all_candidates_with_rules (dict): data to save; its keys are converted with int()
        rule_lengths (list): rule lengths (accepted for interface compatibility, not used)
        window (int): time window used for rule application
        score_func_str (str): scoring function
        type: label appended to the file name (presumably the data split; confirm with callers)
        negative: value inserted after "neg" in the file name
        rule_used: value inserted after "ruleUsed" in the file name

    Returns:
        None
    """

    # Only all_candidates_with_rules is written; the former int-normalisation of
    # all_candidates built a copy that was never used and has been removed.
    all_candidates_with_rules = {
        int(k): v for k, v in all_candidates_with_rules.items()
    }

    filename = "{0}_win{1}_{2}_cands_with_rules_neg{3}_ruleUsed{4}_{5}.pt".format(
        rules_file[:-11], window, score_func_str, negative, rule_used, type
    )
    filename = filename.replace(" ", "")
    # Create the output directory like save_candidates_static does; an empty
    # dir_path means the current directory, which os.makedirs would reject.
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)
    torch.save(all_candidates_with_rules, os.path.join(dir_path, filename))


# Write all queries and their candidate entities (with rule information, "sort" variant) to a file
def save_candidates_with_rules_sort(
    rules_file,
    dir_path,
    all_candidates,
    all_candidates_with_rules,
    rule_lengths,
    window,
    score_func_str,
    type,
    negative,
):
    """
    Save the candidates with rule information using torch.save; the file name
    carries a "sort" marker.

    The file name is built from rules_file without its last 11 characters (the
    length of a "_rules.json" suffix, presumably what rule files end with), the
    window, the scoring function string, negative and type, with all spaces
    removed. The output directory is created if it does not exist yet.

    Parameters:
        rules_file (str): name of rules file
        dir_path (str): path to output directory
        all_candidates (dict): candidates for all test queries (accepted for
            interface compatibility, not written)
        all_candidates_with_rules (dict): data to save; its keys are converted with int()
        rule_lengths (list): rule lengths (accepted for interface compatibility, not used)
        window (int): time window used for rule application
        score_func_str (str): scoring function
        type: label appended to the file name (presumably the data split; confirm with callers)
        negative: value inserted after "neg" in the file name

    Returns:
        None
    """

    # Only all_candidates_with_rules is written; the former int-normalisation of
    # all_candidates built a copy that was never used and has been removed.
    all_candidates_with_rules = {
        int(k): v for k, v in all_candidates_with_rules.items()
    }

    filename = "{0}_win{1}_{2}_cands_with_rules_neg{3}_sort_{4}.pt".format(
        rules_file[:-11], window, score_func_str, negative, type
    )
    filename = filename.replace(" ", "")
    # Create the output directory like save_candidates_static does; an empty
    # dir_path means the current directory, which os.makedirs would reject.
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)
    torch.save(all_candidates_with_rules, os.path.join(dir_path, filename))


# Write all queries and their candidate entities ("evolve" variant) to a file
def save_candidates_evolve(
    rules_file,
    dir_path,
    all_candidates,
    all_candidates_with_evolve,
    rule_lengths,
    window,
    score_func_str,
    type,
    negative,
    rule_used,
):
    """
    Save the "evolve" candidates using torch.save.

    The file name is built from rules_file without its last 11 characters (the
    length of a "_rules.json" suffix, presumably what rule files end with), the
    window, the scoring function string, negative, rule_used and type, with all
    spaces removed. The output directory is created if it does not exist yet.

    Parameters:
        rules_file (str): name of rules file
        dir_path (str): path to output directory
        all_candidates (dict): candidates for all test queries (accepted for
            interface compatibility, not written)
        all_candidates_with_evolve (dict): data to save; its keys are converted with int()
        rule_lengths (list): rule lengths (accepted for interface compatibility, not used)
        window (int): time window used for rule application
        score_func_str (str): scoring function
        type: label appended to the file name (presumably the data split; confirm with callers)
        negative: value inserted after "neg" in the file name
        rule_used: value inserted after "ruleUsed" in the file name

    Returns:
        None
    """

    # Only all_candidates_with_evolve is written; the former int-normalisation of
    # all_candidates built a copy that was never used and has been removed.
    all_candidates_with_evolve = {
        int(k): v for k, v in all_candidates_with_evolve.items()
    }

    filename = "{0}_win{1}_{2}_cands_evolve_neg{3}_ruleUsed{4}_{5}.pt".format(
        rules_file[:-11], window, score_func_str, negative, rule_used, type
    )
    filename = filename.replace(" ", "")
    # Create the output directory like save_candidates_static does; an empty
    # dir_path means the current directory, which os.makedirs would reject.
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)
    torch.save(all_candidates_with_evolve, os.path.join(dir_path, filename))


# Write all queries and their candidate entities ("static" variant) to a file
def save_candidates_static(
    rules_file,
    dir_path,
    all_candidates,
    all_candidates_with_evolve,
    rule_lengths,
    window,
    score_func_str,
    type,
    negative,
    rule_used,
):
    """
    Save the "static" candidates using torch.save.

    The file name is built from rules_file without its last 11 characters (the
    length of a "_rules.json" suffix, presumably what rule files end with), the
    window, the scoring function string, negative, rule_used and type, with all
    spaces removed. The output directory is created if it does not exist yet.

    Parameters:
        rules_file (str): name of rules file
        dir_path (str): path to output directory; an empty string means the
            current directory
        all_candidates (dict): candidates for all test queries (accepted for
            interface compatibility, not written)
        all_candidates_with_evolve (dict): data to save; its keys are converted with int()
        rule_lengths (list): rule lengths (accepted for interface compatibility, not used)
        window (int): time window used for rule application
        score_func_str (str): scoring function
        type: label appended to the file name (presumably the data split; confirm with callers)
        negative: value inserted after "neg" in the file name
        rule_used: value inserted after "ruleUsed" in the file name

    Returns:
        None
    """

    # Only all_candidates_with_evolve is written; the former int-normalisation of
    # all_candidates built a copy that was never used and has been removed.
    all_candidates_with_evolve = {
        int(k): v for k, v in all_candidates_with_evolve.items()
    }

    filename = "{0}_win{1}_{2}_cands_static_neg{3}_ruleUsed{4}_{5}.pt".format(
        rules_file[:-11], window, score_func_str, negative, rule_used, type
    )
    filename = filename.replace(" ", "")
    # os.makedirs("") raises FileNotFoundError, so only create a directory when
    # one was actually given.
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)
    torch.save(all_candidates_with_evolve, os.path.join(dir_path, filename))


def verbalize_walk(walk, data):
    """
    Verbalize walk from rule application.

    The walk is read as a start entity followed by (relation, entity, timestamp)
    triples; every id is translated through the lookup tables of data and the
    resulting strings are joined with tab characters.

    Parameters:
        walk (pandas.core.series.Series): walk that matches the rule body from
            get_walks_complete
        data (grapher.Grapher): graph data providing id2entity, id2relation and id2ts

    Returns:
        walk_str (str): verbalized walk
    """

    num_steps = len(walk) // 3
    ids = walk.values.tolist()

    parts = [data.id2entity[ids[0]]]
    for step in range(num_steps):
        offset = 3 * step
        parts.append(data.id2relation[ids[offset + 1]])
        parts.append(data.id2entity[ids[offset + 2]])
        parts.append(data.id2ts[ids[offset + 3]])

    return "\t".join(parts)
