from collections import namedtuple
from typing import List, Dict, Tuple
from pathlib import Path
from numpy import median
import json
import logging

from torch import Tensor, load as load_data, save as save_data
from torchtext.data import batch
from torchtext.data.utils import RandomShuffler

from .eq_field import TokenEquationField, OperationEquationField
from .text_field import TransformerTextField
from page.const import FUN_NEW_VAR

# Record type used both for a single preprocessed example and for a collated batch
# (TokenBatchIterator builds both with these same seven fields).
ProblemInstance = namedtuple(
    'ProblemInstance',
    [
        'text',        # problem text, as produced by the problem (text) field
        'token_gen',   # equation as handled by the `token_gen` field
        'token_ptr',   # equation as handled by the `token_ptr` field
        'tuple_gen',   # equation as handled by the `tuple_gen` field
        'tuple_ptr',   # equation as handled by the `tuple_ptr` field
        'index',       # value of the 'id' key of the dataset line
        'expected',    # value of the 'answer' key of the dataset line
    ]
)


def _get_token_length(item: ProblemInstance) -> Tuple[int, ...]:
    """
    Measure the sequence lengths of a single (not yet batched) problem instance.

    :param item: Instance whose `text.token`, `token_gen` and `tuple_gen` support `len()`.
    :return: Triple of (text length, token-unit length, tuple-unit length), each enlarged by 2.
    """
    # The constant 2 presumably reserves room for begin/end-of-sequence symbols -- TODO confirm with the fields.
    measured = (item.text.token, item.token_gen, item.tuple_gen)
    return tuple(len(sequence) + 2 for sequence in measured)


def _get_item_size(item: ProblemInstance, size: int, prev_size: int) -> int:
    """
    Size callback passed to `batch`: size of a group after adding `item`, measured as
    (longest sequence length in the group) x (number of items in the group).

    :param item: Instance that has just been added to the group.
    :param size: Number of items in the group, including `item`.
    :param prev_size: Value this function returned for the group before `item` was added.
    :return: Size of the group including `item`.
    """
    if size > 1:
        # The previous size was (longest length) * (size - 1), so divide to recover the longest length.
        longest_before = prev_size // (size - 1)
    else:
        longest_before = 0

    longest_now = max(longest_before, *_get_token_length(item))
    return size * longest_now


class TokenBatchIterator(object):
    """
    Iterator over batches of problems whose size is bounded by a token budget.

    A batch is closed when (longest sequence length in the batch) x (number of items) reaches
    `token_batch_size`. Unless `testing_purpose` is set, batches are shuffled and the iterator
    restarts itself when exhausted, so it can be iterated endlessly.
    """

    def __init__(self, dataset: str, problem_field: TransformerTextField,
                 token_gen_field: TokenEquationField, token_ptr_field: TokenEquationField,
                 tuple_gen_field: OperationEquationField, tuple_ptr_field: OperationEquationField,
                 token_batch_size: int = 4096, testing_purpose: bool = False):
        """
        Load (or build and cache) the dataset and prepare the first pass of batches.

        :param dataset: Path of a JSON-lines file; each non-blank line must provide the keys
            'text', 'expr', 'id' and 'answer'. A cache is read from / written to `dataset + '.cached'`.
        :param problem_field: Field used to preprocess and collate the problem text.
        :param token_gen_field: Field for the `token_gen` view of the equation.
        :param token_ptr_field: Field for the `token_ptr` view of the equation.
        :param tuple_gen_field: Field for the `tuple_gen` view of the equation.
        :param tuple_ptr_field: Field for the `tuple_ptr` view of the equation.
        :param token_batch_size: Token budget of a single batch.
        :param testing_purpose: If True, keep dataset order (no shuffling) and stop after one pass.
        """
        self._batch_size = token_batch_size

        self.problem_field = problem_field
        self.token_gen_field = token_gen_field
        self.token_ptr_field = token_ptr_field
        self.tuple_gen_field = tuple_gen_field
        self.tuple_ptr_field = tuple_ptr_field

        self._testing_purpose = testing_purpose

        self._bootstrap = None
        self._batches = None
        self._iterator = None
        # No shuffler is needed when the order has to be preserved.
        self._random = RandomShuffler() if not testing_purpose else None

        # Read dataset
        cached_path = Path(dataset + '.cached')
        if cached_path.exists():
            self._load_cache(cached_path)
        else:
            self._build_and_cache(dataset, cached_path)

        self._examples = len(self._dataset)
        self.reset()

    def _load_cache(self, cached_path: Path):
        """Restore the preprocessed dataset from the cache and fill every vocabulary that is still empty."""
        # NOTE(security): `torch.load` unpickles the file, so only cache files from a trusted source
        # should be placed next to the dataset.
        cache = load_data(cached_path)
        self._dataset = cache['dataset']
        vocab_cache = cache['vocab']

        # Fields that already own a vocabulary keep it; only empty ones are filled from the cache.
        if self.token_gen_field.has_empty_vocab:
            self.token_gen_field.token_vocab = vocab_cache['token']
        if self.token_ptr_field.has_empty_vocab:
            self.token_ptr_field.token_vocab = vocab_cache['token_nonum']
        if self.tuple_gen_field.has_empty_vocab:
            self.tuple_gen_field.function_word_vocab = vocab_cache['func']
            self.tuple_gen_field.constant_word_vocab = vocab_cache['arg']
        if self.tuple_ptr_field.has_empty_vocab:
            self.tuple_ptr_field.function_word_vocab = vocab_cache['func']
            self.tuple_ptr_field.constant_word_vocab = vocab_cache['const']

    def _build_and_cache(self, dataset: str, cached_path: Path):
        """Read the JSON-lines dataset, build missing vocabularies, preprocess all items and write the cache."""
        raw_items = []
        items_for_vocab = []
        # The file is only read, so open it read-only (this also works for write-protected files).
        with Path(dataset).open('rt', encoding='UTF-8') as fp:
            for line in fp:
                line = line.strip()
                if not line:
                    continue

                item = json.loads(line)
                raw_items.append((item['text'], item['expr'], item['id'], item['answer']))
                items_for_vocab.append(item['expr'])

        # Build vocab if required
        if self.token_gen_field.has_empty_vocab:
            self.token_gen_field.build_vocab(items_for_vocab)
        if self.token_ptr_field.has_empty_vocab:
            self.token_ptr_field.build_vocab(items_for_vocab)
        if self.tuple_gen_field.has_empty_vocab:
            self.tuple_gen_field.build_vocab(items_for_vocab)
        if self.tuple_ptr_field.has_empty_vocab:
            self.tuple_ptr_field.build_vocab(items_for_vocab)

        self._dataset = [self._tokenize_equation(item) for item in raw_items]

        # Cache dataset and vocab.
        save_data({'dataset': self._dataset,
                   'vocab': {
                       'token': self.token_gen_field.token_vocab,
                       'token_nonum': self.token_ptr_field.token_vocab,
                       'func': self.tuple_gen_field.function_word_vocab,
                       'arg': self.tuple_gen_field.constant_word_vocab,
                       'const': self.tuple_ptr_field.constant_word_vocab,
                   }}, cached_path)

    def get_rng_state(self):
        """
        :return: Random state of the shuffler, or None when no shuffler exists (testing mode).
        """
        if self._random is None:
            return None
        return self._random.random_state

    def set_rng_state(self, state):
        """
        Replace the shuffler with one created from `state` and regenerate the batches.

        :param state: Random state, e.g. a value obtained from `get_rng_state`.
        """
        self._random = RandomShuffler(state)
        self.reset()

    def print_item_statistics(self, logger):
        """
        Log length statistics of the dataset and the contents of all vocabularies.

        :param logger: Logger-like object; only its `info` method is used.
        """
        item_stats = self.get_item_statistics()

        lengths = item_stats['text_token']
        logger.info('Information about lengths of text sequences: Range %s - %s (mean: %s)',
                    min(lengths), max(lengths), sum(lengths) / self._examples)

        lengths = item_stats['eqn_op_token']
        logger.info('Information about lengths of token unit sequences: Range %s - %s (mean: %s)',
                    min(lengths), max(lengths), sum(lengths) / self._examples)
        logger.info('Token unit vocabulary (no-pointer): %s', self.token_gen_field.token_vocab.itos)
        logger.info('Token unit vocabulary (pointer): %s', self.token_ptr_field.token_vocab.itos)

        lengths = item_stats['eqn_expr_token']
        logger.info('Information about lengths of operator unit sequences: Range %s - %s (mean: %s)',
                    min(lengths), max(lengths), sum(lengths) / self._examples)

        logger.info('Operator unit vocabulary (operator): %s', self.tuple_gen_field.function_word_vocab.itos)
        logger.info('Operator unit vocabulary (operand): %s', self.tuple_gen_field.constant_word_vocab.itos)
        logger.info('Operator unit vocabulary (constant): %s', self.tuple_ptr_field.constant_word_vocab.itos)

    def get_item_statistics(self):
        """
        Collect per-item statistics, one list entry per dataset item.

        :return: Dictionary with the keys 'text_token', 'text_number', 'eqn_op_token', 'eqn_expr_token'
            (lengths of the respective sequences) and 'eqn_unk' (number of entries in `tuple_gen`
            whose first element equals FUN_NEW_VAR).
        """
        return dict(
            text_token=[len(item.text.token) for item in self._dataset],
            text_number=[len(item.text.number_value) for item in self._dataset],
            eqn_op_token=[len(item.token_gen) for item in self._dataset],
            eqn_expr_token=[len(item.tuple_gen) for item in self._dataset],
            eqn_unk=[sum(func == FUN_NEW_VAR for func, _ in item.tuple_gen) for item in self._dataset]
        )

    def _tokenize_equation(self, item) -> ProblemInstance:
        """
        Preprocess one raw item with every field.

        :param item: Sequence of (text, expr, id, answer) as read from the dataset file.
        :return: Instance holding the preprocessed text, four views of the equation, the id and the answer.
        """
        return ProblemInstance(
            text=self.problem_field.preprocess(item[0]),
            token_gen=self.token_gen_field.preprocess(item[1]),
            token_ptr=self.token_ptr_field.preprocess(item[1]),
            tuple_gen=self.tuple_gen_field.preprocess(item[1]),
            tuple_ptr=self.tuple_ptr_field.preprocess(item[1]),
            index=item[2],
            expected=item[3]
        )

    def reset(self):
        """Regenerate all batches and restart the iteration (shuffled unless in testing mode)."""
        self._batches = list(self._generate_batches())

        if not self._testing_purpose:
            self._iterator = iter(self._random(self._batches))
        else:
            # Preserve the order when testing.
            self._iterator = iter(self._batches)

    def _generate_batches(self):
        """
        Yield collated batches that respect the token budget.

        The dataset is first cut into large groups (budget x 1024), each group is sorted by length,
        and items are then packed greedily. Every item is emitted exactly once; an item that alone
        exceeds the budget is emitted as a batch of its own.
        """
        budget = self._batch_size

        for batch_group in batch(self._dataset, budget * 1024, _get_item_size):
            # Start every group from a clean state, so that leftovers which were already flushed
            # at the end of the previous group are not emitted a second time.
            items = []
            max_token_size = 0

            # Sort within each batch-group
            for item in sorted(batch_group, key=_get_token_length):
                items.append(item)

                # Compute the max-length key.
                token_size = max(_get_token_length(item))
                max_token_size = max(max_token_size, token_size)

                # Padded size of the batch if it were closed right now.
                batch_size = max_token_size * len(items)

                if batch_size == budget:
                    yield self._concatenate_batch(items)
                    items = []
                    max_token_size = 0
                elif batch_size > budget:
                    # Flush everything except the newest item. When the newest item is the only one
                    # (it exceeds the budget on its own) there is nothing to flush yet; it is kept
                    # and emitted alone by a later flush.
                    if len(items) > 1:
                        yield self._concatenate_batch(items[:-1])
                    items = items[-1:]
                    max_token_size = token_size

            # Flush whatever remains of this group.
            if items:
                yield self._concatenate_batch(items)

    def _concatenate_batch(self, items: List[ProblemInstance]) -> ProblemInstance:
        """
        Collate single instances into one batch.

        :param items: Instances to be batched; callers within this class always pass a non-empty list.
        :return: Instance whose five field-backed entries are the outputs of the fields' `process`
            methods, and whose `index` / `expected` entries are plain lists in item order.
        """
        # Transpose: list of records -> one list per field name.
        kwargs = {key: [getattr(item, key) for item in items]
                  for key in ProblemInstance._fields}

        kwargs['text'] = self.problem_field.process(kwargs['text'])
        kwargs['token_gen'] = self.token_gen_field.process(kwargs['token_gen'])
        kwargs['token_ptr'] = self.token_ptr_field.process(kwargs['token_ptr'])
        kwargs['tuple_gen'] = self.tuple_gen_field.process(kwargs['tuple_gen'])
        kwargs['tuple_ptr'] = self.tuple_ptr_field.process(kwargs['tuple_ptr'])

        return ProblemInstance(**kwargs)

    def __len__(self):
        """
        :return: Number of batches in the current pass.
        """
        return len(self._batches)

    def __iter__(self):
        return self

    def __next__(self) -> ProblemInstance:
        """
        Return the next batch.

        In testing mode StopIteration is raised at the end of the pass. Otherwise the batches are
        regenerated and iteration continues; StopIteration is raised only when there is no batch at all.
        """
        try:
            return next(self._iterator)
        except StopIteration:
            # Without any batch a restart could never succeed, so stop instead of recursing forever.
            if self._testing_purpose or not self._batches:
                raise

            # Re-initialize iterator when iterator is empty
            self.reset()
            return next(self._iterator)


# Names exported by a star-import of this module; the underscore-prefixed size helpers stay private.
__all__ = ['TokenBatchIterator', 'ProblemInstance']
