import argparse
import json
import os
import sys

import torch


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--rel2rule_file_path", "-f", default="", type=str)
    parser.add_argument("--dataset", "-d", default="ICEWS14", type=str)
    parser.add_argument("--rule_len", default=3, type=int)
    parser.add_argument('--seed', '-s', type=int, default=6666)
    return parser.parse_args(argv)


def build_rel2rule_list(rel2rule_dict, rel_num, rule_len, device=None):
    """Turn a relation -> rules mapping into a list of padded rule tensors.

    Args:
        rel2rule_dict: Mapping from a relation id (as a string) to a list of
            rule dicts; only each rule's ``'body_rels'`` list is read.
        rel_num: Number of relations (int, or a string holding an int).
            Relation ids ``0 .. 2 * rel_num - 1`` are visited, and
            ``2 * rel_num`` is used as the padding id.
            NOTE(review): the doubling presumably accounts for inverse
            relations -- confirm against the dataset loader.
        rule_len: Target length every rule body is right-padded to.
        device: Device the tensors are placed on. ``None`` selects CUDA when
            it is available and falls back to the CPU otherwise.

    Returns:
        A list with one entry per relation id. The entry is an empty tuple
        when the relation has no rules, otherwise a pair
        ``(padded_bodies, last_index)`` of int64 tensors where row ``i`` of
        ``padded_bodies`` is the padded body of rule ``i`` and
        ``last_index[i]`` is ``len(body) - 1``.
    """
    if device is None:
        # The original script required a GPU; fall back instead of crashing.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pad_rel_id = int(rel_num) * 2
    rel2rule_ls = []
    for idx in range(pad_rel_id):
        key = str(idx)
        if key not in rel2rule_dict:  # this relation has no rules
            rel2rule_ls.append(())
            continue

        bodies = [rule['body_rels'] for rule in rel2rule_dict[key]]
        # Right-pad every body to rule_len with the padding relation id.
        # NOTE: a body longer than rule_len gets no padding and is NOT
        # truncated (a list multiplied by a negative count is empty).
        padded_rule_list = [
            body + [pad_rel_id] * (rule_len - len(body)) for body in bodies
        ]
        # Zero-based index of the last real relation of each rule.
        extend_rule_id = [len(body) - 1 for body in bodies]
        rel2rule_ls.append((
            torch.LongTensor(padded_rule_list).to(device),
            torch.LongTensor(extend_rule_id).to(device),
        ))
    return rel2rule_ls


def rule_tensor_path(dataset, seed, rel2rule_file_path):
    """Return the path the tensor list for ``rel2rule_file_path`` is saved to.

    The last five characters of the rule file path are dropped before the
    ``_rule_ten_ls.pt`` suffix is appended, i.e. the path is assumed to end
    in ``.json``.
    """
    save_dir = f'../output_rule/{dataset}/seed{seed}/cands_with_rules/{rel2rule_file_path}'
    return save_dir[:-5] + '_rule_ten_ls.pt'


def main(argv=None):
    """Load the rule JSON and dataset stats, build the tensors, save them."""
    args = parse_args(argv)

    rule_file = f'../output_rule/{args.dataset}/seed{args.seed}/{args.rel2rule_file_path}'
    stat_file = f'../data/{args.dataset}/stat.txt'
    with open(rule_file, 'r') as file:
        rel2rule_dict = json.load(file)
    with open(stat_file, 'r') as file:
        # First line holds three whitespace-separated fields; the middle one
        # is used as the relation count.
        _, rel_num, _ = file.readline().strip().split()

    rel2rule_ls = build_rel2rule_list(rel2rule_dict, rel_num, args.rule_len)

    save_file = rule_tensor_path(args.dataset, args.seed, args.rel2rule_file_path)
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(save_file), exist_ok=True)
    torch.save(rel2rule_ls, save_file)


if __name__ == "__main__":
    main()
