import os

import torch
from torch.nn import CosineSimilarity
from torch.utils.data import DataLoader, TensorDataset

from utility.randomfixedseed import Random
from utility.distributions import UnknownDistribution
from BERT_eval_sim_definition_lemma.bert_tensor_from_word_in_synset import DefinitionsAndLemma_AsBertInput

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class Test_Definition_Lemma:
    """Base class of the definition/lemma embedding-similarity tests.

    Stores the model, the tokenizer, the candidate items, the input processor and
    the output directory shared by the concrete tests. Subclasses implement
    ``run``; ``instantiate`` is the factory that selects the subclass for a mode.
    """

    def __init__(self, model, tokenizer, words_in_synsets,
                 input_processor: "DefinitionsAndLemma_AsBertInput", output_path):
        """Store the collaborators, move the model to ``device`` and create the output directory.

        Args:
            model: Model to evaluate; it is moved to the module-level ``device``.
            tokenizer: Tokenizer kept alongside the model.
            words_in_synsets: Collection of candidate items the tests sample from.
            input_processor: Object whose ``get_inputs_from`` builds the model inputs.
                The annotation is quoted so it is not evaluated at class-definition time.
            output_path: Directory for the results; created (with parents) if missing.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.words_in_synsets = words_in_synsets

        self.model.to(device)

        # makedirs(exist_ok=True) also creates missing parents and does not race
        # with a concurrent creation the way an exists()/mkdir() pair does.
        os.makedirs(output_path, exist_ok=True)
        print(output_path)
        self.output_path = output_path
        # Ad hoc helper class that produces the inputs the model needs
        self.input_processor: "DefinitionsAndLemma_AsBertInput" = input_processor

    def _get_model_inputs(self, words_in_synsets, output_path=None):
        """Delegate to the input processor and return whatever it builds for the given items."""
        return self.input_processor.get_inputs_from(words_in_synsets, output_path=output_path)

    def run(self, max_len, pos=None, seed=19):
        """Run the test; the base class has no implementation.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Use one of the subclasses, e.g. Test_Definition_LemmaInExample, "
            "Test_Parent_LemmaInExample or Test_ParentFromExample_LemmaInExample")

    @staticmethod
    def instantiate(mode, model, tokenizer, words_in_synsets, output_path):
        """Build the test object (and its input processor) that corresponds to ``mode``.

        Args:
            mode: One of ``'def_bert_cls'``, ``'def_bert_head'`` or ``'bert_head_example'``.
            model, tokenizer, words_in_synsets, output_path: Forwarded to the test constructor.

        Returns:
            An instance of the test class associated with ``mode``.

        Raises:
            NotImplementedError: If ``mode`` is not one of the supported values.
        """
        if mode == 'def_bert_cls':
            test_class = Test_Definition_LemmaInExample
        elif mode == 'def_bert_head':
            test_class = Test_Parent_LemmaInExample
        elif mode == 'bert_head_example':
            test_class = Test_ParentFromExample_LemmaInExample
        else:
            # NotImplemented is a constant, not an exception: calling it raised TypeError.
            raise NotImplementedError(
                f"Unknown mode {mode!r}: 'def_bert_cls', 'def_bert_head' and "
                "'bert_head_example' are the available modes")

        # The input processor is selected with the same mode string as the test class.
        input_processor = DefinitionsAndLemma_AsBertInput.instantiate(mode, tokenizer)
        return test_class(model, tokenizer, words_in_synsets, input_processor, output_path)


class Test_Definition_LemmaInExample(Test_Definition_Lemma):
    """Compare the embedding of a definition with the embedding of its lemma inside an example.

    For every sampled item the definition is represented by the hidden state at
    position 0 (its CLS token) and the lemma by the hidden state of the example
    at the position supplied by the input processor; the two are compared with
    cosine similarity.
    """

    def __init__(self, model, tokenizer, words_in_synsets, input_processor, output_path):
        super().__init__(model, tokenizer, words_in_synsets, input_processor, output_path)

    def _get_model_inputs(self, words_in_synsets, output_path=None):
        """Return ``(input_tensor_defs, input_tensor_examples, indexes)`` from the input processor.

        The explicit unpacking fails fast if the processor does not return exactly three values.
        """
        input_tensor_defs, input_tensor_examples, indexes = super()._get_model_inputs(
            words_in_synsets, output_path=output_path)
        return input_tensor_defs, input_tensor_examples, indexes

    def run(self, max_len, pos=None, seed=19):
        """Sample items, compute the cosine similarities and save histogram and raw values.

        Args:
            max_len: Sample size passed to ``Random.sample``.
            pos: Optional filter; only items whose ``pos[0].lower()`` equals this value
                are sampled. ``None`` samples from all items.
            seed: Seed given to ``Random.set_seed``; also names the ``seed_<seed>`` output folder.

        Returns:
            Tuple ``(distr.mu, distr.std, cosines)`` where ``cosines`` is the list of
            similarities and ``distr`` the ``UnknownDistribution`` built from it.
        """
        Random.set_seed(int(seed))
        if pos is None:
            candidates = self.words_in_synsets
        else:
            candidates = [x for x in self.words_in_synsets if x.pos[0].lower() == pos]
        words_in_synsets = Random.sample(candidates, max_len)

        print(f'TASK 1.2 su {pos}, len test = {len(words_in_synsets)}, seed={seed}')
        input_tensor_defs, input_tensor_examples, indexes = self._get_model_inputs(words_in_synsets)

        dataloader = DataLoader(TensorDataset(input_tensor_defs, input_tensor_examples, indexes),
                                batch_size=32)
        cosines = []
        evaluator = CosineSimilarity()

        self.model.eval()
        with torch.no_grad():
            for defs, examples, lemma_positions in dataloader:
                # The model lives on `device` (moved in __init__), so its inputs must too.
                # Index [0] takes the output of the last layer.
                last_hidden_state_defs = self.model(defs.to(device))[0]
                last_hidden_state_examples = self.model(examples.to(device))[0]

                # The definition embedding is always the embedding of its CLS token
                cls_defs = last_hidden_state_defs[:, 0, :]

                # The lemma embedding sits in the example at the position given per row
                rows = torch.arange(last_hidden_state_examples.size(0),
                                    device=last_hidden_state_examples.device)
                positions = lemma_positions.to(last_hidden_state_examples.device).long()
                lemma_embeddings = last_hidden_state_examples[rows, positions]

                # One batched call yields one similarity per row
                cosines.extend(evaluator(cls_defs, lemma_embeddings).tolist())

        seed_dir = os.path.join(self.output_path, 'seed_' + str(seed))
        os.makedirs(seed_dir, exist_ok=True)

        # Same optional suffix for both files, so pos=None no longer breaks the .txt name
        suffix = '' if pos is None else '_' + pos
        distr = UnknownDistribution(data=cosines)
        distr.save(output_path=os.path.join(seed_dir, 'hist_cosines' + suffix + '.png'))

        with open(os.path.join(seed_dir, 'cosines' + suffix + '.txt'), 'w', encoding='utf-8') as f:
            f.writelines(str(cosine) + '\t#\n' for cosine in cosines)

        return distr.mu, distr.std, cosines



class Test_Parent_LemmaInExample(Test_Definition_Lemma):
    """Compare the embedding of a "parent" token with the embedding of the lemma in its example.

    Both embeddings are read from the last hidden state at the per-row positions
    supplied by the input processor and compared with cosine similarity.
    NOTE(review): what exactly the "parent" sequence contains is decided by the
    input processor, which is not visible here.
    """

    def __init__(self, model, tokenizer, words_in_synsets, input_processor, output_path):
        super().__init__(model, tokenizer, words_in_synsets, input_processor, output_path)

    def _get_model_inputs(self, words_in_synsets, output_path=None):
        """Return ``(parent inputs, parent indexes, example inputs, example indexes)``.

        The explicit unpacking fails fast if the processor does not return exactly four values.
        """
        input_tensor_parent_examples, parent_indexes, input_tensor_examples, indexes = \
            super()._get_model_inputs(words_in_synsets, output_path=output_path)
        return input_tensor_parent_examples, parent_indexes, input_tensor_examples, indexes

    def run(self, max_len, pos=None, seed=19):
        """Sample items, compute the cosine similarities and save histogram and raw values.

        Args:
            max_len: Sample size passed to ``Random.sample``.
            pos: Optional filter; only items whose ``pos[0].lower()`` equals this value
                are sampled. ``None`` samples from all items.
            seed: Seed given to ``Random.set_seed``; also names the ``seed_<seed>`` output folder.

        Returns:
            Tuple ``(distr.mu, distr.std, cosines)`` where ``cosines`` is the list of
            similarities and ``distr`` the ``UnknownDistribution`` built from it.
        """
        Random.set_seed(int(seed))
        if pos is None:
            candidates = self.words_in_synsets
        else:
            candidates = [x for x in self.words_in_synsets if x.pos[0].lower() == pos]
        words_in_synsets = Random.sample(candidates, max_len)

        print(f'TASK 1.4 on {pos}, len test = {len(words_in_synsets)}, seed={seed}')
        input_tensor_parent_examples, parent_indexes, \
            input_tensor_examples, indexes = self._get_model_inputs(words_in_synsets)

        dataloader = DataLoader(TensorDataset(input_tensor_parent_examples, parent_indexes,
                                              input_tensor_examples, indexes),
                                batch_size=32)
        cosines = []
        evaluator = CosineSimilarity()

        self.model.eval()
        with torch.no_grad():
            for parent_examples, parent_positions, examples, lemma_positions in dataloader:
                # The model lives on `device` (moved in __init__), so its inputs must too.
                # Index [0] takes the output of the last layer.
                hidden_parents = self.model(parent_examples.to(device))[0]
                hidden_examples = self.model(examples.to(device))[0]

                # Pick, for every row, the hidden state at that row's own position
                rows = torch.arange(hidden_parents.size(0), device=hidden_parents.device)
                embedding_parents = hidden_parents[
                    rows, parent_positions.to(hidden_parents.device).long()]
                embedding_words = hidden_examples[
                    rows, lemma_positions.to(hidden_examples.device).long()]

                # One batched call yields one similarity per row
                cosines.extend(evaluator(embedding_parents, embedding_words).tolist())

        seed_dir = os.path.join(self.output_path, 'seed_' + str(seed))
        os.makedirs(seed_dir, exist_ok=True)

        # Same optional suffix for both files, so pos=None no longer breaks the .txt name
        suffix = '' if pos is None else '_' + pos
        distr = UnknownDistribution(data=cosines)
        distr.save(output_path=os.path.join(seed_dir, 'hist_cosines' + suffix + '.png'))

        with open(os.path.join(seed_dir, 'cosines' + suffix + '.txt'), 'w', encoding='utf-8') as f:
            f.writelines(str(cosine) + '\t#\n' for cosine in cosines)

        return distr.mu, distr.std, cosines


class Test_Random_LemmaInExample(Test_Definition_Lemma):
    """Compare two positional embeddings (first sequence vs. lemma example) with cosine similarity.

    The computation matches Test_Parent_LemmaInExample; only the task label printed
    by ``run`` differs. Test_Definition_Lemma.instantiate has no mode that returns this class.
    NOTE(review): the name suggests the first sequence is a randomly chosen one;
    that is decided by the input processor passed in and cannot be confirmed from this file.
    """

    def __init__(self, model, tokenizer, words_in_synsets, input_processor, output_path):
        super().__init__(model, tokenizer, words_in_synsets, input_processor, output_path)

    def _get_model_inputs(self, words_in_synsets, output_path=None):
        """Return ``(parent inputs, parent indexes, example inputs, example indexes)``.

        The explicit unpacking fails fast if the processor does not return exactly four values.
        """
        input_tensor_parent_examples, parent_indexes, input_tensor_examples, indexes = \
            super()._get_model_inputs(words_in_synsets, output_path=output_path)
        return input_tensor_parent_examples, parent_indexes, input_tensor_examples, indexes

    def run(self, max_len, pos=None, seed=19):
        """Sample items, compute the cosine similarities and save histogram and raw values.

        Args:
            max_len: Sample size passed to ``Random.sample``.
            pos: Optional filter; only items whose ``pos[0].lower()`` equals this value
                are sampled. ``None`` samples from all items.
            seed: Seed given to ``Random.set_seed``; also names the ``seed_<seed>`` output folder.

        Returns:
            Tuple ``(distr.mu, distr.std, cosines)`` where ``cosines`` is the list of
            similarities and ``distr`` the ``UnknownDistribution`` built from it.
        """
        Random.set_seed(int(seed))
        if pos is None:
            candidates = self.words_in_synsets
        else:
            candidates = [x for x in self.words_in_synsets if x.pos[0].lower() == pos]
        words_in_synsets = Random.sample(candidates, max_len)

        print(f'TASK 1.5 su {pos}, len test = {len(words_in_synsets)}, seed={seed}')
        input_tensor_parent_examples, parent_indexes, \
            input_tensor_examples, indexes = self._get_model_inputs(words_in_synsets)

        dataloader = DataLoader(TensorDataset(input_tensor_parent_examples, parent_indexes,
                                              input_tensor_examples, indexes),
                                batch_size=32)
        cosines = []
        evaluator = CosineSimilarity()

        self.model.eval()
        with torch.no_grad():
            for parent_examples, parent_positions, examples, lemma_positions in dataloader:
                # The model lives on `device` (moved in __init__), so its inputs must too.
                # Index [0] takes the output of the last layer.
                hidden_parents = self.model(parent_examples.to(device))[0]
                hidden_examples = self.model(examples.to(device))[0]

                # Pick, for every row, the hidden state at that row's own position
                rows = torch.arange(hidden_parents.size(0), device=hidden_parents.device)
                embedding_parents = hidden_parents[
                    rows, parent_positions.to(hidden_parents.device).long()]
                embedding_words = hidden_examples[
                    rows, lemma_positions.to(hidden_examples.device).long()]

                # One batched call yields one similarity per row
                cosines.extend(evaluator(embedding_parents, embedding_words).tolist())

        seed_dir = os.path.join(self.output_path, 'seed_' + str(seed))
        os.makedirs(seed_dir, exist_ok=True)

        # Same optional suffix for both files, so pos=None no longer breaks the .txt name
        suffix = '' if pos is None else '_' + pos
        distr = UnknownDistribution(data=cosines)
        distr.save(output_path=os.path.join(seed_dir, 'hist_cosines' + suffix + '.png'))

        with open(os.path.join(seed_dir, 'cosines' + suffix + '.txt'), 'w', encoding='utf-8') as f:
            f.writelines(str(cosine) + '\t#\n' for cosine in cosines)

        return distr.mu, distr.std, cosines


class Test_ParentFromExample_LemmaInExample(Test_Definition_Lemma):
    """Compare the embedding of a "parent" token in an example with the lemma's embedding in its example.

    Both embeddings are read from the last hidden state at the per-row positions
    supplied by the input processor and compared with cosine similarity. Unlike
    the sibling tests, ``run`` accepts ``output_parent`` and forwards it to the
    input processor.
    NOTE(review): how the parent example is chosen is decided by the input
    processor, which is not visible here.
    """

    def __init__(self, model, tokenizer, words_in_synsets, input_processor, output_path):
        super().__init__(model, tokenizer, words_in_synsets, input_processor, output_path)

    def _get_model_inputs(self, words_in_synsets, output_path=None):
        """Return ``(parent inputs, parent indexes, example inputs, example indexes)``.

        The explicit unpacking fails fast if the processor does not return exactly four values.
        """
        input_tensor_parent_examples, parent_indexes, input_tensor_examples, indexes = \
            super()._get_model_inputs(words_in_synsets, output_path=output_path)
        return input_tensor_parent_examples, parent_indexes, input_tensor_examples, indexes

    def run(self, max_len, pos=None, seed=19, output_parent=None):
        """Sample items, compute the cosine similarities and save histogram and raw values.

        Args:
            max_len: Sample size passed to ``Random.sample``.
            pos: Optional filter; only items whose ``pos[0].lower()`` equals this value
                are sampled. ``None`` samples from all items.
            seed: Seed given to ``Random.set_seed``; also names the ``seed_<seed>`` output folder.
            output_parent: Passed to the input processor as its ``output_path`` argument.

        Returns:
            Tuple ``(distr.mu, distr.std, cosines)`` where ``cosines`` is the list of
            similarities and ``distr`` the ``UnknownDistribution`` built from it.
        """
        Random.set_seed(int(seed))
        if pos is None:
            candidates = self.words_in_synsets
        else:
            candidates = [x for x in self.words_in_synsets if x.pos[0].lower() == pos]
        words_in_synsets = Random.sample(candidates, max_len)

        print(f'TASK 1.6 su {pos}, len test = {len(words_in_synsets)}, seed={seed}')
        input_tensor_parent_examples, parent_indexes, \
            input_tensor_examples, indexes = self._get_model_inputs(words_in_synsets,
                                                                    output_path=output_parent)

        dataloader = DataLoader(TensorDataset(input_tensor_parent_examples, parent_indexes,
                                              input_tensor_examples, indexes),
                                batch_size=32)
        cosines = []
        evaluator = CosineSimilarity()

        self.model.eval()
        with torch.no_grad():
            for parent_examples, parent_positions, examples, lemma_positions in dataloader:
                # The model lives on `device` (moved in __init__), so its inputs must too.
                # Index [0] takes the output of the last layer.
                hidden_parents = self.model(parent_examples.to(device))[0]
                hidden_examples = self.model(examples.to(device))[0]

                # Pick, for every row, the hidden state at that row's own position
                rows = torch.arange(hidden_parents.size(0), device=hidden_parents.device)
                embedding_parents = hidden_parents[
                    rows, parent_positions.to(hidden_parents.device).long()]
                embedding_words = hidden_examples[
                    rows, lemma_positions.to(hidden_examples.device).long()]

                # One batched call yields one similarity per row
                cosines.extend(evaluator(embedding_parents, embedding_words).tolist())

        seed_dir = os.path.join(self.output_path, 'seed_' + str(seed))
        os.makedirs(seed_dir, exist_ok=True)

        # Same optional suffix for both files, so pos=None no longer breaks the .txt name
        suffix = '' if pos is None else '_' + pos
        distr = UnknownDistribution(data=cosines)
        distr.save(output_path=os.path.join(seed_dir, 'hist_cosines' + suffix + '.png'))

        with open(os.path.join(seed_dir, 'cosines' + suffix + '.txt'), 'w', encoding='utf-8') as f:
            f.writelines(str(cosine) + '\t#\n' for cosine in cosines)

        return distr.mu, distr.std, cosines

