# Copyright (c) <anonymized for review>
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from fairseq.models.roberta import RobertaModel
from fairseq.modules import FairseqDropout
from fairseq import utils
import torch
import torch.nn.functional as F
import numpy as np

from lama.modules.base_connector import *

# ROBERTA_PAD = "<pad>"


class RobertaVocab(object):
    """Index-addressable view of a RoBERTa vocabulary.

    ``vocab[token_id]`` returns the decoded, whitespace-stripped text of
    ``token_id``.  Lookups are best effort: any failure is printed and an
    empty string (or whatever was decoded so far) is returned instead of
    raising.
    """

    def __init__(self, roberta):
        # ``roberta`` must expose ``task.source_dictionary.string(...)`` and
        # ``bpe.decode(...)``; both are used by ``__getitem__``.
        self.roberta = roberta

    def __getitem__(self, arg):
        """Return the text for token id ``arg``.

        The mask and start-of-sentence symbols are returned verbatim; every
        other entry is passed through the BPE decoder.  On any exception the
        diagnostics are printed and the current ``value`` is returned.
        """
        value = ""
        # Bound before the ``try`` so the handler below can always print it,
        # even when the dictionary lookup itself is what failed.
        predicted_token_bpe = None
        try:
            predicted_token_bpe = self.roberta.task.source_dictionary.string([arg])
            stripped = predicted_token_bpe.strip()
            if stripped in (ROBERTA_MASK, ROBERTA_START_SENTENCE):
                # Special symbols are not BPE-encoded; keep them as they are.
                value = stripped
            else:
                value = self.roberta.bpe.decode(str(predicted_token_bpe)).strip()
        except Exception as e:
            print(arg)
            print(predicted_token_bpe)
            print(value)
            print("Exception {} for input {}".format(e, arg))
        return value


class Roberta(Base_Connector):
    def __init__(self, args):
        """Load a pretrained RoBERTa checkpoint and set up the connector.

        Args:
            args: object providing ``roberta_model_dir``,
                ``roberta_model_name``, ``roberta_vocab_name`` and
                ``max_sentence_length``.
        """
        super().__init__()

        model_dir = args.roberta_model_dir
        self.dict_file = f"{model_dir}/{args.roberta_vocab_name}"
        self.model = RobertaModel.from_pretrained(
            model_dir, checkpoint_file=args.roberta_model_name
        )
        self.bpe = self.model.bpe
        self.task = self.model.task

        # The vocabulary has to exist before the inverse mapping is derived.
        self._build_vocab()
        self._init_inverse_vocab()
        self.max_sentence_length = args.max_sentence_length

        # Cache the ids of the special symbols used when scoring sentences.
        dictionary = self.task.source_dictionary
        self.bos_id = dictionary.bos()
        self.eos_id = dictionary.eos()
        self.pad_id = dictionary.pad()
        self.mask_id = self.task.mask_idx

        self.eval()

    def _cuda(self):
        """Move the loaded RoBERTa model to the GPU via its ``cuda()`` method."""
        self.model.cuda()

    def _build_vocab(self):
        """Populate ``self.vocab`` with one display string per token id.

        For every id in ``range(ROBERTA_VOCAB_SIZE)`` the dictionary symbol is
        BPE-decoded.  Tokens whose decoded text starts with a space are stored
        stripped; all others are treated as sub-word pieces and stored as
        ``_text_``.  A string that is already present gets ``_<id>`` appended.
        If decoding fails (or yields an empty string), the stripped dictionary
        symbol is stored instead.
        """
        self.vocab = []
        # Mirrors the contents of ``self.vocab`` so the duplicate check is a
        # hash lookup instead of a linear scan of the growing list.
        seen = set()
        dictionary = self.task.source_dictionary

        for key in range(ROBERTA_VOCAB_SIZE):
            predicted_token_bpe = dictionary.string([key])
            try:
                value = self.bpe.decode(predicted_token_bpe)

                if value[0] == " ":  # the token starts with a whitespace
                    value = value.strip()
                else:
                    # this is subword information
                    value = "_{}_".format(value)

                if value in seen:
                    # Disambiguate repeated surface forms with the token id.
                    value = "{}_{}".format(value, key)
            except Exception:
                # Deliberate best effort: entries the decoder rejects, and
                # empty decodes (``value[0]`` raises), keep the raw symbol.
                value = predicted_token_bpe.strip()

            self.vocab.append(value)
            seen.add(value)

    def eval(self):
        """Put ``self.model`` and the module it wraps (``self.model.model``) in eval mode."""
        for module in (self.model, self.model.model):
            module.eval()

    def _activate_dropout(self, m):
        """Switch ``m`` to training mode if it is a dropout layer.

        Only modules whose type is exactly ``torch.nn.Dropout`` or
        ``FairseqDropout`` are affected; subclasses are left untouched.
        """
        dropout_types = (torch.nn.Dropout, FairseqDropout)
        if type(m) not in dropout_types:
            return
        m.train()

    def apply_dropout(self, seed):
        """Seed torch's RNG with ``seed`` and switch the model's dropout layers to training mode."""
        torch.manual_seed(seed)
        # ``apply`` visits every sub-module; only dropout layers are switched.
        network = self.model.model
        network.apply(self._activate_dropout)


    def get_id(self, input_string):
        """Return the vocabulary ids of ``input_string`` as a list of ints.

        No end-of-sentence symbol is appended.
        """
        # A leading space is added because RoBERTa predicts ' London' and
        # not 'London'.
        spaced = (" " + str(input_string).strip()).rstrip()
        bpe_text = self.bpe.encode(spaced)
        encoded = self.task.source_dictionary.encode_line(
            bpe_text, append_eos=False
        )
        return encoded.long().flatten().tolist()

    def get_batch_generation(self, sentences_list, logger=None, try_cuda=True):
        """Run the model on a batch of masked samples and return log-probabilities.

        Args:
            sentences_list: sequence of samples; each sample is a list of
                strings that may contain the ``MASK`` placeholder.
            logger: unused by this method.
            try_cuda: when true, ``self.try_cuda()`` is called first.

        Returns:
            ``None`` for an empty ``sentences_list``; otherwise a tuple
            ``(log_probs, output_tokens_list, masked_indices_list)`` where
            ``log_probs`` is a CPU tensor (presumably
            ``(batch, longest_sample, vocab)`` - confirm against the model),
            ``output_tokens_list`` holds each sample's unpadded token ids as a
            NumPy array, and ``masked_indices_list`` holds, per sample, the
            positions whose id equals the mask id.
        """
        if not sentences_list:
            return None
        if try_cuda:
            self.try_cuda()

        dictionary = self.task.source_dictionary
        mask_joiner = " {0} ".format(ROBERTA_MASK)

        encoded_samples = []
        output_tokens_list = []
        masked_indices_list = []
        for masked_inputs_list in sentences_list:
            pieces = []
            for position, masked_input in enumerate(masked_inputs_list):
                # Substitute [MASK] with <mask>, BPE-encode the text around
                # each mask, then put the mask symbols back between the spans.
                spans = masked_input.replace(MASK, ROBERTA_MASK).split(ROBERTA_MASK)
                bpe_text = mask_joiner.join(
                    self.bpe.encode(span.rstrip()) for span in spans
                ).strip()

                # Only the first string of a sample gets the start symbol.
                prefix = ROBERTA_START_SENTENCE if position == 0 else ""
                line = str(prefix + " " + bpe_text).strip()
                pieces.append(dictionary.encode_line(line, append_eos=True))

            sample = torch.cat(pieces)[: self.max_sentence_length]
            encoded_samples.append(sample)
            output_tokens_list.append(sample.long().cpu().numpy())

            mask_positions = (sample == self.task.mask_idx).nonzero().numpy()
            masked_indices_list.append([row[0] for row in mask_positions])

        # Right-pad every sample to the length of the longest one.
        longest = max(len(sample) for sample in encoded_samples)
        pad_id = dictionary.pad()
        padded_samples = []
        for sample in encoded_samples:
            shortfall = longest - len(sample)
            if shortfall > 0:
                filler = torch.full([shortfall], pad_id, dtype=torch.int)
                sample = torch.cat((sample, filler))
            padded_samples.append(sample)

        batch_tokens = torch.stack(padded_samples)

        with torch.no_grad():
            logits, _ = self.model.model(
                batch_tokens.long().to(device=self._model_device),
                features_only=False,
                return_all_hiddens=False,
            )
            # The softmax is computed on the CPU copy of the logits.
            log_probs = F.log_softmax(logits.cpu(), dim=-1)

        return log_probs, output_tokens_list, masked_indices_list

    def get_contextual_embeddings(self, sentences_list, try_cuda=True):
        """Placeholder: not implemented yet; ignores its arguments and returns ``None``."""
        # TBA
        return None
    
    def get_sentence_score(self, token_ids, logger=None, try_cuda=True, max_batch_size=32):
        """Return the pseudo-log-likelihood (PLL) of one tokenized sentence.

        Every token between the leading BOS and the trailing EOS is replaced
        by the mask id in turn.  The score is the sum, over those positions,
        of the log-probability the model assigns to the original token at the
        masked position.

        If ``token_ids`` corresponds to
        ``["<s>", "I", "have", "a", "pen", ".", "</s>"]`` the model is run on::

            ["<s>", "<mask>", "have", "a", "pen", ".", "</s>"]
            ["<s>", "I", "<mask>", "a", "pen", ".", "</s>"]
            ["<s>", "I", "have", "<mask>", "pen", ".", "</s>"]
            ["<s>", "I", "have", "a", "<mask>", ".", "</s>"]
            ["<s>", "I", "have", "a", "pen", "<mask>", "</s>"]

        Args:
            token_ids: 1-D sequence of vocabulary ids (NumPy array or anything
                ``np.asarray`` accepts) that starts with ``self.bos_id`` and
                ends with ``self.eos_id``, optionally followed by
                ``self.pad_id`` padding.  The caller's data is not modified.
            logger: unused by this method.
            try_cuda: when true, ``self.try_cuda()`` is called first.
            max_batch_size: maximum number of masked variants per forward
                pass.

        Returns:
            The score as a Python float; ``0.0`` when there is no token
            between BOS and EOS.

        Raises:
            IndexError: if ``token_ids`` contains no non-padding id.
            AssertionError: if the unpadded sequence does not start with BOS
                or does not end with EOS.
        """
        if try_cuda:
            self.try_cuda()

        token_ids = np.asarray(token_ids)

        # Remove paddings: keep everything up to the last non-pad id.
        content_positions = np.flatnonzero(token_ids != self.pad_id)
        token_ids = token_ids[: content_positions[-1] + 1]

        assert token_ids[0] == self.bos_id, (
            f"token_ids: {token_ids}, bos_id: {self.bos_id}")
        assert token_ids[-1] == self.eos_id, (
            f"token_ids: {token_ids}, eos_id: {self.eos_id}")

        # BOS and EOS are never masked for PLL.
        ids64 = token_ids.astype(np.int64)
        original_token_ids = ids64[1:-1]
        num_targets = len(original_token_ids)
        if num_targets == 0:
            # Nothing to predict: the sum over zero positions is zero.
            return 0.0

        # 1. Build all masked variants at once: row i is the sentence with
        #    sequence position i + 1 replaced by the mask id.
        mask_positions = np.arange(1, num_targets + 1)
        masked_variants = np.tile(ids64, (num_targets, 1))
        masked_variants[np.arange(num_targets), mask_positions] = self.mask_id
        masked_variants = torch.from_numpy(masked_variants)

        # 2. Run the model batch by batch.  Only the logits at each row's own
        #    masked position are kept, so memory does not grow with
        #    (sentence length)^2 * vocab size.
        per_batch_log_probs = []
        with torch.no_grad():
            for start in range(0, num_targets, max_batch_size):
                stop = start + max_batch_size
                logits, _ = self.model.model(
                    masked_variants[start:stop].to(self._model_device),
                    features_only=False,
                    return_all_hiddens=False,
                )
                rows = torch.arange(logits.shape[0], device=logits.device)
                cols = torch.as_tensor(
                    mask_positions[start:stop], device=logits.device
                )
                masked_logits = logits[rows, cols].cpu()
                # log_softmax normalises each row independently, so selecting
                # the masked row first gives the same values as normalising
                # the full logits and selecting afterwards.
                per_batch_log_probs.append(F.log_softmax(masked_logits, dim=-1))

            # 3. Log probability of the correct token at every masked position.
            log_probs_at_mask = torch.cat(per_batch_log_probs, dim=0)
            targets = torch.from_numpy(np.ascontiguousarray(original_token_ids))
            token_scores = log_probs_at_mask[torch.arange(num_targets), targets]

            return torch.sum(token_scores, dim=0).item()
