# -*- coding: utf-8 -*-

import os
from typing import Optional, Union

import torch
import torch.nn as nn

from ...common.dataclass_options import ExistFile, OptionsBase, argfield
from ...common.logger import LOGGER
from ...common.utils import lazy_property
from ...data.vocab import Tokens
from ..input_plugin_base import InputPluginBase
from ..layers.dropout import FeatureDropout
from ..utils import broadcast_gather, pad_and_stack_1d

# TODO: open question - how should dropout be handled during training?

# Replacement table applied before BERT tokenization: Penn Treebank-style
# bracket escapes (e.g. ``-LRB-``) and assorted typographic quote characters
# are mapped to plain ASCII brackets and quotes.
#
# The insertion order is significant for `bert_cleanup`, which applies the
# entries one after another with `str.replace`: the two-character entries
# (double backtick, doubled apostrophe) come before the single backtick so
# that they are rewritten as a unit first.
BERT_TOKEN_MAPPING = {
    '-LRB-': '(',
    '-RRB-': ')',
    '-LCB-': '{',
    '-RCB-': '}',
    '-LSB-': '[',
    '-RSB-': ']',
    '``': '"',
    "''": '"',
    '`': "'",
    '«': '"',
    '»': '"',
    '‘': "'",
    '’': "'",
    '“': '"',
    '”': '"',
    '„': '"',
    '‹': "'",
    '›': "'",
}


def load_bert(options):
    """Load a pretrained BERT model and freeze all but the requested layers.

    Args:
        options: Object exposing ``bert_path`` (handed to
            ``BertModel.from_pretrained``), ``trainable_layers`` and
            ``training``.  ``trainable_layers`` is either the string
            ``'all'`` (the ``requires_grad`` flags of the loaded model are
            left untouched) or an integer ``n``: every parameter is frozen
            except those of the last ``n`` encoder layers.  ``-1`` is a
            shorthand for "every encoder layer".

    Returns:
        The loaded ``BertModel`` with ``requires_grad`` set as described.

    Raises:
        TypeError: ``trainable_layers`` is neither ``'all'`` nor an integer.
        ValueError: ``options.training`` is true and ``trainable_layers`` is
            outside ``[0, number of encoder layers]`` (after resolving ``-1``).
    """
    from transformers import BertModel

    trainable_layers = options.trainable_layers
    keep_all_trainable = trainable_layers == 'all'
    if not keep_all_trainable and not isinstance(trainable_layers, int):
        # Fail with a readable message instead of the `int - str` TypeError
        # the slice arithmetic below would otherwise produce.
        raise TypeError(
            "trainable_layers must be 'all' or an integer, "
            f'got {trainable_layers!r}')

    bert = BertModel.from_pretrained(options.bert_path)
    if keep_all_trainable:
        return bert

    num_bert_layers = len(bert.encoder.layer)

    # Resolve the "-1 == all encoder layers" shorthand in every mode, so the
    # same option value freezes the same parameters whether or not
    # `options.training` is set.
    if trainable_layers == -1:
        trainable_layers = num_bert_layers

    # As before, the range is only enforced when training; an explicit raise
    # is used because `assert` is stripped under `python -O`.
    if options.training and not 0 <= trainable_layers <= num_bert_layers:
        raise ValueError(
            f'trainable_layers must be in [0, {num_bert_layers}] or -1, '
            f'got {options.trainable_layers!r}')

    # some parameters are not trainable (bert has no named buffers)
    for params in bert.parameters():
        params.requires_grad_(False)

    # Re-enable gradients for the last `trainable_layers` encoder layers; the
    # slice is empty when `trainable_layers` is 0.
    for params in bert.encoder.layer[num_bert_layers - trainable_layers:].parameters():
        params.requires_grad_(True)

    return bert


def bert_cleanup(sentence):
    """Return `sentence` with every `BERT_TOKEN_MAPPING` key replaced.

    The replacements are plain substring substitutions applied one after
    another, in the mapping's insertion order.
    """
    cleaned = sentence
    for pattern in BERT_TOKEN_MAPPING:
        cleaned = cleaned.replace(pattern, BERT_TOKEN_MAPPING[pattern])
    return cleaned


def bert_encode(tokenizer, text_a, max_sequence_length, text_b=None):
    """Encode one text, or a pair of texts, into BERT input lists.

    The layout follows the usual BERT convention
    (see https://github.com/huggingface/pytorch-pretrained-BERT):

        pair:    [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        segment:   0    0   0    0    0      0    0   0    1  1  1  1  1   1

        single:  [CLS] the dog is hairy . [SEP]
        segment:   0    0   0   0   0   0   0

    Segment ids mark whether a position belongs to the first or the second
    sequence.

    Args:
        tokenizer: Object providing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        text_a: First (or only) text.
        max_sequence_length: Upper bound on the encoded length, special
            tokens included.
        text_b: Optional second text.

    Returns:
        ``(token_ids, token_mask, segment_ids)``.  ``token_mask`` holds a 1
        for every returned token id; no padding is added here.
    """
    first = tokenizer.tokenize(text_a)
    second = None

    if text_b is None:
        # Single sequence: reserve two slots for [CLS] and [SEP].
        budget = max_sequence_length - 2
        if len(first) > budget:
            first = first[:budget]
    else:
        second = tokenizer.tokenize(text_b)
        # Pair: reserve three slots for [CLS], [SEP], [SEP]; both lists are
        # shortened in place.
        _truncate_sequence_pair(first, second, max_sequence_length - 3)

    tokens = ['[CLS]'] + first + ['[SEP]']
    segment_ids = [0] * (len(first) + 2)

    if second:
        tokens = tokens + second + ['[SEP]']
        segment_ids.extend([1] * (len(second) + 1))

    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    # Every position produced here is a real token, hence an all-ones mask.
    token_mask = [1] * len(token_ids)

    return token_ids, token_mask, segment_ids


def bert_decode(tokenizer, token_ids):
    """Map `token_ids` back to tokens via ``tokenizer.convert_ids_to_tokens``."""
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    return tokens


def _truncate_sequence_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""
    # copy from https://github.com/huggingface/pytorch-pretrained-BERT/
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


class BERTPlugin(InputPluginBase):
    """Input plugin producing BERT features for the words of a sample.

    `postprocess_sample` tokenizes the words of a sample into word pieces and
    records, per word, the index of its first and last piece.
    `postprocess_batch` pads and stacks those per-sample tensors, and
    `forward` runs BERT and optionally gathers one feature vector per word,
    projects it and applies feature dropout.
    """

    class Options(OptionsBase):
        # Path handed to `BertModel.from_pretrained`.
        bert_path: ExistFile = argfield(active_time='both')
        # Vocabulary location; falls back to `bert_path` when left unset.
        vocab_path: Optional[ExistFile] = argfield(None, active_time='both')

        # When set, every word is split on this separator before tokenization.
        subword_separator: Optional[str] = None

        # Which word-piece position represents a word in `forward`;
        # 'none' keeps the raw word-piece features.
        pool_method: str = argfield('word_end', choices=['none', 'word_end', 'word_start'])
        # Output size of an optional bias-free linear projection.
        project_to: Optional[int] = None
        # Rate passed to `FeatureDropout`; values <= 0 disable it.
        feature_dropout: float = 0.0
        # Consumed by `load_bert`: 'all' or a number of trailing encoder layers.
        trainable_layers: Union[int, str] = 0

    def __init__(self, options: Options):
        """Load BERT and build the optional projection / dropout layers.

        Note that `options.vocab_path` is rewritten in place: it defaults to
        `bert_path`, and a directory is replaced by its ``vocab.txt`` file.
        """
        super().__init__()

        self.options = options

        bert_path = options.bert_path
        vocab_path = options.vocab_path
        if vocab_path is None:
            options.vocab_path = vocab_path = bert_path

        if os.path.isdir(bert_path):
            # Strip separators first (as done for `vocab_name` below) so a
            # trailing '/' does not turn the name into an empty string.
            self.bert_name = os.path.basename(bert_path.strip(os.path.sep))
        else:
            self.bert_name = bert_path

        if os.path.isdir(vocab_path):
            options.vocab_path = os.path.join(vocab_path, 'vocab.txt')
            self.vocab_name = os.path.basename(vocab_path.strip(os.path.sep))
        else:
            self.vocab_name = vocab_path

        self.bert = load_bert(options)

        if options.feature_dropout > 0:
            self.feature_dropout = FeatureDropout(options.feature_dropout)
        else:
            self.feature_dropout = None

        # The pooler's dense layer takes the encoder output as input, so its
        # `in_features` is the feature size produced in `forward`.
        self.output_size = self.bert.pooler.dense.in_features
        if options.project_to:
            self.projection = nn.Linear(self.output_size, options.project_to, bias=False)
            self.output_size = options.project_to
        else:
            self.projection = None

    @lazy_property
    def tokenizer(self):
        """Return a `BertTokenizer` loaded from the first usable vocab path.

        `options.vocab_path` is tried first, then ``<bert_path>/vocab.txt``
        when `bert_path` is a directory.  Lower-casing is enabled when the
        path contains ``'uncased'``.

        Raises:
            OSError: none of the candidate paths could be loaded.
        """
        from transformers import BertTokenizer

        bert_path = self.options.bert_path
        vocab_paths = [self.options.vocab_path]
        if os.path.isdir(bert_path):
            fallback = os.path.join(bert_path, 'vocab.txt')
            # With the default options both candidates are the same file;
            # do not try (and warn about) it twice.
            if fallback != vocab_paths[0]:
                vocab_paths.append(fallback)

        last_error = None
        for vocab_path in vocab_paths:
            try:
                return BertTokenizer.from_pretrained(vocab_path,
                                                     do_lower_case=('uncased' in vocab_path))
            except OSError as err:
                last_error = err
                LOGGER.warning('can not load bert vocab: %s', vocab_path)

        # Fail here rather than returning None and crashing later with an
        # unrelated AttributeError on first use.
        tried = ', '.join(str(path) for path in vocab_paths)
        raise OSError(f'can not load bert vocab from any of: {tried}') from last_error

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the state dict without the frozen BERT parameters.

        Parameters of `self.bert` with ``requires_grad == False`` are removed,
        and two string entries, ``bert_name`` and ``vocab_name``, are added so
        that `_load_from_state_dict` can detect a changed model or vocab.
        """
        # Keyword arguments: newer PyTorch releases deprecate passing these
        # positionally, and older ones accept the keywords as well.
        destination = super().state_dict(destination=destination,
                                         prefix=prefix,
                                         keep_vars=keep_vars)

        for name, params in self.bert.named_parameters():
            if not params.requires_grad:
                destination.pop(f'{prefix}bert.{name}', None)

        destination[f'{prefix}bert_name'] = self.bert_name
        destination[f'{prefix}vocab_name'] = self.vocab_name

        return destination

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Load this module's state, tolerating the entries `state_dict` drops.

        The ``bert_name`` / ``vocab_name`` markers are removed from
        `state_dict` (a mismatch only logs a warning), and BERT entries absent
        from `state_dict` are filled in from the currently loaded model.
        """
        # Default of None: a state dict without these markers is still loadable.
        bert_name = state_dict.pop(f'{prefix}bert_name', None)
        vocab_name = state_dict.pop(f'{prefix}vocab_name', None)
        if vocab_name is not None and vocab_name != self.vocab_name:
            LOGGER.warning('vocab of BERTPlugin is changed: %s => %s', vocab_name, self.vocab_name)
        if bert_name is not None and bert_name != self.bert_name:
            LOGGER.warning('bert of BERTPlugin is changed: %s => %s', bert_name, self.bert_name)

        # fill state_dict to avoid missing keys error
        for key, value in self.bert.state_dict(prefix=f'{prefix}bert.').items():
            if key not in state_dict:
                state_dict[key] = value

        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def postprocess_sample(self, sample, sos_and_eos=False, **_kwargs):
        """Attach BERT token ids and per-word piece indices to `sample.attrs`.

        Writes three tensors:
            ``bert_tokens``: ids of ``[CLS] <word pieces> [SEP]``.
            ``bert_word_starts``: index of the first piece of every word.
            ``bert_word_ends``: index of the last piece of every word.

        With `sos_and_eos`, both index lists additionally start with the
        position of ``[CLS]`` and end with the position of ``[SEP]``.
        Words equal to `Tokens.SOS` / `Tokens.EOS` are skipped entirely.
        """
        subword_separator = self.options.subword_separator
        words = sample.original_object.words
        tokenizer = self.tokenizer

        word_starts = [0] if sos_and_eos else []
        word_ends = [0] if sos_and_eos else []

        word_pieces = ['[CLS]']
        for word in words:
            # Skip before touching either list, so that `word_starts` and
            # `word_ends` always hold exactly one entry per remaining word.
            if word == Tokens.SOS or word == Tokens.EOS:
                continue

            word_starts.append(len(word_pieces))
            subwords = word.split(subword_separator) if subword_separator is not None else [word]
            for subword in subwords:
                subword = BERT_TOKEN_MAPPING.get(subword, subword)
                word_pieces.extend(tokenizer.tokenize(subword))
            # NOTE(review): if the tokenizer returns no pieces for a word, the
            # start index points at the following piece and the end index at
            # the preceding one - confirm whether such words can occur.
            word_ends.append(len(word_pieces) - 1)
        word_pieces.append('[SEP]')

        if sos_and_eos:
            sep_index = len(word_pieces) - 1
            word_starts.append(sep_index)
            word_ends.append(sep_index)

        sample.attrs['bert_tokens'] = \
            torch.tensor(tokenizer.convert_tokens_to_ids(word_pieces))
        sample.attrs['bert_word_starts'] = torch.tensor(word_starts)
        sample.attrs['bert_word_ends'] = torch.tensor(word_ends)

    def postprocess_batch(self, batch_samples, inputs):
        """Pad and stack the per-sample tensors of a batch into `inputs`.

        ``bert_tokens`` is always added; the index tensor matching
        `options.pool_method` is added unless the method is ``'none'``.
        """
        inputs['bert_tokens'] = \
            pad_and_stack_1d([sample.attrs['bert_tokens'] for sample in batch_samples])

        method = self.options.pool_method
        if method != 'none':
            # 'word_end' -> 'bert_word_ends', 'word_start' -> 'bert_word_starts'
            method = f'bert_{method}s'
            inputs[method] = \
                pad_and_stack_1d([sample.attrs[method] for sample in batch_samples])

    def forward(self, inputs):
        """Compute BERT features for a batch prepared by `postprocess_batch`.

        Returns the first output of the BERT model, gathered along dim 1 at
        the per-word indices unless `pool_method` is ``'none'``, then passed
        through the optional projection and feature dropout.
        """
        # NOTE(review): the mask marks every id >= 0 as a real token, so it
        # only hides padding if `pad_and_stack_1d` pads with a negative
        # value - confirm the pad value.
        outputs = self.bert(inputs.bert_tokens, attention_mask=(inputs.bert_tokens >= 0))
        # Index instead of unpacking: works for plain tuples as well as for
        # the `ModelOutput` objects returned by newer `transformers` releases.
        features = outputs[0]

        method = self.options.pool_method
        if method != 'none':
            features = broadcast_gather(features, 1, inputs[f'bert_{method}s'])

        if self.projection is not None:
            features = self.projection(features)

        if self.feature_dropout is not None:
            features = self.feature_dropout(features)

        return features
