# -*- coding: UTF-8 -*-
import torch
from torch.utils.data import DataLoader, Dataset
import pickle
from path import ROOT_DIR
import json


def read_data(filename="./data/train_data.pkl"):
    """
    Load a pickled dialogue file and extract one topic sequence per record.

    Each record is expected to expose a ``'goal_path'`` entry that maps a
    key to a "behavior" list.  Only behaviors whose first element is
    ``"Seeker"`` are used; from those, the elements at positions
    2, 4, 6, ... are collected as topics.  A topic slot may hold a single
    value or a list of values; ``None`` entries are dropped.

    :param filename: path of the pickle file to read.
    :return: list of topic sequences (one list per record, possibly empty).
    """
    # NOTE: pickle can execute arbitrary code while loading -- only use this
    # on trusted data files.
    with open(filename, "rb") as f:  # context manager closes the handle
        train_data = pickle.load(f)

    topic_seqs = []
    for line in train_data:
        goal_path = line['goal_path']
        seq = []
        for idx in goal_path:
            behavior = goal_path[idx]
            if behavior[0] != "Seeker":
                continue
            # Topics sit at every second position, starting from index 2.
            for topic in behavior[2::2]:
                if isinstance(topic, list):
                    # A topic slot may itself be a list of topics.
                    seq.extend(t for t in topic if t is not None)
                elif topic is not None:
                    seq.append(topic)
        topic_seqs.append(seq)
    return topic_seqs


def get_topic_dict(train_seqs):
    """
    Build the topic vocabulary from the training sequences.

    Topics receive consecutive integer ids starting at 1, in order of first
    appearance.

    :param train_seqs: iterable of topic sequences.
    :return: ``(topic_index, index_topic)`` -- the topic -> id mapping and
        its inverse id -> topic mapping.
    """
    topic_index = {}
    for sequence in train_seqs:
        for item in sequence:
            # The next free id is always "current vocabulary size + 1";
            # setdefault leaves already-known topics untouched.
            topic_index.setdefault(item, len(topic_index) + 1)
    index_topic = {number: item for item, number in topic_index.items()}
    return topic_index, index_topic


def seq_to_index(seqs, topic_index):
    """
    Translate topic sequences into sequences of ids.

    Topics that are missing from ``topic_index`` are silently dropped, so a
    converted sequence may be shorter than its source (or even empty).

    :param seqs: iterable of topic sequences.
    :param topic_index: mapping topic -> id.
    :return: list with one id list per input sequence.
    """
    return [
        [topic_index[item] for item in sequence if item in topic_index]
        for sequence in seqs
    ]


def split_seq_label(seqs):
    """
    Split every sequence into (all-but-last items, last item).

    Sequences with fewer than two items are skipped, because they would
    leave an empty input part.

    :param seqs: iterable of sequences.
    :return: ``(seq_list, labels)`` -- two parallel lists.
    """
    usable = [sequence for sequence in seqs if len(sequence[:-1]) > 0]
    seq_list = [sequence[:-1] for sequence in usable]
    labels = [sequence[-1] for sequence in usable]
    return seq_list, labels

def split_seq_label_aug(seqs):
    """
    Split sequences into (prefix, label) samples with prefix augmentation.

    For a sequence of length ``n`` one sample is produced for every cut
    position ``n - 2, n - 3, ..., 1``: the prefix ``seq[:cut]`` is the
    input and ``seq[cut]`` is the label.  Sequences with fewer than three
    items therefore yield no sample.

    NOTE(review): the final item of a sequence is never used as a label
    here -- confirm that this is intended.

    :param seqs: iterable of sequences.
    :return: ``(seq_list, labels)`` -- two parallel lists.
    """
    seq_list = []
    labels = []
    for sequence in seqs:
        # Walk the cut point from right to left (longest prefix first).
        for cut in range(len(sequence) - 2, 0, -1):
            labels.append(sequence[cut])
            seq_list.append(sequence[:cut])
    return seq_list, labels



def statistics(seqs, labels):
    """
    Print summary statistics for a set of (sequence, label) samples.

    Every sample is measured as the sequence plus its label.  Printed
    values: total number of items, longest sample, average sample length,
    number of distinct topics and the largest topic id.

    Empty input is reported with an average of ``0.0`` and a maximum topic
    index of ``0`` instead of raising.

    :param seqs: list of id sequences.
    :param labels: list of labels, parallel to ``seqs``.
    """
    topic_click = 0
    max_length = 0
    topics = set()
    for seq, label in zip(seqs, labels):
        sample_length = len(seq) + 1  # the label counts as one more item
        max_length = max(max_length, sample_length)
        topic_click += sample_length
        # Grow the set in place instead of rebuilding it for every sample.
        topics.update(seq)
        topics.add(label)

    sample_count = len(seqs)
    avg_length = topic_click / sample_count if sample_count else 0.0

    print(f"topics_clicks: {topic_click}")
    print(f"max_length: {max_length}")
    print(f"avg_length: {avg_length}")
    print(f"topic_num: {len(topics)}")
    print(f"max_topic_index: {max(topics, default=0)}")


def get_data(seq_length=10, front_padding=True):
    """
    Build the train / validation / test datasets.

    The three pickle files under ``./data`` are read, the topic vocabulary
    is built from the training split only, every split is converted to ids
    and expanded with ``split_seq_label_aug``, statistics for each split
    are printed, and finally each split is wrapped in a
    ``SequentialDataset``.

    :param seq_length: fixed sequence length passed to ``SequentialDataset``.
    :param front_padding: padding side flag passed to ``SequentialDataset``.
    :return: ``(train_dataset, valid_dataset, test_dataset)``.
    """
    # (headline printed before the statistics, source file) per split,
    # in train / valid / test order.
    splits = (
        ("\ntraining data:", "./data/train_data.pkl"),
        ("\nvalidate data:", "./data/valid_data.pkl"),
        ("\ntesting data:", "./data/test_data.pkl"),
    )

    raw_seqs = [read_data(path) for _, path in splits]

    # The vocabulary comes from the training split (first entry) only.
    topic_index, _ = get_topic_dict(raw_seqs[0])

    samples = [
        split_seq_label_aug(seq_to_index(raw, topic_index))
        for raw in raw_seqs
    ]

    for (headline, _), (split_seqs, split_labels) in zip(splits, samples):
        print(headline)
        statistics(split_seqs, split_labels)
    print()

    datasets = [
        SequentialDataset(seq=split_seqs,
                          label=split_labels,
                          seq_length=seq_length,
                          front_padding=front_padding)
        for split_seqs, split_labels in samples
    ]
    train_dataset, valid_dataset, test_dataset = datasets
    return train_dataset, valid_dataset, test_dataset


# def read_data(filename, topic_to_id):
#     """
#     分成序列和label
#     """
#     data = pickle.load(open(f"./data/{filename}processed_data.pkl", "rb"))
#     seqs = []
#     labels = []
#     for d in data:
#         seq = d[2]
#         label = d[5][0]
#         seq_after_index = []
#         for x in seq:
#             seq_after_index += [topic_to_id[x[0]]]
#         seqs += [seq_after_index]
#         labels += [topic_to_id[label]]
#
#     return seqs, labels
#
#
# def get_data(seq_length=10, front_padding=True):
#     """
#     获取dataset
#     """
#     topic_to_id = json.load(open(f"{ROOT_DIR}/data/topic_to_id_2k5.json", 'r'))
#
#     train_seqs, train_labels = read_data("train", topic_to_id)
#     valid_seqs, valid_labels = read_data("valid", topic_to_id)
#     test_seqs, test_labels = read_data("test", topic_to_id)
#
#     print("\ntraining data:")
#     statistics(train_seqs, train_labels)
#     print("\nvalidate data:")
#     statistics(valid_seqs, valid_labels)
#     print("\ntesting data:")
#     statistics(test_seqs, test_labels)
#     print()
#
#     train_dataset = SequentialDataset(seq=train_seqs,
#                                       label=train_labels,
#                                       seq_length=seq_length,
#                                       front_padding=front_padding)
#     valid_dataset = SequentialDataset(seq=valid_seqs,
#                                       label=valid_labels,
#                                       seq_length=seq_length,
#                                       front_padding=front_padding)
#     test_dataset = SequentialDataset(seq=test_seqs,
#                                      label=test_labels,
#                                      seq_length=seq_length,
#                                      front_padding=front_padding)
#     return train_dataset, valid_dataset, test_dataset


class SequentialDataset(Dataset):
    """
    Dataset of fixed-length, zero-padded id sequences with one label each.

    Every input sequence is truncated to its last ``seq_length`` items and
    written into a row of a ``(len(label), seq_length)`` long tensor that
    is initialised with zeros.
    """

    def __init__(self, seq, label, seq_length=15, front_padding=True):
        """
        :param seq: list of id sequences (lists of ints).
        :param label: list of integer labels, parallel to ``seq``.
        :param seq_length: width of every padded row.
        :param front_padding: if True the items are right-aligned (zeros in
            front); otherwise they are left-aligned (zeros at the end).
        """
        # padding
        self.seq = torch.zeros((len(label), seq_length), dtype=torch.long)
        for row, items in enumerate(seq):
            keep = min(len(items), seq_length)
            if keep == 0:
                # Nothing to copy.  This guard is required: a slice
                # ``[-0:]`` selects the WHOLE row / list rather than
                # nothing, so an empty sequence would otherwise fail with
                # a size mismatch.
                continue
            recent = torch.tensor(items[-keep:], dtype=torch.long)
            if front_padding:
                self.seq[row, -keep:] = recent
            else:
                self.seq[row, :keep] = recent
        # change to tensor
        self.labels = torch.tensor(label, dtype=torch.long)

    def __getitem__(self, idx):
        """Return the ``(padded sequence, label)`` pair at ``idx``."""
        return self.seq[idx], self.labels[idx]

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

    def get_data_loader(self, device, batch_size=128, shuffle=True):
        """
        Wrap this dataset in a ``DataLoader``.

        :param device: target device, given as a string (``"cpu"``,
            ``"CPU"``, ``"cuda:0"``, ...) or a ``torch.device``.
        :param batch_size: number of samples per batch.
        :param shuffle: whether to reshuffle the data every epoch.
        :return: the configured ``DataLoader``; memory pinning is enabled
            for every device whose type is not CPU.
        """
        # Normalise through str() so torch.device("cpu") is recognised as
        # CPU as well, not only the plain strings "cpu" / "CPU".
        device_type = str(device).lower().split(":")[0]
        pin_memory = device_type != "cpu"
        data_loader = DataLoader(self, batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory)
        return data_loader


if __name__ == '__main__':
    # Smoke test: build the datasets and show the first two training batches.
    train_dataset, valid_dataset, test_dataset = get_data(seq_length=15, front_padding=True)
    data_loader = train_dataset.get_data_loader(device="cpu", batch_size=128, shuffle=False)
    # zip() against range(2) stops after two batches without requesting a
    # third one from the loader.
    for _, batch in zip(range(2), data_loader):
        print(batch)
