# Copyright (c) <anonymized for review>
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import torch
import pytorch_pretrained_bert.tokenization as btok
from transformers import BertTokenizer, BertForMaskedLM, BasicTokenizer, BertModel
import numpy as np
import torch.nn.functional as F

from lama.modules.base_connector import (
    Base_Connector,
    BERT_PAD,
    BERT_UNK,
    BERT_SEP,
    BERT_CLS,
    MASK,
    activate_dropout
)


class Bert(Base_Connector):
    """LAMA connector around a HuggingFace ``BertForMaskedLM`` model.

    Provides masked-token prediction (``get_batch_generation``), per-layer
    hidden states (``get_contextual_embeddings``) and pseudo-log-likelihood
    sentence scoring (``get_sentence_score``). ``try_cuda``, ``convert_ids``,
    ``_init_inverse_vocab``, ``inverse_vocab`` and ``_model_device`` come from
    ``Base_Connector`` (defined elsewhere).
    """

    def __init__(self, args, vocab_subset = None):
        """Load the BERT tokenizer and masked-LM weights.

        Args:
            args: namespace providing ``bert_model_name`` and ``bert_model_dir``;
                when ``bert_model_dir`` is set, ``bert_vocab_name`` as well.
            vocab_subset: accepted for interface compatibility; unused here.
        """
        super().__init__()

        bert_model_name = args.bert_model_name
        dict_file = bert_model_name

        if args.bert_model_dir is not None:
            # load bert model from file
            bert_model_name = str(args.bert_model_dir) + "/"
            dict_file = bert_model_name + args.bert_vocab_name
            self.dict_file = dict_file
            print("loading BERT model from {}".format(bert_model_name))
        # otherwise the model and vocabulary are resolved by name (huggingface cache)

        # Cased models must not be lowercased anywhere in the tokenization pipeline.
        do_lower_case = 'uncased' in bert_model_name

        # Load pre-trained model tokenizer (vocabulary). do_lower_case is passed
        # explicitly so the outer tokenizer agrees with the basic tokenizer
        # installed below: BertTokenizer defaults to do_lower_case=True when it
        # is built from a bare vocab file.
        self.tokenizer = BertTokenizer.from_pretrained(dict_file, do_lower_case=do_lower_case)

        # original vocab
        self.map_indices = None
        self.vocab = list(self.tokenizer.ids_to_tokens.values())
        self._init_inverse_vocab()

        # Replace the basic tokenizer so its lowercasing follows do_lower_case.
        # NOTE(review): the [MASK] token presumably survives tokenization because
        # transformers splits special tokens out before this runs — confirm.
        custom_basic_tokenizer = BasicTokenizer(do_lower_case = do_lower_case)
        self.tokenizer.basic_tokenizer = custom_basic_tokenizer

        # Load pre-trained model (weights)
        # ... to get prediction/generation
        self.masked_bert_model = BertForMaskedLM.from_pretrained(bert_model_name)

        self.masked_bert_model.eval()

        # ... to get hidden states (encoder shared with the masked-LM model)
        self.bert_model = self.masked_bert_model.bert

        self.pad_id = self.inverse_vocab[BERT_PAD]

        self.unk_index = self.inverse_vocab[BERT_UNK]

        self.mask_id = self.inverse_vocab[MASK]

    def eval(self):
        """Put the shared BERT encoder in evaluation mode."""
        self.bert_model.eval()

    def apply_dropout(self, seed):
        """Seed torch's RNG, then apply ``activate_dropout`` to every encoder submodule."""
        torch.manual_seed(seed)
        self.bert_model.apply(activate_dropout)

    def get_id(self, string):
        """Tokenize ``string`` and return its token ids.

        When ``map_indices`` is set, the ids are remapped with ``convert_ids``.
        """
        tokenized_text = self.tokenizer.tokenize(string)
        indexed_string = self.tokenizer.convert_tokens_to_ids(tokenized_text)
        if self.map_indices is not None:
            # map indices to subset of the vocabulary
            indexed_string = self.convert_ids(indexed_string)

        return indexed_string

    def __get_input_tensors_batch(self, sentences_list):
        """Tokenize each data point and right-pad everything into batch tensors.

        Args:
            sentences_list: list of data points, each a list of 1 or 2 sentences.

        Returns:
            ``(tokens, segments, attention_mask, masked_indices_list,
            tokenized_text_list)``. The first three are ``[batch, max_len]``
            LongTensors padded with the [PAD] id, segment 0 and attention 0
            respectively; all three are ``None`` when ``sentences_list`` is empty.
        """
        encoded = [self.__get_input_tensors(sentences) for sentences in sentences_list]
        masked_indices_list = [item[2] for item in encoded]
        tokenized_text_list = [item[3] for item in encoded]

        if not encoded:
            return None, None, None, masked_indices_list, tokenized_text_list

        max_tokens = max(item[0].shape[1] for item in encoded)

        padded_tokens = []
        padded_segments = []
        attention_rows = []
        for tokens_tensor, segments_tensor, _, _ in encoded:
            length = tokens_tensor.shape[1]
            pad_length = max_tokens - length
            attention_tensor = torch.ones([1, length], dtype=torch.long)
            if pad_length > 0:
                # use [PAD] for tokens, 0 for segments, 0 (ignore) for attention
                tokens_tensor = torch.cat(
                    (tokens_tensor, torch.full([1, pad_length], self.pad_id, dtype=torch.long)),
                    dim=1)
                segments_tensor = torch.cat(
                    (segments_tensor, torch.zeros([1, pad_length], dtype=torch.long)), dim=1)
                attention_tensor = torch.cat(
                    (attention_tensor, torch.zeros([1, pad_length], dtype=torch.long)), dim=1)
            padded_tokens.append(tokens_tensor)
            padded_segments.append(segments_tensor)
            attention_rows.append(attention_tensor)

        # Concatenate once instead of growing the batch tensor row by row.
        return (
            torch.cat(padded_tokens, dim=0),
            torch.cat(padded_segments, dim=0),
            torch.cat(attention_rows, dim=0),
            masked_indices_list,
            tokenized_text_list,
        )

    def __get_input_tensors(self, sentences):
        """Encode one data point as ``[CLS] s1 [SEP] (s2 [SEP])``.

        Args:
            sentences: list with one or two sentence strings.

        Returns:
            ``(tokens_tensor, segments_tensor, masked_indices, tokenized_text)``;
            the tensors have shape ``[1, seq_len]``, segment ids are 0 for
            [CLS] and the first sentence and 1 for the second, and
            ``masked_indices`` lists the positions of the MASK token.

        Raises:
            ValueError: if more than two sentences are given.
        """
        if len(sentences) > 2:
            print(sentences)
            raise ValueError("BERT accepts maximum two sentences in input for each data point")

        # [CLS] + first sentence + [SEP], all in segment 0
        first_tokens = self.tokenizer.tokenize(sentences[0]) + [BERT_SEP]
        tokenized_text = [BERT_CLS] + first_tokens
        segments_ids = [0] * (len(first_tokens) + 1)

        if len(sentences) > 1:
            # optional second sentence + [SEP], all in segment 1
            second_tokens = self.tokenizer.tokenize(sentences[1]) + [BERT_SEP]
            tokenized_text += second_tokens
            segments_ids += [1] * len(second_tokens)

        # look for masked indices
        masked_indices = [i for i, token in enumerate(tokenized_text) if token == MASK]

        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)

        # Convert inputs to PyTorch tensors
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([segments_ids])

        return tokens_tensor, segments_tensors, masked_indices, tokenized_text

    def __get_token_ids_from_tensor(self, indexed_string):
        """Return the ids unchanged, or remapped with ``convert_ids`` (as a numpy array) when ``map_indices`` is set."""
        if self.map_indices is None:
            return indexed_string
        # map indices to subset of the vocabulary
        return np.asarray(self.convert_ids(indexed_string))

    def _cuda(self):
        """Move the masked-LM model (and with it the shared encoder) to the GPU."""
        self.masked_bert_model.cuda()

    def get_batch_generation(self, sentences_list, logger= None,
                             try_cuda=True):
        """Run the masked LM on a batch of data points.

        Args:
            sentences_list: list of data points, each a list of 1 or 2 sentences.
            logger: optional logger; receives the tokenized batch at debug level.
            try_cuda: move the model to GPU first via ``try_cuda``.

        Returns:
            ``(log_probs, token_ids_list, masked_indices_list)``, where
            ``log_probs`` is a CPU tensor of log-softmaxed logits over the
            vocabulary for every position; ``None`` if ``sentences_list`` is empty.
        """
        if not sentences_list:
            return None
        if try_cuda:
            self.try_cuda()

        tokens_tensor, segments_tensor, attention_mask_tensor, masked_indices_list, tokenized_text_list = self.__get_input_tensors_batch(sentences_list)

        if logger is not None:
            logger.debug("\n{}\n".format(tokenized_text_list))

        with torch.no_grad():
            logits = self.masked_bert_model(
                input_ids=tokens_tensor.to(self._model_device),
                token_type_ids=segments_tensor.to(self._model_device),
                attention_mask=attention_mask_tensor.to(self._model_device),
            ).logits

            log_probs = F.log_softmax(logits, dim=-1).cpu()

        token_ids_list = [
            self.__get_token_ids_from_tensor(indexed_string)
            for indexed_string in tokens_tensor.numpy()
        ]

        return log_probs, token_ids_list, masked_indices_list

    def get_contextual_embeddings(self, sentences_list, try_cuda=True):
        """Return the hidden states of every encoder layer for a batch.

        Args:
            sentences_list: list of data points, each a list of 1 or 2 sentences
                (only the first two are considered).
            try_cuda: move the model to GPU first via ``try_cuda``.

        Returns:
            ``(all_encoder_layers, sentence_lengths, tokenized_text_list)``, or
            ``None`` if ``sentences_list`` is empty. ``all_encoder_layers`` is a
            list with one CPU tensor ``[batch, seq_len, hidden]`` per encoder
            layer (the embedding-layer output is excluded).
        """
        if not sentences_list:
            return None
        if try_cuda:
            self.try_cuda()

        tokens_tensor, segments_tensor, attention_mask_tensor, masked_indices_list, tokenized_text_list = self.__get_input_tensors_batch(sentences_list)

        with torch.no_grad():
            # The encoder returns a ModelOutput, not a (layers, pooled) tuple;
            # per-layer states must be requested with output_hidden_states.
            outputs = self.bert_model(
                input_ids=tokens_tensor.to(self._model_device),
                token_type_ids=segments_tensor.to(self._model_device),
                attention_mask=attention_mask_tensor.to(self._model_device),
                output_hidden_states=True,
                return_dict=True,
            )

        # hidden_states[0] is the embedding output; keep only encoder layers.
        all_encoder_layers = [layer.cpu() for layer in outputs.hidden_states[1:]]

        sentence_lengths = [len(x) for x in tokenized_text_list]

        # all_encoder_layers: one full sequence of hidden states per attention
        # block (12 for BERT-base, 24 for BERT-large), each a FloatTensor of
        # size [batch_size, sequence_length, hidden_size]
        return all_encoder_layers, sentence_lengths, tokenized_text_list

    def get_sentence_score(self, token_ids, logger=None, try_cuda=True, max_batch_size=32):
        """Return the pseudo-log-likelihood (PLL) of one tokenized sentence.

        Each token strictly between [CLS] and [SEP] is masked in turn, and the
        masked LM's log-probability of the original token at that position is
        summed. For ["[CLS]", "I", "have", "[SEP]"] the model is run on
        ["[CLS]", "[MASK]", "have", "[SEP]"] and ["[CLS]", "I", "[MASK]", "[SEP]"].

        Args:
            token_ids: 1-D token ids starting with [CLS] and ending with [SEP],
                optionally followed by trailing [PAD] ids. Numpy arrays, lists
                and CPU tensors are accepted (converted with ``np.asarray``).
            logger: unused; kept for interface compatibility.
            try_cuda: move the model to GPU first via ``try_cuda``.
            max_batch_size: number of masked copies per forward pass (> 0).

        Returns:
            float: summed log-probability; 0.0 when nothing lies between
            [CLS] and [SEP].
        """
        if try_cuda:
            self.try_cuda()

        token_ids = np.asarray(token_ids)

        # Remove trailing paddings: keep everything up to the last non-[PAD] id.
        content_indices = np.flatnonzero(token_ids != self.pad_id)
        first_pad_index = int(content_indices[-1]) + 1
        token_ids = token_ids[:first_pad_index]

        assert token_ids[0] == self.inverse_vocab[BERT_CLS], (
            f"token_ids: {token_ids}, id(BERT_CLS): {self.inverse_vocab[BERT_CLS]}")
        assert token_ids[-1] == self.inverse_vocab[BERT_SEP], (
            f"token_ids: {token_ids}, id(BERT_SEP): {self.inverse_vocab[BERT_SEP]}")

        # [CLS] and [SEP] are never masked for PLL.
        original_token_ids = token_ids[1:-1]
        num_masked = len(original_token_ids)
        if num_masked == 0:
            # Empty sentence: the sum over no tokens.
            return 0.0

        # 1. One copy of the sentence per content token, with that token masked.
        mask_positions = np.arange(1, len(token_ids) - 1)
        masked_inputs = np.tile(token_ids, (num_masked, 1))
        masked_inputs[np.arange(num_masked), mask_positions] = self.mask_id
        masked_inputs = torch.as_tensor(masked_inputs, dtype=torch.long)
        mask_positions = torch.as_tensor(mask_positions, dtype=torch.long)

        # 2. All copies have the same length, so no padding is required.
        segment_ids = torch.zeros_like(masked_inputs)
        attention_mask = torch.ones_like(masked_inputs)

        # 3. Score in chunks, keeping only the logits at each copy's masked
        #    position so [chunk, vocab] (not [chunk, seq_len, vocab]) is moved
        #    to the CPU.
        masked_logits = []
        with torch.no_grad():
            for start in range(0, num_masked, max_batch_size):
                end = start + max_batch_size
                logits_b = self.masked_bert_model(
                    input_ids=masked_inputs[start:end].to(self._model_device),
                    token_type_ids=segment_ids[start:end].to(self._model_device),
                    attention_mask=attention_mask[start:end].to(self._model_device),
                ).logits
                rows = torch.arange(logits_b.shape[0], device=logits_b.device)
                cols = mask_positions[start:end].to(logits_b.device)
                masked_logits.append(logits_b[rows, cols].cpu())

            # size: (num_masked, vocab_size)
            log_probs_at_mask = F.log_softmax(torch.cat(masked_logits, dim=0), dim=-1)

            # Log probability of the original token at each masked position.
            token_scores = log_probs_at_mask[
                torch.arange(num_masked),
                torch.as_tensor(original_token_ids, dtype=torch.long),
            ]

            return torch.sum(token_scores, dim=0).item()
