import torch
from scipy.stats import spearmanr
from torch.utils.data import DataLoader, TensorDataset
from collections import Counter
import matplotlib.pyplot as plt

from BERT_eval_sister_terms_similarity.bert_tensor_from_pedersen_synset_couple import PedersenSynsetCouple_AsBertInput
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class InVocabulary_PedersenSimilaritySister_Test:
    """Base class of the tests that correlate BERT-derived scores with reference similarity values.

    An instance holds the tokenizer, the BERT model (moved to the module-level
    ``device`` on construction) and the input processor that turns synset
    couples into model inputs. Concrete subclasses implement :meth:`run`;
    :meth:`instantiate` builds the subclass that matches a mode name.
    """

    def __init__(self, tokenizer, bert, input_processor: PedersenSynsetCouple_AsBertInput):
        """Store the collaborators and move ``bert`` to ``device``.

        :param tokenizer: tokenizer kept on the instance (not used by this base class itself).
        :param bert: model called on the input tensors by the subclasses.
        :param input_processor: object whose ``get_inputs_from`` builds the model inputs.
        """
        self.tokenizer = tokenizer
        self.bert = bert

        self.bert.to(device)
        self.input_processor: PedersenSynsetCouple_AsBertInput = input_processor

    def _get_inputs_from(self, couples, output_path=None):
        """Delegate to the input processor and return whatever it returns.

        The shape of the returned tuple depends on the processor mode; each
        subclass unpacks the layout it expects.
        """
        return self.input_processor.get_inputs_from(pedersen_synset_couples=couples, output_path=output_path)

    def run(self, couples, evaluator, output_path=None):
        """Run the test; must be overridden by a concrete subclass.

        :raises NotImplementedError: always, on the base class.
        """
        raise NotImplementedError('Use one of subclasses: OneWordSentenceEmbedding_PedersenSimilarity, '
                                  'WordInOneSentenceEmbedding_PedersenSimilarity, '
                                  'DefinitionEmbedding_PedersenSimilarity, '
                                  'ExampleEmbedding_PedersenSimilarity')

    @staticmethod
    def instantiate(mode, tokenizer, bert):
        """Build the test object that corresponds to ``mode``.

        :param mode: one of ``'one_word_sentence'``, ``'word_in_one_word_sentence'``,
            ``'definition_to_definition'`` or ``'word_in_example'``.
        :param tokenizer: forwarded to the input processor factory and to the test object.
        :param bert: model forwarded to the test object.
        :return: an instance of the subclass registered for ``mode``.
        :raises NotImplementedError: if ``mode`` is not one of the supported names.
        """
        # test mode -> (input-processor mode, test class). Built inside the
        # method because the subclasses are defined further down in the module.
        # Note that 'word_in_one_word_sentence' reuses the 'one_word_sentence'
        # input processor; only the hidden-state position read afterwards differs.
        registry = {
            'one_word_sentence': ('one_word_sentence', OneWordSentenceEmbedding_PedersenSimilarity),
            'word_in_one_word_sentence': ('one_word_sentence', WordInOneSentenceEmbedding_PedersenSimilarity),
            'definition_to_definition': ('definition_to_definition', DefinitionEmbedding_PedersenSimilarity),
            'word_in_example': ('word_in_example', ExampleEmbedding_PedersenSimilarity),
        }

        if mode not in registry:
            # NotImplementedError is the exception class; the NotImplemented
            # constant is not callable and would surface as a TypeError here.
            raise NotImplementedError(
                f'Unknown mode {mode!r}: \'one_word_sentence\', \'word_in_one_word_sentence\', '
                '\'word_in_example\' and \'definition_to_definition\' are available modes')

        processor_mode, test_class = registry[mode]
        input_processor = PedersenSynsetCouple_AsBertInput.instantiate(processor_mode, tokenizer)
        return test_class(tokenizer, bert, input_processor)


class OneWordSentenceEmbedding_PedersenSimilarity(InVocabulary_PedersenSimilaritySister_Test):
    """Scores each couple from the model output at sequence position 0 of the two inputs.

    Construction and input preparation are inherited unchanged from the base class.
    """

    def run(self, couples, evaluator, output_path=None):
        """Compute the Spearman correlation between reference values and evaluator scores.

        :param couples: synset couples handed to the input processor.
        :param evaluator: callable applied to two ``(batch, hidden)`` tensors; its
            result must support ``.tolist()`` and yield one score per row.
        :param output_path: forwarded to the input processor.
        :return: the result of ``scipy.stats.spearmanr(reference_values, scores)``.
        """
        tensor_ids_w1, tensor_ids_w2, sim_value = self._get_inputs_from(couples, output_path)
        dataloader = DataLoader(TensorDataset(tensor_ids_w1, tensor_ids_w2),
                                batch_size=10)

        cosines = []
        self.bert.eval()
        with torch.no_grad():
            for w1, w2 in dataloader:
                # The model was moved to `device` at construction time; the
                # inputs must live on the same device (no-op if they already do).
                w1 = w1.to(device)
                w2 = w2.to(device)

                # First element of the model output is used as the last hidden
                # state; position 0 is presumably the [CLS] token -- TODO confirm
                # against the tokenizer used by the input processor.
                cls_w1 = self.bert(w1)[0][:, 0, :]
                cls_w2 = self.bert(w2)[0][:, 0, :]

                cosines.extend(evaluator(cls_w1, cls_w2).tolist())

        return spearmanr(sim_value, cosines)


class WordInOneSentenceEmbedding_PedersenSimilarity(InVocabulary_PedersenSimilaritySister_Test):
    """Scores each couple from the model output at sequence position 1 of the two inputs.

    Construction and input preparation are inherited unchanged from the base class.
    """

    def run(self, couples, evaluator, output_path=None):
        """Compute the Spearman correlation between reference values and evaluator scores.

        :param couples: synset couples handed to the input processor.
        :param evaluator: callable applied to two ``(batch, hidden)`` tensors; its
            result must support ``.tolist()`` and yield one score per row.
        :param output_path: forwarded to the input processor.
        :return: the result of ``scipy.stats.spearmanr(reference_values, scores)``.
        """
        tensor_ids_w1, tensor_ids_w2, sim_values = self._get_inputs_from(couples, output_path)
        dataloader = DataLoader(TensorDataset(tensor_ids_w1, tensor_ids_w2),
                                batch_size=10)

        cosines = []
        self.bert.eval()
        with torch.no_grad():
            for w1, w2 in dataloader:
                # The model was moved to `device` at construction time; the
                # inputs must live on the same device (no-op if they already do).
                w1 = w1.to(device)
                w2 = w2.to(device)

                # Position 1 is presumably the first word piece right after the
                # leading special token -- TODO confirm; a word split into
                # several pieces would only contribute its first piece.
                emb_w1 = self.bert(w1)[0][:, 1, :]
                emb_w2 = self.bert(w2)[0][:, 1, :]

                cosines.extend(evaluator(emb_w1, emb_w2).tolist())

        return spearmanr(sim_values, cosines)


class DefinitionEmbedding_PedersenSimilarity(InVocabulary_PedersenSimilaritySister_Test):
    """Scores each couple from the model output at position 0 of the two definition inputs.

    Construction and input preparation are inherited unchanged from the base class.
    """

    def run(self, couples, evaluator, output_path=None):
        """Compute the Spearman correlation between reference values and evaluator scores.

        :param couples: synset couples handed to the input processor.
        :param evaluator: callable applied to two ``(batch, hidden)`` tensors; its
            result must support ``.tolist()`` and yield one score per row.
        :param output_path: forwarded to the input processor, as the sibling
            test classes do.
        :return: the result of ``scipy.stats.spearmanr(reference_values, scores)``.
        """
        # The processor returns five values here; the two word tensors are not
        # needed because only the definitions are embedded.
        _, tensor_def1s, _, tensor_def2s, sim_values = self._get_inputs_from(couples, output_path)

        dataloader = DataLoader(TensorDataset(tensor_def1s, tensor_def2s),
                                batch_size=10)

        cosines = []
        self.bert.eval()
        with torch.no_grad():
            for def1, def2 in dataloader:
                # The model was moved to `device` at construction time; the
                # inputs must live on the same device (no-op if they already do).
                def1 = def1.to(device)
                def2 = def2.to(device)

                # First element of the model output is used as the last hidden
                # state; position 0 is presumably the [CLS] token -- TODO confirm.
                cls_definition = self.bert(def1)[0][:, 0, :]
                cls_sister_definition = self.bert(def2)[0][:, 0, :]

                cosines.extend(evaluator(cls_definition, cls_sister_definition).tolist())

        return spearmanr(sim_values, cosines)


class ExampleEmbedding_PedersenSimilarity(InVocabulary_PedersenSimilaritySister_Test):
    """Scores each couple from the hidden state of one chosen position in two example sentences.

    The position to read in each example is supplied per row by the input
    processor. Construction and input preparation are inherited unchanged from
    the base class.
    """

    def run(self, couples, evaluator, output_path=None):
        """Compute the Spearman correlation between reference values and evaluator scores.

        :param couples: synset couples handed to the input processor.
        :param evaluator: callable applied to two ``(batch, hidden)`` tensors; its
            result must support ``.tolist()`` and yield one score per row.
        :param output_path: forwarded to the input processor.
        :return: the result of ``scipy.stats.spearmanr(reference_values, scores)``.
        """
        # The processor returns seven values; the stand-alone word/sister id
        # tensors are not used because the embeddings are read from the examples.
        (_, input_ids_example, indexes,
         _, input_ids_example_sister, indexes_sister, sim_values) = \
            self._get_inputs_from(couples, output_path)

        dataloader = DataLoader(TensorDataset(input_ids_example, indexes,
                                              input_ids_example_sister, indexes_sister),
                                batch_size=10)
        cosines = []
        self.bert.eval()
        with torch.no_grad():
            for example, index, sister_example, sister_index in dataloader:
                # The model was moved to `device` at construction time; the
                # inputs must live on the same device (no-op if they already do).
                hidden_example = self.bert(example.to(device))[0]
                hidden_sister_example = self.bert(sister_example.to(device))[0]

                # Pick, for every row i, the hidden state at position index[i].
                # Done with tensor indexing rather than by converting the whole
                # hidden-state tensors to Python lists.
                # assumes `index` / `sister_index` hold one integer per row -- TODO confirm
                rows = torch.arange(hidden_example.size(0), device=hidden_example.device)
                positions = index.to(hidden_example.device).long()
                sister_positions = sister_index.to(hidden_sister_example.device).long()

                embedding_word = hidden_example[rows, positions]
                embedding_sister = hidden_sister_example[rows, sister_positions]

                cosines.extend(evaluator(embedding_word, embedding_sister).tolist())

        return spearmanr(sim_values, cosines)


def plot(values, output_path):
    """Save a scatter plot of how often each distinct value occurs.

    The figure is written next to ``output_path``: the file extension is
    dropped and ``'_graph.png'`` is appended.

    :param values: iterable of hashable values to count.
    :param output_path: path string the image file name is derived from.
    """
    import os  # local import so this function does not depend on the module's import block

    counter = Counter(values)

    x = list(counter.keys())
    y = list(counter.values())

    # Strip only the final extension. Splitting on the first '.' would truncate
    # paths such as './out/result.txt' or 'run.v2/result.txt' in the wrong place.
    root, _ = os.path.splitext(output_path)

    plt.scatter(x, y, alpha=0.5)
    try:
        plt.savefig(root + '_graph.png')
    finally:
        # Always release the figure, even if saving fails.
        plt.close()