# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
""" Tokenization classes for XLM-RoBERTa model."""


import logging
import os
from shutil import copyfile

from transformers.tokenization_xlm_roberta import XLMRobertaTokenizer


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# File names associated with each vocabulary resource. The "vocab_restrict"
# entry is the name DenseTokenizer.save_vocabulary gives to the copied
# restricted-vocabulary file.
VOCAB_FILES_NAMES = dict(
    vocab_file="sentencepiece.bpe.model",
    vocab_restrict="vocab_restrict",
)


class DenseTokenizer(XLMRobertaTokenizer):
    """
        Adapted from RobertaTokenizer and XLNetTokenizer
        SentencePiece based tokenizer. Peculiarities:

            - requires `SentencePiece <https://github.com/google/sentencepiece>`_
            - optionally restricts the vocabulary to the pieces listed in a
              ``vocab_restrict`` file (one piece per line)

        Id layout when ``vocab_restrict`` is given (``n`` = number of lines in the file):

            - ``0..3``     : ``<s>``, ``<pad>``, ``</s>``, ``<unk>``
            - ``4..n + 3`` : pieces of the restricted vocabulary, in file order
            - ``n + 4``    : ``<mask>``

        When ``vocab_restrict`` is ``None`` the class falls back to the parent's
        SentencePiece id mapping.
    """

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        vocab_restrict=None,
        **kwargs
    ):
        """ Build the tokenizer and, if requested, the restricted vocabulary tables.

            Args:
                vocab_file: path of the SentencePiece model, forwarded to the parent class.
                bos_token, eos_token, sep_token, cls_token, unk_token, pad_token, mask_token:
                    special tokens, forwarded to the parent class.
                vocab_restrict: optional path of a UTF-8 text file with one piece per line.
                    When given, only these pieces (plus the special tokens) receive ids.
                **kwargs: forwarded to the parent class.
        """
        super().__init__(
            vocab_file,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            **kwargs,
        )
        self.vocab_restrict = vocab_restrict
        if self.vocab_restrict is not None:
            # Mimic fairseq token-to-id alignment for the first 4 token
            self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}

            # Read the restricted vocabulary; the context manager closes the file handle.
            with open(vocab_restrict, "r", encoding="utf-8") as vocab_handle:
                self.small_vocab = [line.strip() for line in vocab_handle]
            # NOTE(review): if the parent rebuilds ``sp_model`` when unpickling, this
            # restriction would have to be re-applied there -- confirm.
            self.sp_model.set_vocabulary(self.small_vocab)

            # Restricted pieces are numbered right after the four special tokens.
            first_piece_id = len(self.fairseq_tokens_to_ids)
            self.token2id = {token: idx for idx, token in enumerate(self.small_vocab, start=first_piece_id)}
            self.id2token = {idx: token for idx, token in enumerate(self.small_vocab, start=first_piece_id)}

            # <mask> takes the first id after the last restricted piece.
            self.fairseq_tokens_to_ids["<mask>"] = len(self.small_vocab) + first_piece_id
            self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}

    @property
    def vocab_size(self):
        """ Number of ids in use: restricted pieces plus the special tokens (including <mask>). """
        if self.vocab_restrict is None:
            return super().vocab_size
        # ``vocab_restrict`` is a file path; the pieces themselves live in ``small_vocab``.
        return len(self.small_vocab) + len(self.fairseq_tokens_to_ids)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab.

            In restricted mode, tokens outside the restricted vocabulary map to the id of "<unk>".
        """
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        if self.vocab_restrict is None:
            return self.sp_model.PieceToId(token) + self.fairseq_offset
        return self.token2id.get(token, self.fairseq_tokens_to_ids["<unk>"])

    def _convert_id_to_token(self, index):
        """ Converts an index (integer) in a token (str) using the vocab.

            In restricted mode, raises KeyError for an index that is neither a special
            token id nor the id of a restricted piece.
        """
        if index in self.fairseq_ids_to_tokens:
            return self.fairseq_ids_to_tokens[index]
        if self.vocab_restrict is None:
            return self.sp_model.IdToPiece(index - self.fairseq_offset)
        return self.id2token[index]

    def save_vocabulary(self, save_directory):
        """ Save the sentencepiece vocabulary (copy original file) and special tokens file
            to a directory. In restricted mode the restricted vocabulary file is copied too.

            Returns:
                A tuple with the path(s) returned by the parent class, followed by the
                path of the restricted vocabulary file when ``vocab_restrict`` is set.
                If the parent class returns ``None``, ``None`` is returned unchanged.
        """
        saved = super().save_vocabulary(save_directory)
        if saved is None:
            # The parent may return None instead of a tuple (presumably when it could not
            # write to ``save_directory`` -- confirm); pass that through rather than
            # failing with an unpacking error.
            return None
        saved = tuple(saved)
        if self.vocab_restrict is None:
            return saved
        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_restrict"])
        # Skip the copy when the source already is the destination file.
        if os.path.abspath(self.vocab_restrict) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_restrict, out_vocab_file)
        return saved + (out_vocab_file,)
