import json
import random
import sys
import numpy as np
from datetime import datetime

from torch.nn import functional as F
import torch
from torch.nn.parameter import Parameter
import math
import os
from torch.nn import GRU
import torch.nn as nn
# Working directory captured once at import time.
# NOTE(review): not referenced anywhere in the code visible in this chunk.
path_dir = os.getcwd()
# os.environ['CUDA_VISIBLE_DEVICES'] = '5,6'  # put your GPU ids here to restrict visible devices


class TimeConvTransR(torch.nn.Module):
    """ConvTransE-style decoder that scores relations for (subject, object, time) inputs.

    The subject embedding, object embedding and two time embeddings are stacked
    as 4 Conv1d input channels, passed through Conv1d -> FC, and the resulting
    query vector is scored against every row of ``emb_rel``.

    Args:
        num_relations: number of relations; sizes the bias parameter ``b`` as
            ``num_relations * 2`` (``b`` is not used by ``forward``).
        embedding_dim: dimension of the entity / relation / time embeddings.
        input_dropout: dropout applied to the stacked inputs.
        hidden_dropout: dropout applied after the FC layer.
        feature_map_dropout: dropout applied to the conv feature maps.
        channels: number of Conv1d output channels.
        kernel_size: Conv1d kernel size; odd sizes keep the sequence length.
        use_bias: accepted for API compatibility; not used.
    """

    def __init__(self, num_relations, embedding_dim, input_dropout=0, hidden_dropout=0, feature_map_dropout=0, channels=50, kernel_size=3, use_bias=True):
        super(TimeConvTransR, self).__init__()
        self.inp_drop = torch.nn.Dropout(input_dropout)
        self.hidden_drop = torch.nn.Dropout(hidden_dropout)
        self.feature_map_drop = torch.nn.Dropout(feature_map_dropout)
        self.loss = torch.nn.BCELoss()

        input_cha = 4  # subject, object, time embedding 1, time embedding 2
        # For an odd kernel size, padding = floor(kernel_size / 2) keeps the length at embedding_dim.
        self.conv1 = torch.nn.Conv1d(input_cha, channels, kernel_size, stride=1,
                                     padding=int(math.floor(kernel_size / 2)))
        self.bn0 = torch.nn.BatchNorm1d(input_cha)
        self.bn1 = torch.nn.BatchNorm1d(channels)
        self.bn2 = torch.nn.BatchNorm1d(embedding_dim)
        self.register_parameter('b', Parameter(torch.zeros(num_relations*2)))
        self.fc = torch.nn.Linear(embedding_dim * channels, embedding_dim)

    def forward(self, embedding, emb_rel, emb_time, triplets, nodes_id=None, mode="train", negative_rate=0, partial_embeding=None, fre_norm=False):
        """Score every relation for each (subject, object) pair in ``triplets``.

        Args:
            embedding: entity embedding matrix (num_entities, embedding_dim);
                tanh is applied before use.
            emb_rel: relation embedding matrix (num_candidates, embedding_dim);
                scores are ``query @ emb_rel.T``.
            emb_time: pair ``(emb_time_1, emb_time_2)``, each (batch, embedding_dim).
            triplets: integer tensor; column 0 is the subject id and column 2
                the object id.
            nodes_id, mode, negative_rate: unused; kept for interface compatibility.
            partial_embeding: optional tensor multiplied element-wise into the
                scores (zero entries suppress the corresponding score).
            fre_norm: if True, L2-normalize ``partial_embeding`` rows first.

        Returns:
            Score tensor of shape (batch, num_candidates).
        """
        e1_embedded_all = torch.tanh(embedding)
        batch_size = len(triplets)
        e1_embedded = e1_embedded_all[triplets[:, 0]].unsqueeze(1)  # (batch, 1, dim)
        e2_embedded = e1_embedded_all[triplets[:, 2]].unsqueeze(1)  # (batch, 1, dim)
        emb_time_1, emb_time_2 = emb_time

        stacked_inputs = torch.cat(
            [e1_embedded, e2_embedded, emb_time_1.unsqueeze(1), emb_time_2.unsqueeze(1)], 1)  # (batch, 4, dim)
        stacked_inputs = self.bn0(stacked_inputs)
        x = self.inp_drop(stacked_inputs)
        x = self.conv1(x)  # (batch, channels, dim)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(batch_size, -1)  # (batch, channels * dim)
        x = self.fc(x)  # (batch, dim)
        x = self.hidden_drop(x)
        # BatchNorm1d cannot compute batch statistics from a single sample while
        # training (it raises); in eval mode it uses running stats and is safe.
        if batch_size > 1 or not self.training:
            x = self.bn2(x)
        x = F.relu(x)
        x = torch.mm(x, emb_rel.transpose(1, 0))
        if partial_embeding is not None:
            if fre_norm:
                partial_embeding = F.normalize(partial_embeding)
            # Keep only the scores allowed by the mask (e.g. candidates seen in the current snapshot).
            x = torch.mul(x, partial_embeding)
        return x


class TimeConvTransE(torch.nn.Module):
    """ConvTransE-style decoder that scores candidate objects for (subject, relation, time) queries.

    The subject embedding, relation embedding and two time embeddings are
    stacked as 4 Conv1d input channels; the resulting query vector is scored
    against all (tanh-activated) entity embeddings.

    Args:
        args: config namespace. ``args.cands_type == 'evolve'`` means ``emb_rel``
            is a sequence whose last element holds the relation embeddings.
        num_entities: number of entities; sizes the bias parameter ``b``
            (``b`` is not used by the methods of this class).
        embedding_dim: dimension of the entity / relation / time embeddings.
        input_dropout: dropout applied to the stacked inputs.
        hidden_dropout: dropout applied after the FC layer.
        feature_map_dropout: dropout applied to the conv feature maps.
        channels: number of Conv1d output channels.
        kernel_size: Conv1d kernel size; odd sizes keep the sequence length.
        use_bias: accepted for API compatibility; not used.
    """

    def __init__(self, args, num_entities, embedding_dim, input_dropout=0, hidden_dropout=0, feature_map_dropout=0, channels=50, kernel_size=3, use_bias=True):

        super(TimeConvTransE, self).__init__()

        self.inp_drop = torch.nn.Dropout(input_dropout)
        self.hidden_drop = torch.nn.Dropout(hidden_dropout)
        self.feature_map_drop = torch.nn.Dropout(feature_map_dropout)
        self.loss = torch.nn.BCELoss()

        input_cha = 4  # subject, relation, time embedding 1, time embedding 2
        self.conv1 = torch.nn.Conv1d(input_cha, channels, kernel_size, stride=1,
                                     padding=int(math.floor(kernel_size / 2)))
        self.bn0 = torch.nn.BatchNorm1d(input_cha)
        self.bn1 = torch.nn.BatchNorm1d(channels)
        self.bn2 = torch.nn.BatchNorm1d(embedding_dim)
        self.register_parameter('b', Parameter(torch.zeros(num_entities)))
        self.fc = torch.nn.Linear(embedding_dim * channels, embedding_dim)
        self.args = args

    def _relation_embedding(self, emb_rel, rel_ids):
        """Return (batch, 1, dim) relation embeddings; 'evolve' uses the last step of ``emb_rel``."""
        if self.args.cands_type == 'evolve':
            emb_rel = emb_rel[-1]
        return emb_rel[rel_ids].unsqueeze(1)

    def _query_repr(self, stacked_inputs, batch_size):
        """Map (batch, 4, dim) stacked inputs to a (batch, dim) query vector."""
        x = self.bn0(stacked_inputs)
        x = self.inp_drop(x)
        x = self.conv1(x)  # (batch, channels, dim)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(batch_size, -1)  # (batch, channels * dim)
        x = self.fc(x)  # (batch, dim)
        x = self.hidden_drop(x)
        # BatchNorm1d cannot use the batch statistics of a single sample while
        # training. In eval mode it relies on running statistics, so apply it
        # there unconditionally to keep single-query scores consistent with
        # batched ones.
        if batch_size > 1 or not self.training:
            x = self.bn2(x)
        return F.relu(x)

    def _stack_inputs(self, e1_embedded_all, emb_rel, emb_time, triplets):
        """Build the (batch, 4, dim) Conv1d input from subject, relation and both time embeddings."""
        e1_embedded = e1_embedded_all[triplets[:, 0]].unsqueeze(1)  # (batch, 1, dim)
        rel_embedded = self._relation_embedding(emb_rel, triplets[:, 1])  # (batch, 1, dim)
        emb_time_1, emb_time_2 = emb_time
        return torch.cat([e1_embedded, rel_embedded, emb_time_1.unsqueeze(1), emb_time_2.unsqueeze(1)], 1)

    def forward(self, embedding, emb_rel, emb_time, triplets, nodes_id=None, mode="train", negative_rate=0, partial_embeding=None, fre_norm=False):
        """Score every entity as the object of each (subject, relation) query.

        Args:
            embedding: entity embedding matrix (num_entities, embedding_dim);
                tanh is applied before use.
            emb_rel: relation embeddings (see ``_relation_embedding``).
            emb_time: pair ``(emb_time_1, emb_time_2)``, each (batch, embedding_dim).
            triplets: integer tensor; column 0 is the subject id, column 1 the
                relation id.
            nodes_id, mode, negative_rate: unused; kept for interface compatibility.
            partial_embeding: optional tensor multiplied element-wise into the
                scores (the original author used it to zero tails never seen in history).
            fre_norm: if True, L2-normalize ``partial_embeding`` rows first.

        Returns:
            Score tensor of shape (batch, num_entities).
        """
        e1_embedded_all = torch.tanh(embedding)
        batch_size = len(triplets)
        stacked_inputs = self._stack_inputs(e1_embedded_all, emb_rel, emb_time, triplets)
        x = self._query_repr(stacked_inputs, batch_size)
        x = torch.mm(x, e1_embedded_all.transpose(1, 0))
        if partial_embeding is not None:
            if fre_norm:
                partial_embeding = F.normalize(partial_embeding)
            x = torch.mul(x, partial_embeding)
        return x

    def forward_slow(self, embedding, emb_rel, triplets, emb_time=None):
        """Score only the gold object (column 2) of each triplet.

        Equivalent to ``forward(...)[i, triplets[i, 2]]`` for every row ``i``.

        Args:
            embedding: entity embedding matrix; tanh is applied before use.
            emb_rel: relation embeddings (see ``_relation_embedding``).
            triplets: integer tensor with subject (col 0), relation (col 1)
                and object (col 2) ids.
            emb_time: pair ``(emb_time_1, emb_time_2)``, each (batch, embedding_dim).
                Required: ``conv1``/``bn0`` are built for 4 input channels.

        Returns:
            1-D tensor of per-triplet scores.

        Raises:
            ValueError: if ``emb_time`` is not provided.
        """
        if emb_time is None:
            raise ValueError("forward_slow requires emb_time: the conv input expects "
                             "4 channels (subject, relation, time_1, time_2)")
        e1_embedded_all = torch.tanh(embedding)
        batch_size = len(triplets)
        stacked_inputs = self._stack_inputs(e1_embedded_all, emb_rel, emb_time, triplets)
        x = self._query_repr(stacked_inputs, batch_size)
        e2_embedded = e1_embedded_all[triplets[:, 2]]
        return torch.sum(torch.mul(x, e2_embedded), dim=1)


class RuleConvTransE_old(torch.nn.Module):
    def __init__(self, args, num_entities, embedding_dim, input_dropout=0, hidden_dropout=0, feature_map_dropout=0, mode='train', channels=50, kernel_size=3, use_bias=True):
        """Build the rule-aware ConvTransE decoder and load its precomputed candidate files.

        Args:
            args: config namespace. Attributes read here: ``time_emb``, ``n_hidden``,
                ``cands_type``, ``dataset``, ``seed``, ``cands_with_rules_file``,
                ``rule_ten_ls_file`` (``'with_rules'`` only), ``atten_matr_num``,
                ``dropout`` (when ``atten_matr_num > 0``), ``weight_score_learnable``
                / ``score_atten`` and ``weight_time_learnable`` / ``lam_gru`` /
                ``lam_rule`` / ``sim_atten``.
            num_entities: number of entities; sizes the bias parameter ``b``.
            embedding_dim: embedding size used by the conv/FC layers and the GRU.
            input_dropout: dropout applied to the stacked inputs.
            hidden_dropout: dropout applied after the FC layer.
            feature_map_dropout: dropout applied to the conv feature maps.
            mode: ``'train'`` additionally loads the train and valid candidate
                files; the test candidate file is always loaded.
            channels: number of Conv1d output channels.
            kernel_size: Conv1d kernel size.
            use_bias: accepted for API compatibility; not used.

        Raises:
            ValueError: if ``args.cands_type`` is neither ``'with_rules'`` nor
                ``'evolve'`` (this used to surface as an unbound-variable error).
        """
        super(RuleConvTransE_old, self).__init__()

        self.inp_drop = torch.nn.Dropout(input_dropout)
        self.hidden_drop = torch.nn.Dropout(hidden_dropout)
        self.feature_map_drop = torch.nn.Dropout(feature_map_dropout)
        self.loss = torch.nn.BCELoss()

        # Conv input channels: entity, relation and, optionally, two time embeddings.
        input_cha = 4 if args.time_emb else 2
        self.conv1 = torch.nn.Conv1d(input_cha, channels, kernel_size, stride=1,
                                     padding=int(math.floor(kernel_size / 2)))
        self.bn0 = torch.nn.BatchNorm1d(input_cha)
        self.bn1 = torch.nn.BatchNorm1d(channels)
        self.bn2 = torch.nn.BatchNorm1d(embedding_dim)
        self.register_parameter('b', Parameter(torch.zeros(num_entities)))
        self.fc = torch.nn.Linear(embedding_dim * channels, embedding_dim)
        self.h_dim = args.n_hidden
        self.args = args

        self.gru = GRU(input_size=embedding_dim, hidden_size=embedding_dim, batch_first=True)
        self.cands_with_rules_dict_train = {}
        self.cands_with_rules_dict_valid = {}
        self.cands_with_rules_dict_test = {}
        if args.cands_type == 'with_rules':
            # NOTE(review): machine-specific absolute path; the portable variant was
            # f'../output_rule/{args.dataset}/seed{args.seed}/cands_with_rules'.
            file_dir = f'/data/zsz/Fusion/output_rule/{args.dataset}/seed{args.seed}/cands_with_rules'
        elif args.cands_type == 'evolve':
            file_dir = f'../output_rule/{args.dataset}/seed{args.seed}/cands_evolve'
        else:
            raise ValueError(
                f"unsupported cands_type {args.cands_type!r}; expected 'with_rules' or 'evolve'")

        def split_path(split):
            # Candidate files are named <cands_with_rules_file>_<split>.pt inside file_dir.
            return os.path.join(file_dir, f'{args.cands_with_rules_file}_{split}.pt')

        # Training needs all three splits; otherwise only the test split is loaded.
        # NOTE(review): with mode != 'train', forward(mode='valid') sees an empty dict.
        if mode == 'train':
            self.cands_with_rules_dict_train = torch.load(split_path('train'))
            self.cands_with_rules_dict_valid = torch.load(split_path('valid'))
        self.cands_with_rules_dict_test = torch.load(split_path('test'))
        if args.cands_type == 'with_rules':
            # Per-relation (rule_tensor, rule_id) entries consumed by forward().
            rule_ten_ls_path = os.path.join(file_dir, f'{args.rule_ten_ls_file}')
            self.rule_ten_ls = torch.load(rule_ten_ls_path)

        # Learnable matrices that mix the aggregated rule representation with the relation embedding.
        if args.atten_matr_num > 0:
            self.weight_rule = nn.Parameter(torch.Tensor(self.h_dim, self.h_dim))
            nn.init.xavier_uniform_(self.weight_rule, gain=nn.init.calculate_gain('relu'))
            self.activation = F.rrelu
            self.dropout = nn.Dropout(args.dropout)
            if args.atten_matr_num > 1:
                self.weight_rel = nn.Parameter(torch.Tensor(self.h_dim, self.h_dim))
                nn.init.xavier_uniform_(self.weight_rel, gain=nn.init.calculate_gain('relu'))

        # Optional learnable versions of the score / time mixing coefficients.
        if args.weight_score_learnable:
            self.score_atten = nn.Parameter(torch.tensor(args.score_atten).float(), requires_grad=True)

        if args.weight_time_learnable:
            self.lam_gru = nn.Parameter(torch.tensor(args.lam_gru).float(), requires_grad=True)
            self.lam_rule = nn.Parameter(torch.tensor(args.lam_rule).float(), requires_grad=True)
            self.sim_atten = nn.Parameter(torch.tensor(args.sim_atten).float(), requires_grad=True)

    def get_query_repr(self, ent_emb, rel_emb, time_emb_1, time_emb_2, batch_size):
        """Encode stacked query inputs into a query vector via Conv1d -> FC.

        Args:
            ent_emb: entity embeddings, (batch_size, 1, dim) as built by forward().
            rel_emb: relation (or rule-informed relation) embeddings, (batch_size, 1, dim).
            time_emb_1, time_emb_2: time embeddings, (batch_size, 1, dim); only
                used when ``self.args.time_emb`` is truthy (4 conv channels
                instead of 2).
            batch_size: number of rows; used to flatten the conv output and to
                decide whether ``bn2`` can run.

        Returns:
            Tensor of shape (batch_size, embedding_dim) after a final ReLU.
        """
        if self.args.time_emb:
            parts = [ent_emb, rel_emb, time_emb_1, time_emb_2]
        else:
            parts = [ent_emb, rel_emb]  # time embeddings disabled
        x = self.bn0(torch.cat(parts, 1))  # (batch_size, 4 or 2, dim)
        x = self.inp_drop(x)
        x = self.conv1(x)  # (batch_size, channels, dim)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(batch_size, -1)  # (batch_size, channels * dim)
        x = self.fc(x)  # (batch_size, dim)
        x = self.hidden_drop(x)
        # BatchNorm1d cannot use the batch statistics of a single sample while
        # training. In eval mode it relies on running statistics, so apply it
        # there unconditionally to keep single-row results consistent with
        # batched ones. forward() frequently calls this with one reachable tail.
        if batch_size > 1 or not self.training:
            x = self.bn2(x)
        return F.relu(x)

    def forward(self, args, embedding, emb_rel, emb_time, triplets, nodes_id=None, mode="train", negative_rate=0, partial_embeding=None, fre_norm=False):
        """Score all entities per query, fusing ConvTransE scores with rule-based scores.

        For each sample, the precomputed candidates for the batch timestamp give
        a set of reachable tail entities plus the rules (relation chains) that
        reach them. Rule bodies are encoded (GRU or mean of relation
        embeddings), aggregated per reachable tail, optionally mixed with the
        query relation embedding, and passed through ``get_query_repr`` to
        score only those tails (``score_rule``). ``score_rule`` is then combined
        with the plain ConvTransE scores over all entities (``score_all``)
        according to ``args.score_method``.

        Args:
            args: config namespace, read directly here (in addition to self.args).
            embedding: entity embedding matrix; tanh is applied before use.
            emb_rel: relation embeddings. For cands_type 'with_rules' a 2-D
                (num_rels, h_dim) matrix; for 'evolve' a 3-D
                (seq_len, num_rels, h_dim) tensor whose last step provides the
                query relation embedding.
            emb_time: pair (emb_time_1, emb_time_2), each (batch, h_dim).
            triplets: integer tensor; columns 0, 1 and 3 are head id, relation
                id and timestamp.
            nodes_id, negative_rate: unused.
            mode: 'train' / 'valid' select the matching candidate dict; any
                other value uses the test dict.
            partial_embeding: optional tensor multiplied element-wise into the
                final scores.
            fre_norm: if True, L2-normalize ``partial_embeding`` first.

        Returns:
            Score tensor of shape (batch, num_entities).

        NOTE(review): tensors are moved with ``.cuda()`` throughout, so a CUDA
        device is required.
        """
        # NOTE(review): any cands_type other than these two leaves rel_embedded undefined.
        if self.args.cands_type == 'with_rules':
            pseudo_rel_emb = torch.zeros((1, args.n_hidden), requires_grad=False).cuda()  # zero "pseudo" relation row appended after the real relations; presumably the padding id for rule bodies shorter than rule_len
            extend_rel_emb = torch.cat((emb_rel, pseudo_rel_emb), dim=0)
            rel_embedded = emb_rel[triplets[:, 1]].unsqueeze(1)  # batch_size,1,h_dim
        elif self.args.cands_type == 'evolve':
            seq_len = emb_rel.shape[0]
            pseudo_rel_emb = torch.zeros((seq_len, 1, args.n_hidden), requires_grad=False).cuda()  # same zero padding relation, appended once per time step
            extend_rel_emb = torch.cat((emb_rel, pseudo_rel_emb), dim=1)
            rel_embedded = emb_rel[-1][triplets[:, 1]].unsqueeze(1)  # batch_size,1,h_dim

        e1_embedded_all = F.tanh(embedding)
        batch_size = len(triplets)
        e1_embedded = e1_embedded_all[triplets[:, 0]].unsqueeze(1)  # batch_size,1,h_dim
        # rel_embedded = emb_rel[triplets[:, 1]].unsqueeze(1)  # batch_size,1,h_dim
        emb_time_1, emb_time_2 = emb_time
        emb_time_1 = emb_time_1.unsqueeze(1)
        emb_time_2 = emb_time_2.unsqueeze(1)
        ent_num = e1_embedded_all.shape[0]

        # ConvTransE score of every sample against all entities
        query_repr_all = self.get_query_repr(e1_embedded, rel_embedded, emb_time_1, emb_time_2, batch_size)
        score_all = torch.mm(query_repr_all, e1_embedded_all.transpose(1, 0))
        # Rule-based scores; entries stay 0 for tails no rule reaches
        score_rule = torch.zeros(batch_size, ent_num).cuda()

        # NOTE(review): only the first triplet's timestamp is read, so the whole batch is assumed to share one timestamp.
        time = triplets[0, 3].item()
        # Per-timestamp list of per-sample tuples (reachable tail-entity id tensor plus rule information), indexed by sample position i below
        if mode == 'train':
            tail2rule_list = self.cands_with_rules_dict_train[time]
        elif mode == 'valid':
            tail2rule_list = self.cands_with_rules_dict_valid[time]
        else:
            tail2rule_list = self.cands_with_rules_dict_test[time]

        # Process the samples one at a time
        for i, triplet in enumerate(triplets):
            start_time = datetime.now()  # leftover timing instrumentation (only referenced in a commented-out print)

            h, r, _, _ = triplet.tolist()
            h_emb = e1_embedded[i, 0]
            r_emb = rel_embedded[i, 0]

            if len(tail2rule_list[i]) == 0:  # no reachable tail entities for this sample
                continue

            if args.cands_type == 'with_rules':
                # Gather the relation embeddings of every rule body for relation r, then (optionally) run them through the GRU to get one representation per rule
                if len(self.rule_ten_ls[r]) == 0:  # relation r has no rules
                    continue
                rule_tensor, rule_id = self.rule_ten_ls[r]  # rule_num * rule_len
                rule_tensor = rule_tensor.cuda()
                rule_id = rule_id.cuda()
                rule_num, rule_len = rule_tensor.shape
                # rule_id appears to hold (body length - 1) per rule (the mean branch below uses rule_id + 1 as the length); the offsets turn it into row ids of the flattened (rule_num * rule_len) tensor
                indices_times = (torch.arange(len(rule_id)) * rule_len).cuda()
                extend_rule_id = rule_id + indices_times
                rule_tensor = rule_tensor.reshape(-1)
                rule2rel_emb = extend_rel_emb[rule_tensor]  # (rule_num * rule_len) * h_dim
                rule2rel_emb = rule2rel_emb.reshape(rule_num, rule_len, -1)
                if args.gru:
                    if args.gru_init == 'zero':
                        init_hidden_emb = torch.zeros(1, rule_num, self.h_dim).cuda()
                    elif args.gru_init == 'rel':
                        init_hidden_emb = (r_emb.unsqueeze(0).repeat_interleave(rule_num, 0)).unsqueeze(0)  # 1 * rule_num * h_dim, query relation embedding as the initial hidden state
                    elif args.gru_init == 'rand':
                        init_hidden_emb = torch.rand(1, rule_num, self.h_dim).cuda()
                        init_hidden_emb = F.normalize(init_hidden_emb)

                    all_len_emb, _ = self.gru(rule2rel_emb, init_hidden_emb)
                    all_len_emb = all_len_emb.reshape(-1, args.n_hidden)  # flattened to (rule_num * rule_len) * h_dim

                    if args.gru_atten != '':
                        mask_template = torch.tril(torch.ones(rule_len, rule_len), diagonal=0).cuda()  # lower-triangular ones (diagonal included): row k keeps steps 0..k
                        mask_gru = mask_template[rule_id]  # rule_num * rule_len mask selecting each rule's valid GRU steps
                        if args.gru_atten == 'mean':
                            mask_gru = mask_gru / mask_gru.sum(1).reshape(-1, 1)  # divide by the length to average
                        # elif args.gru_atten == 'time':
                        #     if args.weight_time_learnable:
                        #         lam_gru = F.sigmoid(self.lam_gru)
                        #     else:
                        #         lam_gru = args.lam_gru
                        #     tss_info_all_exp = lam_gru * torch.exp(tss_info_all - args.train_history_len)
                        #     mask_gru = tss_info_all_exp.reshape(tail_ent_num * rule_used, rule_len) * mask_gru  # multiply first, then pad with -inf
                        #     mask_gru = torch.where(mask_gru == 0, -1e9, mask_gru)
                        #     mask_gru = F.softmax(mask_gru)
                        mask_gru = mask_gru.flatten().reshape(-1, 1)
                        rule_emb = (all_len_emb * mask_gru).reshape(rule_num, rule_len, -1)
                        rule_emb = rule_emb.sum(1)  # rule_num * h_dim rule tensor
                    else:
                        rule_emb = all_len_emb[extend_rule_id]  # GRU output at each rule's last body step
                    rule_emb = F.normalize(rule_emb)

                else:  # no GRU: a rule's representation is the mean of its body relation embeddings
                    rule_len_ls = extend_rule_id.tolist()
                    # the comprehension's own `i` is local to it and does not clobber the sample index
                    rule_len_ls = [idx + 1 - rule_len * i for i, idx in enumerate(rule_len_ls)]
                    rule_len_tensor = torch.Tensor(rule_len_ls).reshape(-1, 1).cuda()
                    # sums all rule_len positions; presumably padding positions hit the zero pseudo relation
                    rule_emb = rule2rel_emb.sum(1) / rule_len_tensor
                    rule_emb = F.normalize(rule_emb)

                # Combine the sample's tail->rule assignment with the rule representations into an aggregated (tail_ent_reachable, h_dim) tensor
                # tail_ent_ids, tail2rule = tail2rule_list[i]
                # tail2rule = tail2rule.cuda()
                tail_ent_ids, tail2rule_sorted = tail2rule_list[i]
                tail_ent_ids = tail_ent_ids.cuda()
                tail2rule_sorted = tail2rule_sorted.cuda()
                tail_ent_num = tail_ent_ids.shape[0]

                # Build a 0/1 tail x rule matrix marking the first rule_used rule ids of each row
                # (presumably each row of tail2rule_sorted lists rule ids in ranked order — confirm against the producer)
                tail2rule = torch.zeros_like(tail2rule_sorted).cuda()
                if args.rule_used > -1:
                    tail2rule_sorted = tail2rule_sorted[:, :args.rule_used]
                tail2rule[torch.arange(tail_ent_num).unsqueeze(1).long(), tail2rule_sorted] = 1

                if self.args.semantic_sim != '':  # query/rule similarity replaces the 1-entries of the 0/1 matrix, weighting the rule representations
                    if self.args.semantic_sim == 'head':
                        sim_rel2rule = h_emb @ rule_emb.transpose(0, 1)
                    elif self.args.semantic_sim == 'rel':
                        sim_rel2rule = r_emb @ rule_emb.transpose(0, 1)
                    elif self.args.semantic_sim == 'plus':
                        sim_rel2rule = (h_emb + r_emb) @ rule_emb.transpose(0, 1)
                    elif self.args.semantic_sim == 'cosine':
                        sim_rel2rule = F.cosine_similarity(r_emb.unsqueeze(0), rule_emb)
                    elif self.args.semantic_sim == 'euclidean':
                        # NOTE(review): a distance is used as the similarity here, so farther rules get larger softmax weight — confirm intended
                        sim_rel2rule = torch.norm(rule_emb - r_emb.unsqueeze(0), dim=1)
                    # softmax over all rules of r; the weights of the rules selected in a row need not sum to 1
                    sim_rel2rule = sim_rel2rule.softmax(0)
                    sim_tensor = sim_rel2rule.unsqueeze(0).repeat_interleave(tail_ent_num, 0)
                    tail2rule_normalized = tail2rule * sim_tensor
                else:
                    tail2rule_normalized = tail2rule / tail2rule.sum(1).reshape(-1, 1)  # plain mean over the selected rules
                tail2rule_emb = torch.mm(tail2rule_normalized, rule_emb)
                tail2rule_emb = F.normalize(tail2rule_emb)
            elif args.cands_type == 'evolve':
                tail_ent_ids, rels_info_all, tss_info_all, rel_len_info_all_ori, mask_info_all = tail2rule_list[i]
                tail_ent_ids = tail_ent_ids.cuda()
                rels_info_all = rels_info_all.cuda()
                tss_info_all = tss_info_all.cuda()
                rel_len_info_all_ori = rel_len_info_all_ori.cuda()
                mask_info_all = mask_info_all.cuda()
                tail_ent_num, rule_used, rule_len = rels_info_all.shape
                # Truncate the tensors to the first args.rule_used rules per tail entity
                if args.rule_used > -1 and args.rule_used < rule_used:
                    rel_len_info_all_ori = rel_len_info_all_ori.reshape(tail_ent_num, rule_used)
                    mask_info_all = mask_info_all.reshape(tail_ent_num, rule_used)
                    rel_len_info_all_ori = rel_len_info_all_ori[:, :args.rule_used].reshape(-1)
                    mask_info_all = mask_info_all[:, :args.rule_used].reshape(-1)
                    rels_info_all = rels_info_all[:, :args.rule_used]
                    tss_info_all = tss_info_all[:, :args.rule_used]
                    rule_used = args.rule_used
                length = len(rel_len_info_all_ori)
                indices_times = (torch.arange(length) * rule_len).cuda()  # row offsets into the flattened GRU output
                rel_len_info_all = rel_len_info_all_ori + indices_times

                # Gather relation embeddings into a (tail_reachable * rule_num) * rule_len * h_dim 3-D tensor;
                # tss_info_all indexes the time-step axis of extend_rel_emb, rels_info_all the relation axis
                rels_info_all = rels_info_all.reshape(-1, rule_len)  # (tail_reachable * rule_num) * rule_len 2-D tensor
                tss_info_all = tss_info_all.reshape(-1, rule_len)
                rule2rel_emb = extend_rel_emb[tss_info_all, rels_info_all]

                # Run the relation chains through the GRU to obtain a (tail_ent_num * rule_used) * h_dim rule tensor
                if args.gru_init == 'zero':
                    init_hidden_emb = torch.zeros(1, tail_ent_num * rule_used, self.h_dim).cuda()
                elif args.gru_init == 'rel':
                    init_hidden_emb = (r_emb.unsqueeze(0).repeat_interleave(tail_ent_num * rule_used, 0)).unsqueeze(0)
                elif args.gru_init == 'rand':
                    init_hidden_emb = torch.rand(1, tail_ent_num * rule_used, self.h_dim).cuda()
                    init_hidden_emb = F.normalize(init_hidden_emb)

                all_len_emb, _ = self.gru(rule2rel_emb, init_hidden_emb)
                all_len_emb = all_len_emb.reshape(-1, args.n_hidden)  # flattened to (tail_ent_num * rule_used * rule_len) * h_dim
                if args.gru_atten != '':
                    mask_template = torch.tril(torch.ones(rule_len, rule_len), diagonal=0).cuda()  # lower-triangular ones (diagonal included): row k keeps steps 0..k
                    mask_gru = mask_template[rel_len_info_all_ori]  # (tail_ent_num * rule_used) * rule_len 2-D GRU step mask
                    if args.gru_atten == 'mean':
                        mask_gru = mask_gru / mask_gru.sum(1).reshape(-1, 1)  # divide by the length to average
                    elif args.gru_atten == 'time':
                        if args.weight_time_learnable:
                            lam_gru = F.sigmoid(self.lam_gru)
                        else:
                            lam_gru = args.lam_gru
                        tss_info_all_exp = lam_gru * torch.exp(tss_info_all - args.train_history_len)
                        mask_gru = tss_info_all_exp.reshape(tail_ent_num * rule_used, rule_len) * mask_gru  # weight first, then replace masked (zero) entries with -1e9
                        mask_gru = torch.where(mask_gru == 0, -1e9, mask_gru)
                        # NOTE(review): no dim given; PyTorch infers it implicitly (deprecated) — confirm it is the rule_len axis
                        mask_gru = F.softmax(mask_gru)
                    mask_gru = mask_gru.flatten().reshape(-1, 1)
                    rule_emb = (all_len_emb * mask_gru).reshape(tail_ent_num * rule_used, rule_len, -1)
                    rule_emb = rule_emb.sum(1)  # (tail_ent_num * rule_used) * h_dim 2-D rule tensor
                else:
                    rule_emb = all_len_emb[rel_len_info_all]  # GRU output at each rule's own body length
                rule_emb = F.normalize(rule_emb)

                # Score each rule against the query, mask padded rules, softmax over the remaining scores,
                # weight the rule tensor and sum-pool it into the aggregated rule representation.
                # The similarity vectors below are 1-D over all tail_ent_num * rule_used rules, so the softmax
                # normalizes across all tails jointly; the F.normalize after pooling removes per-tail scale.
                # NOTE(review): masking turns exact-zero scores into -1e9, so a genuine score of 0 is treated as padding.
                if self.args.rule_atten != '':
                    if self.args.semantic_sim != '':  # query/rule similarity used as attention weights over rule representations
                        if self.args.semantic_sim == 'head':
                            sim_rel2rule_semantic = h_emb @ rule_emb.transpose(0, 1)
                        elif self.args.semantic_sim == 'rel':
                            sim_rel2rule_semantic = r_emb @ rule_emb.transpose(0, 1)
                        elif self.args.semantic_sim == 'plus':
                            sim_rel2rule_semantic = (h_emb + r_emb) @ rule_emb.transpose(0, 1)
                        elif self.args.semantic_sim == 'cosine':
                            sim_rel2rule_semantic = F.cosine_similarity(r_emb.unsqueeze(0), rule_emb)
                        elif self.args.semantic_sim == 'euclidean':
                            # NOTE(review): distance used as similarity — larger distance yields larger weight
                            sim_rel2rule_semantic = torch.norm(rule_emb - r_emb.unsqueeze(0), dim=1)
                        sim_rel2rule_semantic = sim_rel2rule_semantic * mask_info_all
                    if self.args.time_sim != '':
                        # if self.args.time_sim == 'earliest':
                        #     tss_info_earliest = tss_info_all[:, 0]  # earliest timestamp used by every applicable rule
                        #     if args.weight_time_learnable:
                        #         lam_rule = F.sigmoid(self.lam_rule)
                        #     else:
                        #         lam_rule = args.lam_rule
                        #     sim_rel2rule_time = lam_rule * torch.exp(tss_info_earliest - args.train_history_len)
                        if args.weight_time_learnable:
                            lam_rule = F.sigmoid(self.lam_rule)
                        else:
                            lam_rule = args.lam_rule
                        if 'earliest' in self.args.time_sim:
                            tss_info_used = tss_info_all[:, 0]  # first timestamp column of each rule (presumably the earliest)
                        elif 'average' in self.args.time_sim:
                            mask_template = torch.tril(torch.ones(rule_len, rule_len), diagonal=0).cuda()  # lower-triangular ones (diagonal included)
                            # mask_template = torch.where(mask_template == 0, -1, mask_template)  # use -1 as the mask so genuinely-zero timestamps are not confused with padding
                            mask_rel = mask_template[rel_len_info_all_ori]  # (tail_ent_num * rule_used) * rule_len mask of valid body steps
                            tss_info_used = tss_info_all * mask_rel  # masked timestamps
                            tss_info_used = tss_info_used.sum(1) / (mask_rel != 0).sum(1)  # mean over the valid timestamps; negative timestamps were meant to mark masked entries
                        if 'exp' in self.args.time_sim:
                            sim_rel2rule_time = lam_rule * torch.exp(tss_info_used - args.train_history_len)
                        elif 'tan' in self.args.time_sim:
                            sim_rel2rule_time = lam_rule * torch.tan(tss_info_used - args.train_history_len)
                        sim_rel2rule_time = sim_rel2rule_time * mask_info_all
                    # NOTE(review): each mode assumes the matching semantic_sim / time_sim option is set; otherwise the referenced name is undefined
                    if self.args.rule_atten == 'semantic':
                        sim_rel2rule_semantic = torch.where(sim_rel2rule_semantic == 0, -1e9, sim_rel2rule_semantic)
                        sim_rel2rule = F.softmax(sim_rel2rule_semantic).reshape(-1, 1)
                    elif self.args.rule_atten == 'time':
                        sim_rel2rule_time = torch.where(sim_rel2rule_time == 0, -1e9, sim_rel2rule_time)
                        sim_rel2rule = F.softmax(sim_rel2rule_time).reshape(-1, 1)
                    elif self.args.rule_atten == 'fusion':
                        sim_rel2rule = torch.exp(sim_rel2rule_semantic) * sim_rel2rule_time
                        sim_rel2rule = torch.where(sim_rel2rule == 0, -1e9, sim_rel2rule)
                        sim_rel2rule = F.softmax(sim_rel2rule).reshape(-1, 1)
                    elif self.args.rule_atten == 'weighted':
                        sim_rel2rule_semantic = torch.where(sim_rel2rule_semantic == 0, -1e9, sim_rel2rule_semantic)
                        sim_rel2rule_time = torch.where(sim_rel2rule_time == 0, -1e9, sim_rel2rule_time)
                        sim_rel2rule_semantic = F.softmax(sim_rel2rule_semantic)
                        sim_rel2rule_time = F.softmax(sim_rel2rule_time)
                        if args.weight_time_learnable:
                            sim_atten = F.sigmoid(self.sim_atten)
                        else:
                            sim_atten = args.sim_atten
                        # convex mix of the semantic and time attention distributions
                        sim_rel2rule = (sim_atten * sim_rel2rule_semantic + (1 - sim_atten) * sim_rel2rule_time).reshape(-1, 1)
                    rule_emb = (rule_emb * sim_rel2rule).reshape(tail_ent_num, rule_used, -1)
                else:  # no rule attention: zero out padded rules and sum the remaining rule representations
                    mask_info_all = mask_info_all.reshape(-1, 1)
                    rule_emb = (rule_emb * mask_info_all).reshape(tail_ent_num, rule_used, -1)
                tail2rule_emb = rule_emb.sum(1)  # tail_ent_num * h_dim 2-D tensor
                tail2rule_emb = F.normalize(tail2rule_emb)

            # Combine the aggregated rule representation with the query relation embedding
            if args.atten_matr_num > 0:
                rel_emb_repeat_reachable = r_emb.unsqueeze(0).repeat_interleave(tail_ent_num, dim=0)
                if args.atten_matr_num == 1:
                    # W = sigmoid(weight_rule); mix = tail2rule @ W + rel @ (1 - W)
                    weight_rule_normalized = F.sigmoid(self.weight_rule)
                    rel_emb_att = self.activation(torch.mm(tail2rule_emb, weight_rule_normalized) + torch.mm(rel_emb_repeat_reachable, 1 - weight_rule_normalized))
                elif args.atten_matr_num == 2:
                    rel_emb_att = self.activation(torch.mm(tail2rule_emb, self.weight_rule) + torch.mm(rel_emb_repeat_reachable, self.weight_rel))
                rel_emb_att = self.dropout(rel_emb_att)
                rel_emb_att = F.normalize(rel_emb_att)
            else:
                rel_emb_att = tail2rule_emb

            h_emb_repeat = h_emb.unsqueeze(0).repeat_interleave(tail_ent_num, dim=0)
            assert h_emb_repeat.shape == tail2rule_emb.shape  # shapes must match

            # Score only the reachable tail entities, with the rule-informed representation in the relation slot
            time_emb_1_sample = emb_time_1[i].repeat_interleave(tail_ent_num, dim=0).unsqueeze(1)
            time_emb_2_sample = emb_time_2[i].repeat_interleave(tail_ent_num, dim=0).unsqueeze(1)
            query_repr = self.get_query_repr(h_emb_repeat.unsqueeze(1), rel_emb_att.unsqueeze(1), time_emb_1_sample, time_emb_2_sample, tail_ent_num)
            tail_ent_emb = e1_embedded_all[tail_ent_ids]
            sample_score = (query_repr * tail_ent_emb).sum(1)
            score_rule[i, tail_ent_ids] = sample_score
            # print(sample_score)
            # score_all[i, tail_ent_ids] = sample_score

            time_2 = datetime.now()  # leftover timing instrumentation; unused
            # print(tail_ent_num, (time_1 - start_time) / tail_ent_num, time_2 - time_1)
            # print(count, tail_ent_num)

            # Legacy per-tail-entity implementation, kept for reference:
            # tail_entity_vectors = []  # list of aggregated relation vectors
            # rules = self.rules_dict[str(r)]  # list of rule dicts for the query relation

            # tail_ent_ids = list(tail_ent_list[i].keys())
            # tail_ent_ids = [int(_) for _ in tail_ent_ids]
            # tail_ent_ids = torch.LongTensor(tail_ent_ids).cuda()
            # tail_ent_id_ls.append(tail_ent_ids)
            # count = 0
            # for tail_entity_id, rule_ids in tail_ent_list[i].items():
            #     # list holding the encoding of each rule
            #     rule_encodings = []

            #     if len(rule_ids) <= 3:
            #         count += 1
            #     for rule_id in rule_ids:
            #         rule = np.array(rules[rule_id]['body_rels'])  # relation chain of the rule
            #         rel_chain_embedding = emb_rel[rule]
            #         _, rule_encoding = self.gru(rel_chain_embedding, r_emb.unsqueeze(0))  # GRU
            #         rule_encoding = F.normalize(rule_encoding.squeeze(0), dim=0)  # last hidden state as the rule encoding
            #         rule_encodings.append(rule_encoding)

            #     # average the encodings of all rules supporting this tail to get the aggregated relation vector
            #     aggregated_vector = torch.mean(torch.stack(rule_encodings), dim=0)
            #     aggregated_vector = F.normalize(aggregated_vector, dim=0)

            #     tail_entity_vectors.append(aggregated_vector)

            # time_1 = datetime.now()

            # # handle the empty-list case
            # if len(tail_entity_vectors) == 0:
            #     score_ls.append(torch.Tensor([]).cuda())
            #     continue

            # aggregated_tensor = torch.stack(tail_entity_vectors)  # tail_ent_reachable * h_dim
            # tail_ent_num = aggregated_tensor.shape[0]

        # Fuse the ConvTransE scores with the rule scores
        if args.score_method == 'all_plus':
            score_all += score_rule
        elif args.score_method == 'limited_plus':
            # NOTE(review): "reached by a rule" is detected as score_rule > 0, so a reachable tail whose rule score is <= 0 is treated as unreachable (its score becomes 0 + score_rule)
            score_mask = score_rule > 0
            score_all = score_all * score_mask + score_rule
        elif args.score_method == 'limited_new':
            score_all = score_rule
        elif args.score_method == 'atten_plus':
            if args.weight_score_learnable:
                score_atten = F.sigmoid(self.score_atten)
            else:
                score_atten = args.score_atten
            # tails without a positive rule score keep the ConvTransE score; the others get a convex mix of both scores
            score_mask = score_rule > 0
            score_all = score_all * ~score_mask + score_atten * score_all * score_mask + (1 - score_atten) * score_rule

        if partial_embeding is not None:
            if fre_norm:
                partial_embeding = F.normalize(partial_embeding)
            score_all = torch.mul(score_all, partial_embeding)

        return score_all

    def forward_no_batch(self, args, embedding, emb_rel, emb_time, triplets, nodes_id=None, mode="train", negative_rate=0, partial_embeding=None, fre_norm=False):
        """Score all entities as tails, replacing rule-reachable tails sample by sample.

        Every entity is first scored with the query representation built from
        (head, relation, time). Then, for each sample, the tails listed in the
        candidate dictionary are re-scored. Each such tail uses a relation vector
        obtained by GRU-encoding every supporting rule's relation chain and
        averaging the normalized encodings.

        Args:
            args: experiment config; not read here (kept for interface parity).
            embedding: entity embedding matrix; ``tanh`` is applied before use.
            emb_rel: relation embedding matrix indexed by relation id.
            emb_time: pair ``(emb_time_1, emb_time_2)`` of per-sample time embeddings.
            triplets: integer tensor with rows ``(head, rel, _, time)``. Only the
                first row's timestamp selects the candidate entry (presumably a
                batch comes from a single snapshot — assumption).
            nodes_id, negative_rate: unused.
            mode: ``'train'``, ``'valid'``, or anything else (test); picks the
                candidate dictionary.
            partial_embeding: optional tensor multiplied element-wise into the scores.
            fre_norm: if True, L2-normalize ``partial_embeding`` first.

        Returns:
            Score tensor of shape ``(batch_size, num_entities)``.
        """
        # Keep every tensor created here on the same device as the embeddings.
        device = embedding.device
        e1_embedded_all = F.tanh(embedding)
        batch_size = len(triplets)
        e1_embedded = e1_embedded_all[triplets[:, 0]].unsqueeze(1)  # batch_size, 1, h_dim
        rel_embedded = emb_rel[triplets[:, 1]].unsqueeze(1)  # batch_size, 1, h_dim
        emb_time_1, emb_time_2 = emb_time
        emb_time_1 = emb_time_1.unsqueeze(1)
        emb_time_2 = emb_time_2.unsqueeze(1)

        # Candidate dictionaries are keyed by the string form of the timestamp.
        time_key = str(triplets[0, 3].item())
        if mode == 'train':
            tail_ent_list = self.cands_with_rules_dict_train[time_key]
        elif mode == 'valid':
            tail_ent_list = self.cands_with_rules_dict_valid[time_key]
        else:
            tail_ent_list = self.cands_with_rules_dict_test[time_key]

        score_ls = []  # one 1-D score tensor per sample
        tail_ent_id_ls = []  # one 1-D tail-id tensor per sample
        for i, triplet in enumerate(triplets):
            _, r, _, _ = triplet.tolist()
            h_emb = e1_embedded[i, 0]
            r_emb = rel_embedded[i, 0]
            rules = self.rules_dict[str(r)]  # rules whose head is relation r
            # Maps tail entity id -> ids (into `rules`) of the rules reaching it.
            cands = tail_ent_list[i]

            tail_ent_ids = torch.tensor([int(t) for t in cands.keys()], dtype=torch.long, device=device)
            tail_ent_id_ls.append(tail_ent_ids)

            tail_entity_vectors = []  # one aggregated relation vector per candidate tail
            for rule_ids in cands.values():
                rule_encodings = []
                for rule_id in rule_ids:
                    rule = np.array(rules[rule_id]['body_rels'])  # relation chain of the rule body
                    rel_chain_embedding = emb_rel[rule]
                    # The query relation initializes the GRU; its final hidden state encodes the rule.
                    _, rule_encoding = self.gru(rel_chain_embedding, r_emb.unsqueeze(0))
                    rule_encodings.append(F.normalize(rule_encoding.squeeze(0), dim=0))
                # Average the encodings of all rules supporting this tail.
                aggregated_vector = torch.mean(torch.stack(rule_encodings), dim=0)
                tail_entity_vectors.append(F.normalize(aggregated_vector, dim=0))

            if not tail_entity_vectors:  # no rule-reachable tail for this sample
                score_ls.append(torch.empty(0, device=device))
                continue

            aggregated_tensor = torch.stack(tail_entity_vectors)  # tail_ent_num, h_dim
            tail_ent_num = aggregated_tensor.shape[0]
            h_emb_repeat = h_emb.unsqueeze(0).repeat_interleave(tail_ent_num, dim=0)
            assert h_emb_repeat.shape == aggregated_tensor.shape

            # One query per reachable tail: (head, aggregated rule vector, time).
            time_emb_1_sample = emb_time_1[i].repeat_interleave(tail_ent_num, dim=0).unsqueeze(1)
            time_emb_2_sample = emb_time_2[i].repeat_interleave(tail_ent_num, dim=0).unsqueeze(1)
            query_repr = self.get_query_repr(h_emb_repeat.unsqueeze(1), aggregated_tensor.unsqueeze(1),
                                             time_emb_1_sample, time_emb_2_sample, tail_ent_num)
            tail_ent_emb = e1_embedded_all[tail_ent_ids]
            score_ls.append((query_repr * tail_ent_emb).sum(1))

        query_repr_all = self.get_query_repr(e1_embedded, rel_embedded, emb_time_1, emb_time_2, batch_size)
        score_all = torch.mm(query_repr_all, e1_embedded_all.transpose(1, 0))

        assert len(score_ls) == len(tail_ent_id_ls) == batch_size
        # Overwrite the generic scores of rule-reachable tails with the rule-aware ones.
        for i, (ids, scores) in enumerate(zip(tail_ent_id_ls, score_ls)):
            score_all[i, ids] = scores

        if partial_embeding is not None:
            if fre_norm:
                partial_embeding = F.normalize(partial_embeding)
            score_all = torch.mul(score_all, partial_embeding)

        return score_all

    def forward_slow(self, embedding, emb_rel, triplets):
        """Score each triplet's own tail with a ConvTransE query (no full ranking).

        The head and relation embeddings are stacked as two input channels and
        passed through bn0 -> dropout -> conv1 -> bn1 -> ReLU -> dropout -> fc
        -> dropout -> (bn2) -> ReLU. The result is dotted with the tail embedding
        of the same row.

        Args:
            embedding: entity embedding matrix; ``tanh`` is applied first.
            emb_rel: relation embedding matrix indexed by relation id.
            triplets: integer tensor; columns 0, 1, 2 are read as head,
                relation and tail ids.

        Returns:
            1-D tensor with one score per triplet.
        """
        ent_all = F.tanh(embedding)
        n = len(triplets)
        heads, rels, tails = triplets[:, 0], triplets[:, 1], triplets[:, 2]

        # Two input channels: (n, 2, dim).
        features = torch.cat([ent_all[heads].unsqueeze(1), emb_rel[rels].unsqueeze(1)], 1)
        features = self.inp_drop(self.bn0(features))
        features = self.feature_map_drop(F.relu(self.bn1(self.conv1(features))))
        query = self.hidden_drop(self.fc(features.view(n, -1)))
        # bn2 is skipped for a single row: BatchNorm1d cannot form batch statistics from one sample in training mode.
        if n > 1:
            query = self.bn2(query)
        query = F.relu(query)

        return (query * ent_all[tails]).sum(dim=1)


class RuleConvTransE(torch.nn.Module):
    def __init__(self, args, num_entities, embedding_dim, input_dropout=0, hidden_dropout=0, feature_map_dropout=0, mode='train', channels=50, kernel_size=3, use_bias=True):

        super(RuleConvTransE, self).__init__()

        self.inp_drop = torch.nn.Dropout(input_dropout)
        self.hidden_drop = torch.nn.Dropout(hidden_dropout)
        self.feature_map_drop = torch.nn.Dropout(feature_map_dropout)
        self.loss = torch.nn.BCELoss()

        if args.time_emb:
            input_cha = 4
        else:
            input_cha = 2
        self.conv1 = torch.nn.Conv1d(input_cha, channels, kernel_size, stride=1,
                               padding=int(math.floor(kernel_size / 2)))
        self.bn0 = torch.nn.BatchNorm1d(input_cha)
        self.bn1 = torch.nn.BatchNorm1d(channels)
        self.bn2 = torch.nn.BatchNorm1d(embedding_dim)
        self.register_parameter('b', Parameter(torch.zeros(num_entities)))
        self.fc = torch.nn.Linear(embedding_dim * channels, embedding_dim)
        self.h_dim = args.n_hidden
        self.args = args

        self.gru = GRU(input_size=embedding_dim, hidden_size=embedding_dim, batch_first=True)
        self.cands_with_rules_dict_train = {}
        self.cands_with_rules_dict_valid = {}
        self.cands_with_rules_dict_test = {}
        if args.cands_type == 'with_rules':
            # file_dir = f'../output_rule/{args.dataset}/seed{args.seed}/cands_with_rules'
            file_dir = f'..output_rule/{args.dataset}/seed{args.seed}/cands_with_rules'
        elif args.cands_type == 'evolve':
            file_dir = f'../output_rule/{args.dataset}/seed{args.seed}/cands_evolve'

        # 读取文件
        train_name = f'{args.cands_with_rules_file}_train.pt'
        train_path = os.path.join(file_dir, train_name)
        neg_used = f'neg{args.negative}'
        cands_with_rules_file_neg_all = args.cands_with_rules_file.replace(neg_used, 'neg-1')
        valid_name = f'{cands_with_rules_file_neg_all}_valid.pt'
        valid_path = os.path.join(file_dir, valid_name)
        test_name = f'{cands_with_rules_file_neg_all}_test.pt'
        test_path = os.path.join(file_dir, test_name)

        # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # 训练需要加载三个文件，测试只需要加载一个文件
        if mode == 'train':
        #     self.cands_with_rules_dict_train = torch.load(train_path, map_location=device)
        #     self.cands_with_rules_dict_valid = torch.load(valid_path, map_location=device)
        # self.cands_with_rules_dict_test = torch.load(test_path, map_location=device)
            self.cands_with_rules_dict_train = torch.load(train_path)
            self.cands_with_rules_dict_valid = torch.load(valid_path)
        self.cands_with_rules_dict_test = torch.load(test_path)
        if args.cands_type == 'with_rules':
            rule_ten_ls_file = f'{args.rule_ten_ls_file}'
            rule_ten_ls_path = os.path.join(file_dir, rule_ten_ls_file)
            # self.rule_ten_ls = torch.load(rule_ten_ls_path, map_location=device)
            self.rule_ten_ls = torch.load(rule_ten_ls_path)

        # 对聚合规则表征和关系表征作加权
        if args.atten_matr_num > 0:
            self.weight_rule = nn.Parameter(torch.Tensor(self.h_dim, self.h_dim))
            nn.init.xavier_uniform_(self.weight_rule, gain=nn.init.calculate_gain('relu'))
            self.activation = F.rrelu
            self.dropout = nn.Dropout(args.dropout)
            if args.atten_matr_num > 1:
                self.weight_rel = nn.Parameter(torch.Tensor(self.h_dim, self.h_dim))
                nn.init.xavier_uniform_(self.weight_rel, gain=nn.init.calculate_gain('relu'))

        if args.weight_score_learnable:
            self.score_atten = nn.Parameter(torch.tensor(args.score_atten).float(), requires_grad=True)

        if args.weight_time_learnable:
            self.lam_gru = nn.Parameter(torch.tensor(args.lam_gru).float(), requires_grad=True)
            self.lam_rule = nn.Parameter(torch.tensor(args.lam_rule).float(), requires_grad=True)
            self.sim_atten = nn.Parameter(torch.tensor(args.sim_atten).float(), requires_grad=True)

        if args.rule_atten == 'multi_head':
            self.multi_head_linear = nn.Linear(self.h_dim * 2, self.h_dim)
        # elif args.rule_atten == 'multi_head_weight':
        #     self.multi_head_weight = nn.Parameter(torch.Tensor(self.h_dim * 2, self.h_dim))

        if 'general' in args.semantic_sim:
            self.general_linear = nn.Linear(self.h_dim, self.h_dim)

    def get_query_repr(self, ent_emb, rel_emb, time_emb_1, time_emb_2, batch_size):
        """Encode a batch of queries into vectors with the ConvTransE stack.

        Args:
            ent_emb, rel_emb, time_emb_1, time_emb_2: per-query inputs that are
                concatenated along dim 1 as channels (presumably each is
                ``(batch_size, 1, embedding_dim)`` — confirm at call sites). The
                time embeddings are only used when ``self.args.time_emb`` is set.
            batch_size: number of queries; used to flatten the conv output.

        Returns:
            Tensor of shape ``(batch_size, embedding_dim)`` after the final ReLU.
        """
        parts = [ent_emb, rel_emb]
        if self.args.time_emb:
            parts += [time_emb_1, time_emb_2]
        x = torch.cat(parts, 1)  # batch_size, 2 or 4 input channels, embedding_dim

        x = self.inp_drop(self.bn0(x))
        x = self.conv1(x)  # batch_size, conv channels, embedding_dim
        x = self.feature_map_drop(F.relu(self.bn1(x)))
        x = self.hidden_drop(self.fc(x.view(batch_size, -1)))  # batch_size, embedding_dim
        # bn2 is skipped for a single query: BatchNorm1d cannot form batch statistics from one sample in training mode.
        if batch_size > 1:
            x = self.bn2(x)
        return F.relu(x)

    def forward(self, args, embedding, emb_rel, emb_time, triplets, batch_gpu, nodes_id=None, mode="train", negative_rate=0, partial_embeding=None, fre_norm=False):
        device = emb_rel.device
        if self.args.cands_type == 'with_rules':
            pseudo_rel_emb = torch.zeros((1, args.n_hidden), requires_grad=False).cuda(device)  # 伪关系嵌入，padding长度不足的关系
            extend_rel_emb = torch.cat((emb_rel, pseudo_rel_emb), dim=0)
            rel_embedded = emb_rel[triplets[:, 1]].unsqueeze(1)  # batch_size,1,h_dim
        elif self.args.cands_type == 'evolve':
            seq_len = emb_rel.shape[0]
            pseudo_rel_emb = torch.zeros((seq_len, 1, args.n_hidden), requires_grad=False).cuda(device)  # 伪关系嵌入，padding长度不足的关系
            extend_rel_emb = torch.cat((emb_rel, pseudo_rel_emb), dim=1)
            rel_embedded = emb_rel[-1][triplets[:, 1]].unsqueeze(1)  # batch_size,1,h_dim

        e1_embedded_all = F.tanh(embedding)
        batch_size = len(triplets)
        e1_embedded = e1_embedded_all[triplets[:, 0]].unsqueeze(1)  # batch_size,1,h_dim
        e2_embedded = e1_embedded_all[triplets[:, 2]].unsqueeze(1)
        # rel_embedded = emb_rel[triplets[:, 1]].unsqueeze(1)  # batch_size,1,h_dim
        emb_time_1, emb_time_2 = emb_time
        emb_time_1 = emb_time_1.unsqueeze(1)
        emb_time_2 = emb_time_2.unsqueeze(1)
        ent_num = e1_embedded_all.shape[0]

        # 利用ConvTransE每个样本对所有尾实体的评分
        # query_repr_all = self.get_query_repr(e1_embedded, rel_embedded, emb_time_1, emb_time_2, batch_size)
        # score_all = torch.mm(query_repr_all, e1_embedded_all.transpose(1, 0))
        score_rule = torch.zeros(batch_size, ent_num).cuda(device)

        time = triplets[0, 3].item()
        # 得到对应时间戳元组列表，包括能达到的尾实体一维id张量和01二维张量
        if mode == 'train':
            tail2rule_list = self.cands_with_rules_dict_train[time]
        elif mode == 'valid':
            tail2rule_list = self.cands_with_rules_dict_valid[time]
        else:
            tail2rule_list = self.cands_with_rules_dict_test[time]

        # 遍历每个样本
        for i, triplet in enumerate(triplets):
            start_time = datetime.now()

            h, r, _, _ = triplet.tolist()
            h_emb = e1_embedded[i, 0]
            r_emb = rel_embedded[i, 0]
            t_emb = e2_embedded[i, 0]

            device_num = torch.cuda.current_device()

            # if device == torch.device('cuda:0'):  # 给特定的gpu加上一定偏移量
            #     idx = i
            # else:
            #     idx = i + len(triplets)
            idx = i + device_num * batch_gpu

            if len(tail2rule_list[idx]) == 0:  # 可到达的尾实体数为0
                continue

            if args.cands_type == 'with_rules':
                # 得到每个规则对应的关系嵌入，初始化隐藏表征，再过gru，得到演化规则张量
                if len(self.rule_ten_ls[r]) == 0:  # 关系对应规则数为0
                    continue
                rule_tensor, rule_id = self.rule_ten_ls[r]  # rule_num * rule_len
                rule_tensor = rule_tensor.cuda(device)
                rule_id = rule_id.cuda(device)
                rule_num, rule_len = rule_tensor.shape
                indices_times = (torch.arange(len(rule_id)) * rule_len).cuda(device)
                extend_rule_id = rule_id + indices_times
                rule_tensor = rule_tensor.reshape(-1)
                rule2rel_emb = extend_rel_emb[rule_tensor]  # (rule_num * rule_len) * h_dim
                rule2rel_emb = rule2rel_emb.reshape(rule_num, rule_len, -1)
                if args.gru:
                    if args.gru_init == 'zero':
                        init_hidden_emb = torch.zeros(1, rule_num, self.h_dim).cuda(device)
                    elif args.gru_init == 'rel':
                        init_hidden_emb = (r_emb.unsqueeze(0).repeat_interleave(rule_num, 0)).unsqueeze(0)  # 1 * rule_num * h_dim，使用关系作初始化隐藏嵌入
                    elif args.gru_init == 'rand':
                        init_hidden_emb = torch.rand(1, rule_num, self.h_dim).cuda(device)
                        init_hidden_emb = F.normalize(init_hidden_emb)

                    all_len_emb, _ = self.gru(rule2rel_emb, init_hidden_emb)
                    all_len_emb = all_len_emb.reshape(-1, args.n_hidden)  # 扩展为(rule_num * rule_len) * h_dim

                    if args.gru_atten != '':
                        mask_template = torch.tril(torch.ones(rule_len, rule_len), diagonal=0).cuda()  # 将主对角线及下三角的元素设为1
                        mask_gru = mask_template[rule_id]  # 大小为rule_num * rule_len的二维gru掩码张量
                        if args.gru_atten == 'mean':
                            mask_gru = mask_gru / mask_gru.sum(1).reshape(-1, 1)  # 除以长度求平均
                        # elif args.gru_atten == 'time':
                        #     if args.weight_time_learnable:
                        #         lam_gru = F.sigmoid(self.lam_gru)
                        #     else:
                        #         lam_gru = args.lam_gru
                        #     tss_info_all_exp = lam_gru * torch.exp(tss_info_all - args.train_history_len)
                        #     mask_gru = tss_info_all_exp.reshape(tail_ent_num * rule_used, rule_len) * mask_gru  # 先乘再padding成负无穷
                        #     mask_gru = torch.where(mask_gru == 0, -1e9, mask_gru)
                        #     mask_gru = F.softmax(mask_gru)
                        mask_gru = mask_gru.flatten().reshape(-1, 1)
                        rule_emb = (all_len_emb * mask_gru).reshape(rule_num, rule_len, -1)
                        rule_emb = rule_emb.sum(1)  # rule_num * h_dim的二维规则张量
                    else:
                        rule_emb = all_len_emb[extend_rule_id]  # 得到每个规则对应长度的表征
                    rule_emb = F.normalize(rule_emb)

                else:  # 简单的用规则关系链的关系表征平均值作为规则表征
                    rule_len_ls = extend_rule_id.tolist()
                    rule_len_ls = [idx + 1 - rule_len * i for i, idx in enumerate(rule_len_ls)]
                    rule_len_tensor = torch.Tensor(rule_len_ls).reshape(-1, 1).cuda()
                    rule_emb = rule2rel_emb.sum(1) / rule_len_tensor
                    rule_emb = F.normalize(rule_emb)

                # 用样本对应的01张量和演化规则张量得到大小为tail_ent_reachable * h_dim的聚合规则表征张量
                # tail_ent_ids, tail2rule = tail2rule_list[i]
                # tail2rule = tail2rule.cuda()
                if not args.time_file:
                    tail_ent_ids, tail2rule_sorted = tail2rule_list[idx]
                else:
                    tail_ent_ids, tail2rule_sorted, tss_info_all, rel_len_info_all_ori, mask_info_all = tail2rule_list[idx]
                    tss_info_all = tss_info_all.cuda()
                    rel_len_info_all_ori = rel_len_info_all_ori.cuda()
                    mask_info_all = mask_info_all.cuda()
                    tail_ent_num, rule_used, rule_len = tss_info_all.shape  # 这里的rule_used是指生成文件时设定的规则使用数量

                    if args.rule_used > -1 and args.rule_used < rule_used:
                        rel_len_info_all_ori = rel_len_info_all_ori.reshape(tail_ent_num, rule_used)
                        mask_info_all = mask_info_all.reshape(tail_ent_num, rule_used)
                        rel_len_info_all_ori = rel_len_info_all_ori[:, :args.rule_used].reshape(-1)
                        mask_info_all = mask_info_all[:, :args.rule_used].reshape(-1)
                        # rels_info_all = rels_info_all[:, :args.rule_used]
                        tss_info_all = tss_info_all[:, :args.rule_used]
                        # rule_used = args.rule_used
                    length = len(rel_len_info_all_ori)
                    indices_times = (torch.arange(length) * rule_len).cuda()  # 加上索引的偏移量
                    rel_len_info_all = rel_len_info_all_ori + indices_times
                    tss_info_all = tss_info_all.reshape(-1, rule_len)  # (tail_reachable * rule_used) * rule_len二维张量
                tail_ent_ids = tail_ent_ids.cuda()
                tail2rule_sorted = tail2rule_sorted.cuda()
                tail_ent_num = tail_ent_ids.shape[0]

                tail2rule = torch.zeros_like(tail2rule_sorted).cuda()
                if args.rule_used > -1 and args.rule_used < rule_used:
                    tail2rule_sorted = tail2rule_sorted[:, :args.rule_used]
                    rule_used = args.rule_used
                tail2rule[torch.arange(tail_ent_num).unsqueeze(1).long(), tail2rule_sorted] = 1

                # if self.args.semantic_sim != '':  # 计算关系和规则的相似度，代替01张量中1的位置，对规则表征做加权
                #     if self.args.semantic_sim == 'head':
                #         sim_rel2rule = h_emb @ rule_emb.transpose(0, 1)
                #     elif self.args.semantic_sim == 'rel':
                #         sim_rel2rule = r_emb @ rule_emb.transpose(0, 1)
                #     elif self.args.semantic_sim == 'plus':
                #         sim_rel2rule = (h_emb + r_emb) @ rule_emb.transpose(0, 1)
                #     elif self.args.semantic_sim == 'cosine':
                #         sim_rel2rule = F.cosine_similarity(r_emb.unsqueeze(0), rule_emb)
                #     elif self.args.semantic_sim == 'euclidean':
                #         sim_rel2rule = torch.norm(rule_emb - r_emb.unsqueeze(0), dim=1)
                #     sim_rel2rule = sim_rel2rule.softmax(0)
                #     sim_tensor = sim_rel2rule.unsqueeze(0).repeat_interleave(tail_ent_num, 0)
                #     tail2rule_normalized = tail2rule * sim_tensor
                # else:
                #     tail2rule_normalized = tail2rule / tail2rule.sum(1).reshape(-1, 1)
                # tail2rule_emb = torch.mm(tail2rule_normalized, rule_emb)
                # tail2rule_emb = F.normalize(tail2rule_emb)

                # 取对应规则
                _, rule_num = tail2rule_sorted.shape
                if rule_num < args.rule_used:
                    tail2rule_sorted = F.pad(tail2rule_sorted, (0, args.rule_used - rule_num, 0, 0), 'constant', 0)
                tail2rule_sorted = tail2rule_sorted.reshape(-1)  # tail_reachable * rule_used一维张量
                if self.args.rule_atten != '':
                    rule_emb = rule_emb[tail2rule_sorted]  # (tail_reachable * rule_used) * h_dim 二维张量
                    if self.args.semantic_sim != '':  # 计算关系和规则的相似度，代替01张量中1的位置，对规则表征做加权
                        if self.args.semantic_sim == 'head':
                            sim_rel2rule_semantic = h_emb @ rule_emb.transpose(0, 1)
                        elif self.args.semantic_sim == 'rel':
                            sim_rel2rule_semantic = r_emb @ rule_emb.transpose(0, 1)
                        elif self.args.semantic_sim == 'tail':
                            t_embs = e1_embedded_all[tail_ent_ids]  # tail_reachable * h_dim 二维张量
                            tail_reachable, h_dim = t_embs.shape
                            # t_embs_extended = self.general_linear(t_embs).unsqueeze(1)  # tail_reachable * 1 * h_dim 三维张量
                            t_embs_extended = t_embs.unsqueeze(1)
                            rule_emb_extended = rule_emb.reshape(tail_reachable, -1, h_dim)  # tail_reachable * rule_used * h_dim 三维张量
                            sim_rel2rule_semantic = (rule_emb_extended * t_embs_extended).sum(2).reshape(-1)  # tail_reachable * rule_used 一维张量
                            # sim_rel2rule_semantic = t_emb @ rule_emb.transpose(0, 1)
                        elif self.args.semantic_sim == 'plus':
                            sim_rel2rule_semantic = (h_emb + r_emb) @ rule_emb.transpose(0, 1)
                        elif self.args.semantic_sim == 'cosine':
                            sim_rel2rule_semantic = F.cosine_similarity(r_emb.unsqueeze(0), rule_emb)
                        elif self.args.semantic_sim == 'euclidean':
                            sim_rel2rule_semantic = torch.norm(rule_emb - r_emb.unsqueeze(0), dim=1)
                        elif self.args.semantic_sim == 'general_rel':
                            sim_rel2rule_semantic = self.general_linear(r_emb) @ rule_emb.transpose(0, 1)
                        elif self.args.semantic_sim == 'general_head':
                            sim_rel2rule_semantic = self.general_linear(h_emb) @ rule_emb.transpose(0, 1)
                        elif self.args.semantic_sim == 'general_tail':
                            t_embs = e1_embedded_all[tail_ent_ids]  # tail_reachable * h_dim 二维张量
                            tail_reachable, h_dim = t_embs.shape
                            t_embs_extended = self.general_linear(t_embs).unsqueeze(1)  # tail_reachable * 1 * h_dim 三维张量
                            rule_emb_extended = rule_emb.reshape(tail_reachable, -1, h_dim)  # tail_reachable * rule_used * h_dim 三维张量
                            sim_rel2rule_semantic = (rule_emb_extended * t_embs_extended).sum(2).reshape(-1)  # tail_reachable * rule_used 一维张量
                            # sim_rel2rule_semantic = self.general_linear(t_emb) @ rule_emb.transpose(0, 1)
                        sim_rel2rule_semantic = sim_rel2rule_semantic * mask_info_all
                    if self.args.time_sim != '':
                        latest_time = args.window
                        # if args.evolve_type == 'dynamic':
                        #     latest_time = args.train_history_len
                        # elif args.evolve_type == 'static':
                        #     latest_time = args.window
                        if args.weight_time_learnable:
                            lam_rule = F.sigmoid(self.lam_rule)
                        else:
                            lam_rule = args.lam_rule
                        if 'earliest' in self.args.time_sim:
                            tss_info_used = tss_info_all[:, 0]  # 所有适用的规则使用的最早时间戳
                        elif 'average' in self.args.time_sim:
                            mask_template = torch.tril(torch.ones(rule_len, rule_len), diagonal=0).cuda()  # 将主对角线及下三角的元素设为1
                            # mask_template = torch.where(mask_template == 0, -1, mask_template)  # 用-1作掩码，防止原本时间戳均为0的情况
                            mask_rel = mask_template[rel_len_info_all_ori]  # 大小为(tail_ent_num * rule_used) * rule_len的二维gru掩码张量
                            tss_info_used = tss_info_all * mask_rel  # 掩码后时间戳
                            tss_info_used = tss_info_used.sum(1) / (mask_rel != 0).sum(1)  # 有效时间戳的平均长度，时间戳为负数表示要被要被掩盖
                        if 'exp' in self.args.time_sim:
                            sim_rel2rule_time = torch.exp(lam_rule * (tss_info_used - latest_time))
                            # if self.args.time_fun_origin:
                            #     sim_rel2rule_time = torch.exp(lam_rule * (tss_info_used - latest_time))
                            # else:
                            #     sim_rel2rule_time = lam_rule * torch.exp(tss_info_used - latest_time)
                        elif 'tan' in self.args.time_sim:
                            sim_rel2rule_time = lam_rule * torch.tan(tss_info_used - latest_time)
                            # sim_rel2rule_time = torch.tan(lam_rule * (tss_info_used - latest_time))
                        sim_rel2rule_time = sim_rel2rule_time * mask_info_all
                    if self.args.rule_atten == 'semantic':
                        sim_rel2rule_semantic = torch.where(sim_rel2rule_semantic == 0, -1e9, sim_rel2rule_semantic)
                        sim_rel2rule = F.softmax(sim_rel2rule_semantic.reshape(tail_ent_num, rule_used)).reshape(-1, 1)
                        # sim_rel2rule = F.softmax(sim_rel2rule_semantic).reshape(-1, 1)
                    elif self.args.rule_atten == 'time':
                        sim_rel2rule_time = torch.where(sim_rel2rule_time == 0, -1e9, sim_rel2rule_time)
                        sim_rel2rule = F.softmax(sim_rel2rule_time.reshape(tail_ent_num, rule_used)).reshape(-1, 1)
                    elif self.args.rule_atten == 'fusion':
                        sim_rel2rule = torch.exp(sim_rel2rule_semantic) * sim_rel2rule_time
                        sim_rel2rule = torch.where(sim_rel2rule == 0, -1e9, sim_rel2rule)
                        sim_rel2rule = F.softmax(sim_rel2rule.reshape(tail_ent_num, rule_used)).reshape(-1, 1)
                    elif self.args.rule_atten == 'weighted':
                        sim_rel2rule_semantic = torch.where(sim_rel2rule_semantic == 0, -1e9, sim_rel2rule_semantic)
                        sim_rel2rule_time = torch.where(sim_rel2rule_time == 0, -1e9, sim_rel2rule_time)
                        sim_rel2rule_semantic = F.softmax(sim_rel2rule_semantic.reshape(tail_ent_num, rule_used))
                        sim_rel2rule_time = F.softmax(sim_rel2rule_time.reshape(tail_ent_num, rule_used))
                        if args.weight_time_learnable:
                            sim_atten = F.sigmoid(self.sim_atten)
                        else:
                            sim_atten = args.sim_atten
                        sim_rel2rule = (sim_atten * sim_rel2rule_semantic + (1 - sim_atten) * sim_rel2rule_time).reshape(-1, 1)
                    if self.args.rule_atten != 'multi_head':
                        rule_emb = (rule_emb * sim_rel2rule).reshape(tail_ent_num, rule_used, -1)
                        tail2rule_emb = rule_emb.sum(1)  # tail_ent_num * h_dim二维张量
                        tail2rule_emb = F.normalize(tail2rule_emb)
                    else:  # 多头
                        sim_rel2rule_semantic = torch.where(sim_rel2rule_semantic == 0, -1e9, sim_rel2rule_semantic)
                        sim_rel2rule_time = torch.where(sim_rel2rule_time == 0, -1e9, sim_rel2rule_time)
                        sim_rel2rule_semantic = F.softmax(sim_rel2rule_semantic.reshape(tail_ent_num, rule_used)).reshape(-1, 1)
                        sim_rel2rule_time = F.softmax(sim_rel2rule_time.reshape(tail_ent_num, rule_used)).reshape(-1, 1)
                        rule_emb_semantic = (rule_emb * sim_rel2rule_semantic).reshape(tail_ent_num, rule_used, -1)
                        tail2rule_emb_semantic = rule_emb_semantic.sum(1)  # tail_ent_num * h_dim二维张量
                        rule_emb_time = (rule_emb * sim_rel2rule_time).reshape(tail_ent_num, rule_used, -1)
                        tail2rule_emb_time = rule_emb_time.sum(1)  # tail_ent_num * h_dim二维张量
                        tail2rule_emb_cat = torch.cat((tail2rule_emb_semantic, tail2rule_emb_time), dim=1)
                        tail2rule_emb = self.multi_head_linear(tail2rule_emb_cat)
                        tail2rule_emb = F.normalize(tail2rule_emb)
                else:
                    tail2rule_normalized = tail2rule / tail2rule.sum(1).reshape(-1, 1)
                    tail2rule_emb = torch.mm(tail2rule_normalized, rule_emb)
                    # tail2rule = tail2rule.to(torch.float32)
                    # tail2rule_emb = torch.mm(tail2rule, rule_emb)
                    tail2rule_emb = F.normalize(tail2rule_emb)

            elif args.cands_type == 'evolve':
                tail_ent_ids, rels_info_all, tss_info_all, rel_len_info_all_ori, mask_info_all = tail2rule_list[i]
                tail_ent_ids = tail_ent_ids.cuda()
                rels_info_all = rels_info_all.cuda()
                tss_info_all = tss_info_all.cuda()
                rel_len_info_all_ori = rel_len_info_all_ori.cuda()
                mask_info_all = mask_info_all.cuda()
                tail_ent_num, rule_used, rule_len = rels_info_all.shape
                # 根据使用规则数量对张量做剪切
                if args.rule_used > -1 and args.rule_used < rule_used:
                    rel_len_info_all_ori = rel_len_info_all_ori.reshape(tail_ent_num, rule_used)
                    mask_info_all = mask_info_all.reshape(tail_ent_num, rule_used)
                    rel_len_info_all_ori = rel_len_info_all_ori[:, :args.rule_used].reshape(-1)
                    mask_info_all = mask_info_all[:, :args.rule_used].reshape(-1)
                    rels_info_all = rels_info_all[:, :args.rule_used]
                    tss_info_all = tss_info_all[:, :args.rule_used]
                    rule_used = args.rule_used
                length = len(rel_len_info_all_ori)
                indices_times = (torch.arange(length) * rule_len).cuda()  # 加上索引的偏移量
                rel_len_info_all = rel_len_info_all_ori + indices_times

                # 得到大小为(tail_reachable * rule_num) * rule_len * h_dim的三维关系张量
                rels_info_all = rels_info_all.reshape(-1, rule_len)  # (tail_reachable * rule_num) * rule_len二维张量
                tss_info_all = tss_info_all.reshape(-1, rule_len)
                if args.evolve_type == 'dynamic':
                    rule2rel_emb = extend_rel_emb[tss_info_all, rels_info_all]
                # elif args.evolve_type == 'static':  # 静态设置只使用最近时间戳的关系嵌入
                #     time_padding = min(time, args.train_history_len)
                #     time_padding -= 1
                #     tss_info_all_latest = tss_info_all.clone().fill_(time_padding)
                #     # print(tss_info_all_latest)
                #     # sys.exit()
                #     rule2rel_emb = extend_rel_emb[tss_info_all_latest, rels_info_all]

                # 关系张量过gru，利用规则id张量，得到大小为(tail_ent_num * rule_used) * h_dim的规则张量
                if args.gru_init == 'zero':
                    init_hidden_emb = torch.zeros(1, tail_ent_num * rule_used, self.h_dim).cuda()
                elif args.gru_init == 'rel':
                    init_hidden_emb = (r_emb.unsqueeze(0).repeat_interleave(tail_ent_num * rule_used, 0)).unsqueeze(0)
                elif args.gru_init == 'rand':
                    init_hidden_emb = torch.rand(1, tail_ent_num * rule_used, self.h_dim).cuda()
                    init_hidden_emb = F.normalize(init_hidden_emb)

                all_len_emb, _ = self.gru(rule2rel_emb, init_hidden_emb)
                all_len_emb = all_len_emb.reshape(-1, args.n_hidden)  # 扩展为(tail_ent_num * rule_used * rule_len) * h_dim
                if args.gru_atten != '':
                    mask_template = torch.tril(torch.ones(rule_len, rule_len), diagonal=0).cuda()  # 将主对角线及下三角的元素设为1
                    mask_gru = mask_template[rel_len_info_all_ori]  # 大小为(tail_ent_num * rule_used) * rule_len的二维gru掩码张量
                    if args.gru_atten == 'mean':
                        mask_gru = mask_gru / mask_gru.sum(1).reshape(-1, 1)  # 除以长度求平均
                    elif args.gru_atten == 'time':
                        if args.weight_time_learnable:
                            lam_gru = F.sigmoid(self.lam_gru)
                        else:
                            lam_gru = args.lam_gru
                        tss_info_all_exp = lam_gru * torch.exp(tss_info_all - args.train_history_len)
                        mask_gru = tss_info_all_exp.reshape(tail_ent_num * rule_used, rule_len) * mask_gru  # 先乘再padding成负无穷
                        mask_gru = torch.where(mask_gru == 0, -1e9, mask_gru)
                        mask_gru = F.softmax(mask_gru)
                    mask_gru = mask_gru.flatten().reshape(-1, 1)
                    rule_emb = (all_len_emb * mask_gru).reshape(tail_ent_num * rule_used, rule_len, -1)
                    rule_emb = rule_emb.sum(1)  # (tail_ent_num * rule_used) * h_dim的二维规则张量
                else:
                    rule_emb = all_len_emb[rel_len_info_all]  # 得到每个规则对应长度的表征
                rule_emb = F.normalize(rule_emb)

                # 计算关系和规则的相似度，将padding规则的评分掩盖，再对非0评分softmax，对规则张量加权，最后池化为聚合规则表征张量
                if self.args.rule_atten != '':
                    if self.args.semantic_sim != '':  # 计算关系和规则的相似度，代替01张量中1的位置，对规则表征做加权
                        if self.args.semantic_sim == 'head':
                            sim_rel2rule_semantic = h_emb @ rule_emb.transpose(0, 1)
                        elif self.args.semantic_sim == 'rel':
                            sim_rel2rule_semantic = r_emb @ rule_emb.transpose(0, 1)
                        elif self.args.semantic_sim == 'plus':
                            sim_rel2rule_semantic = (h_emb + r_emb) @ rule_emb.transpose(0, 1)
                        elif self.args.semantic_sim == 'cosine':
                            sim_rel2rule_semantic = F.cosine_similarity(r_emb.unsqueeze(0), rule_emb)
                        elif self.args.semantic_sim == 'euclidean':
                            sim_rel2rule_semantic = torch.norm(rule_emb - r_emb.unsqueeze(0), dim=1)
                        sim_rel2rule_semantic = sim_rel2rule_semantic * mask_info_all
                    if self.args.time_sim != '':
                        latest_time = args.train_history_len
                        # if args.evolve_type == 'dynamic':
                        #     latest_time = args.train_history_len
                        # elif args.evolve_type == 'static':
                        #     latest_time = args.window
                        if args.weight_time_learnable:
                            lam_rule = F.sigmoid(self.lam_rule)
                        else:
                            lam_rule = args.lam_rule
                        if 'earliest' in self.args.time_sim:
                            tss_info_used = tss_info_all[:, 0]  # 所有适用的规则使用的最早时间戳
                        elif 'average' in self.args.time_sim:
                            mask_template = torch.tril(torch.ones(rule_len, rule_len), diagonal=0).cuda()  # 将主对角线及下三角的元素设为1
                            # mask_template = torch.where(mask_template == 0, -1, mask_template)  # 用-1作掩码，防止原本时间戳均为0的情况
                            mask_rel = mask_template[rel_len_info_all_ori]  # 大小为(tail_ent_num * rule_used) * rule_len的二维gru掩码张量
                            tss_info_used = tss_info_all * mask_rel  # 掩码后时间戳
                            tss_info_used = tss_info_used.sum(1) / (mask_rel != 0).sum(1)  # 有效时间戳的平均长度，时间戳为负数表示要被要被掩盖
                        if 'exp' in self.args.time_sim:
                            sim_rel2rule_time = lam_rule * torch.exp(tss_info_used - latest_time)
                        elif 'tan' in self.args.time_sim:
                            sim_rel2rule_time = lam_rule * torch.tan(tss_info_used - latest_time)
                        sim_rel2rule_time = sim_rel2rule_time * mask_info_all
                    if self.args.rule_atten == 'semantic':
                        sim_rel2rule_semantic = torch.where(sim_rel2rule_semantic == 0, -1e9, sim_rel2rule_semantic)
                        sim_rel2rule = F.softmax(sim_rel2rule_semantic).reshape(-1, 1)
                    elif self.args.rule_atten == 'time':
                        sim_rel2rule_time = torch.where(sim_rel2rule_time == 0, -1e9, sim_rel2rule_time)
                        sim_rel2rule = F.softmax(sim_rel2rule_time).reshape(-1, 1)
                    elif self.args.rule_atten == 'fusion':
                        sim_rel2rule = torch.exp(sim_rel2rule_semantic) * sim_rel2rule_time
                        sim_rel2rule = torch.where(sim_rel2rule == 0, -1e9, sim_rel2rule)
                        sim_rel2rule = F.softmax(sim_rel2rule).reshape(-1, 1)
                    elif self.args.rule_atten == 'weighted':
                        sim_rel2rule_semantic = torch.where(sim_rel2rule_semantic == 0, -1e9, sim_rel2rule_semantic)
                        sim_rel2rule_time = torch.where(sim_rel2rule_time == 0, -1e9, sim_rel2rule_time)
                        sim_rel2rule_semantic = F.softmax(sim_rel2rule_semantic)
                        sim_rel2rule_time = F.softmax(sim_rel2rule_time)
                        if args.weight_time_learnable:
                            sim_atten = F.sigmoid(self.sim_atten)
                        else:
                            sim_atten = args.sim_atten
                        sim_rel2rule = (sim_atten * sim_rel2rule_semantic + (1 - sim_atten) * sim_rel2rule_time).reshape(-1, 1)
                    rule_emb = (rule_emb * sim_rel2rule).reshape(tail_ent_num, rule_used, -1)
                else:  # 直接掩盖padding的规则，再对剩余规则表征求和作为聚合规则表征
                    mask_info_all = mask_info_all.reshape(-1, 1)
                    rule_emb = (rule_emb * mask_info_all).reshape(tail_ent_num, rule_used, -1)
                tail2rule_emb = rule_emb.sum(1)  # tail_ent_num * h_dim二维张量
                tail2rule_emb = F.normalize(tail2rule_emb)

            # 对聚合规则表征和关系表征作加权
            if args.atten_matr_num > 0:
                rel_emb_repeat_reachable = r_emb.unsqueeze(0).repeat_interleave(tail_ent_num, dim=0)
                if args.atten_matr_num == 1:
                    weight_rule_normalized = F.sigmoid(self.weight_rule)
                    rel_emb_att = self.activation(torch.mm(tail2rule_emb, weight_rule_normalized) + torch.mm(rel_emb_repeat_reachable, 1 - weight_rule_normalized))
                elif args.atten_matr_num == 2:
                    rel_emb_att = self.activation(torch.mm(tail2rule_emb, self.weight_rule) + torch.mm(rel_emb_repeat_reachable, self.weight_rel))
                rel_emb_att = self.dropout(rel_emb_att)
                rel_emb_att = F.normalize(rel_emb_att)
            else:
                rel_emb_att = tail2rule_emb

            h_emb_repeat = h_emb.unsqueeze(0).repeat_interleave(tail_ent_num, dim=0)
            assert h_emb_repeat.shape == tail2rule_emb.shape  # 确保大小一致

            # 计算能到达的尾实体个数个评分
            time_emb_1_sample = emb_time_1[i].repeat_interleave(tail_ent_num, dim=0).unsqueeze(1)
            time_emb_2_sample = emb_time_2[i].repeat_interleave(tail_ent_num, dim=0).unsqueeze(1)
            query_repr = self.get_query_repr(h_emb_repeat.unsqueeze(1), rel_emb_att.unsqueeze(1), time_emb_1_sample, time_emb_2_sample, tail_ent_num)
            tail_ent_emb = e1_embedded_all[tail_ent_ids]
            sample_score = (query_repr * tail_ent_emb).sum(1)
            score_rule[i, tail_ent_ids] = sample_score
            # print(sample_score)
            # score_all[i, tail_ent_ids] = sample_score

            if args.dataset == 'WIKI':
                del tail_ent_ids, tail2rule_sorted, tss_info_all, rel_len_info_all_ori, mask_info_all, rel_len_info_all, tail2rule, rule_tensor, rule_id, rule2rel_emb, all_len_emb, rule_emb, rule_emb_extended, t_embs_extended, sim_rel2rule_semantic, sim_rel2rule_time, tail2rule_emb

            time_2 = datetime.now()

        return score_rule
