# -*- coding: utf-8 -*-

"""Class for grouping input data."""

import itertools
from collections import Counter, OrderedDict
from typing import Optional

import numpy as np

from ..common.dataclass_options import OptionsBase, argfield
from ..common.logger import LOGGER
from ..common.utils import MethodFactory
from .dataset import iter_batches, iter_sub_batches

# Registry of bucket strategies: classes in this module register themselves by
# name with ``@BUCKET_FACTORY.register(...)``; ``BucketOptions`` lists the names
# via ``BUCKET_FACTORY.keys()`` and builds one with ``BUCKET_FACTORY.invoke``.
BUCKET_FACTORY = MethodFactory()


def group_samples(length_counter, num_buckets, batch_item_count=None, batch_sample_count=None):
    """Choose bucket boundaries that minimize total padding (dynamic programming).

    Distinct sample lengths are partitioned into at most ``num_buckets``
    contiguous ranges; every sample in a range is padded to that range's
    maximum length. The cost of a partition is the number of padding items,
    including the empty slots of each bucket's last, partially filled batch.

    Args:
        length_counter: mapping ``length -> number of samples of that length``
            (e.g. a ``collections.Counter``).
        num_buckets: maximum number of buckets; clamped to the number of
            distinct lengths.
        batch_item_count: per-batch item budget; a bucket whose maximum length
            is ``L`` holds ``max(1, batch_item_count // L)`` samples per batch.
        batch_sample_count: fixed number of samples per batch.
            Exactly one of ``batch_item_count`` / ``batch_sample_count`` must
            be given.

    Returns:
        Ascending list of Python ints ``[0, b_1, ..., b_k]`` where ``b_k`` is
        the largest length; bucket ``i`` covers lengths in ``(b_{i-1}, b_i]``.
        An empty ``length_counter`` yields ``[0]``.

    Raises:
        ValueError: if not exactly one batch-size argument is given, or if
            ``num_buckets < 1`` while ``length_counter`` is non-empty.
    """
    if (batch_item_count is None) == (batch_sample_count is None):
        raise ValueError('exactly one of batch_item_count and batch_sample_count must be given')
    if not length_counter:
        return [0]
    if num_buckets < 1:
        raise ValueError(f'num_buckets must be at least 1, got {num_buckets}')

    num_buckets = min(num_buckets, len(length_counter))
    # Index 0 is a sentinel "length 0" with no samples so prefix sums start at 0.
    lengths = np.array([0] + sorted(length_counter))
    length_counts = np.array([(length_counter[length] if length else 0) for length in lengths])
    # $\sum_{length}$
    sigma_counts = np.cumsum(length_counts)
    # $\sum_{length \times count}$
    sigma_length_x_count = np.cumsum(lengths * length_counts)

    num_lengths = len(lengths)
    # dp[i, k]: minimal padding covering lengths[1..i] with exactly k buckets,
    # the last of which ends at lengths[i]; bp[i, k]: end index of the previous bucket.
    dp = np.zeros((num_lengths, num_buckets + 1), dtype=np.int64)
    bp = np.zeros((num_lengths, num_buckets + 1), dtype=np.int64)

    for bucket_count in range(1, num_buckets + 1):
        for max_length_index in range(bucket_count, num_lengths):
            max_length = lengths[max_length_index]
            # samples per batch depends only on the bucket's max length
            if batch_sample_count is not None:
                size = batch_sample_count
            else:
                size = max(1, batch_item_count // max_length)

            min_padding_size = 2 ** 63 - 1
            min_previous_length_index = -1
            for previous_length_index in range(bucket_count - 1, max_length_index):
                # the first bucket must start right after the sentinel
                if previous_length_index != 0 and bucket_count == 1:
                    continue

                num_samples_in_bucket = (sigma_counts[max_length_index] -
                                         sigma_counts[previous_length_index])
                padding_size = dp[previous_length_index, bucket_count - 1]
                padding_size += num_samples_in_bucket * max_length
                padding_size -= (sigma_length_x_count[max_length_index] -
                                 sigma_length_x_count[previous_length_index])

                # empty slots in the bucket's last batch; zero when batches are exactly full
                padding_size += ((size - num_samples_in_bucket % size) % size) * max_length

                if padding_size < min_padding_size:
                    min_padding_size = padding_size
                    min_previous_length_index = previous_length_index

            assert min_previous_length_index != -1, \
                f'dp[{lengths[max_length_index]}, {bucket_count}] <- ???'
            dp[max_length_index, bucket_count] = min_padding_size
            bp[max_length_index, bucket_count] = min_previous_length_index

    # Backtrack from the best bucket count (ties resolve to fewer buckets).
    max_length_index = num_lengths - 1
    buckets = [int(lengths[max_length_index])]
    bucket_count = min(range(1, num_buckets + 1),
                       key=lambda count: dp[max_length_index, count])
    while bucket_count > 1:
        max_length_index = bp[max_length_index, bucket_count]
        buckets.append(int(lengths[max_length_index]))
        bucket_count -= 1

    buckets.append(0)
    buckets.reverse()
    return buckets


def samples_to_batch(batch_samples, sort_fn, return_original, **kwargs):
    """Pack a list of samples into a single batch.

    If ``sort_fn`` is given, ``batch_samples`` is sorted IN PLACE in
    descending order of ``sort_fn``. Packing is delegated to the first
    sample's ``pack_to_batch(batch_samples, **kwargs)``.

    Returns:
        The packed batch, or ``(batch_samples, batch)`` when
        ``return_original`` is true.
    """
    if sort_fn is not None:
        batch_samples.sort(key=sort_fn, reverse=True)

    packer = batch_samples[0]
    packed = packer.pack_to_batch(batch_samples, **kwargs)
    return (batch_samples, packed) if return_original else packed


def samples_to_sub_batches(batch_samples, batch_item_count, batch_sample_count,
                           sort_fn, return_original, **kwargs):
    """Split ``batch_samples`` with ``iter_sub_batches`` and pack each piece.

    Every sub-batch is packed by ``samples_to_batch`` with the same
    ``sort_fn``, ``return_original`` and ``kwargs``.

    Returns:
        A list with one packed result per sub-batch, in iteration order.
    """
    chunks = iter_sub_batches(batch_samples, batch_sample_count, batch_item_count)
    return [samples_to_batch(chunk, sort_fn, return_original, **kwargs) for chunk in chunks]


@BUCKET_FACTORY.register('simple')
class SimpleBuckets:
    """Split samples to batches by a single sort function.

    Samples are visited in dataset order (or a random permutation when
    shuffling) and chunked by ``iter_batches`` using ``batch_sample_count``.
    Each object is preprocessed lazily on first access and cached.
    """

    def __init__(self, original_objects, preprocess_fn,
                 batch_item_count, batch_sample_count=16384,
                 original_length_fn=len, **_kwargs):
        self.original_objects = original_objects
        self.preprocess_fn = preprocess_fn
        self.length_fn = original_length_fn

        self.batch_item_count = batch_item_count
        self.batch_sample_count = batch_sample_count

        # one cache slot per object; None means "not preprocessed yet"
        self.data_samples = [None] * len(original_objects)

    def get_sample(self, index, **kwargs):
        """Return the preprocessed sample at ``index``, computing and caching it on first use.

        A ``preprocess_fn`` result of ``None`` is not cached and will be recomputed.
        """
        cached = self.data_samples[index]
        if cached is not None:
            return cached
        cached = self.preprocess_fn(index, self.original_objects[index], **kwargs)
        self.data_samples[index] = cached
        return cached

    def __len__(self):
        return len(self.original_objects)

    def iter_batch_samples(self, shuffle, **kwargs):
        """Yield lists of preprocessed samples, one list per batch."""
        order = np.arange(len(self.original_objects))
        if shuffle:
            np.random.shuffle(order)

        for _, _, chunk in iter_batches(order, self.batch_sample_count):
            yield [self.get_sample(position, **kwargs) for position in chunk]

    def generate_batches(self, shuffle=False, return_original=False,
                         sort_fn=None, use_sub_batch=False, **kwargs):
        """Yield packed batches; with ``use_sub_batch`` each batch is further split."""
        # TODO check correctness
        for batch_samples in self.iter_batch_samples(shuffle, **kwargs):
            if not use_sub_batch:
                yield samples_to_batch(batch_samples, sort_fn, return_original, **kwargs)
                continue
            yield from samples_to_sub_batches(batch_samples,
                                              self.batch_item_count, self.batch_sample_count,
                                              sort_fn, return_original, **kwargs)


@BUCKET_FACTORY.register('stream')
class StreamBuckets(SimpleBuckets):
    """Greedily pack samples into batches in (optionally shuffled) dataset order.

    The current batch is closed when adding the next sample would exceed
    ``batch_sample_count`` samples, or would make
    ``max_length * num_samples`` exceed ``batch_item_count``. A limit set to
    ``None`` is not enforced. A sample that alone exceeds ``batch_item_count``
    still gets a (single-sample) batch of its own.
    """

    def iter_batch_samples(self, shuffle, **kwargs):
        """Yield non-empty lists of preprocessed samples."""
        indices = np.arange(len(self.original_objects))
        if shuffle:
            np.random.shuffle(indices)

        sample_limit = self.batch_sample_count
        item_limit = self.batch_item_count

        batch_samples = []
        max_length_in_batch = -1
        for index in indices:
            original_object = self.original_objects[index]
            sample = self.get_sample(index, **kwargs)
            length = self.length_fn(original_object)

            max_length = max(length, max_length_in_batch)
            count = len(batch_samples)
            too_many = sample_limit is not None and count >= sample_limit
            too_long = item_limit is not None and max_length * (count + 1) > item_limit

            # Never flush an empty batch: an oversized first sample must start
            # a batch rather than emit [] (which samples_to_batch cannot pack).
            if batch_samples and (too_many or too_long):
                yield batch_samples

                batch_samples = [sample]
                max_length_in_batch = length
            else:
                batch_samples.append(sample)
                max_length_in_batch = max_length

        if batch_samples:
            yield batch_samples


@BUCKET_FACTORY.register('min_padding')
class MinPaddingBuckets(SimpleBuckets):
    """
    Group samples into similar lengths and generate batches.

    Bucket boundaries come from ``group_samples``; each sample goes to the
    smallest bucket whose boundary is >= its ``length_fn`` length, and every
    batch is drawn from a single bucket.
    """

    def __init__(self, original_objects, preprocess_fn, num_buckets, logger=LOGGER, **kwargs):
        super().__init__(original_objects, preprocess_fn, **kwargs)

        # compute each length once; it is needed for counting and for assignment
        object_lengths = [self.length_fn(obj) for obj in original_objects]
        length_counter = Counter(object_lengths)
        lengths = group_samples(length_counter, num_buckets,
                                self.batch_item_count, self.batch_sample_count)

        # length_to_padded_length[l] is the bucket boundary that length l is padded to
        length_to_padded_length = [0]
        for last_length, length in zip(lengths, lengths[1:]):
            length_to_padded_length.extend(itertools.repeat(length, length - last_length))

        # bucket boundary -> sample indices, in ascending boundary order
        self.buckets = OrderedDict((length, []) for length in lengths[1:])
        for index, length in enumerate(object_lengths):
            self.buckets[length_to_padded_length[length]].append(index)

        logger.info('use %s buckets: %s',
                    len(lengths) - 1, {k: len(v) for k, v in self.buckets.items()})

    def iter_batch_samples(self, shuffle, **kwargs):
        """Yield lists of preprocessed samples, each list drawn from one bucket.

        The batch size is ``batch_sample_count`` when set, otherwise
        ``max(batch_item_count // padded_length, 1)``. With ``shuffle``, the
        samples inside each bucket and the order of batches are shuffled;
        ``self.buckets`` itself is never reordered.
        """
        batches = []
        for padded_length, indices in self.buckets.items():
            batch_sample_count = self.batch_sample_count
            if batch_sample_count is None:
                batch_sample_count = max(self.batch_item_count // padded_length, 1)

            if shuffle:
                # shuffle a copy so later unshuffled passes keep the original order
                indices = list(indices)
                np.random.shuffle(indices)

            for _, _, batch_indices in iter_batches(indices, batch_sample_count):
                batches.append((padded_length, batch_indices))

        if shuffle:
            np.random.shuffle(batches)

        for _, batch_indices in batches:
            yield [self.get_sample(index, **kwargs) for index in batch_indices]


@BUCKET_FACTORY.register('sorted')
class SortedBuckets(SimpleBuckets):
    """Group samples by length: sort ascending by ``length_fn``, then chunk.

    With ``shuffle``, the order of the chunks is shuffled up front, and the
    samples inside each chunk are shuffled just before it is yielded.
    """

    def iter_batch_samples(self, shuffle, **kwargs):
        objects = self.original_objects

        def length_of(position):
            return self.length_fn(objects[position])

        by_length = sorted(range(len(objects)), key=length_of)
        chunks = [chunk for _, _, chunk in iter_batches(by_length, self.batch_sample_count)]

        if shuffle:
            np.random.shuffle(chunks)

        for chunk in chunks:
            if shuffle:
                np.random.shuffle(chunk)
            yield [self.get_sample(position, **kwargs) for position in chunk]


class BucketOptions(OptionsBase, active_time='both'):
    """Options selecting and configuring a bucket strategy registered in ``BUCKET_FACTORY``."""
    # samples per batch; forwarded unchanged to the bucket class
    batch_sample_count: Optional[int] = None
    # per-batch item budget; forwarded unchanged to the bucket class
    batch_item_count: Optional[int] = 5000
    # maximum number of buckets; forwarded unchanged to the bucket class
    num_buckets: int = 100

    # registered name of the bucket class to instantiate
    bucket_type: str = argfield('min_padding', choices=BUCKET_FACTORY.keys())

    def create(self, original_objects, preprocess_fn, **kwargs):
        """Build the configured bucket object.

        Extra ``kwargs`` are forwarded to the bucket constructor and must not
        repeat ``num_buckets``, ``batch_item_count`` or ``batch_sample_count``.
        """
        option_kwargs = {
            'num_buckets': self.num_buckets,
            'batch_item_count': self.batch_item_count,
            'batch_sample_count': self.batch_sample_count,
        }
        return BUCKET_FACTORY.invoke(self.bucket_type, original_objects, preprocess_fn,
                                     **option_kwargs, **kwargs)
