import logging
from collections import Counter, namedtuple
from typing import List, Union

import regex as re
import torch
from torchtext.data import RawField
from torchtext.vocab import Vocab
from transformers import PreTrainedTokenizer

from page.const import *
from page.util import find_numbers_in_text

# Text-side record. After numericalize: token ids, PAD mask (True at PAD positions),
# per-token number indices, and the extracted number values.
NumericTextInstance = namedtuple('NumericTextInstance', 'token pad number number_value')
# Equation-side record: tokens, variable indices/pointers, number indices, template id(s).
EquationInstance = namedtuple('EquationInstance', 'token variable number template')
# SentencePiece word-boundary marker (U+2581, LOWER ONE EIGHTH BLOCK).
SPIECE_UNDERLINE = '\u2581'


class TransformerTextField(RawField):
    """Field turning problem text into pretrained-tokenizer ids plus number positions.

    ``preprocess`` tokenizes a raw string and extracts its numbers. ``pad`` removes the
    ``NUM_TOKEN`` markers while recording which sub-word positions belong to which
    number, then truncates and pads the batch. ``numericalize`` converts the padded
    batch into tensors.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, is_target=False, maximum_sequence_length: int = 510):
        """
        :param tokenizer: Pretrained tokenizer. ``NUM_TOKEN`` is registered on it as an
            additional special token.
        :param is_target: Passed through to ``RawField``.
        :param maximum_sequence_length: Maximum number of text tokens kept per item,
            excluding the BOS/EOS tokens added in :meth:`pad`. A falsy value disables
            truncation.
        """
        super().__init__(is_target=is_target)
        self.tokenizer = tokenizer
        self.maximum_sequence_length = maximum_sequence_length
        self.dtype = torch.long
        # NOTE(review): computed before NUM_TOKEN is registered below. Confirm that the
        # ``+ 1`` offset matches what consumers of ``num_token_id`` expect.
        self.num_token_id = len(tokenizer) + 1

        # Base vocabulary size, read before NUM_TOKEN is added.
        self.pretrained_tokens = tokenizer.vocab_size
        self.tokenizer.add_special_tokens({'additional_special_tokens': [NUM_TOKEN]})

    def preprocess(self, x):
        """Tokenize a single problem text.

        :param str x: Raw problem text.
        :return: ``NumericTextInstance(tokens, None, None, numbers)``. ``tokens`` is the
            tokenizer output of the text returned by ``find_numbers_in_text``, which may
            contain ``NUM_TOKEN`` markers. ``numbers`` is that helper's extracted numbers.
        """
        assert type(x) is str, "We expect a string for problem text."

        tokenized, numbers = find_numbers_in_text(x, append_number_token=True)
        tokenized = self.tokenizer.tokenize(tokenized.strip())

        return NumericTextInstance(tokenized, None, None, numbers)

    def process(self, batch: List[NumericTextInstance], device=None, **kwargs):
        """Pad and numericalize a batch of preprocessed instances."""
        return self.numericalize(self.pad(batch), device=device)

    def pad(self, minibatch: List[NumericTextInstance]):
        """Drop NUM_TOKEN markers, record number positions, then truncate and pad a batch.

        When a ``NUM_TOKEN`` marker is met, the preceding tokens are tagged with the
        running index of that number. The walk goes back to, and includes, the nearest
        token starting with ``SPIECE_UNDERLINE``. Bare ``SPIECE_UNDERLINE`` tokens are
        skipped. All other positions are tagged ``PAD_ID``. Each row is then wrapped in
        BOS/EOS and right-padded with the PAD token.

        :param minibatch: Instances produced by :meth:`preprocess`.
        :return: ``NumericTextInstance(tokens, None, number_positions, number_values)``.
            Every row of ``tokens`` and ``number_positions`` has the same length.
        :raises AssertionError: if truncation would drop a number, or if the number of
            tagged numbers differs from the number of extracted values.
        """
        minibatch = list(minibatch)
        # NUM_TOKEN markers are removed below, so they must not count toward the length.
        # Count them in ``x.token`` (counting on the namedtuple itself counts its fields).
        max_len = max((len(x.token) - x.token.count(NUM_TOKEN) for x in minibatch), default=0)
        if self.maximum_sequence_length:
            max_len = min(max_len, self.maximum_sequence_length)

        max_len_with_specials = max_len + 2  # room for BOS and EOS
        padded = []
        numbers = []
        num_pos = []

        bos_token = self.tokenizer.bos_token
        eos_token = self.tokenizer.eos_token
        pad_token = self.tokenizer.pad_token

        for item in minibatch:
            tokens = []
            number_indicators = []
            number_index = 0

            for tok in item.token:
                if tok != NUM_TOKEN:
                    tokens.append(tok)
                    number_indicators.append(PAD_ID)
                    continue

                # Walk back over collected tokens (positions -1 .. -len(tokens)) until the
                # start of the word, i.e. a token beginning with SPIECE_UNDERLINE.
                for i in range(1, len(tokens) + 1):
                    if tokens[-i] != SPIECE_UNDERLINE:
                        # A bare 'space' token is not marked as part of the number.
                        number_indicators[-i] = number_index

                    if tokens[-i].startswith(SPIECE_UNDERLINE):
                        break

                number_index += 1

            # Truncation to max_len must not drop any tagged number position.
            assert max(number_indicators[max_len:], default=PAD_ID) == PAD_ID, \
                "A number token should not be discarded. You should increase the number of input tokens."
            # Count the distinct number indices actually tagged. Excluding PAD_ID explicitly
            # also handles rows that have no PAD_ID entry at all (e.g. empty text).
            tagged_numbers = set(number_indicators) - {PAD_ID}
            assert number_index == len(item.number_value) and len(tagged_numbers) == number_index, \
                "The extracted numbers are not the same! %s vs %s" % (number_index, len(item.number_value))

            # Build tokens
            tokens = [bos_token] + tokens[:max_len] + [eos_token]
            number_indicators = [PAD_ID] + number_indicators[:max_len] + [PAD_ID]

            # Build padding
            remain_len = max(0, max_len_with_specials - len(tokens))
            padded.append(tokens + [pad_token] * remain_len)
            num_pos.append(number_indicators + [PAD_ID] * remain_len)
            numbers.append(item.number_value)

        return NumericTextInstance(padded, None, num_pos, numbers)

    def numericalize(self, minibatch: NumericTextInstance, device=None):
        """Convert a padded batch into tensors.

        :return: ``NumericTextInstance(token_ids, pad_masks, number_positions, number_values)``.
            ``token_ids`` and ``number_positions`` are LongTensors [B, S]. ``pad_masks`` is a
            BoolTensor [B, S] that is True where the token is the tokenizer's PAD token.
        """
        # Convert tokens to token ids.
        tokens = [self.tokenizer.convert_tokens_to_ids(tok) for tok in minibatch.token]
        token_ids = torch.as_tensor(tokens, dtype=self.dtype, device=device)
        # Padding mask: True at positions holding the PAD token.
        pad_masks = token_ids == self.tokenizer.pad_token_id
        # Number positions
        number_positions = torch.as_tensor(minibatch.number, dtype=self.dtype, device=device)

        return NumericTextInstance(token_ids, pad_masks, number_positions, minibatch.number_value)

    def convert_ids_to_string(self, minibatch: torch.Tensor):
        """Decode each row of token ids back into a string via the tokenizer."""
        return [self.tokenizer.convert_tokens_to_string(self.tokenizer.convert_ids_to_tokens(text.tolist()))
                for text in minibatch]


class EquationField(RawField):
    """Field that normalizes, numericalizes and pads equation (target) sequences.

    :meth:`tokenize` normalizes raw equations: variables are renumbered by first
    appearance and number/constant prefixes are rewritten. :meth:`preprocess` splits
    them into tokens plus variable/number index sequences. :meth:`pad` and
    :meth:`numericalize` batch them. Distinct token sequences seen in
    :meth:`build_vocab` are kept as templates.
    """

    def __init__(self, variable_prefixes, number_prefixes, constant_prefix, number_postprocess=None,
                 is_target=False, maximum_sequence_length: int = 0, init_token='∵', eos_token='■', unk_token='?',
                 generate_num=False, generate_var=False):
        """
        :param variable_prefixes: Prefixes marking variable tokens in raw equations (e.g. ``'X_'``).
        :param number_prefixes: Prefixes marking number tokens in raw equations (e.g. ``'N_'``).
        :param constant_prefix: Prefix marking constants. It is rewritten to ``EQ_CON_PREFIX``.
        :param number_postprocess: Maps a number token to its integer index. Defaults to
            parsing the text after the last ``'_'``.
        :param is_target: Passed through to ``RawField``.
        :param maximum_sequence_length: Maximum equation tokens per item, excluding the
            init/eos tokens. ``0`` disables truncation.
        :param init_token: Token prepended to each equation.
        :param eos_token: Token appended to each equation.
        :param unk_token: Token used for out-of-vocabulary entries.
        :param generate_num: If True, keep indexed number tokens in the vocabulary instead of
            collapsing them to ``EQ_NUM_TOKEN``.
        :param generate_var: Same as ``generate_num``, for variable tokens / ``EQ_VAR_TOKEN``.
        """
        super().__init__(is_target=is_target)
        self.dtype = torch.long
        self.maximum_sequence_length = maximum_sequence_length

        self.variable_prefixes = set(variable_prefixes)
        # (sic) attribute name kept as-is for backward compatibility.
        self.number_perfixes = set(number_prefixes)
        self.constant_prefix = constant_prefix

        self.number_postprocess = number_postprocess if number_postprocess is not None \
            else lambda x: int(x.split('_')[-1])

        self.init_token = init_token
        self.eos_token = eos_token
        self.unk_token = unk_token
        self.special_tokens = {init_token, eos_token, unk_token}

        # Placeholder ids. build_vocab() overwrites them from the built vocabulary.
        self.unk_token_id = 0
        self.init_token_id = 1
        self.eos_token_id = 2

        self.var_token_id = 5
        self.num_token_id = 6

        self.max_variable_index = 0
        self.max_number_index = 0

        self.can_generate_num = generate_num
        self.can_generate_var = generate_var

        self.equation_vocab = None
        self.templates = []

    @property
    def vocab_without_specials(self):
        """Vocabulary entries (in index order) other than the init/eos/unk tokens."""
        itos = self.equation_vocab.itos
        return [tok for tok in itos if tok not in self.special_tokens]

    @property
    def template_tensors(self):
        """Token tensor of all known templates (see :meth:`process` with ``template_only``)."""
        return self.process(self.templates, template_only=True)

    @property
    def template_size(self):
        """Number of distinct templates collected by :meth:`build_vocab`."""
        return len(self.templates)

    def tokenize(self, x):
        """Flatten and normalize a list of ``(TYPE, EQUATION)`` pairs into equation tokens.

        ``PREP_KEY_EQN`` entries are appended as-is. ``PREP_KEY_MEM`` entries become
        ``M_<k> <expr> =``, where ``k`` counts memory entries. Any other TYPE is collected
        as answer tokens. EQUATION may be a token list or a whitespace-separated string.
        Variables are renumbered by order of first appearance, numbers are mapped through
        ``number_postprocess``, and a leading ``constant_prefix`` is replaced by
        ``EQ_CON_PREFIX``.

        :return: list of normalized token strings.
        """
        assert type(x) is list, "We expect [(TYPE, EQUATION), ...] " \
                                "where TYPE = 0, 1, 2 and EQUATION is a list of tokens."

        equation = []
        answers = set()
        memory_counter = 0

        for typ, expr in x:
            if type(expr) is str:
                expr = re.split('\\s+', expr.strip())

            if typ == PREP_KEY_EQN:
                equation += expr
            elif typ == PREP_KEY_MEM:
                equation += ['M_%s' % memory_counter] + expr + ['=']
                memory_counter += 1
            else:
                answers.update(expr)

        # str.startswith accepts a tuple of alternatives.
        variable_prefixes = tuple(self.variable_prefixes)
        number_prefixes = tuple(self.number_perfixes)

        tokens = []
        variables = {}
        answer_remapped = set()     # TODO use this.

        for token in equation:
            if token.startswith(variable_prefixes):
                # Remapping variables by order of appearance.
                if token not in variables:
                    assert len(variables) < EQ_VAR_MAX
                    variables[token] = len(variables)

                answer_var = token in answers
                # To preserve order, we padded indices with zeros at the front.
                token = EQ_VAR_PATTERN % variables[token]  # By the index of the first appearance.

                if answer_var:
                    answer_remapped.add(token)
            elif token.startswith(number_prefixes):
                # To preserve order, we padded indices with zeros at the front.
                num_id = self.number_postprocess(token)
                assert num_id < EQ_NUM_MAX

                token = EQ_NUM_PATTERN % num_id
            elif token.startswith(self.constant_prefix):
                # Swap only the leading prefix. str.replace would also rewrite any
                # later occurrence of the prefix inside the token.
                token = EQ_CON_PREFIX + token[len(self.constant_prefix):]

            tokens.append(token)

        return tokens

    def preprocess(self, tokens):
        """Split normalized tokens into token / variable-index / number-index sequences.

        ``tokens`` is expected to be the output of :meth:`tokenize`. For tokens starting
        with ``EQ_VAR_PREFIX`` or ``EQ_NUM_PREFIX``, the integer suffix is stored in the
        variable/number sequence. The token itself is collapsed to ``EQ_VAR_TOKEN`` /
        ``EQ_NUM_TOKEN`` unless generation of that kind is enabled. All other positions
        hold ``PAD_ID``.

        :return: ``EquationInstance(tokens, variables, numbers, None)``.
        """
        res = EquationInstance([], [], [], None)

        for token in tokens:
            res.token.append(token)
            res.variable.append(PAD_ID)
            res.number.append(PAD_ID)

            if token.startswith(EQ_VAR_PREFIX):
                if not self.can_generate_var:
                    res.token[-1] = EQ_VAR_TOKEN
                res.variable[-1] = int(token[len(EQ_VAR_PREFIX):])
            elif token.startswith(EQ_NUM_PREFIX):
                if not self.can_generate_num:
                    res.token[-1] = EQ_NUM_TOKEN
                res.number[-1] = int(token[len(EQ_NUM_PREFIX):])

        return res

    def build_vocab(self, equations: list):
        """Build the equation vocabulary and collect distinct equation templates.

        :param equations: Raw equations, each accepted by :meth:`tokenize`.
        """
        equation_counter = Counter()
        # Token sequences already stored in self.templates, for O(1) de-duplication.
        seen_templates = {tuple(tpl.token) for tpl in self.templates}

        for item in equations:
            # Build vocab for template or equation
            item = self.preprocess(self.tokenize(item))
            equation_counter.update(item.token)

            # Build templates
            key = tuple(item.token)
            if key not in seen_templates:
                seen_templates.add(key)
                self.templates.append(item)

        # Place the unk/init/eos tokens at the front of the vocabulary.
        special_tokens = [self.unk_token, self.init_token, self.eos_token]
        # Enforce number and variable tokens are sorted with their indices.
        if self.can_generate_num:
            special_tokens += [EQ_NUM_PATTERN % i for i in range(EQ_NUM_MAX)]
        if self.can_generate_var:
            special_tokens += [EQ_VAR_PATTERN % i for i in range(EQ_VAR_MAX)]

        self.equation_vocab = Vocab(equation_counter, specials=special_tokens)
        logging.info("Equation vocab: %s", self.equation_vocab.stoi)
        logging.info("Number of equation templates: %s", len(self.templates))

        self.unk_token_id = self.equation_vocab.stoi[self.unk_token]
        self.init_token_id = self.equation_vocab.stoi[self.init_token]
        self.eos_token_id = self.equation_vocab.stoi[self.eos_token]

        # Assign the index of variable/number tokens (the first indexed one when generating).
        if self.can_generate_var:
            self.var_token_id = self.equation_vocab.stoi[EQ_VAR_PATTERN % 0]
        else:
            self.var_token_id = self.equation_vocab.stoi[EQ_VAR_TOKEN]
        if self.can_generate_num:
            self.num_token_id = self.equation_vocab.stoi[EQ_NUM_PATTERN % 0]
        else:
            self.num_token_id = self.equation_vocab.stoi[EQ_NUM_TOKEN]

    def process(self, batch, device=None, template_only=False, **kwargs):
        """Pad and numericalize a batch of :meth:`preprocess` outputs."""
        return self.numericalize(self.pad(batch, template_only=template_only),
                                 device=device, template_only=template_only)

    def pad(self, minibatch: List[EquationInstance], template_only=False):
        """Wrap each equation in init/eos tokens and right-pad to a common length.

        If ``maximum_sequence_length`` is set, rows are truncated to it so that all rows
        have the same length. Token padding is ``None``, which :meth:`convert_token_to_id`
        maps to ``PAD_ID``. Number/variable padding is ``PAD_ID``. Each variable entry is
        replaced by the position of that variable's first occurrence in the padded row.
        The template id is matched on the untruncated tokens. It is ``PAD_ID`` when the
        equation is not a known template.

        :param template_only: If True, only the token rows are built.
        :return: ``EquationInstance`` of per-row lists.
        """
        max_len = max((len(x.token) for x in minibatch), default=0)
        if self.maximum_sequence_length:
            max_len = min(max_len, self.maximum_sequence_length)

        padded = EquationInstance([], [], [], [])

        # [TOK, ANS, VAR, NUM]
        # We will ignore INIT/EOS/PAD in predicting non-token targets
        for item in minibatch:
            # Truncate to max_len. Without this, over-long rows would make the batch ragged
            # and tensor conversion in numericalize() would fail.
            tokens = item.token[:max_len]
            remain_len = max_len - len(tokens)
            padded.token.append([self.init_token] + tokens + [self.eos_token] + [None] * remain_len)

            if not template_only:
                padded.number.append([PAD_ID] + item.number[:max_len] + [PAD_ID] + [PAD_ID] * remain_len)

                # Build variable positions: as the first occurrence
                variable_sequence = [PAD_ID] + item.variable[:max_len] + [PAD_ID] + [PAD_ID] * remain_len
                variable_sequence = [variable_sequence.index(var) if var != PAD_ID else PAD_ID
                                     for var in variable_sequence]
                padded.variable.append(variable_sequence)

                # Build template index
                template_id = PAD_ID
                for tpl_id, tpl in enumerate(self.templates):
                    if tpl.token == item.token:
                        template_id = tpl_id
                        break

                padded.template.append(template_id)

        return padded

    def convert_token_to_id(self, token_or_list: Union[list, str]):
        """Map a token (or nested lists of tokens) to vocabulary ids.

        ``None`` maps to ``PAD_ID``, and unknown tokens map to ``unk_token_id``.
        """
        if type(token_or_list) is list:
            return [self.convert_token_to_id(item) for item in token_or_list]
        elif token_or_list is None:
            return PAD_ID
        elif token_or_list in self.equation_vocab.stoi:
            return self.equation_vocab.stoi[token_or_list]
        else:
            return self.unk_token_id

    def numericalize(self, minibatch: EquationInstance, device=None, template_only=False):
        """Convert a padded batch to tensors.

        :return: ``EquationInstance(tokens [B, T], variables [B, T], numbers [B, T],
            templates [B])``. With ``template_only``, only ``tokens`` is set.
        """
        # Shape [B, T]
        tokens = torch.as_tensor(self.convert_token_to_id(minibatch.token), dtype=self.dtype, device=device)

        if not template_only:
            # Shape [B, T]
            variables = torch.as_tensor(minibatch.variable, dtype=self.dtype, device=device)
            # Shape [B, T]
            numbers = torch.as_tensor(minibatch.number, dtype=self.dtype, device=device)
            # Template index [B]
            templates = torch.as_tensor(minibatch.template, dtype=self.dtype, device=device)

            return EquationInstance(tokens, variables, numbers, templates)
        else:
            return EquationInstance(tokens, None, None, None)

    def convert_to_template(self, ids: torch.Tensor, var_index: torch.Tensor = None, num_index: torch.Tensor = None):
        """Convert generated ids to template tokens, trimming the init token and anything from EOS on.

        Negative ids are rendered as ``''``. ``var_index`` and ``num_index`` are cut to the
        same span. Non-PAD variable pointers are shifted so they stay relative to the
        trimmed sequence. The input tensors are not modified.

        Assumes ``ids`` is a 1-D tensor (TODO confirm with callers).

        :return: ``(tokens, var_index, num_index)``.
        """
        tokens = [self.equation_vocab.itos[token] if token >= 0 else '' for token in ids.tolist()]
        begin = 0
        length = len(tokens)

        if tokens and tokens[0] == self.init_token:
            # Ignore the first INIT token.
            tokens = tokens[1:]
            begin = 1
            length -= 1

        if self.eos_token in tokens:
            # Don't generate tokens after the EOS token.
            length = tokens.index(self.eos_token)
            tokens = tokens[:length]

        if var_index is not None:
            # Trim tensor positions
            var_index = var_index[begin:begin + length]

            # Adjust variable pointing positions from 'begin'. This is done out-of-place on
            # purpose: the slice above is a view, so an in-place update would change the
            # caller's tensor.
            shift = torch.full_like(var_index, fill_value=begin).masked_fill(var_index == PAD_ID, 0)
            var_index = var_index - shift

        if num_index is not None:
            # Trim tensor positions
            num_index = num_index[begin:begin + length]

        return tokens, var_index, num_index

    def convert_to_aligned_equation(self, template: Union[List[str], torch.Tensor], var_index: torch.Tensor = None,
                                    num_index: torch.Tensor = None) -> List[str]:
        """Fill a template's placeholder tokens with concrete variable/number tokens.

        ``EQ_VAR_TOKEN`` at position ``i`` becomes ``EQ_VAR_PATTERN % k``. ``k`` is assigned
        in order of first appearance of the pointer ``var_index[i]``; pointers to positions
        that were themselves pointers resolve to the same ``k``. ``EQ_NUM_TOKEN`` becomes
        ``EQ_NUM_PATTERN % num_index[i]``. If ``template`` is a tensor of ids, it is first
        converted with :meth:`convert_to_template`.
        """
        if isinstance(template, torch.Tensor):
            template, var_index, num_index = \
                self.convert_to_template(template, num_index=num_index, var_index=var_index)

        assert var_index is None or var_index.numel() == len(template)
        assert num_index is None or num_index.numel() == len(template)

        variables = {}
        var_counter = 0

        string_tokens = []
        for i, token in enumerate(template):
            if token == EQ_VAR_TOKEN:
                var = var_index[i].item()
                if var not in variables:
                    variables[var] = var_counter
                    var_counter += 1

                # Enable chaining across variables, i.e. [29]->[3]->[1]
                variables[i] = variables[var]
                token = EQ_VAR_PATTERN % variables[i]
            elif token == EQ_NUM_TOKEN:
                token = EQ_NUM_PATTERN % num_index[i].item()

            string_tokens.append(token)

        return string_tokens


def get_fields(tokenizer: PreTrainedTokenizer,
               variable_prefixes=None, number_prefixes=None, constant_prefix=None,
               generate_num=False, generate_var=False):
    """Create the (text field, equation field) pair for a dataset.

    Falsy prefix arguments (``None`` or empty) fall back to the defaults ``['X_', 'M_']``
    for variables, ``['N_']`` for numbers and ``'C_'`` for constants.

    :param tokenizer: Pretrained tokenizer passed to :class:`TransformerTextField`.
    :param generate_num: Forwarded to :class:`EquationField`.
    :param generate_var: Forwarded to :class:`EquationField`.
    :return: ``(TransformerTextField, EquationField)``. The equation field is marked as target.
    """
    variable_prefixes = variable_prefixes or ['X_', 'M_']
    number_prefixes = number_prefixes or ['N_']
    constant_prefix = constant_prefix or 'C_'

    text_field = TransformerTextField(tokenizer)
    equation_field = EquationField(variable_prefixes, number_prefixes, constant_prefix, is_target=True,
                                   generate_num=generate_num, generate_var=generate_var)
    return text_field, equation_field


# Public API of this module.
__all__ = [
    'TransformerTextField',
    'EquationField',
    'get_fields',
    'EquationInstance',
    'NumericTextInstance',
]
