from transformers import BertTokenizer
import json
import torch


class RobustTokenizer:
    """Tokenizer producing parallel BERT-vocab, glyph-shape and pronunciation ids.

    Every input element (each character when a ``str`` is passed) is looked up
    in three tables: the ``BertTokenizer`` vocabulary, and the
    ``word_shape.json`` / ``word_pronunciation.json`` mappings stored in the
    same directory. Elements missing from either JSON mapping fall back to
    that mapping's ``"[UNK]"`` id.
    """

    def __init__(self, tokenizer_path, max_length=128):
        """Load the vocabulary and the shape/pronunciation lookup tables.

        Args:
            tokenizer_path: Directory loadable by ``BertTokenizer.from_pretrained``
                that also holds ``word_shape.json`` and
                ``word_pronunciation.json``; both must contain an ``"[UNK]"`` key.
            max_length: Upper bound on the encoded length, [CLS]/[SEP] included.
        """
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        self.word_shape = self._load_json(f"{tokenizer_path}/word_shape.json")
        self.word_pronunciation = self._load_json(f"{tokenizer_path}/word_pronunciation.json")
        self.shape_unk_id = self.word_shape["[UNK]"]
        self.proun_unk_id = self.word_pronunciation["[UNK]"]
        # The shape/pronunciation tables are not consulted for [CLS]/[SEP] here:
        # both special positions reuse each table's [UNK] id.
        self.word_sep_id = self.tokenizer.convert_tokens_to_ids("[SEP]")
        self.shape_sep_id = self.shape_unk_id
        self.proun_sep_id = self.proun_unk_id
        self.word_cls_id = self.tokenizer.convert_tokens_to_ids("[CLS]")
        self.shape_cls_id = self.shape_unk_id
        self.proun_cls_id = self.proun_unk_id
        self.max_length = max_length
        # Tensors returned by encode_plus(return_tensor=True) are moved here.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _load_json(path):
        """Read a UTF-8 JSON file, closing the handle afterwards."""
        with open(path, "r", encoding="utf8") as handle:
            return json.load(handle)

    def _lookup_ids(self, tokens):
        """Map ``tokens`` to parallel (vocab ids, shape ids, pronunciation ids) lists."""
        ids = [self.tokenizer.convert_tokens_to_ids(token) for token in tokens]
        shapes = [self.word_shape.get(token, self.shape_unk_id) for token in tokens]
        prons = [self.word_pronunciation.get(token, self.proun_unk_id) for token in tokens]
        return ids, shapes, prons

    def encode_plus(self, sentence_a, sentence_b=None, return_tensor=False, **kwargs):
        """Encode one sentence or a sentence pair.

        Single sentence: ``[CLS] a [SEP]`` with ``a`` cut to
        ``max_length - 2`` elements. Pair: ``[CLS] a [SEP] b [SEP]`` where the
        ``a [SEP] b`` body is cut from the end to ``max_length - 2`` elements,
        so ``b`` is shortened first.

        Args:
            sentence_a: Sliceable sequence of tokens; a ``str`` yields characters.
            sentence_b: Optional second segment, iterated like ``sentence_a``.
            return_tensor: If true, return ``torch.long`` tensors of shape
                ``(1, seq_len)`` on ``self.device`` plus an ``attention_mask``;
                otherwise return plain lists.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            dict with ``input_ids``, ``shape_ids``, ``token_type_ids`` and the
            pronunciation ids.
            NOTE(review): the pronunciation key is ``"pronunciation_ids"`` in
            tensor mode but ``"pinyin_ids"`` in list mode; kept as-is because
            callers may rely on either name.
        """
        limit = self.max_length - 2
        if sentence_b is None:
            input_token = ["[CLS]"] + list(sentence_a[:limit]) + ["[SEP]"]
            # NOTE(review): this branch looks [CLS]/[SEP] up in the shape and
            # pronunciation tables (falling back to [UNK]) whereas the pair
            # branch always uses the *_cls_id/*_sep_id attributes; they agree
            # only while the tables lack "[CLS]"/"[SEP]" keys -- confirm.
            cur_input_ids, cur_input_shape, cur_input_pronunciation = self._lookup_ids(input_token)
            token_type_ids = [0] * len(cur_input_ids)
        else:
            ids_a, shape_a, pron_a = self._lookup_ids(list(sentence_a))
            ids_b, shape_b, pron_b = self._lookup_ids(list(sentence_b))

            # Segment a plus its separator is type 0, segment b is type 1.
            body_ids = ids_a + [self.word_sep_id] + ids_b
            body_shape = shape_a + [self.shape_sep_id] + shape_b
            body_pron = pron_a + [self.proun_sep_id] + pron_b
            body_types = [0] * (len(ids_a) + 1) + [1] * len(ids_b)

            # Truncate the body from the end, leaving room for [CLS] and the final [SEP].
            cur_input_ids = [self.word_cls_id] + body_ids[:limit] + [self.word_sep_id]
            cur_input_shape = [self.shape_cls_id] + body_shape[:limit] + [self.shape_sep_id]
            cur_input_pronunciation = [self.proun_cls_id] + body_pron[:limit] + [self.proun_sep_id]
            token_type_ids = [0] + body_types[:limit] + [1]
        assert len(cur_input_ids) <= self.max_length

        if return_tensor:
            def as_batch(values):
                # Add a leading batch dimension of 1 and move to the target device.
                return torch.tensor(values, dtype=torch.long).unsqueeze(0).to(self.device)

            return {
                "input_ids": as_batch(cur_input_ids),
                "shape_ids": as_batch(cur_input_shape),
                "pronunciation_ids": as_batch(cur_input_pronunciation),
                "token_type_ids": as_batch(token_type_ids),
                "attention_mask": as_batch([1] * len(cur_input_ids)),
            }
        return {
            "input_ids": cur_input_ids,
            "shape_ids": cur_input_shape,
            "pinyin_ids": cur_input_pronunciation,
            "token_type_ids": token_type_ids,
        }

    def save_vocabulary(self, *args, **kwargs):
        """No-op save hook: nothing is written.

        Positional arguments are accepted as well as keywords, so a call like
        ``save_vocabulary(save_directory)`` does not raise ``TypeError``.
        """
        return

    def save_pretrained(self, *args, **kwargs):
        """No-op save hook: nothing is written.

        Positional arguments are accepted as well as keywords, so a call like
        ``save_pretrained(output_dir)`` does not raise ``TypeError``.
        """
        return