from modeling_bart import BartForConditionalGeneration
from transformers import BartTokenizer
from transformers import BertTokenizer, BertConfig
# from transformers import AlbertTokenizer, AlbertForMaskedLM, AlbertConfig
# from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaConfig
from modeling_bert import BertForMaskedLM

import torch
import torch.nn.functional as F
import numpy as np
import os
import json
import logging
import random
import data_loader

from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

logger = logging.getLogger(__name__)


class Prober:
    def __init__(self, args):
        """Seed the random generators and load both pre-trained models.

        Args:
            args: Options object; ``seed``, ``language_model_dir`` and
                ``generative_model_dir`` are the attributes read here.
        """
        super().__init__()

        # Seed Python's and torch's (CPU and CUDA) generators with one value.
        seed = args.seed
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        if torch.cuda.device_count() > 1:
            torch.cuda.manual_seed_all(seed)

        # Masked language model: config, tokenizer (vocabulary) and weights
        # are all read from the same directory.
        mlm_dir = args.language_model_dir
        self.mlm_config = AutoConfig.from_pretrained(mlm_dir)
        self.model_type = 'bert'
        self.mlm_tokenizer = BertTokenizer.from_pretrained(mlm_dir)
        self.mlm_model = BertForMaskedLM.from_pretrained(
            mlm_dir,
            config=self.mlm_config,
        )

        # Generator: config and tokenizer are resolved through the Auto*
        # classes, the weights through the local BART implementation.
        generator_dir = args.generative_model_dir
        self.generator_config = AutoConfig.from_pretrained(generator_dir)
        self.generator_tokenizer = AutoTokenizer.from_pretrained(generator_dir)
        self.generator_model = BartForConditionalGeneration.from_pretrained(
            generator_dir,
            config=self.generator_config,
        )


    def init_indices_for_filter_logprobs(self, vocab_subset, tokenizer, logger=None):
        """Collect the token ids of the words that map to a single token.

        Each word is tokenized with a leading space; it is kept only when the
        result is exactly one token and that token is not the tokenizer's
        unknown token. Every other word is reported through the logger and
        skipped.

        Args:
            vocab_subset: Iterable of candidate words.
            tokenizer: Object providing ``tokenize``, ``convert_tokens_to_ids``
                and ``unk_token``.
            logger: Optional logger; skipped words are reported through its
                ``warning`` method. When ``None``, this module's own logger is
                used and the message is emitted at info level with a
                ``WARNING:`` prefix.

        Returns:
            A ``(indices, index_list)`` tuple: the kept ids as an int64 tensor
            and as a plain list, both in the order of ``vocab_subset``.
        """
        # The ``logger`` parameter shadows the module-level logger, so the
        # fallback must be looked up by name; calling ``logger.info`` directly
        # would dereference ``None`` when no logger is passed.
        fallback_logger = logging.getLogger(__name__)

        index_list = []
        for word in vocab_subset:
            tokens = tokenizer.tokenize(' ' + word)
            if len(tokens) == 1 and tokens[0] != tokenizer.unk_token:
                index_list.append(tokenizer.convert_tokens_to_ids(tokens)[0])
                continue

            msg = "word {} from vocab_subset not in model vocabulary!".format(word)
            if logger is not None:
                logger.warning(msg)
            else:
                fallback_logger.info("WARNING: {}".format(msg))

        # An explicit dtype keeps an empty result an integer tensor instead of
        # the float32 default torch picks for an empty list.
        indices = torch.as_tensor(index_list, dtype=torch.long)
        return indices, index_list

    def get_vocab_map(self, generator_index_list, mlm_index_list):
        """Build a dict translating generator token ids into MLM token ids.

        The mapping holds, in this insertion order: the pad/cls/sep/mask
        special tokens, the position-wise pairs of the two index lists, and
        the first token each tokenizer produces for the text " None".

        Args:
            generator_index_list: Generator-side token ids.
            mlm_index_list: MLM-side token ids, read at the same positions as
                ``generator_index_list`` (it must be at least as long).

        Returns:
            Dict from generator token id to MLM token id.
        """
        gen_tok = self.generator_tokenizer
        mlm_tok = self.mlm_tokenizer

        # Special tokens go in first so that later entries can override them.
        id_map = {}
        for attr in ("pad_token_id", "cls_token_id", "sep_token_id", "mask_token_id"):
            id_map[getattr(gen_tok, attr)] = getattr(mlm_tok, attr)

        # Pair the two lists by position; indexing (rather than zip) keeps the
        # IndexError when the MLM list is the shorter one.
        id_map.update(
            (gen_id, mlm_index_list[pos])
            for pos, gen_id in enumerate(generator_index_list)
        )

        # Finally align the first sub-token of " None" in both vocabularies.
        none_text = " None"
        gen_none_id = gen_tok.convert_tokens_to_ids(gen_tok.tokenize(none_text))[0]
        mlm_none_id = mlm_tok.convert_tokens_to_ids(mlm_tok.tokenize(none_text))[0]
        id_map[gen_none_id] = mlm_none_id
        return id_map

    def get_generator_banned_token_ids(self, generator_index_list):
        """List every generator vocabulary id that is not explicitly allowed.

        An id is allowed when it appears in ``generator_index_list`` or in the
        generator tokenizer's ``all_special_ids``; every other id in
        ``range(vocab_size)`` is returned.

        Args:
            generator_index_list: Allowed generator token ids. A plain
                sequence of ints, or an array-like exposing ``tolist()``.

        Returns:
            The banned token ids as a list, in ascending order.
        """
        tokenizer = self.generator_tokenizer

        # Convert array-likes to plain ints first: their elements would
        # otherwise be hashed by identity and never match an int id.
        if hasattr(generator_index_list, "tolist"):
            generator_index_list = generator_index_list.tolist()

        # Build the allow-set once. The original tested list membership and
        # re-read ``all_special_ids`` for each vocabulary id.
        allowed_ids = set(generator_index_list)
        allowed_ids.update(tokenizer.all_special_ids)

        return [
            idx for idx in range(tokenizer.vocab_size)
            if idx not in allowed_ids
        ]