import logging
from collections import Counter
from typing import List, Union, Tuple, Any

import regex as re
import torch
from torchtext.data import RawField
from torchtext.vocab import Vocab

from page.const import *


def postfix_parser(equation: List[Union[str, Tuple[str, Any]]], memory: list) -> int:
    """
    Walk a postfix (RPN) token list symbolically, recording every operator application in ``memory``.

    Each token that is a key of OPERATORS pops ``OPERATORS[tok]['arity']`` operands from the stack,
    appends ``(tok, operands)`` to ``memory`` and pushes ``(ARG_MEM, index)`` pointing at that new memory
    entry. Every other token is pushed unchanged as an operand.

    :param equation: Postfix tokens; operands are strings or ``(type, value)`` tuples.
    :param memory: List extended in place with ``(operator, operands)`` entries. The ``(ARG_MEM, i)``
        references pushed on the stack index this list, including entries it already held before the call.
    :return: Number of items left on the stack; a well-formed single expression leaves exactly one.
        If fewer operands than an operator's arity are available, the operator takes all remaining items.
    """
    stack = []

    for tok in equation:
        if tok in OPERATORS:
            arity = OPERATORS[tok]['arity']

            # Split point between untouched stack items and this operator's operands. Computing it
            # explicitly keeps arity 0 correct (``stack[-0:]`` would grab the whole stack), and clamping at
            # zero preserves the underflow behaviour of taking whatever is left on the stack.
            split = max(len(stack) - arity, 0)
            args = stack[split:]
            del stack[split:]

            stack.append((ARG_MEM, len(memory)))
            memory.append((tok, args))
        else:
            stack.append(tok)

    return len(stack)


class TokenEquationField(RawField):
    """
    Field that converts equations into a flat sequence of tokens.

    ``preprocess`` yields ``(type, value)`` pairs:
    - ``(TOK_TOK_ID, token_string)`` for vocabulary tokens;
    - ``(TOK_NUM_ID, number_index)`` for number placeholders, unless ``generate_all`` is set.

    ``process`` pads a batch and turns it into a ``torch.long`` tensor of ``(type, value, is_eq_sign)``
    triplets.
    """

    def __init__(self, variable_prefixes, number_prefixes, constant_prefix, number_postprocess=None,
                 is_target=False, generate_all=False):
        """
        :param variable_prefixes: Prefixes marking variable tokens.
        :param number_prefixes: Prefixes marking number tokens.
        :param constant_prefix: Prefix marking constant tokens; it is replaced by EQ_CON_PREFIX.
        :param number_postprocess: Maps a number token to its integer index. Defaults to the integer after
            the last underscore (e.g. ``'N_3' -> 3``).
        :param is_target: Forwarded to ``RawField``.
        :param generate_all: If True, numbers become vocabulary tokens (``FORMAT_NUM % index``) instead of
            ``(TOK_NUM_ID, index)`` pairs, and ``build_vocab`` reserves NUM_MAX number and VAR_MAX variable
            tokens.
        """
        super().__init__(is_target=is_target)
        self.dtype = torch.long

        self.variable_prefixes = set(variable_prefixes)
        # NOTE: the attribute keeps its historical spelling because other code may read it.
        self.number_perfixes = set(number_prefixes)
        self.constant_prefix = constant_prefix

        self.number_postprocess = number_postprocess if number_postprocess is not None \
            else lambda string: int(string.split('_')[-1])

        # Fixed ids kept as public attributes; they are not read inside this class.
        self.unk_token_id = 0
        self.init_token_id = 1
        self.eos_token_id = 2

        self.var_token_id = 5
        self.num_token_id = 6

        self.max_variable_index = 0
        self.max_number_index = 0

        self.generate_all = generate_all
        self.token_vocab = None

    @property
    def has_empty_vocab(self):
        """True until ``build_vocab`` has been called."""
        return self.token_vocab is None

    @property
    def eq_sign_id(self):
        """Vocabulary id of '=' (SEQ_UNK_TOK_ID if '=' is not in the vocabulary)."""
        return self.token_vocab.stoi.get('=', SEQ_UNK_TOK_ID)

    def preprocess(self, item):
        """
        Normalize one problem's equations into ``(type, value)`` token pairs.

        :param item: List of ``(TYPE, EQUATION)`` where EQUATION is a token list or a whitespace-separated
            string.
            - PREP_KEY_ANS entries are skipped.
            - PREP_KEY_MEM entries are wrapped as ``['M_<k>'] + EQUATION + ['=']``.
        :return: Tokens of all kept equations, concatenated in order.
        """
        assert type(item) is list, "We expect [(TYPE, EQUATION), ...] " \
                                   "where TYPE = 0, 1, 2 and EQUATION is a list of tokens."

        tokens = []
        memory_counter = 0
        # Shared across every equation of the item, so a variable keeps one index for the whole problem.
        variables = {}

        for typ, expr in item:
            if type(expr) is str:
                expr = re.split('\\s+', expr.strip())

            if typ == PREP_KEY_ANS:
                continue
            elif typ == PREP_KEY_MEM:
                expr = ['M_%s' % memory_counter] + expr + ['=']
                memory_counter += 1

            for token in expr:
                if any(token.startswith(prefix) for prefix in self.variable_prefixes):
                    # Remap variables by order of first appearance.
                    if token not in variables:
                        variables[token] = len(variables)

                    position = variables[token]
                    token = FORMAT_VAR % position
                    tokens.append((TOK_TOK_ID, token))
                elif any(token.startswith(prefix) for prefix in self.number_perfixes):
                    # To preserve order, we padded indices with zeros at the front.
                    position = self.number_postprocess(token)
                    if self.generate_all:
                        tokens.append((TOK_TOK_ID, FORMAT_NUM % position))
                    else:
                        tokens.append((TOK_NUM_ID, position))
                else:
                    if token.startswith(self.constant_prefix):
                        # Swap only the leading prefix; str.replace would also rewrite any later occurrence.
                        token = EQ_CON_PREFIX + token[len(self.constant_prefix):]
                    tokens.append((TOK_TOK_ID, token))

        return tokens

    def build_vocab(self, equations: list):
        """
        Build ``token_vocab`` from the TOK_TOK_ID tokens of every preprocessed item.

        SEQ_TOKENS always come first. With ``generate_all``, the number and variable tokens follow in index
        order so that their ids are contiguous.
        """
        equation_counter = Counter()

        for item in equations:
            equation_counter.update([tok for typ, tok in self.preprocess(item) if typ == TOK_TOK_ID])

        # Make sure that BOS and EOS always at the front of the vocabulary.
        special_tokens = SEQ_TOKENS.copy()
        if self.generate_all:
            # Enforce number and variable tokens are sorted with their indices.
            special_tokens += [FORMAT_NUM % i for i in range(NUM_MAX)]
            special_tokens += [FORMAT_VAR % i for i in range(VAR_MAX)]

        self.token_vocab = Vocab(equation_counter, specials=special_tokens)

    def process(self, batch, device=None, **kwargs):
        """Pad ``batch`` and convert it into a long tensor on ``device``."""
        return self.numericalize(self.pad(batch), device=device)

    def pad(self, minibatch: List[List[Tuple[int, Any]]]):
        """
        Convert tokens to ``(type, value, is_eq_sign)`` triplets and pad every item to a common length.

        Each item is wrapped with SEQ_NEW_EQN_ID / SEQ_END_EQN_ID triplets and filled up with
        ``(PAD_ID, PAD_ID, 0)``.
        """
        max_len = max(len(item) for item in minibatch) + 2  # 2 = EOE/BOE
        padded_batch = []

        # We will ignore INIT/EOS/PAD in predicting non-token targets
        for item in minibatch:
            converted = []
            for typ, tok in item:
                if typ == TOK_TOK_ID:
                    # Vocabulary token: its id plus a flag marking the '=' sign.
                    converted.append((typ, self.token_vocab.stoi.get(tok, SEQ_UNK_TOK_ID), int(tok == '=')))
                else:
                    # Number placeholder: keep the raw index; never an '=' sign.
                    converted.append((typ, tok, 0))

            padded_item = [(TOK_TOK_ID, SEQ_NEW_EQN_ID, 0)] + converted + [(TOK_TOK_ID, SEQ_END_EQN_ID, 0)]
            padded_item += [(PAD_ID, PAD_ID, 0)] * max(0, max_len - len(padded_item))

            padded_batch.append(padded_item)

        return padded_batch

    def numericalize(self, minibatch: List[List[Tuple[int, int]]], device=None):
        """Wrap padded triplets into a tensor of ``self.dtype``."""
        # Shape [B, T, 3]
        return torch.as_tensor(minibatch, dtype=self.dtype, device=device)

    def convert_ids_to_equations(self, minibatch: torch.Tensor) -> List[List[str]]:
        """
        Decode a tensor of ``(type, value, _)`` triplets back into token strings.

        Within a row, decoding restarts at SEQ_NEW_EQN and stops at SEQ_END_EQN. It also stops at the first
        triplet whose type is neither TOK_TOK_ID nor TOK_NUM_ID. Number placeholders are rendered with
        FORMAT_NUM.
        """
        equation_batch = []
        for item in minibatch:
            equation = []

            for typ, token, _ in item.tolist():
                if typ == TOK_TOK_ID:
                    token = self.token_vocab.itos[token]
                    if token == SEQ_NEW_EQN:
                        equation.clear()
                        continue
                    elif token == SEQ_END_EQN:
                        break
                elif typ == TOK_NUM_ID:
                    token = FORMAT_NUM % token
                else:
                    break

                equation.append(token)

            equation_batch.append(equation)
        return equation_batch


class OperationEquationField(RawField):
    """
    Field that converts equations into a sequence of operations over earlier results.

    ``preprocess`` yields ``(function, arguments)`` pairs:
    - first, one ``(FUN_NEW_VAR, [])`` per distinct variable;
    - then one entry per operator application, whose arguments are ``(argument_type, value)`` pairs.

    ARG_MEM values index earlier entries of the same list. ``process`` pads a batch and turns it into a
    ``torch.long`` tensor whose rows are ``[function_id] + [arg_type_id, arg_value] * max_arity`` (assuming
    no operation has more than ``max_arity`` arguments).
    """

    def __init__(self, variable_prefixes, number_prefixes, constant_prefix,
                 number_postprocess=None, is_target=False, max_arity: int = 2, force_generation: bool = False):
        """
        :param variable_prefixes: Prefixes marking variable tokens.
        :param number_prefixes: Prefixes marking number tokens.
        :param constant_prefix: Prefix marking constant tokens; it is replaced by EQ_CON_PREFIX.
        :param number_postprocess: Maps a number token to its integer index. Defaults to the integer after
            the last underscore.
        :param is_target: Forwarded to ``RawField``.
        :param max_arity: Number of argument slots per operation; shorter argument lists are padded.
        :param force_generation: If True, every argument (number, memory or constant) is encoded as a string
            token through ``constant_word_vocab`` instead of a raw index.
        """
        super().__init__(is_target=is_target)
        self.dtype = torch.long

        self.variable_prefixes = set(variable_prefixes)
        # NOTE: the attribute keeps its historical spelling because other code may read it.
        self.number_perfixes = set(number_prefixes)
        self.constant_prefix = constant_prefix
        self.force_generation = force_generation

        self.number_postprocess = number_postprocess if number_postprocess is not None \
            else lambda string: int(string.split('_')[-1])

        self.function_word_vocab = None
        self.constant_word_vocab = None
        self.max_arity = max_arity

    @property
    def has_empty_vocab(self):
        """True until ``build_vocab`` has been called."""
        return self.function_word_vocab is None

    @property
    def function_arities(self):
        """Map each function vocabulary id at or beyond ``len(FUN_TOKENS)`` to its operator arity."""
        return {i: OPERATORS[f]['arity'] for i, f in enumerate(self.function_word_vocab.itos) if i >= len(FUN_TOKENS)}

    def preprocess(self, formulae) -> List[Tuple[str, list]]:
        """
        Turn one problem's ``[(TYPE, EQUATION), ...]`` into ``(function, arguments)`` operations.

        EQUATION is a postfix token list or a whitespace-separated string. Only PREP_KEY_EQN and
        PREP_KEY_MEM entries are parsed into operations.

        :raises AssertionError: If ``formulae`` is not a list, or a parsed entry does not reduce to exactly
            one stack item.
        """
        assert type(formulae) is list, "We expect [(TYPE, EQUATION), ...] " \
                                       "where TYPE = 0, 1, 2 and EQUATION is a list of tokens."

        variables = []
        memories = []
        for typ, expr in formulae:
            if type(expr) is str:
                expr = re.split('\\s+', expr.strip())

            # NOTE(review): entries of every type (including PREP_KEY_ANS) are normalized here. A variable
            # that appears only in such an entry still gets a FUN_NEW_VAR slot below, whereas
            # TokenEquationField skips PREP_KEY_ANS before normalizing. Confirm this difference is intended.
            normalized = []
            for token in expr:
                if any(token.startswith(prefix) for prefix in self.variable_prefixes):
                    # Case 1: Variable
                    if token not in variables:
                        variables.append(token)

                    # Set as negative numbers, since we don't know how many variables are in the list.
                    normalized.append((ARG_MEM, - variables.index(token) - 1))
                elif any(token.startswith(prefix) for prefix in self.number_perfixes):
                    # Case 2: Number
                    if self.force_generation:
                        # Treat number indicator as constant.
                        normalized.append((ARG_NUM, FORMAT_NUM % self.number_postprocess(token)))
                    else:
                        normalized.append((ARG_NUM, self.number_postprocess(token)))
                elif token.startswith(self.constant_prefix):
                    # Case 3: Constant. Swap only the leading prefix; str.replace would rewrite every occurrence.
                    normalized.append((ARG_CON, EQ_CON_PREFIX + token[len(self.constant_prefix):]))
                else:
                    # Anything else stays a plain string; postfix_parser applies it as an operator if it is a
                    # key of OPERATORS.
                    normalized.append(token)

            # Build memory representation
            if typ == PREP_KEY_EQN:
                stack_len = postfix_parser(normalized, memories)
                assert stack_len == 1, "Equation is not correct! '%s'" % (expr,)
            elif typ == PREP_KEY_MEM:
                stack_len = postfix_parser(normalized, memories)
                assert stack_len == 1, "Intermediate representation of memory is not correct! '%s'" % (expr,)

        # Reconstruct memory indices. Variables take slots 0..V-1 as FUN_NEW_VAR entries and the recorded
        # operations follow, so memory index i becomes i + V and variable -(k + 1) becomes k.
        var_length = len(variables)
        preprocessed = [(FUN_NEW_VAR, []) for _ in range(var_length)]
        for func, arguments in memories:
            new_arguments = []
            for typ, tok in arguments:
                if typ == ARG_MEM:
                    tok = tok + var_length if tok >= 0 else -(tok + 1)

                    if self.force_generation:
                        # Build as a string
                        tok = FORMAT_MEM % tok

                new_arguments.append((typ, tok))

            preprocessed.append((func, new_arguments))

        return preprocessed

    def build_vocab(self, equations: list):
        """
        Build the function vocabulary and the constant/argument vocabulary from ``equations``.

        With ``force_generation``, every argument value is counted. The FORMAT_NUM / FORMAT_MEM tokens are
        also registered as specials so that their ids follow their indices.
        """
        function_counter = Counter()
        constant_counter = Counter()

        constant_specials = [ARG_UNK]
        if self.force_generation:
            # Enforce index of numbers become 1 ~ NUM_MAX
            constant_specials += [FORMAT_NUM % i for i in range(NUM_MAX)]
            # Enforce index of memory indices become NUM_MAX+1 ~ NUM_MAX+MEM_MAX
            constant_specials += [FORMAT_MEM % i for i in range(MEM_MAX)]

        for item in equations:
            # Equation is not tokenized
            item = self.preprocess(item)
            if not item:
                # zip(*[]) has nothing to unpack; an empty item contributes no tokens anyway.
                continue

            functions, arguments = zip(*item)
            function_counter.update(functions)
            for args in arguments:
                constant_counter.update([const for t, const in args if t == ARG_CON or self.force_generation])

        self.function_word_vocab = Vocab(function_counter, specials=FUN_TOKENS_WITH_EQ.copy())
        self.constant_word_vocab = Vocab(constant_counter, specials=constant_specials)

    def process(self, batch: List[List[Tuple[str, list]]], device=None, **kwargs):
        """Pad ``batch`` and convert it into a long tensor on ``device``."""
        return self.numericalize(self.pad(batch), device=device)

    def pad(self, minibatch: List[List[Tuple[str, list]]]) -> List[List[Tuple[str, list]]]:
        """
        Pad argument lists to ``max_arity`` and items to a common length.

        Each item is wrapped with FUN_NEW_EQN / FUN_END_EQN entries and filled up with ``(None, ...)``
        entries. ``(None, None)`` marks an empty argument slot.
        """
        max_len = max(len(item) for item in minibatch) + 2  # 2 = BOE/EOE
        padded_batch = []

        # We will ignore INIT/EOS/PAD in predicting non-token targets
        max_arity_pad = [(None, None)] * self.max_arity

        for item in minibatch:
            padded_item = [(FUN_NEW_EQN, max_arity_pad)]

            for func, args in item:
                # We had to pad arguments.
                remain_arity = max(0, self.max_arity - len(args))
                args = args + max_arity_pad[:remain_arity]

                padded_item.append((func, args))

            padded_item.append((FUN_END_EQN, max_arity_pad))
            padded_item += [(None, max_arity_pad)] * max(0, max_len - len(padded_item))

            # Add batched item
            padded_batch.append(padded_item)

        return padded_batch

    def convert_token_to_id(self, memory_item: Tuple[str, list]):
        """
        Encode one padded ``(function, arguments)`` entry as ``[function_id, type_id, value, ...]``.

        Empty slots and ``None`` functions become PAD_ID. Constant values (all values under
        ``force_generation``) are looked up in ``constant_word_vocab`` with ARG_UNK_ID as the fallback.
        """
        func, args = memory_item

        func = PAD_ID if func is None else self.function_word_vocab.stoi[func]
        new_args = []
        for t, i in args:
            if t is None:
                new_args += [PAD_ID, PAD_ID]
            else:
                new_args.append(ARG_TOKENS.index(t))
                if t == ARG_CON or self.force_generation:
                    new_args.append(self.constant_word_vocab.stoi.get(i, ARG_UNK_ID))
                else:
                    new_args.append(i)

        return [func] + new_args

    def numericalize(self, minibatch: List[List[Tuple[str, list]]], device=None) -> torch.Tensor:
        """Encode every padded entry and wrap the result into a tensor of ``self.dtype``."""
        # Shape [B, T, 1 + 2*ArgN]
        minibatch = [[self.convert_token_to_id(token) for token in item] for item in minibatch]
        return torch.as_tensor(minibatch, dtype=self.dtype, device=device)

    def convert_ids_to_memories(self, minibatch: torch.Tensor) -> List[List[Tuple[str, list]]]:
        """
        Decode an id tensor back into ``(function, arguments)`` lists.

        Within a row, decoding restarts at FUN_NEW_EQN and stops at FUN_END_EQN. Argument slots whose type
        id is PAD_ID are dropped. Vocabulary strings starting with MEM_PREFIX become ``(ARG_MEM, index)``.
        """
        memory_batch = []

        for item in minibatch.tolist():
            memory_item = []

            for token in item:
                func = self.function_word_vocab.itos[token[0]]
                if func == FUN_NEW_EQN:
                    memory_item.clear()
                    continue

                if func == FUN_END_EQN:
                    break

                args = []
                for i in range(1, len(token), 2):
                    t = token[i]
                    if t != PAD_ID:
                        t = ARG_TOKENS[t]
                        arg = token[i + 1]
                        if t == ARG_CON or self.force_generation:
                            arg = self.constant_word_vocab.itos[arg]

                        if type(arg) is str and arg.startswith(MEM_PREFIX):
                            # Strip exactly the prefix that was matched instead of a hard-coded length.
                            args.append((ARG_MEM, int(arg[len(MEM_PREFIX):])))
                        else:
                            args.append((t, arg))

                memory_item.append((func, args))

            memory_batch.append(memory_item)

        return memory_batch

    def convert_ids_to_equations(self, minibatch: torch.Tensor) -> List[List[str]]:
        """
        Rebuild postfix token lists from an id tensor.

        Each operation's postfix form inlines the postfix form of every memory entry it references. Entries
        referenced by a later operation are dropped, and the remaining ones are concatenated. Invalid memory
        references become ARG_UNK.
        """
        memory_batch = self.convert_ids_to_memories(minibatch)
        equation_batch = []

        for item in memory_batch:
            computation_history = []
            memory_used = []

            for func, args in item:
                computation = []

                if func == FUN_NEW_VAR:
                    computation.append(FORMAT_VAR % len(computation_history))
                else:
                    for t, arg in args:
                        if t == ARG_NUM and not self.force_generation:
                            computation.append(FORMAT_NUM % arg)
                        elif t == ARG_MEM:
                            # Predicted references may point forward, be negative, or (with force_generation)
                            # carry a non-memory token string; comparing that string with an int would raise.
                            if isinstance(arg, int) and 0 <= arg < len(computation_history):
                                computation += computation_history[arg]
                                memory_used[arg] = True
                            else:
                                computation.append(ARG_UNK)
                        else:
                            computation.append(arg)

                    # Postfix representation
                    computation.append(func)

                computation_history.append(computation)
                memory_used.append(False)

            computation_history = [equation for used, equation in zip(memory_used, computation_history) if not used]
            equation_batch.append(sum(computation_history, []))

        return equation_batch


# Names exported by ``from <module> import *``; postfix_parser is not included.
__all__ = [
    'TokenEquationField',
    'OperationEquationField',
]
