import torch
import pandas as pd
from keras_preprocessing.sequence import pad_sequences
from torch.nn import CosineSimilarity
from torch.utils.data import DataLoader, TensorDataset
from nltk.corpus import wordnet as wn
from pytorch_transformers import BertTokenizer, BertModel

from BERT.DefBERT import ParentModel

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class DefBERTCLS:
    """Compare BERT embeddings of WordNet lemmas with those of their definitions.

    For every (lemma, synset) pair, the lemma is embedded in the context of
    one usage example of the synset, the synset definition is embedded through
    its ``[CLS]`` token, and the two vectors are compared with cosine
    similarity.
    """

    def __init__(self, model_name='bert-base-uncased'):
        """Load the tokenizer and the BERT model and move the model to ``device``.

        Args:
            model_name: name passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.bert = BertModel.from_pretrained(model_name, output_hidden_states=True)
        self.parent_model = ParentModel(self.tokenizer)
        self.bert.to(device)

    @staticmethod
    def _find_example(synset, word):
        """Return the first example of *synset* that contains *word*, or None.

        The match is exact and case-sensitive against the example split on
        single spaces, so a lemma followed by punctuation is not matched.
        """
        for example in synset.examples():
            if word in str(example).split(' '):
                return example
        return None

    def _pad_token_ids(self, tokenized_sentences):
        """Convert token lists to ids and right-pad them to the longest one."""
        return pad_sequences([self.tokenizer.convert_tokens_to_ids(tokens) for tokens in tokenized_sentences],
                             maxlen=max(len(tokens) for tokens in tokenized_sentences),
                             dtype="long", truncating="post", padding="post")

    def _get_model_definition_inputs(self, correlations):
        """Build the model inputs for the definitions and the usage examples.

        Args:
            correlations: iterable of mappings providing the lemma under
                ``'w1'`` and the synset name (as accepted by ``wn.synset``)
                under ``'s1'``.

        Returns:
            A tuple of three tensors placed on ``device``:
            the padded token ids of the definitions, the padded token ids of
            the examples, and for each example the position of the lemma
            token inside the tokenized example. Pairs whose synset has no
            example containing the lemma are skipped; when nothing is left
            the three tensors are empty.

        Raises:
            KeyError: if a lemma is not tokenized as exactly one token
                between ``[CLS]`` and ``[SEP]``.
            ValueError: if the lemma token cannot be found in its tokenized
                example (raised by ``list.index``).
        """
        words = []
        definitions = []
        examples = []

        for ws in correlations:
            synset = wn.synset(ws['s1'])
            example = self._find_example(synset, ws['w1'])

            # A pair is dropped when the synset has no examples or when none
            # of its examples uses the lemma.
            if example is not None:
                words.append(ws['w1'])
                definitions.append(synset.definition())
                examples.append(example)

        print('dataset_size:', len(words))

        if not words:
            # Nothing to embed: return empty tensors instead of failing on
            # max() of an empty sequence further down.
            return (torch.zeros((0, 0), dtype=torch.long, device=device),
                    torch.zeros((0, 0), dtype=torch.long, device=device),
                    torch.zeros((0,), dtype=torch.long, device=device))

        tokenized_defs = []
        tokenized_examples = []
        lemma_positions = []

        for word, definition, example in zip(words, definitions, examples):
            tokenized_word = self.tokenizer.tokenize("[CLS] " + word + " [SEP]")

            # The lemma must be a single vocabulary token ([CLS], lemma, [SEP]);
            # redundant when the lemmas come from the model vocabulary.
            if len(tokenized_word) != 3:
                raise KeyError('words tokenized as:' + str(tokenized_word) + ' cause it is not in vocabulary')

            tokenized_example = self.tokenizer.tokenize("[CLS] " + example + " [SEP]")
            # Position of the lemma inside the example, needed later to pick
            # its contextual embedding.
            lemma_positions.append(tokenized_example.index(tokenized_word[1]))
            tokenized_examples.append(tokenized_example)
            tokenized_defs.append(self.tokenizer.tokenize("[CLS] " + definition + " [SEP]"))

        input_ids_defs = self._pad_token_ids(tokenized_defs)
        input_ids_examples = self._pad_token_ids(tokenized_examples)

        return torch.tensor(input_ids_defs, device=device), \
               torch.tensor(input_ids_examples, device=device), \
               torch.tensor(lemma_positions, device=device)

    def compare_embeddings_lemma_definition(self, words_in_synsets):
        """Return the cosine similarity between each lemma and its definition.

        The definition is represented by the last-layer embedding of its
        ``[CLS]`` token; the lemma by the last-layer embedding of its token
        inside the usage example.

        Args:
            words_in_synsets: iterable of mappings with keys ``'w1'`` and
                ``'s1'`` (see ``_get_model_definition_inputs``).

        Returns:
            A list of floats, one per pair that was kept, in input order.
            Empty when no pair has a usable example.
        """
        input_tensor_defs, input_tensor_examples, indexes = self._get_model_definition_inputs(words_in_synsets)
        if indexes.numel() == 0:
            return []

        dataloader = DataLoader(TensorDataset(input_tensor_defs, input_tensor_examples, indexes),
                                batch_size=32)
        cosines = []
        evaluator = CosineSimilarity()

        self.bert.eval()
        with torch.no_grad():
            for batch_defs, batch_examples, batch_positions in dataloader:
                # pad_sequences fills with its default value 0, so positions
                # holding 0 are padding and must not be attended to.
                # NOTE(review): assumes 0 is the [PAD] id of the loaded
                # vocabulary (true for the standard BERT vocabularies).
                outputs_defs = self.bert(batch_defs, attention_mask=(batch_defs != 0).long())
                outputs_examples = self.bert(batch_examples, attention_mask=(batch_examples != 0).long())

                # Element 0 of the model output is the last layer's hidden states.
                last_hidden_state_defs = outputs_defs[0]
                last_hidden_state_examples = outputs_examples[0]

                # The definition embedding is always the embedding of its [CLS].
                cls_defs = last_hidden_state_defs[:, 0, :]

                # The lemma embedding sits in the example at the recorded position.
                rows = torch.arange(batch_positions.size(0), device=batch_positions.device)
                lemma_embeddings = last_hidden_state_examples[rows, batch_positions]

                cosines.extend(evaluator(cls_defs, lemma_embeddings).tolist())

        return cosines
