import os
import sys
import torch
import numpy as np
from scipy import sparse
import scipy.sparse as sp
from tqdm import tqdm
import argparse

# Command-line options:
#   --dataset  name of the sub-directory under ../data/ to process
#   --use_fre  store occurrence counts in the history matrices instead of ones
#   --his_len  number of time steps of history to keep (0 = all earlier steps)
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='ICEWS14s')
parser.add_argument('--use_fre', action='store_true', default=False)
parser.add_argument('--his_len', type=int, default=0)
args = parser.parse_args()
print(args)

def load_quadruples(inPath, fileName, fileName2=None):
    """Read quadruples from one or two whitespace-separated text files.

    Each line is expected to start with four integers in the order
    ``head rel tail time``; any further columns are ignored.

    Args:
        inPath: Directory that contains the files.
        fileName: Name of the first file to read.
        fileName2: Optional name of a second file appended after the first.

    Returns:
        A tuple ``(quadruples, times)`` where ``quadruples`` is an array of
        ``[head, rel, tail, time]`` rows in file order and ``times`` is the
        sorted array of distinct timestamps seen.
    """
    sources = [fileName]
    if fileName2 is not None:
        sources.append(fileName2)

    quadrupleList = []
    seen_times = set()
    for source in sources:
        with open(os.path.join(inPath, source), 'r') as handle:
            for record in handle:
                fields = record.split()
                # Columns are converted in the order 0, 2, 1, 3.
                head, tail, rel, time = (int(fields[i]) for i in (0, 2, 1, 3))
                quadrupleList.append([head, rel, tail, time])
                seen_times.add(time)

    return np.asarray(quadrupleList), np.asarray(sorted(seen_times))

def load_all_quadruples(inPath, fileName, fileName2=None, fileName3=None):
    """Read and concatenate quadruples from up to three text files.

    Each line is expected to start with four whitespace-separated integers
    in the order ``head rel tail time``; any further columns are ignored.

    Args:
        inPath: Directory that contains the files.
        fileName: Name of the first file to read.
        fileName2: Optional second file; skipped when ``None``.
        fileName3: Optional third file; skipped when ``None``.

    Returns:
        A tuple ``(quadruples, times)`` where ``quadruples`` is an array of
        ``[head, rel, tail, time]`` rows in read order and ``times`` is the
        sorted array of distinct timestamps seen across all files.
    """
    quadrupleList = []
    times = set()
    for name in (fileName, fileName2, fileName3):
        # The optional names default to None; joining None onto a path
        # raises TypeError, so omitted files are simply skipped.
        if name is None:
            continue
        with open(os.path.join(inPath, name), 'r') as fr:
            for line in fr:
                line_split = line.split()
                head = int(line_split[0])
                tail = int(line_split[2])
                rel = int(line_split[1])
                time = int(line_split[3])
                quadrupleList.append([head, rel, tail, time])
                times.add(time)

    return np.asarray(quadrupleList), np.asarray(sorted(times))

def get_total_number(inPath, fileName):
    """Return the first two integers on the first line of a statistics file.

    Args:
        inPath: Directory that contains the file.
        fileName: Name of the file to read.

    Returns:
        A tuple of the first two whitespace-separated integers on the first
        line (the caller in this file unpacks them as entity and relation
        counts), or ``None`` if the file is empty.
    """
    with open(os.path.join(inPath, fileName), 'r') as handle:
        first_line = handle.readline()
    if not first_line:
        return None
    fields = first_line.split()
    return int(fields[0]), int(fields[1])

def get_data_with_t(data, tim):
    """Select the triples of all quadruples whose timestamp equals ``tim``.

    Args:
        data: Iterable of indexable quadruples; index 3 is compared to ``tim``.
        tim: Timestamp to match.

    Returns:
        An array whose rows are the first three fields of each matching
        quadruple, in input order.
    """
    selected = []
    for quad in data:
        if quad[3] == tim:
            selected.append([quad[0], quad[1], quad[2]])
    return np.array(selected)

# Load the quadruples of all three splits together with the sorted list of
# timestamps, then read the two totals stored in stat.txt.
data_dir = '../data/{}'.format(args.dataset)
all_data, all_times = load_all_quadruples(data_dir, 'train.txt', 'valid.txt', "test.txt")
num_e, num_r = get_total_number(data_dir, 'stat.txt')

# Output directory for the per-timestamp history matrices.
save_dir_obj = '../data/{}/history{}/'.format(args.dataset, args.his_len)

def mkdirs(path):
    """Create ``path`` and any missing parent directories.

    Does nothing when the directory already exists. Using ``exist_ok=True``
    avoids the race between checking for the directory and creating it.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

mkdirs(save_dir_obj)
raw_num_r = num_r
# Every relation r also gets an inverse relation with id r + raw_num_r.
num_r = num_r * 2
# Gap between the first two timestamps; it scales the --his_len window below.
time_int = all_times[1] - all_times[0] if len(all_times) > 1 else 0
quad_times = all_data[:, 3]
tail_shape = (num_e * num_r, num_e)
rel_shape = (num_e * num_e, num_r)

# For every timestamp: gather the quadruples that precede it, add the inverse
# triples, and save two sparse matrices describing that history:
#   tail history: row = subject * num_r + relation, column = object
#   rel history:  row = subject * num_e + object,   column = relation
for tim in tqdm(all_times):
    if not args.his_len:
        # Unlimited history: everything strictly before this timestamp.
        in_history = quad_times < tim
    else:
        lower = tim - args.his_len * time_int
        if args.dataset == 'GDELT' and lower >= 0:
            # Snap the lower bound down to the closest timestamp that actually
            # occurs in the data. When no timestamp is <= lower the bound is
            # left unchanged (a decrement-until-found loop would never end).
            pos = np.searchsorted(all_times, lower, side='right') - 1
            if pos >= 0:
                lower = all_times[pos]
        in_history = (quad_times < tim) & (quad_times >= lower)
    history = all_data[in_history]

    if len(history) == 0:
        # Nothing precedes this timestamp (always true for the first one).
        tail_seq = sp.csr_matrix(([], ([], [])), shape=tail_shape)
        rel_seq = sp.csr_matrix(([], ([], [])), shape=rel_shape)
    else:
        # Fancy indexing returns a copy, so the in-place shift is safe.
        inverse = history[:, [2, 1, 0, 3]]
        inverse[:, 1] += raw_num_r
        triples = np.concatenate([history, inverse])[:, :3]

        # Coordinates and occurrence counts come from the same call, so each
        # count is guaranteed to belong to the (row, col) it is stored at.
        unique_triples, counts = np.unique(triples, axis=0, return_counts=True)
        subj = unique_triples[:, 0]
        rel = unique_triples[:, 1]
        obj = unique_triples[:, 2]

        # With --use_fre store how often each triple occurred in the history,
        # otherwise just mark its presence with 1.
        d = counts if args.use_fre else np.ones(len(unique_triples))
        tail_seq = sp.csr_matrix((d, (subj * num_r + rel, obj)), shape=tail_shape)
        rel_seq = sp.csr_matrix((d, (subj * num_e + obj, rel)), shape=rel_shape)

    sp.save_npz('../data/{}/history{}/tail_history_{}.npz'.format(args.dataset, args.his_len, tim), tail_seq)
    sp.save_npz('../data/{}/history{}/rel_history_{}.npz'.format(args.dataset, args.his_len, tim), rel_seq)
