from typing import Optional

import numpy as np
import torch
from transformers import XLMRobertaTokenizer, XLMRobertaForSequenceClassification


class RobertaModelHubWrapper:
    """Embed pairs of character spans with a fairseq RoBERTa model from torch.hub.

    A "contextual pair" is indexed as
    ``(text, start_char_1, end_char_1, start_char_2, end_char_2)``. Its embedding
    is the concatenation of the pooled last-layer features of the tokens that
    cover ``text[start_char_1:end_char_1]`` and ``text[start_char_2:end_char_2]``.
    """

    # Output width used only when there are no pairs to embed (no forward pass
    # has run, so the real width is unknown). Kept equal to the original fixed
    # width so the empty-input result is unchanged.
    _EMPTY_PAIR_WIDTH = 2048

    def __init__(self, model: str = 'roberta.large'):
        """Load ``model`` from the 'pytorch/fairseq' hub and switch it to eval mode."""
        self._model = torch.hub.load('pytorch/fairseq', model)
        self._model.eval()

    @staticmethod
    def feature_pooling(token_features, token_ids):
        """Mean-pool the rows of ``token_features`` selected by ``token_ids``.

        A single id returns that row as-is. An empty ``token_ids`` gives numpy's
        mean of an empty slice (a row of NaNs plus a RuntimeWarning).
        """
        if len(token_ids) == 1:
            return token_features[token_ids[0], :]
        return np.mean(token_features[token_ids, :], axis=0)

    def find_insertion_tokens(self, tokens, text, start_char, end_char):
        """Return the sorted positions in ``tokens`` that overlap ``text[start_char:end_char]``.

        Alignment is done on whitespace-free strings. Each token is decoded back
        to text and stripped. The character offsets are re-expressed as offsets
        into ``text`` with every whitespace character removed.
        """
        dictionary = self._model.task.source_dictionary
        bpe = self._model.bpe

        # decoded surface string of every token, without surrounding whitespace
        strings = [bpe.decode(dictionary.string([x])).strip() for x in tokens]

        # char_2_token_id[k] = position of the token that produced the k-th
        # non-whitespace character
        char_2_token_id = [i for i, s in enumerate(strings) for _ in s]

        # shift the offsets left by the amount of whitespace that precedes them
        start_char = len("".join(text[:start_char].split()))
        end_char = len("".join(text[:end_char].split()))

        # every distinct token touched by the span, in token order
        return sorted(set(char_2_token_id[start_char:end_char]))

    def insertion_pair_features(self, contextual_pair):
        """Return the concatenated pooled features of the two spans of ``contextual_pair``."""
        text = contextual_pair[0]
        start_char_1 = contextual_pair[1]
        end_char_1 = contextual_pair[2]
        start_char_2 = contextual_pair[3]
        end_char_2 = contextual_pair[4]

        # inference only: avoid building an autograd graph even when this
        # method is called directly (outside all_insertion_pair_features)
        with torch.no_grad():
            tokens = self._model.encode(text)
            last_layer_features = self._model.extract_features(tokens)
        # first (only) sequence of the batch
        token_features = last_layer_features[0, :, :].cpu().detach().numpy()

        token_ids_1 = self.find_insertion_tokens(tokens, text, start_char_1, end_char_1)
        token_ids_2 = self.find_insertion_tokens(tokens, text, start_char_2, end_char_2)
        features_1 = RobertaModelHubWrapper.feature_pooling(token_features, token_ids_1)
        features_2 = RobertaModelHubWrapper.feature_pooling(token_features, token_ids_2)
        return np.concatenate([features_1, features_2])

    def all_insertion_pair_features(self, contextual_pairs, verbose=True):
        """Embed every pair; return a float64 array with one row per pair.

        The row width is taken from the first computed embedding, so models
        whose hidden size differs from the default's are supported. An empty
        input returns shape ``(0, 2048)`` as before.
        """
        n_pairs = len(contextual_pairs)
        embeddings = None
        with torch.no_grad():
            for i, contextual_pair in enumerate(contextual_pairs):
                features = self.insertion_pair_features(contextual_pair)
                if embeddings is None:
                    # allocate lazily: the embedding width depends on the model
                    embeddings = np.zeros([n_pairs, features.shape[0]])
                embeddings[i, :] = features

                if verbose and i % 1000 == 999:
                    text = contextual_pair[0]
                    print("> roberta_embeddings.py:", str(i + 1),
                          "out of", str(len(contextual_pairs)), text)

        if embeddings is None:
            embeddings = np.zeros([n_pairs, self._EMPTY_PAIR_WIDTH])
        return embeddings


class RobertaModelPreTrained:
    """Score texts with a pretrained XLM-RoBERTa sequence classifier.

    The score of a text is ``logits[0] - logits[1]``, taken from the flattened
    classifier logits for that single text.
    """

    def __init__(self, model_checkpoint: str, tokenizer_args: Optional[dict] = None,
                 model_args: Optional[dict] = None,
                 tokenizer_checkpoint: str = "xlm-roberta-base"):
        """Load the classifier and its tokenizer.

        Args:
            model_checkpoint: name or path given to the classifier's ``from_pretrained``.
            tokenizer_args: extra keyword arguments for the tokenizer's ``from_pretrained``.
            model_args: extra keyword arguments for the classifier's ``from_pretrained``.
            tokenizer_checkpoint: name or path of the tokenizer (defaults to
                "xlm-roberta-base", the previously hard-coded value).
        """
        # None defaults avoid the shared-mutable-default pitfall
        self.model = XLMRobertaForSequenceClassification.from_pretrained(
            model_checkpoint, **(model_args or {}))
        self.tokenizer = XLMRobertaTokenizer.from_pretrained(
            tokenizer_checkpoint, **(tokenizer_args or {}))

    def compute_score(self, text: str):
        """Return ``logits[0] - logits[1]`` for ``text`` as a numpy scalar."""
        tokens = self.tokenizer(text, return_tensors='pt')
        # inference only: no autograd graph needed
        with torch.no_grad():
            activations = self.model(**tokens)
        logits = activations.logits.cpu().detach().numpy().flatten()
        return logits[0] - logits[1]

    def compute_all_scores(self, dataset, verbose: bool = True):
        """Score every text in ``dataset``; return a float64 array of scores.

        When ``verbose`` is set, a progress line is printed every 100 texts.
        """
        scores = np.zeros(len(dataset))
        for i, text in enumerate(dataset):
            scores[i] = self.compute_score(text)

            if verbose and i % 100 == 99:
                print("> roberta_scores:", str(i + 1),
                      "out of", str(len(dataset)), text)

        return scores
