import argparse
import itertools
import json
import logging
import os
import sys
import time
import pickle

import dgl
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import random
sys.path.append("..")
from rgcn import utils
from rgcn.utils import build_sub_graph
from evolve.rrgcn import RecurrentRGCN
from evolve.hyperparameter_range import *
import torch.nn.modules.rnn
from collections import defaultdict
from rgcn.knowledge_graph import _read_triplets_as_list
import scipy.sparse as sp
from datetime import datetime
import torch.nn.functional as F
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'  # this environment variable must be set when deterministic algorithms are used


def set_seed(seed=6666):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU generator and all CUDA
    devices) with the same value, then turns on cuDNN's deterministic flag
    and PyTorch's deterministic-algorithms mode.

    Args:
        seed: Value handed to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Request deterministic kernels from cuDNN and from PyTorch itself.
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)


def set_logger(save_dir, log_file):
    """Configure root logging to write to a file and echo to the console.

    Args:
        save_dir: Directory that holds the log file.
        log_file: Log file name, joined onto ``save_dir``.

    Note:
        Every call attaches one more console handler to the root logger.
    """
    log_format = '%(asctime)s %(levelname)-8s %(message)s'
    # One date format shared by both outputs. The console formatter used to
    # be given '%Y-%m%d', which dropped the dash between month and day.
    date_format = '%Y-%m-%d %H:%M:%S'

    # File output: INFO level, file opened in 'w' mode.
    logging.basicConfig(
        format=log_format,
        level=logging.INFO,
        datefmt=date_format,
        filename=os.path.join(save_dir, log_file),
        filemode='w'
    )

    # Console output through an extra stream handler on the root logger.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(logging.Formatter(log_format, datefmt=date_format))
    logging.getLogger('').addHandler(console)


def test(args, model, history_list, test_list, num_rels, num_nodes, use_cuda, all_ans_list, all_ans_r_list, model_name, static_graph, time_list, history_time_nogt, mode):
    """Rank entity predictions for each snapshot in ``test_list``.

    The model predicts one snapshot at a time from a sliding window holding
    the last ``args.test_history_len`` snapshots. After each step the window
    advances with the ground-truth snapshot or, when ``args.multi_step`` is
    set, with a snapshot built from the model's own scores.

    Args:
        args: Parsed options. This function reads ``gpu``,
            ``test_history_len``, ``multi_step``, ``dataset``, ``his_len``,
            ``multi_gpu``, ``setting``, ``relation_evaluation`` and ``topk``.
        model: Model exposing ``predict`` (reached through ``model.module``
            when ``args.multi_gpu`` has more than one entry).
        history_list: Snapshots preceding ``test_list``; only its tail is used.
        test_list: Snapshots to evaluate, one per step.
        num_rels: Number of relations; inverse relation ids are offset by it.
        num_nodes: Number of nodes, passed to ``build_sub_graph`` and used to
            index rows of the relation-history matrix.
        use_cuda: Whether tensors are moved to the GPU.
        all_ans_list: Per-snapshot answers passed to ``utils.get_total_rank``.
        all_ans_r_list: Not used by this function (relation ranking is off).
        model_name: Directory containing ``model.pt``.
        static_graph: Passed through to ``predict``.
        time_list: One timestamp per snapshot in ``test_list``; selects the
            history files when ``args.multi_step`` is off.
        history_time_nogt: Timestamp selecting the single pair of history
            files used for every step when ``args.multi_step`` is on.
        mode: ``"test"`` loads the saved checkpoint first; the value is also
            forwarded to ``predict``.

    Returns:
        ``(mrr_raw, mrr_filter, mrr_raw_r, mrr_filter_r, hit_result_raw,
        hit_result_filter, hit_result_raw_r, hit_result_filter_r)``. The
        relation entries are the placeholders ``0`` and ``[]``.
    """
    ranks_raw, ranks_filter = [], []
    tail_history_path = '../data/{}/history{}/tail_history_{}.npz'
    rel_history_path = '../data/{}/history{}/rel_history_{}.npz'

    if mode == "test":
        # test mode: load parameters from file
        map_location = torch.device(args.gpu) if use_cuda else torch.device('cpu')
        checkpoint = torch.load(os.path.join(model_name, 'model.pt'), map_location=map_location)
        print("Load Model name: {}. Using best epoch : {}".format(model_name, checkpoint['epoch']))  # use best stat checkpoint
        logging.info("-"*10+"start testing"+"-"*10)
        model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    # do not have inverse relation in test input
    input_list = list(history_list[-args.test_history_len:])
    predictor = model.module if len(args.multi_gpu) > 1 else model

    if args.multi_step:
        # Multi-step inference reuses one pair of history matrices throughout.
        all_tail_seq = sp.load_npz(tail_history_path.format(args.dataset, args.his_len, history_time_nogt))
        all_rel_seq = sp.load_npz(rel_history_path.format(args.dataset, args.his_len, history_time_nogt))

    for time_idx, test_snap in enumerate(tqdm(test_list)):
        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]
        test_triples_input = torch.LongTensor(test_snap)
        if use_cuda:
            # Only move to args.gpu when CUDA is in use; an unconditional
            # .to(args.gpu) fails on CPU-only runs.
            test_triples_input = test_triples_input.to(args.gpu)

        # Append the inverse of every fact: swap columns 0 and 2 and shift the
        # relation id by num_rels. Fancy indexing copies, so the input tensor
        # is not modified.
        inverse_triples = test_triples_input[:, [2, 1, 0, 3]]
        inverse_triples[:, 1] = inverse_triples[:, 1] + num_rels
        queries = torch.cat([test_triples_input, inverse_triples]).cpu().numpy()

        if not args.multi_step:
            # Single-step inference loads the history files of this timestamp.
            all_tail_seq = sp.load_npz(tail_history_path.format(args.dataset, args.his_len, time_list[time_idx]))
            all_rel_seq = sp.load_npz(rel_history_path.format(args.dataset, args.his_len, time_list[time_idx]))

        # Tail history is indexed by column 0 and column 1 of each query.
        seq_idx = queries[:, 0] * num_rels * 2 + queries[:, 1]
        tail_seq = torch.Tensor(all_tail_seq[seq_idx].todense())
        one_hot_tail_seq = tail_seq.masked_fill(tail_seq != 0, 1)
        # Relation history is indexed by column 0 and column 2 of each query.
        rel_seq_idx = queries[:, 0] * num_nodes + queries[:, 2]
        rel_seq = torch.Tensor(all_rel_seq[rel_seq_idx].todense())
        one_hot_rel_seq = rel_seq.masked_fill(rel_seq != 0, 1)
        if use_cuda:
            one_hot_tail_seq = one_hot_tail_seq.cuda()
            one_hot_rel_seq = one_hot_rel_seq.cuda()

        test_triples, evolve_score, evolve_r_score = predictor.predict(history_glist, num_rels, static_graph, test_triples_input, one_hot_tail_seq, one_hot_rel_seq, use_cuda, mode)

        _, _, rank_raw, rank_filter = utils.get_total_rank(args.setting, test_triples, evolve_score, all_ans_list[time_idx], eval_bz=1000, rel_predict=0)
        # used to global statistic
        ranks_raw.append(rank_raw)
        ranks_filter.append(rank_filter)

        # reconstruct history graph list
        if args.multi_step:
            if not args.relation_evaluation:
                predicted_snap = utils.construct_snap(test_triples, num_nodes, num_rels, evolve_score, args.topk)
            else:
                predicted_snap = utils.construct_snap_r(test_triples, num_nodes, num_rels, evolve_r_score, args.topk)
            if len(predicted_snap):
                input_list.pop(0)
                input_list.append(predicted_snap)
        else:
            input_list.pop(0)
            input_list.append(test_snap)

    mrr_raw, hit_result_raw = utils.stat_ranks(ranks_raw, "raw_ent")
    mrr_filter, hit_result_filter = utils.stat_ranks(ranks_filter, "filter_ent")
    # Relation metrics are not computed; placeholders keep the 8-tuple shape.
    mrr_raw_r, mrr_filter_r = 0, 0
    hit_result_raw_r, hit_result_filter_r = [], []
    return mrr_raw, mrr_filter, mrr_raw_r, mrr_filter_r, hit_result_raw, hit_result_filter, hit_result_raw_r, hit_result_filter_r


def test_fusion(args, model, history_list, test_list, num_rels, num_nodes, use_cuda, all_ans_list, all_ans_r_list, model_name, static_graph, time_list, history_time_nogt, mode):
    """Rank entity predictions using model scores fused with rule scores.

    Works like ``test`` but, before ranking, combines the model's entity
    scores with per-query rule candidate scores read from a JSON file:
    ``args.alpha * evolve_score + (1 - args.alpha) * rule_score``.

    Args:
        args: Parsed options. This function reads ``dataset``, ``seed``,
            ``cands_file``, ``not_softmax``, ``alpha``, ``gpu``,
            ``test_history_len``, ``multi_step``, ``his_len``, ``multi_gpu``,
            ``setting``, ``relation_evaluation`` and ``topk``.
        model: Model exposing ``predict`` (reached through ``model.module``
            when ``args.multi_gpu`` has more than one entry).
        history_list: Snapshots preceding ``test_list``; only its tail is used.
        test_list: Snapshots to evaluate, one per step.
        num_rels: Number of relations; inverse relation ids are offset by it.
        num_nodes: Number of nodes, passed to ``build_sub_graph`` and used to
            index rows of the relation-history matrix.
        use_cuda: Whether tensors are moved to the GPU.
        all_ans_list: Per-snapshot answers passed to ``utils.get_total_rank``.
        all_ans_r_list: Not used by this function.
        model_name: Directory containing ``model.pt``.
        static_graph: Passed through to ``predict``.
        time_list: One timestamp per snapshot in ``test_list``; selects the
            history files when ``args.multi_step`` is off.
        history_time_nogt: Timestamp selecting the single pair of history
            files used for every step when ``args.multi_step`` is on.
        mode: ``"test"`` loads the saved checkpoint first; the value is also
            forwarded to ``predict``.

    Returns:
        ``(mrr_raw, mrr_filter, mrr_raw_r, mrr_filter_r, hit_result_raw,
        hit_result_filter, hit_result_raw_r, hit_result_filter_r)``. The
        relation entries are the placeholders ``0`` and ``[]``.
    """
    ranks_raw, ranks_filter = [], []
    tail_history_path = '../data/{}/history{}/tail_history_{}.npz'
    rel_history_path = '../data/{}/history{}/rel_history_{}.npz'

    # Load the rule candidate scores. The JSON maps digit-string keys to one
    # {entity id: score} dict each; they are consumed in file order.
    cands_dir = os.path.join(f'../output_rule/{args.dataset}/seed{args.seed}', args.cands_file)
    with open(cands_dir, 'r') as file:
        cands_score = list(json.load(file).values())

    if mode == "test":
        # test mode: load parameters from file
        map_location = torch.device(args.gpu) if use_cuda else torch.device('cpu')
        checkpoint = torch.load(os.path.join(model_name, 'model.pt'), map_location=map_location)
        print("Load Model name: {}. Using best epoch : {}".format(model_name, checkpoint['epoch']))  # use best stat checkpoint
        logging.info(f'softmax: {not args.not_softmax}\talpha: {args.alpha}')
        logging.info("-"*10+"start testing"+"-"*10)
        model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    # do not have inverse relation in test input
    input_list = list(history_list[-args.test_history_len:])
    predictor = model.module if len(args.multi_gpu) > 1 else model

    if args.multi_step:
        # Multi-step inference reuses one pair of history matrices throughout.
        all_tail_seq = sp.load_npz(tail_history_path.format(args.dataset, args.his_len, history_time_nogt))
        all_rel_seq = sp.load_npz(rel_history_path.format(args.dataset, args.his_len, history_time_nogt))

    start_idx = 0
    for time_idx, test_snap in enumerate(tqdm(test_list)):
        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]
        test_triples_input = torch.LongTensor(test_snap)
        if use_cuda:
            # Only move to args.gpu when CUDA is in use; an unconditional
            # .to(args.gpu) fails on CPU-only runs.
            test_triples_input = test_triples_input.to(args.gpu)

        # Append the inverse of every fact: swap columns 0 and 2 and shift the
        # relation id by num_rels. Fancy indexing copies, so the input tensor
        # is not modified.
        inverse_triples = test_triples_input[:, [2, 1, 0, 3]]
        inverse_triples[:, 1] = inverse_triples[:, 1] + num_rels
        queries = torch.cat([test_triples_input, inverse_triples]).cpu().numpy()

        if not args.multi_step:
            # Single-step inference loads the history files of this timestamp.
            all_tail_seq = sp.load_npz(tail_history_path.format(args.dataset, args.his_len, time_list[time_idx]))
            all_rel_seq = sp.load_npz(rel_history_path.format(args.dataset, args.his_len, time_list[time_idx]))

        # Tail history is indexed by column 0 and column 1 of each query.
        seq_idx = queries[:, 0] * num_rels * 2 + queries[:, 1]
        tail_seq = torch.Tensor(all_tail_seq[seq_idx].todense())
        one_hot_tail_seq = tail_seq.masked_fill(tail_seq != 0, 1)
        # Relation history is indexed by column 0 and column 2 of each query.
        rel_seq_idx = queries[:, 0] * num_nodes + queries[:, 2]
        rel_seq = torch.Tensor(all_rel_seq[rel_seq_idx].todense())
        one_hot_rel_seq = rel_seq.masked_fill(rel_seq != 0, 1)
        if use_cuda:
            one_hot_tail_seq = one_hot_tail_seq.cuda()
            one_hot_rel_seq = one_hot_rel_seq.cuda()

        test_triples, evolve_score, evolve_r_score = predictor.predict(history_glist, num_rels, static_graph, test_triples_input, one_hot_tail_seq, one_hot_rel_seq, use_cuda, mode)

        # Build the rule score matrix for this snapshot's slice of candidates.
        end_idx = start_idx + len(test_triples)
        cands_score_part = cands_score[start_idx:end_idx]
        start_idx = end_idx
        # Allocate on the model scores' device so the fusion below works on
        # CPU as well as on any GPU.
        score_device = evolve_score.device
        rule_score = torch.zeros(evolve_score.size(), device=score_device)
        for row, cand in enumerate(cands_score_part):
            # Build ids directly as int64; a float32 round-trip would lose
            # exactness for ids above 2**24.
            keys = torch.tensor([int(key) for key in cand], dtype=torch.long, device=score_device)
            values = torch.tensor(list(cand.values()), dtype=rule_score.dtype, device=score_device)
            rule_score[row, keys] = values

        # Fused score; rows without rule candidates keep a zero rule term.
        fusion_score = args.alpha * evolve_score + (1 - args.alpha) * rule_score
        _, _, rank_raw, rank_filter = utils.get_total_rank(args.setting, test_triples, fusion_score, all_ans_list[time_idx], eval_bz=1000, rel_predict=0)
        ranks_raw.append(rank_raw)
        ranks_filter.append(rank_filter)

        # reconstruct history graph list (uses the model scores, not the fused ones)
        if args.multi_step:
            if not args.relation_evaluation:
                predicted_snap = utils.construct_snap(test_triples, num_nodes, num_rels, evolve_score, args.topk)
            else:
                predicted_snap = utils.construct_snap_r(test_triples, num_nodes, num_rels, evolve_r_score, args.topk)
            if len(predicted_snap):
                input_list.pop(0)
                input_list.append(predicted_snap)
        else:
            input_list.pop(0)
            input_list.append(test_snap)

    mrr_raw, hit_result_raw = utils.stat_ranks(ranks_raw, "raw_ent")
    mrr_filter, hit_result_filter = utils.stat_ranks(ranks_filter, "filter_ent")
    # Relation metrics are not computed; placeholders keep the 8-tuple shape.
    mrr_raw_r, mrr_filter_r = 0, 0
    hit_result_raw_r, hit_result_filter_r = [], []
    return mrr_raw, mrr_filter, mrr_raw_r, mrr_filter_r, hit_result_raw, hit_result_filter, hit_result_raw_r, hit_result_filter_r


def run_experiment(args, history_len=None, n_layers=None, dropout=None, n_bases=None, angle=None, history_rate=None):
    """Train and/or evaluate the model for one configuration.

    In test mode (``args.test``) the saved checkpoint is evaluated on the test
    split with ``test_fusion`` or ``test``, depending on ``args.fusion``.
    Otherwise the model is trained with periodic validation, the best
    checkpoint is saved, and the test split is evaluated afterwards.

    Args:
        args: Parsed command-line options; fields may be overwritten by the
            keyword arguments below.
        history_len: If truthy, overrides ``args.train_history_len`` and
            ``args.test_history_len``.
        n_layers: If truthy, overrides ``args.n_layers``.
        dropout: If truthy, overrides ``args.dropout``.
        n_bases: If truthy, overrides ``args.n_bases``.
        angle: If truthy, overrides ``args.angle``.
        history_rate: If truthy, overrides ``args.history_rate``.

    Returns:
        The 8-tuple returned by the last ``test`` / ``test_fusion`` call, or
        eight ``None`` values when test mode finds no saved checkpoint.
    """
    set_seed(args.seed)
    gpu_used = str(args.multi_gpu)[1:-1]
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_used  # GPU ids made visible to this process

    # load configuration for grid search the best configuration
    if history_len:
        args.train_history_len = history_len
        args.test_history_len = history_len
    if n_layers:
        args.n_layers = n_layers
    if dropout:
        args.dropout = dropout
    if n_bases:
        args.n_bases = n_bases
    if angle:
        args.angle = angle
    if history_rate:
        args.history_rate = history_rate
    results = (None,) * 8

    # load graph data
    print("loading graph data")
    data_dict = utils.load_dataset(args.dataset)
    # Split each partition into per-timestamp snapshots.
    train_list, train_times = utils.split_by_time(data_dict['train'])
    valid_list, valid_times = utils.split_by_time(data_dict['valid'])
    test_list, test_times = utils.split_by_time(data_dict['test'])

    num_nodes = data_dict['num_nodes']
    num_rels = data_dict['num_rels']
    num_times = len(train_list) + len(valid_list) + len(test_list)
    time_interval = train_times[1]-train_times[0]
    print("num_times", num_times, "--------------", time_interval)
    history_val_time_nogt = valid_times[0]
    history_test_time_nogt = test_times[0]
    if args.multi_step:
        print("val only use global history before:", history_val_time_nogt)
        print("test only use global history before:", history_test_time_nogt)

    all_ans_list_test = utils.load_all_answers_for_time_filter(data_dict['test'], num_rels, num_nodes, False)
    all_ans_list_r_test = utils.load_all_answers_for_time_filter(data_dict['test'], num_rels, num_nodes, True)
    all_ans_list_valid = utils.load_all_answers_for_time_filter(data_dict['valid'], num_rels, num_nodes, False)
    all_ans_list_r_valid = utils.load_all_answers_for_time_filter(data_dict['valid'], num_rels, num_nodes, True)

    if not args.model_name:
        model_name = f're{int(args.relation_evolve)}'
    else:
        model_name = args.model_name
    # Directory for this run's checkpoint and logs.
    model_state_file = os.path.join(f'../models_evolve/{args.dataset}', model_name, f'seed{args.seed}')
    os.makedirs(model_state_file, exist_ok=True)
    # test() and test_fusion() load 'model.pt' from inside that directory.
    checkpoint_path = os.path.join(model_state_file, 'model.pt')
    if not args.test:
        log_file = 'log_train.txt'
    elif args.original:
        log_file = f'log_test_original_{args.setting}.txt'
    elif not args.fusion:
        log_file = f'log_test_localFre_{args.setting}.txt'
    elif args.not_softmax:
        log_file = f'log_test_soft0_{args.setting}.txt'
    else:
        log_file = f'log_test_soft1_{args.setting}.txt'
    set_logger(model_state_file, log_file)
    if args.des_info:
        logging.info(f'{args.des_info}')

    print("Sanity Check: stat name : {}".format(model_state_file))
    print("Sanity Check: Is cuda available ? {}".format(torch.cuda.is_available()))
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()

    if args.add_static_graph:
        static_triples = np.array(_read_triplets_as_list("../data/" + args.dataset + "/e-w-graph.txt", {}, {}, load_time=False))
        num_static_rels = len(np.unique(static_triples[:, 1]))
        num_words = len(np.unique(static_triples[:, 2]))
        # Shift column 2 by num_nodes so those ids do not overlap entity ids.
        static_triples[:, 2] = static_triples[:, 2] + num_nodes
    else:
        num_static_rels, num_words, static_triples, static_graph = 0, 0, [], None

    # create stat
    model = RecurrentRGCN(args,
                          args.decoder,
                          args.encoder,
                          num_nodes,
                          num_rels,
                          num_static_rels,
                          num_words,
                          num_times,
                          time_interval,
                          args.n_hidden,
                          args.opn,
                          args.history_rate,
                          sequence_len=args.train_history_len,
                          num_bases=args.n_bases,
                          num_basis=args.n_basis,
                          num_hidden_layers=args.n_layers,
                          dropout=args.dropout,
                          self_loop=args.self_loop,
                          skip_connect=args.skip_connect,
                          layer_norm=args.layer_norm,
                          input_dropout=args.input_dropout,
                          hidden_dropout=args.hidden_dropout,
                          feat_dropout=args.feat_dropout,
                          aggregation=args.aggregation,
                          weight=args.weight,
                          discount=args.discount,
                          angle=args.angle,
                          use_static=args.add_static_graph,
                          entity_prediction=args.entity_prediction,
                          relation_prediction=args.relation_prediction,
                          use_cuda=use_cuda,
                          gpu=args.gpu,
                          analysis=args.run_analysis)

    if use_cuda:
        if len(args.multi_gpu) > 1:
            model = nn.DataParallel(model)
        model.to(torch.device("cuda"))

    if args.add_static_graph:
        # The static graph spans the entity ids plus the shifted word ids.
        num_static_nodes = int(num_words + data_dict['num_nodes'])
        static_graph = build_sub_graph(num_static_nodes, num_static_rels, static_triples, use_cuda, args.gpu)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)

    def evaluate_on_test(evaluate):
        """Run ``evaluate`` (``test`` or ``test_fusion``) on the test split."""
        return evaluate(
            args,
            model,
            train_list+valid_list,
            test_list,
            num_rels,
            num_nodes,
            use_cuda,
            all_ans_list_test,
            all_ans_list_r_test,
            model_state_file,
            static_graph,
            test_times,
            history_test_time_nogt,
            "test")

    # The run directory always exists (created above), so test mode must check
    # for the checkpoint file itself rather than for the directory.
    if args.test and os.path.exists(checkpoint_path):
        results = evaluate_on_test(test_fusion if args.fusion else test)
    elif args.test:
        print("--------------{} not exist, Change mode to train and generate stat for testing----------------\n".format(checkpoint_path))
    else:
        logging.info("----------------------------------------start training----------------------------------------")
        best_mrr = 0
        best_epoch = 0
        multi_gpu = len(args.multi_gpu) > 1
        # Treat a machine without GPUs as one device so the per-device batch
        # size below never divides by zero.
        gpu_num = max(torch.cuda.device_count(), 1)
        for epoch in range(args.n_epochs):
            model.train()
            losses = []
            losses_e = []
            losses_r = []
            losses_static = []

            idx = list(range(len(train_list)))
            random.shuffle(idx)

            for train_sample_num in tqdm(idx):
                if train_sample_num == 0:
                    # The first snapshot has no preceding history.
                    continue
                history_start = max(train_sample_num - args.train_history_len, 0)
                input_list = train_list[history_start:train_sample_num]

                # generate history graph
                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]
                histroy_data = torch.from_numpy(train_list[train_sample_num]).long()
                if use_cuda:
                    histroy_data = histroy_data.cuda()

                # Build the inverse quadruples before the forward pass. Fancy
                # indexing copies, so the training data is not modified.
                inverse_histroy_data = histroy_data[:, [2, 1, 0, 3]]
                inverse_histroy_data[:, 1] = inverse_histroy_data[:, 1] + num_rels
                histroy_data = torch.cat([histroy_data, inverse_histroy_data])
                # Move to the GPU only when CUDA is in use, so CPU runs work.
                histroy_data_cuda = histroy_data.cuda() if use_cuda else histroy_data
                histroy_data = histroy_data.cpu().numpy()

                # tail history: selected sparse rows converted to a dense 2-D
                # tensor; values are used as stored, not binarized
                all_tail_seq = sp.load_npz(
                    '../data/{}/history{}/tail_history_{}.npz'.format(args.dataset, args.his_len, train_times[train_sample_num]))
                seq_idx = histroy_data[:, 0] * num_rels * 2 + histroy_data[:, 1]
                one_hot_tail_seq = torch.Tensor(all_tail_seq[seq_idx].todense())
                # relation history, handled the same way
                all_rel_seq = sp.load_npz(
                    '../data/{}/history{}/rel_history_{}.npz'.format(args.dataset, args.his_len, train_times[train_sample_num]))
                rel_seq_idx = histroy_data[:, 0] * num_nodes + histroy_data[:, 2]
                one_hot_rel_seq = torch.Tensor(all_rel_seq[rel_seq_idx].todense())
                if use_cuda:
                    one_hot_tail_seq = one_hot_tail_seq.cuda()
                    one_hot_rel_seq = one_hot_rel_seq.cuda()

                # Per-device batch size: ceil(len / gpu_num).
                batch_gpu = (len(histroy_data_cuda) + gpu_num - 1) // gpu_num

                loss_e, loss_r, loss_static = model.forward(history_glist, histroy_data_cuda, batch_gpu, static_graph, one_hot_tail_seq, one_hot_rel_seq, use_cuda)
                loss = args.task_weight*loss_e + (1-args.task_weight)*loss_r + loss_static

                if multi_gpu:
                    # With several GPUs the losses are summed before use.
                    losses.append(loss.sum().item())
                    losses_e.append(loss_e.sum().item())
                    losses_r.append(loss_r.sum().item())
                    losses_static.append(loss_static.sum().item())
                    loss.sum().backward()
                else:
                    losses.append(loss.item())
                    losses_e.append(loss_e.item())
                    losses_r.append(loss_r.item())
                    losses_static.append(loss_static.item())
                    loss.backward()

                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients
                optimizer.step()
                optimizer.zero_grad()

            logging.info("Epoch {:04d} | Ave Loss: {:.4f} | entity-relation-static:{:.4f}-{:.4f}-{:.4f} Best MRR {:.4f} | Model {} "
                  .format(epoch, np.mean(losses), np.mean(losses_e), np.mean(losses_r), np.mean(losses_static), best_mrr, model_name))

            # validation
            if epoch % args.evaluate_every == 0:
                _, valid_mrr_filter, _, valid_mrr_filter_r, *_ = test(
                    args,
                    model,
                    train_list,
                    valid_list,
                    num_rels,
                    num_nodes,
                    use_cuda,
                    all_ans_list_valid,
                    all_ans_list_r_valid,
                    model_state_file,
                    static_graph,
                    valid_times,
                    history_val_time_nogt,
                    mode="valid")

                # Model selection metric: entity MRR unless relation
                # evaluation was requested.
                current_mrr = valid_mrr_filter_r if args.relation_evaluation else valid_mrr_filter
                if current_mrr < best_mrr:
                    # Stop once no improvement was seen for max_train_length epochs.
                    if epoch - best_epoch >= args.max_train_length:
                        break
                else:
                    best_mrr = current_mrr
                    best_epoch = epoch
                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, checkpoint_path)

        # Final evaluation on the test split; test() reloads the saved checkpoint.
        results = evaluate_on_test(test)
        if args.fusion:
            results = evaluate_on_test(test_fusion)
    return results


if __name__ == '__main__':
    # Script entry point: parse the command-line options, then either sweep a
    # small hard-coded hyperparameter grid or launch a single experiment.
    # Every run is timed and its wall-clock duration is logged.
    import warnings
    warnings.filterwarnings("ignore")
    parser = argparse.ArgumentParser(description='TIRGN')

    # newly added options
    parser.add_argument('--relation_evolve', action='store_true', default=False)
    parser.add_argument('--seed', type=int, default=6666)
    parser.add_argument('--model_name', type=str, default='')
    parser.add_argument('--des_info', type=str, default='')
    parser.add_argument('--fre_norm', action='store_true', default=False)
    parser.add_argument('--original', action='store_true', default=False)

    # rule fusion
    parser.add_argument('--cands_file', type=str, default='')
    parser.add_argument('--alpha', type=float, default=0.5)
    parser.add_argument('--not_softmax', action='store_true', default=False)
    parser.add_argument('--fusion', action='store_true', default=False)
    parser.add_argument('--setting', type=str, default='random')
    parser.add_argument('--rule_ten_ls_file', type=str, default='')
    parser.add_argument('--cands_with_rules_file', type=str, default='')
    parser.add_argument('--rule_evolve', action='store_true', default=False)
    parser.add_argument('--rule_neg', type=int, default=-1)
    parser.add_argument('--atten_matr_num', type=int, default=0)
    parser.add_argument('--gru', action='store_true', default=False)
    parser.add_argument('--gru_init', type=str, default='rel')
    parser.add_argument('--gru_atten', type=str, default='')
    parser.add_argument('--gru_time_method', type=str, default='')
    parser.add_argument('--lam_gru', type=float, default=0.1)
    parser.add_argument('--cands_type', type=str, default='with_rules')
    parser.add_argument('--score_method', type=str, default='all_plus')
    parser.add_argument('--score_atten', type=float, default=0.5)
    # learn the score weights automatically
    parser.add_argument('--weight_score_learnable', action='store_true', default=False)
    parser.add_argument('--rule_used', type=int, default=-1)
    parser.add_argument('--rule_atten', type=str, default='')
    parser.add_argument('--semantic_sim', type=str, default='')
    parser.add_argument('--time_sim', type=str, default='')
    parser.add_argument('--lam_rule', type=float, default=0.1)
    parser.add_argument('--weight_time_learnable', action='store_true', default=False)
    parser.add_argument('--sim_atten', type=float, default=0.5)
    parser.add_argument('--time_emb', action='store_true', default=False)
    parser.add_argument('--three_plus', action='store_true', default=False)
    parser.add_argument('--rule_rate', type=float, default=0.2)
    parser.add_argument('--evolve_type', type=str, default='dynamic')
    parser.add_argument('--window', type=int, default=30)
    parser.add_argument('--time_file', action='store_true', default=False)
    parser.add_argument('--time_fun_origin', action='store_true', default=False)
    parser.add_argument('--negative', type=int, default=-1)
    # NOTE: when the flag is omitted the value is the int -1, whereas values
    # given on the command line arrive as a list of ints (nargs='*').
    parser.add_argument('--multi_gpu', type=int, nargs='*', default=-1)
    parser.add_argument('--his_len', type=int, default=1)

    # general run options
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--batch-size", type=int, default=1,
                        help="batch-size")
    parser.add_argument("-d", "--dataset", type=str, required=True,
                        help="dataset to use")
    parser.add_argument("--test", action='store_true', default=False,
                        help="load stat from dir and directly test")
    parser.add_argument("--run-analysis", action='store_true', default=False,
                        help="print log info")
    parser.add_argument("--run-statistic", action='store_true', default=False,
                        help="statistic the result")
    parser.add_argument("--multi-step", action='store_true', default=False,
                        help="do multi-steps inference without ground truth")
    parser.add_argument("--topk", type=int, default=50,
                        help="choose top k entities as results when do multi-steps without ground truth")
    parser.add_argument("--add-static-graph", action='store_true', default=False,
                        help="use the info of static graph")
    parser.add_argument("--add-rel-word", action='store_true', default=False,
                        help="use words in relaitons")
    parser.add_argument("--relation-evaluation", action='store_true', default=False,
                        help="save model accordding to the relation evalution")

    # configuration for encoder RGCN stat
    parser.add_argument("--weight", type=float, default=1,
                        help="weight of static constraint")
    parser.add_argument("--task-weight", type=float, default=0.7,
                        help="weight of entity prediction task")
    parser.add_argument("--discount", type=float, default=1,
                        help="discount of weight of static constraint")
    parser.add_argument("--angle", type=int, default=10,
                        help="evolution speed")

    parser.add_argument("--encoder", type=str, default="uvrgcn",
                        help="method of encoder")
    parser.add_argument("--aggregation", type=str, default="none",
                        help="method of aggregation")
    parser.add_argument("--dropout", type=float, default=0.2,
                        help="dropout probability")
    parser.add_argument("--skip-connect", action='store_true', default=False,
                        help="whether to use skip connect in a RGCN Unit")
    parser.add_argument("--n-hidden", type=int, default=200,
                        help="number of hidden units")
    parser.add_argument("--opn", type=str, default="sub",
                        help="opn of compgcn")

    parser.add_argument("--n-bases", type=int, default=100,
                        help="number of weight blocks for each relation")
    parser.add_argument("--n-basis", type=int, default=100,
                        help="number of basis vector for compgcn")
    parser.add_argument("--n-layers", type=int, default=2,
                        help="number of propagation rounds")
    # NOTE: store_true combined with default=True means this option is always
    # True and cannot be disabled from the command line; the default is kept
    # as-is for backward compatibility. The help text used to be a copy of the
    # --layer-norm description and has been corrected.
    parser.add_argument("--self-loop", action='store_true', default=True,
                        help="enable self-loop (default: enabled)")
    parser.add_argument("--layer-norm", action='store_true', default=False,
                        help="perform layer normalization in every layer of gcn ")
    parser.add_argument("--relation-prediction", action='store_true', default=False,
                        help="add relation prediction loss")
    parser.add_argument("--entity-prediction", action='store_true', default=False,
                        help="add entity prediction loss")
    parser.add_argument("--split_by_relation", action='store_true', default=False,
                        help="do relation prediction")

    # configuration for stat training
    parser.add_argument("--n-epochs", type=int, default=50,
                        help="number of minimum training epochs on each time step")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="learning rate")
    parser.add_argument("--grad-norm", type=float, default=1.0,
                        help="norm to clip gradient to")
    parser.add_argument('--max_train_length', type=int, default=10)

    # configuration for evaluating
    parser.add_argument("--evaluate-every", type=int, default=20,
                        help="perform evaluation every n epochs")

    # configuration for decoder
    parser.add_argument("--decoder", type=str, default="convtranse",
                        help="method of decoder")
    parser.add_argument("--input-dropout", type=float, default=0.2,
                        help="input dropout for decoder ")
    parser.add_argument("--hidden-dropout", type=float, default=0.2,
                        help="hidden dropout for decoder")
    parser.add_argument("--feat-dropout", type=float, default=0.2,
                        help="feat dropout for decoder")

    # configuration for sequences stat
    parser.add_argument("--train-history-len", type=int, default=10,
                        help="history length")
    parser.add_argument("--test-history-len", type=int, default=20,
                        help="history length for test")
    parser.add_argument("--dilate-len", type=int, default=1,
                        help="dilate history graph")

    # configuration for optimal parameters
    parser.add_argument("--grid_search", action='store_true', default=False,
                        help="perform grid search for best configuration")
    parser.add_argument("-tune", "--tune", type=str, default="history_len,n_layers,dropout,n_bases,angle,history_rate",
                        help="stat to use")
    parser.add_argument("--num-k", type=int, default=500,
                        help="number of triples generated")

    # configuration for global history
    parser.add_argument("--history_rate", type=float, default=0.3,
                        help="history rate")

    parser.add_argument("--save", type=str, default="one",
                        help="number of save")

    args = parser.parse_args()
    print(args)

    def _timed_run(run_args):
        """Run one experiment and log how long it took (wall-clock)."""
        started = datetime.now()
        run_experiment(run_args)
        logging.info(f'total time: {datetime.now() - started}')

    if args.grid_search:
        # Hyperparameters to sweep; the Cartesian product of all value lists
        # is tried, one full experiment per combination.
        # (The former --tune driven grid search that lived here as
        # commented-out code was dropped; recover it from version control
        # if it is ever needed again.)
        hyper_dict = {
            'lam_rule': [0.6, 0.7],
        }
        base_model_name = args.model_name

        for comb in itertools.product(*hyper_dict.values()):
            update_dict = dict(zip(hyper_dict, comb))
            # Overwrite the parsed options with this combination's values.
            vars(args).update(update_dict)
            # Tag the model name with the swept values, always starting from
            # the user-supplied base name so suffixes do not accumulate.
            suffix = ''.join(f'-{k}_{v}' for k, v in update_dict.items())
            args.model_name = base_model_name + suffix
            _timed_run(args)
    else:
        # single run
        _timed_run(args)
    sys.exit()
