import torch
from scipy.stats import spearmanr
from torch.utils.data import DataLoader, TensorDataset
import numpy
from BERT_eval_definitions_sister_terms_similarity.bert_tensor_from_pedersen_definition_sister import \
    PedersenDefinitionSisterCouple_AsBertInput

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class Definition_PedersenSimilaritySister_Test:
    """Base class for tests correlating BERT embedding similarities with Pedersen similarity scores.

    Concrete behaviour lives in the subclasses' ``run``; use :meth:`instantiate` to build the
    subclass that matches a mode string.
    """

    def __init__(self, tokenizer, bert, input_processor: PedersenDefinitionSisterCouple_AsBertInput):
        self.tokenizer = tokenizer
        self.bert = bert

        # The model is moved once here; subclasses must move every input batch to the same device.
        self.bert.to(device)
        self.input_processor: PedersenDefinitionSisterCouple_AsBertInput = input_processor

    def _get_inputs_from(self, couples, output_path=None):
        """Delegate construction of the input tensors for ``couples`` to the input processor."""
        return self.input_processor.get_inputs_from(pedersen_synset_couples=couples, output_path=output_path)

    def run(self, couples, evaluator, output_path=None):
        """Abstract: implemented by the subclasses returned from :meth:`instantiate`.

        :raises NotImplementedError: always, on the base class.
        """
        raise NotImplementedError('Use one of subclasses: DefinitionExampleEmbedding_PedersenSimilarity, '
                                  'OOVFromExampleEmbedding_PedersenSimilarity, '
                                  'ParentExampleEmbedding_PedersenSimilarity, '
                                  'ParentFromExampleEmbedding_PedersenSimilarity')

    @staticmethod
    def instantiate(mode, tokenizer, bert):
        """Build the test subclass (and its matching input processor) for ``mode``.

        :param mode: one of ``'def_bert_cls'``, ``'bert_wordpieces'``, ``'def_bert_head'``,
            ``'bert_head_example'``.
        :param tokenizer: tokenizer forwarded to the input processor and the test.
        :param bert: model forwarded to the test.
        :return: an instance of the subclass registered for ``mode``.
        :raises NotImplementedError: if ``mode`` is not supported.
        """
        # Built at call time because the subclasses are defined after this class.
        test_classes = {
            'def_bert_cls': DefinitionExampleEmbedding_PedersenSimilarity,
            'bert_wordpieces': OOVFromExampleEmbedding_PedersenSimilarity,
            'def_bert_head': ParentExampleEmbedding_PedersenSimilarity,
            'bert_head_example': ParentFromExampleEmbedding_PedersenSimilarity,
        }
        if mode not in test_classes:
            # The original `raise NotImplemented()` raised a confusing TypeError instead.
            raise NotImplementedError(
                f'Unsupported mode {mode!r}; expected one of {sorted(test_classes)}')

        input_processor = PedersenDefinitionSisterCouple_AsBertInput.instantiate(mode, tokenizer)
        return test_classes[mode](tokenizer, bert, input_processor)


class DefinitionExampleEmbedding_PedersenSimilarity(Definition_PedersenSimilaritySister_Test):
    """Compare a definition's position-0 embedding with a sister term's contextual embedding.

    For each couple, the definition tensor and the sister-example tensor go through BERT
    separately. The definition's hidden state at position 0 is compared with the sister
    example's hidden state at the sister index. Position 0 is presumably the [CLS] token
    (see mode ``'def_bert_cls'``); confirm against the input processor.
    """

    def run(self, couples, evaluator, output_path=None, batch_size=10):
        """Score every couple and correlate the scores with the gold similarity values.

        :param couples: synset couples, forwarded to the input processor.
        :param evaluator: callable taking two ``(1, hidden)`` tensors. It must return a tensor
            whose first element is the score (presumably a cosine similarity; confirm).
        :param output_path: optional path forwarded to the input processor.
        :param batch_size: number of couples fed to the model at once.
        :return: ``scipy.stats.spearmanr(sim_values, scores)`` computed over *all* couples.
        """
        tensor_definitions, tensor_sister, indexes_sister, sim_values = \
            self._get_inputs_from(couples, output_path)

        dataloader = DataLoader(TensorDataset(tensor_definitions, tensor_sister, indexes_sister),
                                batch_size=batch_size)
        cosines = []
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for definitions, sister_examples, sister_indexes in dataloader:
                # Inputs must live on the same device as the model (moved in __init__).
                # Results come back to the CPU as float32, which is what the evaluator
                # received before.
                hidden_definitions = self.bert(definitions.to(device))[0].float().cpu()
                hidden_sisters = self.bert(sister_examples.to(device))[0].float().cpu()

                rows = torch.arange(hidden_sisters.size(0))
                cls_definitions = hidden_definitions[:, 0]
                embeddings_sister = hidden_sisters[rows, sister_indexes.long()]

                for cls_definition, embedding_sister in zip(cls_definitions, embeddings_sister):
                    cos = evaluator(cls_definition.unsqueeze(0), embedding_sister.unsqueeze(0))
                    cosines.append(cos.tolist()[0])

        # Correlate only after every batch has been scored.
        # Previously this returned inside the loop, after the first batch.
        return spearmanr(sim_values, cosines)


class OOVFromExampleEmbedding_PedersenSimilarity(Definition_PedersenSimilaritySister_Test):
    """Compare a summed word-piece embedding of an OOV word with a sister term's embedding.

    For each couple, the OOV example and the sister example go through BERT separately.
    The OOV embedding is the sum of the hidden states in the inclusive span
    ``[oov_indexes[i][0], oov_indexes[i][1]]``. The sister embedding is the hidden state at
    ``indexes_sister[i]``.
    """

    def run(self, couples, evaluator, output_path=None, batch_size=10):
        """Score every couple and correlate the scores with the gold similarity values.

        :param couples: synset couples, forwarded to the input processor.
        :param evaluator: callable taking two ``(1, hidden)`` tensors. It must return a tensor
            whose first element is the score (presumably a cosine similarity; confirm).
        :param output_path: optional path forwarded to the input processor.
        :param batch_size: number of couples fed to the model at once.
        :return: ``scipy.stats.spearmanr(sim_values, scores)`` computed over *all* couples.
        """
        tensor_oov_examples, oov_indexes, tensor_sister, indexes_sister, sim_values = \
            self._get_inputs_from(couples, output_path)

        dataloader = DataLoader(TensorDataset(tensor_oov_examples, oov_indexes,
                                              tensor_sister, indexes_sister), batch_size=batch_size)
        cosines = []
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for oov_examples, oov_spans, sister_examples, sister_indexes in dataloader:
                # Inputs must live on the same device as the model (moved in __init__).
                hidden_oov = self.bert(oov_examples.to(device))[0].float().cpu()
                hidden_sister = self.bert(sister_examples.to(device))[0].float().cpu()

                for i, (span, sister_index) in enumerate(zip(oov_spans.tolist(), sister_indexes.tolist())):
                    start, end = span[0], span[1]
                    # Inclusive span. An empty span sums to a zero vector, as in the former
                    # accumulation loop.
                    embedding_oov = hidden_oov[i, start:end + 1].sum(dim=0)
                    embedding_sister = hidden_sister[i, sister_index]

                    cos = evaluator(embedding_oov.unsqueeze(0), embedding_sister.unsqueeze(0))
                    cosines.append(cos.tolist()[0])

        # Correlate only after every batch has been scored.
        # Previously this returned inside the loop, after the first batch.
        return spearmanr(sim_values, cosines)


class ParentExampleEmbedding_PedersenSimilarity(Definition_PedersenSimilaritySister_Test):
    """Compare a parent term's contextual embedding with a sister term's contextual embedding.

    For each couple, the parent-side tensor and the sister example go through BERT
    separately. The hidden states at ``parent_indexes[i]`` and ``indexes_sister[i]`` are then
    compared.
    """

    def run(self, couples, evaluator, output_path=None, batch_size=10):
        """Score every couple and correlate the scores with the gold similarity values.

        :param couples: synset couples, forwarded to the input processor.
        :param evaluator: callable taking two ``(1, hidden)`` tensors. It must return a tensor
            whose first element is the score (presumably a cosine similarity; confirm).
        :param output_path: optional path forwarded to the input processor.
        :param batch_size: number of couples fed to the model at once.
        :return: ``scipy.stats.spearmanr(sim_values, scores)`` computed over *all* couples.
        """
        tensor_parent_examples, parent_indexes, tensor_sister, indexes_sister, sim_values = \
            self._get_inputs_from(couples, output_path)

        dataloader = DataLoader(TensorDataset(tensor_parent_examples, parent_indexes,
                                              tensor_sister, indexes_sister), batch_size=batch_size)
        cosines = []
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for parent_examples, batch_parent_indexes, sister_examples, sister_indexes in dataloader:
                # Inputs must live on the same device as the model (moved in __init__).
                hidden_parent = self.bert(parent_examples.to(device))[0].float().cpu()
                hidden_sister = self.bert(sister_examples.to(device))[0].float().cpu()

                rows = torch.arange(hidden_parent.size(0))
                embeddings_parent = hidden_parent[rows, batch_parent_indexes.long()]
                embeddings_sister = hidden_sister[rows, sister_indexes.long()]

                for embedding_parent, embedding_sister in zip(embeddings_parent, embeddings_sister):
                    cos = evaluator(embedding_parent.unsqueeze(0), embedding_sister.unsqueeze(0))
                    cosines.append(cos.tolist()[0])

        # Correlate only after every batch has been scored.
        # Previously this returned inside the loop, after the first batch.
        return spearmanr(sim_values, cosines)


class ParentFromExampleEmbedding_PedersenSimilarity(Definition_PedersenSimilaritySister_Test):
    """Compare a parent term's embedding from an example with a sister term's embedding.

    For each couple, the parent example and the sister example go through BERT separately.
    The hidden states at ``parent_indexes[i]`` and ``indexes_sister[i]`` are then compared.
    """

    def run(self, couples, evaluator, output_path=None, batch_size=10):
        """Score every couple and correlate the scores with the gold similarity values.

        :param couples: synset couples, forwarded to the input processor.
        :param evaluator: callable taking two ``(1, hidden)`` tensors. It must return a tensor
            whose first element is the score (presumably a cosine similarity; confirm).
        :param output_path: optional path forwarded to the input processor.
        :param batch_size: number of couples fed to the model at once.
        :return: ``scipy.stats.spearmanr(sim_values, scores)`` computed over *all* couples.
        """
        tensor_parent_examples, parent_indexes, tensor_sister, indexes_sister, sim_values = \
            self._get_inputs_from(couples, output_path)

        dataloader = DataLoader(TensorDataset(tensor_parent_examples, parent_indexes,
                                              tensor_sister, indexes_sister), batch_size=batch_size)
        cosines = []
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for parent_examples, batch_parent_indexes, sister_examples, sister_indexes in dataloader:
                # Inputs must live on the same device as the model (moved in __init__).
                hidden_parent = self.bert(parent_examples.to(device))[0].float().cpu()
                hidden_sister = self.bert(sister_examples.to(device))[0].float().cpu()

                rows = torch.arange(hidden_parent.size(0))
                embeddings_parent = hidden_parent[rows, batch_parent_indexes.long()]
                embeddings_sister = hidden_sister[rows, sister_indexes.long()]

                for embedding_parent, embedding_sister in zip(embeddings_parent, embeddings_sister):
                    cos = evaluator(embedding_parent.unsqueeze(0), embedding_sister.unsqueeze(0))
                    cosines.append(cos.tolist()[0])

        # Correlate only after every batch has been scored.
        # Previously this returned inside the loop, after the first batch.
        return spearmanr(sim_values, cosines)