import numpy as np
import torch

from torch.utils.data import BatchSampler, RandomSampler
import pickle

class CLSBatchSampler():
    '''Yield shuffled mini-batches of dataset indices for plain classification.

    Thin wrapper over ``BatchSampler(RandomSampler(labels))``. Because
    ``drop_last`` is False, the final batch may hold fewer than
    ``batch_size`` indices.
    '''

    def __init__(self, labels, batch_size):
        '''
        Args:
            labels: the dataset labels; only its length is used for sampling
            batch_size: number of indices per yielded batch
        '''
        self.labels = labels
        self.batch_size = batch_size
        shuffled = RandomSampler(labels)
        self.sampler = BatchSampler(shuffled, batch_size=self.batch_size, drop_last=False)

    def __len__(self):
        '''Number of batches per epoch (delegated to the wrapped BatchSampler).'''
        return len(self.sampler)

    def __iter__(self):
        '''Yield each batch (a list of indices) from the wrapped BatchSampler.'''
        yield from self.sampler


class GZSLBatchSampler():
    '''Episodic sampler drawing disjoint "novel" and "memory" classes per episode.

    For each episode a random permutation of all classes is drawn; the first
    ``novel_class_per_it`` entries become novel classes and the following
    ``memory_class_per_it`` entries become memory classes, so the two sets never
    overlap. From each chosen class up to ``num_novel_query`` (novel) or
    ``num_memory_query`` (memory) dataset indices are sampled without replacement.
    If fewer classes exist than requested, the later set is simply smaller.
    '''

    def __init__(self, labels, iterations, novel_class_per_it, num_novel_query, memory_class_per_it, num_memory_query):
        '''Initialize the GZSLBatchSampler object

        Args:
            labels: the dataset labels list (one label per sample)
            iterations: number of iterations (episodes) per epoch
            novel_class_per_it: the number of novel classes in an episode
            num_novel_query: the max number of query samples drawn per novel class
            memory_class_per_it: the number of memory classes in an episode
            num_memory_query: the max number of query samples drawn per memory class
        '''
        self.labels = labels
        self.iterations = iterations
        self.novel_class_per_it = novel_class_per_it
        self.num_novel_query = num_novel_query
        self.memory_class_per_it = memory_class_per_it
        self.num_memory_query = num_memory_query

        # ``inverse[i]`` is the row (class position in ``classes``) of sample i,
        # which avoids a per-sample search over all classes.
        classes, inverse, self.counts = np.unique(
            self.labels, return_inverse=True, return_counts=True)
        inverse = np.asarray(inverse).ravel()
        self.classes = torch.Tensor(classes)
        self.num_classes = len(self.classes)

        self.idxs = range(len(self.labels))
        max_count = int(self.counts.max()) if self.counts.size else 0
        # Row k holds, in dataset order, the indices of the samples of class k;
        # unused trailing slots are -1. Built as int64 (never float) so large
        # dataset indices are stored exactly.
        indexes = np.full((self.num_classes, max_count), -1, dtype=np.int64)
        filled = np.zeros(self.num_classes, dtype=np.int64)
        for idx, label_idx in enumerate(inverse):
            # ``filled[label_idx]`` is the next free slot in that class row.
            indexes[label_idx, filled[label_idx]] = idx
            filled[label_idx] += 1

        self.all_labels = torch.LongTensor(self.labels)
        self.indexes = torch.from_numpy(indexes).int()
        self.numel_per_class = torch.from_numpy(filled).int()
        print('num classes available: {}'.format(self.num_classes))

    def __len__(self):
        '''returns the number of iterations (episodes) per epoch.'''
        return self.iterations

    def __iter__(self):
        '''Yield ``(novel_classes, novel_query, memory_query)`` once per episode.

        ``novel_classes`` lists the novel class labels as ``int``;
        ``novel_query`` and ``memory_query`` are flat lists of dataset indices.
        '''
        ncpi = self.novel_class_per_it
        nnq = self.num_novel_query
        mcpi = self.memory_class_per_it
        nmq = self.num_memory_query

        for _ in range(self.iterations):
            rand_classes = torch.randperm(self.num_classes)
            novel_classes, novel_query = self._sample_classes(rand_classes[:ncpi], nnq)
            # NOTE(review): memory class labels are sampled but not part of the
            # yielded tuple; kept that way so existing callers are unaffected.
            _, memory_query = self._sample_classes(rand_classes[ncpi:ncpi + mcpi], nmq)
            yield novel_classes, novel_query, memory_query

    def _sample_classes(self, class_idxs, num_query):
        '''Draw up to ``num_query`` indices (no replacement) from each class row.

        Returns (class labels as ``int``, flat list of sampled dataset indices).
        '''
        sampled_classes = []
        query = []
        for label_idx in class_idxs.tolist():
            available = int(self.numel_per_class[label_idx])
            sample_idxs = torch.randperm(available)[:num_query]
            sampled_classes.append(int(self.classes[label_idx]))
            query += self.indexes[label_idx][sample_idxs].tolist()
        return sampled_classes, query