# -*- coding: UTF-8 -*-
import torch
from torch.utils.data import DataLoader, Dataset
import pickle
from path import ROOT_DIR
import json


def statistics(seqs, labels):
    """
    Print summary statistics for a collection of sequences and their labels.

    Each sequence is extended with its label before counting, so every
    reported length includes the label item.

    Printed values:
        topics_clicks   -- total number of items over all (sequence + label) pairs
        max_length      -- length of the longest (sequence + label)
        avg_length      -- topics_clicks divided by ``len(seqs)`` (0.0 if ``seqs`` is empty)
        topic_num       -- number of distinct items seen
        max_topic_index -- largest item seen (None if nothing was seen)

    Args:
        seqs: list of lists of topic ids.
        labels: one label id per sequence; pairs are formed with ``zip``, so
            extra entries on the longer side are ignored.
    """
    topic_click = 0
    max_length = 0
    topics = set()
    for seq, l in zip(seqs, labels):
        s = seq + [l]
        max_length = max(max_length, len(s))
        topic_click += len(s)
        # In-place update instead of rebuilding the whole set every iteration.
        topics.update(s)

    # Guard the empty case so an empty split does not abort the whole run.
    num_seqs = len(seqs)
    avg_length = topic_click / num_seqs if num_seqs else 0.0
    max_topic = max(topics) if topics else None

    print(f"topics_clicks: {topic_click}")
    print(f"max_length: {max_length}")
    print(f"avg_length: {avg_length}")
    print(f"topic_num: {len(topics)}")
    print(f"max_topic_index: {max_topic}")


# def read_data(filename, topic_to_id):
#     """
#     分成序列和label
#     """
#     data = pickle.load(open(f"./data/{filename}processed_data.pkl", "rb"))
#     seqs = []
#     labels = []
#     seq_list=[]
#     for d in data:
#         seq = d[2] + d[5][:2]
#         seq_after_index = []
#         for x in seq:
#             seq_after_index += [topic_to_id[x[0]]]
#         seqs += [seq_after_index]
#         for s in seqs:
#             if len(s) >= 2:
#                 for i in range(1, len(s)):
#                     tar = s[-i]
#                     labels += [topic_to_id[tar]]
#                     seq_list += [s[:-i]]
#     return seq_list, labels

def read_data(filename, topic_to_id):
    """
    Load one data split and convert it to id sequences and id labels.

    Reads ``./data/{filename}processed_data.pkl`` (relative to the current
    working directory; ``filename`` is used as a plain prefix of the file
    name). For every record ``d`` in the pickled collection:
      * the sequence is the first element of each entry in ``d[2]``,
        mapped through ``topic_to_id``;
      * the label is ``d[5][0]`` mapped through ``topic_to_id``.

    Args:
        filename: split prefix, e.g. ``"train"``.
        topic_to_id: mapping from raw topic to integer id.

    Returns:
        (seqs, labels): list of id lists and the parallel list of label ids.

    Raises:
        FileNotFoundError: if the pickle file does not exist.
        KeyError: if a topic is missing from ``topic_to_id``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(f"./data/{filename}processed_data.pkl", "rb") as f:
        data = pickle.load(f)

    seqs = []
    labels = []
    for d in data:
        label = d[5][0]
        seqs.append([topic_to_id[x[0]] for x in d[2]])
        labels.append(topic_to_id[label])
    return seqs, labels

def read_data_target(filename, topic_to_id):
    """
    Load one data split, appending a target topic to every sequence.

    Reads ``./data/{filename}processed_data.pkl`` (relative to the current
    working directory; ``filename`` is used as a plain prefix of the file
    name). For every record ``d`` in the pickled collection:
      * the sequence is the first element of each entry in ``d[2]``,
        followed by ``d[4][-1]``, all mapped through ``topic_to_id``;
      * the label is ``d[5][0]`` mapped through ``topic_to_id``.

    Args:
        filename: split prefix, e.g. ``"train"``.
        topic_to_id: mapping from raw topic to integer id.

    Returns:
        (seqs, labels): list of id lists (each with at least one element,
        the appended target) and the parallel list of label ids.

    Raises:
        FileNotFoundError: if the pickle file does not exist.
        KeyError: if a topic is missing from ``topic_to_id``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(f"./data/{filename}processed_data.pkl", "rb") as f:
        data = pickle.load(f)

    seqs = []
    labels = []
    for d in data:
        tar = d[4][-1]
        # Raw topics of the history, then the target as the final step.
        seq_only = [x[0] for x in d[2]] + [tar]
        seqs.append([topic_to_id[y] for y in seq_only])
        labels.append(topic_to_id[d[5][0]])
    return seqs, labels

def read_data_context(filename, topic_to_id):
    """
    Load one data split, using the stored context field as the sequence.

    Reads ``./data/{filename}processed_data.pkl`` (relative to the current
    working directory; ``filename`` is used as a plain prefix of the file
    name). For every record ``d`` in the pickled collection:
      * the sequence is ``d[1]`` taken as-is (it is NOT mapped through
        ``topic_to_id``);
      * the label is ``d[5][0]`` mapped through ``topic_to_id``.

    Args:
        filename: split prefix, e.g. ``"train"``.
        topic_to_id: mapping from raw topic to integer id (labels only).

    Returns:
        (seqs, labels): list of raw ``d[1]`` values and the parallel list of
        label ids.

    Raises:
        FileNotFoundError: if the pickle file does not exist.
        KeyError: if a label topic is missing from ``topic_to_id``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(f"./data/{filename}processed_data.pkl", "rb") as f:
        data = pickle.load(f)

    seqs = []
    labels = []
    for d in data:
        seqs.append(d[1])
        labels.append(topic_to_id[d[5][0]])
    return seqs, labels


def get_data(seq_length=10, front_padding=True):
    """
    Build the train / valid / test datasets.

    Loads the topic-to-id mapping from
    ``{ROOT_DIR}/data/topic_to_id_2k5.json``, reads the three splits with
    ``read_data_target``, prints statistics for each split, and wraps each
    split in a ``SequentialDataset``.

    Args:
        seq_length: fixed sequence length passed to ``SequentialDataset``.
        front_padding: padding side passed to ``SequentialDataset``.

    Returns:
        (train_dataset, valid_dataset, test_dataset)
    """
    # Context manager so the mapping file is closed deterministically.
    with open(f"{ROOT_DIR}/data/topic_to_id_2k5.json", 'r') as f:
        topic_to_id = json.load(f)

    # Read every split up front so a missing file fails before anything is printed.
    splits = [
        ("training", read_data_target("train", topic_to_id)),
        ("validate", read_data_target("valid", topic_to_id)),
        ("testing", read_data_target("test", topic_to_id)),
    ]

    for title, (seqs, labels) in splits:
        print(f"\n{title} data:")
        statistics(seqs, labels)
    print()

    train_dataset, valid_dataset, test_dataset = [
        SequentialDataset(seq=seqs,
                          label=labels,
                          seq_length=seq_length,
                          front_padding=front_padding)
        for _, (seqs, labels) in splits
    ]
    return train_dataset, valid_dataset, test_dataset


class SequentialDataset(Dataset):
    """
    Dataset of fixed-length, zero-padded id sequences with one label each.

    Every input sequence is truncated to its last ``seq_length`` items and
    copied into a row of zeros, either right-aligned (``front_padding=True``,
    zeros in front) or left-aligned (``front_padding=False``, zeros behind).

    NOTE(review): 0 doubles as the padding value; this presumably relies on
    id 0 not being a real topic -- confirm against the id mapping.
    """

    def __init__(self, seq, label, seq_length=15, front_padding=True):
        """
        Args:
            seq: list of id sequences (lists of ints), one per sample.
            label: list of int labels; its length fixes the number of rows.
            seq_length: width of every padded row.
            front_padding: True to put the zeros before the sequence,
                False to put them after it.
        """
        # padding
        self.seq = torch.zeros((len(label), seq_length), dtype=torch.long)
        for i, s in enumerate(seq):
            l = min(len(s), seq_length)
            if l == 0:
                # Nothing to copy. Also avoids the ``-0:`` slice, which selects
                # the whole row/list instead of nothing and raises a size
                # mismatch for empty sequences.
                continue
            tail = torch.tensor(s[-l:], dtype=torch.long)
            if front_padding:
                self.seq[i, -l:] = tail
            else:
                self.seq[i, :l] = tail
        # change to tensor
        self.labels = torch.tensor(label, dtype=torch.long)

    def __getitem__(self, idx):
        """Return the ``(padded_sequence, label)`` pair at ``idx``."""
        return self.seq[idx], self.labels[idx]

    def __len__(self):
        """Return the number of samples (number of labels)."""
        return len(self.labels)

    def get_data_loader(self, device, batch_size=128, shuffle=True):
        """
        Wrap this dataset in a ``DataLoader``.

        Args:
            device: target device, as a string or ``torch.device``. Pinned
                memory is requested for every device except the CPU.
            batch_size: number of samples per batch.
            shuffle: whether to reshuffle the data every epoch.

        Returns:
            A ``torch.utils.data.DataLoader`` over this dataset.
        """
        # Compare on the string form so torch.device("cpu") and "cpu:0" are
        # recognised as CPU as well as the plain "cpu" / "CPU" strings.
        pin_memory = not str(device).lower().startswith("cpu")
        data_loader = DataLoader(self, batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory)
        return data_loader


if __name__ == '__main__':
    # Smoke test: build the datasets and print the first two training batches.
    train_set, _valid_set, _test_set = get_data(seq_length=15, front_padding=True)
    loader = train_set.get_data_loader(device="cpu", batch_size=128, shuffle=False)
    # range(2) comes first in the zip, so iteration stops after two batches
    # without pulling a third one from the loader.
    for _, batch in zip(range(2), loader):
        print(batch)
