# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import itertools
import random
from typing import (
    List,
    Optional,
)
import glob
from collections import defaultdict
import pickle
import time
import math

import torch
from torch.utils.data import (
    Dataset,
    default_collate,
)
from torchvision.transforms import (
    ToTensor,
    ToPILImage,
)
import torch.nn.functional as F
import numpy as np
import tensorflow_datasets as tfds
from tqdm import tqdm


def enumerate_attribute_value(n_attributes, n_values):
    """Return every configuration of ``n_attributes`` attributes.

    Each configuration is a tuple whose entries are drawn from
    ``range(n_values)``. Tuples are produced in ``itertools.product`` order,
    so the last attribute varies fastest.
    """
    axes = [range(n_values)] * n_attributes
    return list(itertools.product(*axes))


def select_subset_V1(data, n_subset, n_attributes, n_values, random_seed=7):
    """Keep only samples whose every attribute value lies in a per-attribute subset.

    For each attribute the allowed subset is value ``0`` plus ``n_subset - 1``
    distinct values drawn (seeded by ``random_seed``) from ``1 .. n_values - 1``.
    Samples are kept in their original order.
    """
    import numpy as np

    assert n_subset <= n_values
    rng = np.random.RandomState(seed=random_seed)

    # Draws happen attribute by attribute so the RNG sequence is reproducible.
    allowed = [
        [0] + list(rng.choice(range(1, n_values), n_subset - 1, replace=False))
        for _ in range(n_attributes)
    ]

    return [
        sample
        for sample in data
        if all(sample[att] in allowed[att] for att in range(n_attributes))
    ]


def select_subset_V2(data, n_subset, n_attributes, n_values, random_seed=7):
    """Select ``n_subset ** n_attributes`` samples, forcing a fixed stride first.

    The samples at flat positions ``k * (n_values + 1)`` for ``k >= 1`` (and
    below ``n_values ** n_attributes``) are always included. When ``data`` is
    ``enumerate_attribute_value(2, n_values)`` these are the diagonal
    configurations ``(v, v)`` minus ``(0, 0)``. The remaining slots are filled
    with distinct random picks (seeded by ``random_seed``) from the rest of
    ``data``.

    NOTE(review): the ``n_values + 1`` stride only traces the diagonal when
    ``n_attributes == 2``; confirm intent for other attribute counts.
    """
    import numpy as np

    assert n_subset <= n_values
    rng = np.random.RandomState(seed=random_seed)

    stride = n_values + 1
    forced = [data[pos] for pos in range(stride, n_values ** n_attributes, stride)]

    # Work on a deep copy so the caller's ``data`` list is left untouched.
    pool = copy.deepcopy(data)
    for sample in forced:
        pool.remove(sample)

    n_extra = (n_subset ** n_attributes) - len(forced)
    picks = list(rng.choice(range(len(pool)), n_extra, replace=False))
    return forced + [pool[pick] for pick in picks]


def one_hotify(data, n_attributes, n_values):
    """One-hot encode each configuration in ``data`` as a flat float tensor.

    Each result has length ``n_attributes * n_values``. Block ``i`` (of width
    ``n_values``) has a 1 at position ``config[i]``.
    """
    def _encode(config):
        grid = torch.zeros((n_attributes, n_values))
        for att in range(n_attributes):
            grid[att, config[att]] = 1
        return grid.view(-1)

    return [_encode(config) for config in data]


def split_holdout(dataset):
    """Partition configurations by how many of their values equal zero.

    Returns ``(train, hold_out)``:
    - ``train`` holds configurations containing no zero.
    - ``hold_out`` holds configurations containing exactly one zero.
    - Configurations with two or more zeros are discarded.
    """
    train, hold_out = [], []

    for values in dataset:
        n_zeros = sum(x == 0 for x in values)
        if n_zeros == 0:
            train.append(values)
        elif n_zeros == 1:
            hold_out.append(values)

    return train, hold_out


def split_train_test(dataset, p_hold_out=0.1, random_seed=7):
    """Randomly split ``dataset`` into ``(train, test)`` lists.

    ``int(p_hold_out * len(dataset))`` items (chosen by a permutation seeded
    with ``random_seed``) go to ``test``; the rest go to ``train``. Both splits
    are asserted to be non-empty.
    """
    import numpy as np

    assert p_hold_out > 0
    shuffled = np.random.RandomState(seed=random_seed).permutation(len(dataset))
    cut = int(p_hold_out * len(dataset))

    test = [dataset[pos] for pos in shuffled[:cut]]
    train = [dataset[pos] for pos in shuffled[cut:]]
    assert train and test

    assert len(train) + len(test) == len(dataset)
    return train, test


def split_train_val_test(dataset, validation_ratio=0.1, test_ratio=0.1, random_seed=7):
    """Randomly split ``dataset`` into ``(train, val, test)`` lists.

    A permutation seeded with ``random_seed`` is laid out as
    ``[test | validation | train]``:
    - the first ``int(test_ratio * n)`` items go to ``test``;
    - the next ``int(validation_ratio * n)`` items go to ``val``;
    - everything else goes to ``train``.
    Any split may be empty.
    """
    import numpy as np

    size = len(dataset)
    order = np.random.RandomState(seed=random_seed).permutation(size)

    test_end = int(test_ratio * size)
    val_end = test_end + int(validation_ratio * size)

    test, val, train = (
        [dataset[pos] for pos in chunk]
        for chunk in (order[:test_end], order[test_end:val_end], order[val_end:])
    )

    assert len(train) + len(test) + len(val) == len(dataset)
    return train, val, test


class ScaledDataset:
    """Virtually repeat ``examples`` ``scaling_factor`` times.

    Indexing wraps around the underlying examples. With ``return_tuple`` each
    item is returned as the pair ``(example, example)``; otherwise the bare
    example is returned.
    """

    def __init__(self, examples, scaling_factor=1, return_tuple: bool = True):
        self.examples = examples
        self.scaling_factor = scaling_factor
        self.return_tuple = return_tuple

    def __len__(self):
        return self.scaling_factor * len(self.examples)

    def __getitem__(self, k):
        example = self.examples[k % len(self.examples)]
        return (example, example) if self.return_tuple else example


class ImageDiscrimiationDataset(Dataset):
    """Discrimination game over pre-computed ``byol_imagenet2012`` logits.

    Every item is a whole batch. It holds ``batch_size`` sender logits plus,
    for each attribute group in ``att_indices``, ``n_distractors`` distractor
    logits drawn from samples whose label falls in that group.

    Samples are addressed by integer position into ``file_name_to_logits``.
    ``full_file_names`` and each list in ``file_names_list`` hold those
    positions.
    """

    def __init__(
        self,
        dataset_dir: str,
        split: str,
        att_indices: List[List[int]],
        n_distractors: int,
        n_samples_per_epoch: Optional[int],
        seed: Optional[int] = None,
        deterministic: bool = False,
        use_cache: bool = True,
        batch_size: int = 1,
        scale: int = 1,
        load_at_init: bool = True,
        att_filter: Optional[List[int]] = None,
        skip_short_atts: bool = False,
    ) -> None:
        """Build the dataset, loading every logit eagerly if ``load_at_init``.

        Args:
            dataset_dir: stored as ``f'{dataset_dir}/{split}'``. The tfds load
                below always reads from ``data_dir='data'``.
            split: tfds split name.
            att_indices: one collection of labels per attribute group.
            n_distractors: distractors returned per group.
            n_samples_per_epoch: number of samples per epoch; ``None`` means
                all loaded samples.
            seed: seed for the deterministic batch order.
            deterministic: precompute a fixed batch order instead of sampling
                sender batches at random (requires ``scale == 1``).
            use_cache: stored; not read in this class.
            batch_size: sender inputs per item.
            scale: multiplier on the epoch length.
            load_at_init: parse the whole split now. Otherwise call
                ``load_data`` later.
            att_filter: when given, only samples whose label is in it are kept.
            skip_short_atts: skip attribute groups that have no samples.
        """
        super().__init__()
        self.dataset_dir = f'{dataset_dir}/{split}'
        self.split = split
        self.n_distractors = n_distractors
        self.n_samples_per_epoch = n_samples_per_epoch
        self.deterministic = deterministic
        self.use_cache = use_cache
        self.att_indices = [set(att_idx) for att_idx in att_indices]
        self.seed = seed
        self.batch_size = batch_size
        self.scale = scale
        self.att_filter = att_filter
        self.skip_short_atts = skip_short_atts

        dataset = tfds.load(
            'byol_imagenet2012',
            split=self.split,
            shuffle_files=False,
            data_dir='data',
        )
        data = tfds.as_numpy(dataset)

        if load_at_init:
            self._reset_index()
            logits = []
            for item in tqdm(data, desc='Parsing Data'):
                label = item['label']
                if self.att_filter is not None and label not in self.att_filter:
                    continue
                # The sample's position is its index in ``logits``.
                self._register(len(logits), label)
                # NOTE(review): logits are moved to CUDA device 0 unconditionally.
                logits.append(torch.tensor(item['logit']).to(0))

            self._print_group_sizes()

            if self.n_samples_per_epoch is None:
                self.n_samples_per_epoch = len(self.full_file_names)

            if self.deterministic:
                self._build_deterministic_indices(self.n_samples_per_epoch)

            self.file_name_to_logits = torch.stack(logits)

        if self.deterministic:
            assert self.scale == 1

    def _reset_index(self) -> None:
        """Clear the per-group and global sample indices before parsing."""
        self.file_names_list = [[] for _ in range(len(self.att_indices))]
        self.full_file_names = []
        self.attributes = set()

    def _register(self, position: int, label) -> None:
        """Record sample ``position`` globally and under every group containing ``label``."""
        self.full_file_names.append(position)
        self.attributes.add(label)
        for group_idx, atts in enumerate(self.att_indices):
            if label in atts:
                self.file_names_list[group_idx].append(position)

    def _print_group_sizes(self) -> None:
        """Print how many samples each attribute group received."""
        for att_idx, names in zip(self.att_indices, self.file_names_list):
            print(f'# samples: {len(names)} for n_att: {len(att_idx)}')

    def _build_deterministic_indices(self, n_samples: int) -> None:
        """Draw ``n_samples`` distinct positions (seeded) and chunk them into batches."""
        rng = np.random.default_rng(self.seed)
        chosen = rng.choice(
            a=len(self.full_file_names),
            size=n_samples,
            replace=False,
        )
        self.indices = [
            idx.tolist() for idx in torch.split(torch.tensor(chosen), self.batch_size)
        ]

    def __getitem__(self, index):
        """Return ``(sender_input, labels, receiver_input, aux)`` for one batch.

        - ``sender_input``: stacked logits, one per distinct sampled position.
        - ``receiver_input``: for each non-skipped group, ``n_distractors``
          logits that are not among the sender positions.
        - ``labels``: zeros of shape ``(n_senders, n_groups_used)``.
        - ``aux``: ``{'sender_input': sender_input}``.

        NOTE(review): ``receiver_input`` contains only distractors while
        ``labels`` are all zero. Presumably the game inserts the target at
        position 0; confirm.
        """
        if self.deterministic:
            # Wrap over the precomputed batches. ``len(self.indices)`` can be
            # smaller than the number of batches in the full split.
            indices = self.indices[index % len(self.indices)]
        else:
            # ``index`` is irrelevant here: every call draws a fresh batch.
            indices = np.random.choice(
                a=len(self.full_file_names), size=self.batch_size, replace=False
            )

        sender_input_paths = set(indices)
        sender_inputs = [self._load_file(path) for path in sender_input_paths]

        receiver_input = []
        for names in self.file_names_list:
            if self.skip_short_atts and len(names) == 0:
                continue

            # Oversample by the number of senders so enough candidates remain
            # after dropping those that coincide with a sender. This raises if
            # the group has fewer than n_distractors + n_senders samples.
            candidates = np.random.choice(
                len(names),
                size=self.n_distractors + len(sender_inputs),
                replace=False,
            )
            distractors = []
            for idx in candidates:
                path = names[idx]
                if path in sender_input_paths:
                    continue
                distractors.append(self._load_file(path))
                if len(distractors) == self.n_distractors:
                    break
            receiver_input.append(torch.stack(distractors))

        sender_input = torch.stack(sender_inputs)
        # One label column per group actually present in receiver_input, so
        # the two stay aligned when skip_short_atts drops empty groups.
        labels = torch.zeros(len(sender_inputs), len(receiver_input), dtype=torch.long)
        receiver_input = torch.stack(receiver_input)

        return sender_input, labels, receiver_input, {'sender_input': sender_input}

    def _load_file(self, path: int):
        """Return the logit row stored at integer position ``path``."""
        logits = self.file_name_to_logits[path]
        return logits

    @property
    def visual_dim(self) -> int:
        """Size of dimension 1 of the sender input of item 0."""
        return self[0][0].size(1)

    @staticmethod
    def n_attributes(dataset_dir: str) -> int:
        """Count distinct labels in the ``test`` split.

        ``dataset_dir`` is not used; the split is read from ``data_dir='data'``.
        """
        dataset = tfds.load(
            'byol_imagenet2012',
            split='test',
            shuffle_files=False,
            data_dir='data',
        )
        data = tfds.as_numpy(dataset)

        attributes = set()
        for item in tqdm(data, desc='Parsing Data'):
            attributes.add(item['label'])
        return len(attributes)

    def __len__(self):
        return math.ceil(self.n_samples_per_epoch * self.scale / self.batch_size)

    @staticmethod
    def collate_fn(item) -> tuple:
        """Unwrap a DataLoader batch of size 1 (each item is already a batch)."""
        assert len(item) == 1
        sender_input, labels, receiver_input, aux = item[0]
        return sender_input, labels, receiver_input, aux

    def load_data(self, num_workers, worker_id):
        """Load only this worker's contiguous shard of the split.

        The shard is chunk ``worker_id`` of ``num_workers`` chunks over the
        first ``n_samples_per_epoch`` items. Samples are indexed by their
        position within the shard, which is the integer indexing that
        ``__getitem__`` and ``_load_file`` expect. Logits are stacked into a
        CPU tensor.

        NOTE(review): ``att_filter`` is not applied here, unlike ``__init__``.
        """
        dataset = tfds.load(
            'byol_imagenet2012',
            split=self.split,
            shuffle_files=False,
            data_dir='data',
        )
        data = tfds.as_numpy(dataset)

        if self.n_samples_per_epoch is None:
            self.n_samples_per_epoch = len(data)

        indices = torch.arange(self.n_samples_per_epoch)
        indices = indices.chunk(num_workers)[worker_id]
        print('num_workers and worker_id:', num_workers, worker_id)
        print('Low and high indices:', indices[0], indices[-1])
        low, high = int(indices[0]), int(indices[-1])

        self._reset_index()
        logits = []
        for i, item in enumerate(tqdm(data, desc='Parsing Data')):
            if i > high:
                break
            if i < low:
                continue
            self._register(len(logits), item['label'])
            logits.append(torch.tensor(item['logit']))
        self.file_name_to_logits = torch.stack(logits)

        self._print_group_sizes()

        if self.deterministic:
            # A shard can be smaller than n_samples_per_epoch; sampling more
            # distinct positions than exist would raise.
            self._build_deterministic_indices(
                min(self.n_samples_per_epoch, len(self.full_file_names))
            )


class ImageReconstructionDataset(Dataset):
    """Batches of pre-computed ``byol_imagenet2012`` logits for reconstruction.

    Every item is a whole batch ``(sender_input,)`` with ``batch_size`` logit
    rows taken from ``file_name_to_logits`` by integer position.
    """

    def __init__(
        self,
        dataset_dir: str,
        split: str,
        n_samples_per_epoch: Optional[int],
        seed: Optional[int] = None,
        deterministic: bool = False,
        batch_size: int = 1,
        scale: int = 1,
        load_at_init: bool = True,
    ) -> None:
        """Build the dataset, loading every logit eagerly if ``load_at_init``.

        Args:
            dataset_dir: stored as ``f'{dataset_dir}/{split}'``. The tfds load
                below always reads from ``data_dir='data'``.
            split: tfds split name.
            n_samples_per_epoch: number of samples per epoch; ``None`` means
                ``len(data)``.
            seed: seed for the deterministic batch order.
            deterministic: precompute a fixed batch order (requires ``scale == 1``).
            batch_size: rows per item.
            scale: multiplier on the epoch length.
            load_at_init: parse the whole split now.
        """
        super().__init__()
        self.dataset_dir = f'{dataset_dir}/{split}'
        self.split = split
        self.n_samples_per_epoch = n_samples_per_epoch
        self.deterministic = deterministic
        self.seed = seed
        self.batch_size = batch_size
        self.scale = scale

        dataset = tfds.load(
            'byol_imagenet2012',
            split=self.split,
            shuffle_files=False,
            data_dir='data',
        )
        data = tfds.as_numpy(dataset)

        if self.n_samples_per_epoch is None:
            self.n_samples_per_epoch = len(data)

        if load_at_init:
            logits = []
            self.full_file_names = []
            self.attributes = set()
            for i, item in enumerate(tqdm(data, desc='Parsing Data')):
                # NOTE(review): logits are moved to CUDA device 0 unconditionally.
                logits.append(torch.tensor(item['logit']).to(0))
                self.full_file_names.append(i)
                self.attributes.add(item['label'])

            if self.deterministic:
                rng = np.random.default_rng(self.seed)
                chosen = rng.choice(
                    a=len(self.full_file_names),
                    size=self.n_samples_per_epoch,
                    replace=False,
                )
                self.indices = [
                    idx.tolist()
                    for idx in torch.split(torch.tensor(chosen), self.batch_size)
                ]

            self.file_name_to_logits = torch.stack(logits)

        if self.deterministic:
            assert self.scale == 1

    def __getitem__(self, index):
        """Return ``(sender_input,)``: stacked logits for one batch."""
        if self.deterministic:
            # Wrap over the precomputed batches. ``len(self.indices)`` can be
            # smaller than the number of batches in the full split.
            indices = self.indices[index % len(self.indices)]
        else:
            # ``index`` is irrelevant here: every call draws a fresh batch.
            indices = np.random.choice(
                a=len(self.full_file_names), size=self.batch_size, replace=False
            )

        sender_input_paths = set(indices)
        sender_inputs = [self._load_file(path) for path in sender_input_paths]

        sender_input = torch.stack(sender_inputs)
        return (sender_input,)

    def _load_file(self, path: int):
        """Return the logit row stored at integer position ``path``."""
        logits = self.file_name_to_logits[path]
        return logits

    @property
    def visual_dim(self) -> int:
        """Size of dimension 1 of the sender input of item 0."""
        return self[0][0].size(1)

    @staticmethod
    def n_attributes(dataset_dir: str) -> int:
        """Length of the first ``logit`` entry in the ``test`` split.

        ``dataset_dir`` is not used; the split is read from ``data_dir='data'``.
        """
        dataset = tfds.load(
            'byol_imagenet2012',
            split='test',
            shuffle_files=False,
            data_dir='data',
        )
        data = tfds.as_numpy(dataset)
        item = next(iter(data))
        logit = item['logit']

        return len(logit)

    def __len__(self):
        return math.ceil(self.n_samples_per_epoch * self.scale / self.batch_size)

    @staticmethod
    def collate_fn(item) -> tuple:
        """Unwrap a DataLoader batch of size 1 (each item is already a batch)."""
        assert len(item) == 1
        sender_input = item[0]
        return sender_input


class ImageDiscriminationDatasetLogitWrapper(Dataset):
    """View over a discrimination dataset that yields one logit row per item.

    For each index, the wrapped item's first element (the sender input) is
    taken and its row 0 is returned.
    """

    def __init__(self, dataset: ImageDiscrimiationDataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sender_input, *_rest = self.dataset[index]
        return sender_input[0]


class Shape3DEvenSplitDataset(Dataset):
    """3D Shapes images restricted to ``N_CLASSES`` values per attribute.

    Each kept attribute's raw label is remapped to ``0 .. N_CLASSES - 1`` by a
    label conversion. Items are split into train/valid/test uniformly at
    random. ``__getitem__`` returns ``(image_tensor, attributes)``, where
    ``attributes`` has one one-hot row per attribute.
    """
    N_CLASSES = 4
    DEFAULT_ATTRIBUTES = ('shape', 'object_hue')

    # Canonical order in which conversions are built. It fixes the sequence
    # of RNG draws, so changing it changes which values are selected.
    _CONVERSION_ORDER = ('shape', 'object_hue', 'floor_hue', 'orientation', 'scale', 'wall_hue')
    # Attributes that use a fixed source-value subset instead of a random one.
    _FIXED_SOURCES = {
        'orientation': (0, 4, 9, 14),
        'scale': (0, 2, 4, 7),
    }

    def __init__(
        self,
        split: str = None,
        seed: int = None,
        val_ratio: float = None,
        test_ratio: float = None,
        dataset=None,
        attributes=None,
    ):
        """Wrap ``dataset`` directly, or load and split the full data.

        Args:
            split: one of ``'train'``, ``'valid'``, ``'test'``; used only when
                ``dataset`` is None.
            seed: seed for the label conversion and the split.
            val_ratio, test_ratio: split ratios.
            dataset: pre-built list of items (as produced by ``get_splits``).
            attributes: attribute names to encode. Defaults to
                ``['shape', 'object_hue']``.
        """
        self.attributes = (
            attributes if attributes is not None else list(Shape3DEvenSplitDataset.DEFAULT_ATTRIBUTES)
        )
        if dataset is None:
            conversion = Shape3DEvenSplitDataset._build_label_conversion(seed, self.attributes)
            full = Shape3DEvenSplitDataset._load_dataset(
                attributes=self.attributes, conversion=conversion,
            )
            train, valid, test = split_train_val_test(
                full, val_ratio, test_ratio, seed
            )
            splits = {'train': train, 'valid': valid, 'test': test}
            if split not in splits:
                raise ValueError(f'Unknown split {split!r}; expected one of {sorted(splits)}')
            self.dataset = splits[split]
        else:
            self.dataset = dataset

        self.transform = ToTensor()

    @staticmethod
    def _build_label_conversion(seed, attributes):
        """Map raw labels to ``0 .. N_CLASSES - 1`` for each requested attribute.

        ``orientation`` and ``scale`` use fixed source values. The other
        attributes draw ``N_CLASSES`` distinct values from ``range(10)`` with
        ``np.random.RandomState(seed)``.

        NOTE(review): ``shape`` also draws from ``range(10)``. If the raw shape
        label has fewer than 10 values, some selected values never occur;
        confirm against the shapes3d label ranges.
        """
        rng = np.random.RandomState(seed)
        conversion = {}
        for att in Shape3DEvenSplitDataset._CONVERSION_ORDER:
            if att not in attributes:
                continue
            if att in Shape3DEvenSplitDataset._FIXED_SOURCES:
                sources = Shape3DEvenSplitDataset._FIXED_SOURCES[att]
            else:
                sources = rng.choice(10, Shape3DEvenSplitDataset.N_CLASSES, replace=False)
            conversion[att] = {src: tgt for tgt, src in enumerate(sources)}
        return conversion

    @staticmethod
    def _filter_and_convert(items, conversion=None, max_len=float('inf')):
        """Keep items whose labels all appear in ``conversion``, remapping them in place.

        Raw-label counts per attribute are printed at the end. Stops once
        ``max_len`` items are kept. ``conversion=None`` keeps every item
        unchanged.
        """
        conversion = conversion or {}
        data = []
        counts = defaultdict(lambda: defaultdict(int))
        for item in items:
            if any(item[f'label_{att}'] not in mapping for att, mapping in conversion.items()):
                continue

            for att, mapping in conversion.items():
                key = f'label_{att}'
                counts[att][item[key]] += 1
                item[key] = mapping[item[key]]
            data.append(item)

            if len(data) >= max_len:
                break

        for att, count in counts.items():
            print(f'{att} count', count)

        return data

    @staticmethod
    def _load_dataset(
        max_len: int = float('inf'),
        dummy: bool = False,
        attributes=None,
        conversion=None,
    ):
        """Load shapes3d (or random dummy images) and apply ``conversion``.

        In dummy mode every requested attribute label is written under the
        same ``label_<attribute>`` key that the conversion and ``__getitem__``
        read.
        """
        if attributes is None:
            attributes = list(Shape3DEvenSplitDataset.DEFAULT_ATTRIBUTES)

        if dummy:
            dataset = []
            to_pil = ToPILImage()
            for _ in range(10000):
                for shape in range(0, 4):
                    for color in range(0, 4):
                        item = dict()
                        if 'shape' in attributes:
                            item['label_shape'] = shape
                        for att in ('object_hue', 'floor_hue', 'orientation', 'scale', 'wall_hue'):
                            if att in attributes:
                                item[f'label_{att}'] = color
                        item['image'] = to_pil(torch.rand(3, 64, 64))
                        dataset.append(item)
        else:
            dataset = tfds.load(
                'shapes3d',
                split='train',
                shuffle_files=True,
            )
            dataset = tfds.as_numpy(dataset)

        return Shape3DEvenSplitDataset._filter_and_convert(
            tqdm(dataset, desc='Iterating dataset'), conversion, max_len,
        )

    def __getitem__(self, idx):
        """Return ``(image, attributes)``, where ``attributes`` is ``(n_attributes, N_CLASSES)`` one-hot."""
        item = self.dataset[idx]
        image = self.transform(item['image'])

        attributes = torch.stack([
            F.one_hot(
                torch.tensor(item[f'label_{att}']),
                num_classes=Shape3DEvenSplitDataset.N_CLASSES,
            )
            for att in self.attributes
        ])

        return image, attributes

    @classmethod
    def get_splits(
        cls,
        seed: int,
        val_ratio: float,
        test_ratio: float,
        max_len: int = float('inf'),
        dummy: bool = False,
        attributes=None,
    ):
        """Load the data once and return ``[train, valid, test]`` datasets."""
        if attributes is None:
            attributes = list(cls.DEFAULT_ATTRIBUTES)

        label_conversion = cls._build_label_conversion(seed, attributes)
        print(label_conversion)

        dataset = cls._load_dataset(max_len, dummy, attributes, label_conversion)
        splits = split_train_val_test(
            dataset, val_ratio, test_ratio, seed
        )
        return [cls(dataset=part, attributes=attributes) for part in splits]

    @property
    def n_attributes(self):
        return self[0][1].size(0)

    @property
    def n_values(self):
        return self[0][1].size(1)

    def __len__(self):
        return len(self.dataset)


class Shape3DUnseenSplitDataset(Dataset):
    """3D Shapes images split by unseen (shape, object_hue) combinations.

    Items are grouped by their ``'{shape}_{color}'`` combination, and whole
    combinations are assigned to train/valid/test. A combination seen in one
    split never appears in another. Only shape and object-hue labels below
    ``N_CLASSES`` are kept.
    """
    N_CLASSES = 4

    def __init__(
        self,
        split: str = None,
        seed: int = None,
        val_ratio: float = None,
        test_ratio: float = None,
        dataset=None,
    ):
        """Wrap ``dataset``, or load the data and pick ``split``.

        ``dataset``, when given, is a list of per-combination item lists
        (as built by ``get_splits``). It is flattened into one item list.
        """
        if dataset is None:
            combinations = Shape3DUnseenSplitDataset._load_dataset()
            train, valid, test = split_train_val_test(
                list(combinations.keys()), val_ratio, test_ratio, seed
            )
            splits = {'train': train, 'valid': valid, 'test': test}
            if split not in splits:
                raise ValueError(f'Unknown split {split!r}; expected one of {sorted(splits)}')
            dataset = [combinations[key] for key in splits[split]]

        self.dataset = list(itertools.chain.from_iterable(dataset))
        self.transform = ToTensor()

    @staticmethod
    def _group_by_combination(items, max_len=float('inf')):
        """Group items with shape and object_hue below ``N_CLASSES`` by ``'{shape}_{color}'``.

        Prints per-shape, per-color and per-combination counts. Stops after
        ``max_len`` items have been kept.
        """
        shape_count = defaultdict(int)
        color_count = defaultdict(int)
        shape_color_count = defaultdict(int)
        comb = defaultdict(list)

        n_kept = 0
        for item in items:
            shape = item['label_shape']
            color = item['label_object_hue']
            if shape >= Shape3DUnseenSplitDataset.N_CLASSES or color >= Shape3DUnseenSplitDataset.N_CLASSES:
                continue
            shape_count[shape] += 1
            color_count[color] += 1
            shape_color_count[f'{shape}_{color}'] += 1
            comb[f'{shape}_{color}'].append(item)

            n_kept += 1
            if n_kept >= max_len:
                break

        print()
        print('shape count', shape_count)
        print('color count', color_count)
        print('shape color count', shape_color_count)

        return comb

    @staticmethod
    def _load_dataset(max_len: int = float('inf'), dummy: bool = False):
        """Load shapes3d (or random dummy images) grouped by combination.

        At most ``max_len`` items are kept.
        """
        if dummy:
            dataset = []
            to_pil = ToPILImage()
            for _ in range(10000):
                for shape in range(0, 4):
                    for color in range(0, 4):
                        item = dict()
                        item['label_shape'] = shape
                        item['label_object_hue'] = color
                        item['image'] = to_pil(torch.rand(3, 64, 64))
                        dataset.append(item)
        else:
            dataset = tfds.load(
                'shapes3d',
                split='train',
                shuffle_files=True,
            )
            dataset = tfds.as_numpy(dataset)

        return Shape3DUnseenSplitDataset._group_by_combination(
            tqdm(dataset, desc='Iterating dataset'), max_len,
        )

    def __getitem__(self, idx):
        """Return ``(image, attributes)``, where ``attributes`` stacks the shape and color one-hots."""
        item = self.dataset[idx]
        image = self.transform(item['image'])

        shape = F.one_hot(
            torch.tensor(item['label_shape']),
            num_classes=Shape3DUnseenSplitDataset.N_CLASSES,
        )
        color = F.one_hot(
            torch.tensor(item['label_object_hue']),
            num_classes=Shape3DUnseenSplitDataset.N_CLASSES,
        )
        attributes = torch.stack([shape, color])

        return image, attributes

    @classmethod
    def get_splits(
        cls,
        seed: int,
        val_ratio: float,
        test_ratio: float,
        max_len: int = float('inf'),
        dummy: bool = False,
    ):
        """Load once, split the combinations, and return ``[train, valid, test]`` datasets."""
        combinations = cls._load_dataset(max_len, dummy=dummy)
        splits = split_train_val_test(
            list(combinations.keys()), val_ratio, test_ratio, seed
        )
        print('splits (train, valid, test):', splits)

        return [cls(dataset=[combinations[key] for key in keys]) for keys in splits]

    @property
    def n_attributes(self):
        return self[0][1].size(0)

    @property
    def n_values(self):
        return self[0][1].size(1)

    def __len__(self):
        return len(self.dataset)

        
if __name__ == "__main__":
    # Smoke test: enumerate a 2x10 attribute grid and report the holdout split.
    full_grid = enumerate_attribute_value(n_attributes=2, n_values=10)
    train_part, holdout_part = split_holdout(full_grid)
    print(len(train_part), len(holdout_part), len(full_grid))

    first_items = [part[0] for part in (train_part, holdout_part, full_grid)]
    print(first_items)
