from data_preprocess import *
import numpy as np


class Dataset():
    """Generalized Zero-Shot Learning dataset, including all labels.

    Every key of the mapping returned by ``load_data_from_pkl`` is copied onto
    the instance as an attribute.  The methods of this class read the
    attributes ``seen_class``, ``unseen_class``, ``train_seen``, ``test_seen``
    and ``test_unseen``; each of them is used as a list of records that are
    indexed with the key ``'y'`` (the class label).

    Args:
        data_path: Path handed to ``load_data_from_pkl``.
        test_num: Number of unseen classes.  Only consumed by the alternative
            ``_init_dataset_70*`` initializers, which ``__init__`` does not
            call by default.
    """

    def __init__(self, data_path, test_num=10):
        self.val_class = 0  # number of total val classes
        self.test_class = 0  # number of total test classes
        self.n_seen_class = 0  # number of seen classes
        self.n_val_unseen_class = 0  # number of evaluation unseen classes
        self.n_unseen_class = 0  # number of unseen classes

        self.seen_class = None
        self.val_unseen_class = None
        self.unseen_class = None

        self.train_seen = None
        self.val_seen = None
        self.test_seen = None
        self.val_unseen = None
        self.test_unseen = None

        # Default initialization.  Alternatives are
        # ``self._init_dataset_70(data_path, test_num)`` for a fixed
        # seen/unseen partition and ``self.reset_dataset_by_random()`` for a
        # random one.
        self._init_dataset(data_path)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _load_attributes(self, data_path):
        """Copy every key/value pair of the loaded file onto this instance."""
        # NOTE(review): load_data_from_pkl presumably unpickles the file;
        # unpickling can execute arbitrary code, so only load trusted files.
        data = load_data_from_pkl(data_path)
        for key, value in data.items():
            setattr(self, key, value)

    def _group_samples_by_label(self):
        """Return ``{label: [samples]}`` for labels ``0..len(all_class)-1``.

        Samples are taken from ``train_seen``, ``test_seen`` and
        ``test_unseen`` (in that order).  A sample whose ``'y'`` is not in
        that label range raises ``KeyError``.
        """
        buckets = {label: [] for label in range(len(self.all_class))}
        for sample in self.train_seen + self.test_seen + self.test_unseen:
            buckets[sample['y']].append(sample)
        return buckets

    def _split_buckets(self, buckets, n_seen, unseen_limit=None):
        """Shuffle each bucket and rebuild the train/test sample lists.

        Labels below ``n_seen`` are split 70% / 30% (integer arithmetic) into
        ``train_seen`` / ``test_seen``; the remaining labels go to
        ``test_unseen``.

        Args:
            buckets: Mapping ``label -> list of samples`` with the keys
                ``0..len(buckets)-1``.
            n_seen: Labels strictly below this value count as seen.
            unseen_limit: When given, unseen labels greater than or equal to
                this value are dropped instead of added to ``test_unseen``.
        """
        train_seen, test_seen, test_unseen = [], [], []
        for label in range(len(buckets)):
            bucket = buckets[label]
            # Every bucket is shuffled, even one that is dropped below, so the
            # random stream is consumed in ascending label order.
            np.random.shuffle(bucket)
            if label < n_seen:
                cut = len(bucket) * 7 // 10
                train_seen += bucket[:cut]
                test_seen += bucket[cut:]
            elif unseen_limit is None or label < unseen_limit:
                test_unseen += bucket

        self.train_seen = train_seen
        self.test_seen = test_seen
        self.test_unseen = test_unseen

    # ------------------------------------------------------------------
    # Initializers
    # ------------------------------------------------------------------
    def _init_dataset(self, data_path):
        """Load the dataset from ``data_path`` and count seen/unseen classes."""
        self._load_attributes(data_path)

        self.n_seen_class = len(self.seen_class)
        self.n_unseen_class = len(self.unseen_class)

    def _init_dataset_default(self):
        """Default init: recount classes from the already loaded attributes."""
        self.n_seen_class = len(self.seen_class)
        self.n_unseen_class = len(self.unseen_class)
        # A disabled variant of this method appended the class entries
        # themselves to ``test_seen`` (for seen classes without test samples)
        # and to ``test_unseen``, i.e. one record object could be reachable
        # from both a class list and a sample list.

    def reset_dataset_by_random(self):
        """Randomly re-partition the classes into seen and unseen.

        All classes are shuffled and relabelled with their new position, the
        first ``n_seen_class`` of them become the seen classes, and the
        samples are relabelled accordingly and re-split (70% / 30% for seen
        classes, everything for unseen classes).

        Records are modified in place.  Each record object is relabelled at
        most once, so an object that is referenced from several lists (for
        example a class entry that was also added to ``test_unseen``) still
        ends up with the correct new label.

        NOTE(review): ``val_seen``, ``val_unseen`` and ``val_unseen_class``
        are not relabelled here - confirm they are unused after a reset.
        """
        self.all_class = self.seen_class + self.unseen_class
        np.random.shuffle(self.all_class)

        # old label -> new label, built before any record is mutated.
        new_label = {cls['y']: pos for pos, cls in enumerate(self.all_class)}

        # ids of the records that already carry their new label; the records
        # stay alive for the whole call, so their ids are unique.
        relabelled = set()

        def relabel(record):
            if id(record) not in relabelled:
                record['y'] = new_label[record['y']]
                relabelled.add(id(record))

        for cls in self.all_class:
            relabel(cls)

        self.seen_class = self.all_class[:self.n_seen_class]
        self.unseen_class = self.all_class[self.n_seen_class:]

        for sample in self.train_seen + self.test_seen + self.test_unseen:
            relabel(sample)

        self._split_buckets(self._group_samples_by_label(), self.n_seen_class)

    def _init_dataset_70_train(self, data_path, test_num, n_seen=70):
        """Load from ``data_path`` keeping only ``test_num`` unseen classes.

        The first ``n_seen`` entries of ``seen_class + unseen_class`` become
        the seen classes and the next ``test_num`` entries the unseen ones.
        Samples of seen labels are split 70% / 30% into ``train_seen`` /
        ``test_seen``; samples of labels in ``[n_seen, n_seen + test_num)`` go
        to ``test_unseen`` and samples of higher labels are discarded.

        Args:
            data_path: Path handed to ``load_data_from_pkl``.
            test_num: Number of unseen classes to keep.
            n_seen: Number of seen classes (70 by default).
        """
        self._load_attributes(data_path)

        self.n_seen_class = n_seen  # number of seen classes
        self.n_unseen_class = test_num  # number of unseen classes

        self.all_class = self.seen_class + self.unseen_class
        self.seen_class = self.all_class[:n_seen]
        self.unseen_class = self.all_class[n_seen:][:test_num]

        self._split_buckets(
            self._group_samples_by_label(),
            n_seen,
            unseen_limit=n_seen + test_num,
        )

    def _init_dataset_70(self, data_path, test_num, n_seen=70):
        """Load from ``data_path`` with the first ``n_seen`` classes as seen.

        The first ``n_seen`` entries of ``seen_class + unseen_class`` become
        the seen classes and all remaining entries the unseen ones.  Samples
        of seen labels are split 70% / 30% into ``train_seen`` /
        ``test_seen``; all other samples go to ``test_unseen``.

        Args:
            data_path: Path handed to ``load_data_from_pkl``.
            test_num: Value stored in ``n_unseen_class``.
            n_seen: Number of seen classes (70 by default).

        NOTE(review): ``n_unseen_class`` is set to ``test_num`` although every
        class past ``n_seen`` is kept as unseen, so it may differ from
        ``len(self.unseen_class)`` - confirm this is intended.
        """
        self._load_attributes(data_path)

        self.n_seen_class = n_seen  # number of seen classes
        self.n_unseen_class = test_num  # number of unseen classes

        self.all_class = self.seen_class + self.unseen_class
        self.seen_class = self.all_class[:n_seen]
        self.unseen_class = self.all_class[n_seen:]

        self._split_buckets(self._group_samples_by_label(), n_seen)



if __name__ == '__main__':
    # Manual smoke run: build the dataset from the bundled pickle file.
    dataset = Dataset(data_path='../data/Clinc_Goog.pkl')

