import torch
import pandas as pd
from keras_preprocessing.sequence import pad_sequences
from nltk.corpus.reader import Synset
from scipy.stats import spearmanr
from torch.nn import CosineSimilarity
from torch.utils.data import DataLoader, TensorDataset
from nltk.corpus import wordnet as wn
from pytorch_transformers import BertTokenizer, BertModel
from sister_terms_similarity.pedersen_similarities import SisterOOVPair
from utility.randomfixedseed import Random
from utility.word_in_vocabulary import WNManager, BertChecker

# Run BERT on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class ParentExampleHandler:
    """Lookup table of example sentences for parents, grouped by part of speech.

    The class attributes ``instance``, ``path`` and ``pos`` hold the single
    handler cached by :meth:`get_instance_from` and the arguments it was
    built from.
    """
    instance = None
    path = None
    pos = None

    @staticmethod
    def get_instance_from(pos, path):
        """Return a handler loaded from ``path`` for ``pos``, reusing the cached one.

        The cache is keyed on both ``pos`` and ``path``: keying on ``path``
        alone returned a handler that only knew the first ``pos`` it was
        built with. The cache fields are updated only after the CSV has been
        read successfully, so a failed load is not remembered as cached.

        Raises whatever ``pandas.read_csv`` raises for an unreadable ``path``.
        """
        cls = ParentExampleHandler
        if cls.instance is None or cls.path != path or cls.pos != pos:
            cls.instance = ParentExampleHandler([(pos, path)])
            cls.path = path
            cls.pos = pos
        return cls.instance

    def __init__(self, path_pos):
        """Load every ``(pos, path)`` CSV; later files override earlier ones per parent."""
        self.example_by_pos = {}
        self.not_found = []
        for (pos, path) in path_pos:
            self.example_by_pos.setdefault(pos, {}).update(ParentExampleHandler.from_csv(path))

    @staticmethod
    def from_csv(path):
        """Read a CSV with ``parent`` and ``sentence`` columns into ``{parent: sentence}``.

        When a parent appears more than once, its last row wins.
        """
        df = pd.read_csv(path)
        return dict(zip(df.parent.values, df.sentence.values))

    def get_example(self, pos, parent):
        """Return the example sentence for ``parent``, or ``parent`` itself when unknown."""
        try:
            return self.example_by_pos[pos][parent]
        except KeyError:
            # Deliberate fallback: use the bare parent as a one-word sentence.
            return parent



class ParentModel:
    """Choose an in-vocabulary WordNet ancestor ("parent") for a word.

    ``hyper_from_def`` counts parents found through the gloss heuristic and
    ``hyper_from_tax`` counts parents picked at random from the taxonomy walk.
    """

    def __init__(self, tokenizer):
        self.checker = BertChecker(tokenizer)
        self.hyper_from_def = 0
        self.hyper_from_tax = 0
        super(ParentModel, self).__init__()

    def _candidate_lemmas(self, hyp, word):
        """Return lemmas of ``hyp`` usable as a parent of ``word``.

        A lemma qualifies when it differs from ``word``, passes the BERT
        vocabulary check and is not flagged by ``WNManager.is_expression``
        (presumably multi-word lemmas; confirm).
        """
        return [lemma for lemma in hyp.lemma_names()
                if lemma != word
                and self.checker.is_in_vocabulary(word=lemma)
                and not WNManager.is_expression(lemma=lemma)]

    def in_voc_parent(self, x, pos_tag=None, synset=None):
        """Return ``(lemma, synset)`` for an in-vocabulary ancestor of ``x``.

        Args:
            x: the word whose parent is wanted.
            pos_tag: WordNet POS used only when ``synset`` is None.
            synset: a synset name (``str``), a ``Synset``, or None to use the
                first synset of ``x``.

        Strategy: walk every hypernym path nearest-first and return the
        first candidate lemma that prefix-matches a word of the gloss of
        ``x``. Failing that, return a random candidate of the nearest
        synset that has any.

        Raises:
            KeyError: no candidate lemma exists on any path.
            TypeError: ``synset`` is of an unsupported type.
        """
        word = x
        if synset is None:
            syn = wn.synsets(lemma=word, pos=pos_tag)[0]
        elif isinstance(synset, str):  # also accepts str subclasses such as numpy.str_
            syn = wn.synset(synset)
        elif isinstance(synset, Synset):
            syn = synset
        else:
            raise TypeError('synset must be a str, a Synset or None, got ' + type(synset).__name__)

        # NLTK paths run root -> ... -> syn; walk them nearest-first.
        # NOTE(review): the paths end at syn itself (per NLTK docs), so synonyms
        # of x are tried before real hypernyms. This is kept as in the original;
        # confirm whether synonyms should count as parents.
        paths = [list(reversed(path)) for path in syn.hypernym_paths()]
        definition = syn.definition().split(' ')

        for path in paths:
            for hyp in path:
                for voc in self._candidate_lemmas(hyp, word):
                    # A gloss word matches when it starts with the lemma and adds
                    # at most 2 characters (e.g. plural forms).
                    if any(w.startswith(voc) and len(w) - len(voc) <= 2 for w in definition):
                        self.hyper_from_def += 1
                        return voc, hyp

        for path in paths:
            for hyp in path:
                candidates = self._candidate_lemmas(hyp, word)
                if candidates:
                    voc = Random.randomchoice(candidates)
                    self.hyper_from_tax += 1
                    return voc, hyp
        raise KeyError('No parent found')


class DefBERT:
    """Compare contextual BERT embeddings of WordNet parents and related words.

    For each item, :class:`ParentModel` picks an in-vocabulary parent of the
    target word; ``'entity'`` is used when it finds none. The parent is
    embedded in the target's gloss when the gloss mentions it, otherwise in
    a one-word sentence made of the parent alone. The other word is embedded
    in one of its WordNet usage examples. The two token vectors from BERT's
    last layer are compared with cosine similarity.
    """

    def __init__(self, model_name='bert-base-uncased'):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.bert = BertModel.from_pretrained(model_name, output_hidden_states=True)
        self.parent_model = ParentModel(self.tokenizer)
        self.bert.to(device)

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _is_prefix_match(word, prefix):
        """Return True when ``word`` starts with ``prefix`` and is at most 2 chars longer.

        This is the same test as ``(len(word) - len(prefix)) in range(-3, 3)``:
        once ``startswith`` holds, the length difference cannot be negative.
        """
        return word.startswith(prefix) and len(word) - len(prefix) <= 2

    @classmethod
    def _index_of(cls, tokens, target):
        """Return the position of the first token prefix-matching ``target``, or None."""
        return next((k for k, token in enumerate(tokens) if cls._is_prefix_match(token, target)), None)

    @staticmethod
    def _example_containing(word, synset):
        """Return the first usage example of ``synset`` with ``word`` as a space-separated token, or None."""
        for example in synset.examples():
            if word in str(example).split(' '):
                return example
        return None

    @classmethod
    def _parent_sentence(cls, parent, synset):
        """Return ``(sentence, from_definition)`` in which ``parent`` is embedded.

        The child's gloss is used when one of its space-separated words
        prefix-matches ``parent``. Otherwise the bare parent is used as a
        one-word sentence.
        """
        definition = synset.definition()
        if any(cls._is_prefix_match(w, parent) for w in definition.split(' ')):
            return definition, True
        return parent, False

    def _parent_of(self, word, pos, synset):
        """Return the in-vocabulary parent of ``word``, or ``'entity'`` when none is found."""
        try:
            parent, _ = self.parent_model.in_voc_parent(word, pos, synset)
        except KeyError:
            parent = 'entity'
        return parent

    def _locate(self, target_piece, sentence):
        """Tokenize ``[CLS] sentence [SEP]`` and return ``(tokens, index of target_piece)``.

        Raises:
            KeyError: no token prefix-matches ``target_piece``. Previously this
                surfaced later as an obscure error from ``torch.tensor`` on
                ``None``.
        """
        tokens = self.tokenizer.tokenize("[CLS] " + sentence + " [SEP]")
        index = self._index_of(tokens, target_piece)
        if index is None:
            raise KeyError('token ' + target_piece + ' not found in: ' + str(tokens))
        return tokens, index

    def _pad(self, token_lists):
        """Convert token lists to ids, zero-pad them at the end to the longest list, move to ``device``."""
        ids = [self.tokenizer.convert_tokens_to_ids(tokens) for tokens in token_lists]
        padded = pad_sequences(ids, maxlen=max(len(tokens) for tokens in token_lists),
                               dtype="long", truncating="post", padding="post")
        return torch.tensor(padded, device=device)

    def _build_pair_inputs(self, parents, parent_sentences, words, word_sentences):
        """Tokenize, locate and pad the parent side and the word side of every pair.

        Each parent/word must be a single WordPiece (``[CLS] w [SEP]`` gives
        3 tokens). Its position in the matching sentence is the first token
        that prefix-matches it.

        Returns:
            ``(parent_ids, parent_indexes, word_ids, word_indexes)`` tensors on
            ``device``.

        Raises:
            KeyError: a parent or word is not a single WordPiece, or cannot be
                found in its sentence.
        """
        parent_tokens, parent_indexes = [], []
        word_tokens, word_indexes = [], []
        for parent, parent_sentence, word, word_sentence in zip(parents, parent_sentences,
                                                                words, word_sentences):
            parent_pieces = self.tokenizer.tokenize("[CLS] " + parent + " [SEP]")
            word_pieces = self.tokenizer.tokenize("[CLS] " + word + " [SEP]")
            if len(word_pieces) != 3 or len(parent_pieces) != 3:
                raise KeyError('words tokenized as:' + str(word_pieces) + ' cause it is not in vocabulary')

            tokens, index = self._locate(parent_pieces[1], parent_sentence)
            parent_tokens.append(tokens)
            parent_indexes.append(index)

            tokens, index = self._locate(word_pieces[1], word_sentence)
            word_tokens.append(tokens)
            word_indexes.append(index)

        return (self._pad(parent_tokens), torch.tensor(parent_indexes, device=device),
                self._pad(word_tokens), torch.tensor(word_indexes, device=device))

    def _last_hidden_state(self, input_ids):
        """Run BERT on padded ids and return the last layer's hidden states.

        Padding positions (id 0, which ``pad_sequences`` inserts and which is
        ``[PAD]`` in the standard BERT vocabularies; verify for custom models)
        are masked out of attention. Without the mask, every embedding would
        depend on how much padding the longest sentence forced.
        """
        attention_mask = (input_ids != 0).long()
        return self.bert(input_ids, attention_mask=attention_mask)[0]

    def _batched_cosines(self, first_ids, first_indexes, second_ids, second_indexes, batch_size):
        """Cosine similarity between the token vectors at the given positions, in input order."""
        loader = DataLoader(TensorDataset(first_ids, first_indexes, second_ids, second_indexes),
                            batch_size=batch_size)
        evaluator = CosineSimilarity()
        cosines = []
        self.bert.eval()
        with torch.no_grad():
            for a_ids, a_idx, b_ids, b_idx in loader:
                a_hidden = self._last_hidden_state(a_ids)
                b_hidden = self._last_hidden_state(b_ids)
                rows = torch.arange(a_hidden.size(0), device=a_hidden.device)
                # Pick one token vector per sentence: shape (batch, hidden).
                cosines.extend(evaluator(a_hidden[rows, a_idx], b_hidden[rows, b_idx]).tolist())
        return cosines

    # ------------------------------------------------------------- public API

    def _get_inputs_from(self, couples):
        """Build the padded BERT inputs for scored sister pairs.

        ``couples`` is an iterable of ``(value, couple)``, where ``couple`` is a
        ``SisterOOVPair``. Its ``target_word``, ``target_pos``,
        ``target_synset``, ``sister_word`` and ``sister_synset`` are read.

        Resets and updates ``self.used_example_definitions`` and
        ``self.used_one_word_sentences``, which count parents embedded in the
        target's gloss vs. as a one-word sentence.

        Returns:
            ``(parent_ids, parent_indexes, sister_ids, sister_indexes, values)``.

        Raises:
            KeyError: a sister synset has no usage example containing the
                sister word, or a word cannot be used as a single WordPiece.
        """
        self.used_example_definitions = 0
        self.used_one_word_sentences = 0

        parents, parent_sentences = [], []
        sister_words, sister_sentences = [], []
        values = []
        for value, couple in couples:
            parent = self._parent_of(couple.target_word, couple.target_pos[0].lower(), couple.target_synset)
            # Use the child's gloss when it mentions the parent, otherwise a one-word sentence.
            parent_sentence, from_definition = self._parent_sentence(parent, wn.synset(couple.target_synset))
            if from_definition:
                self.used_example_definitions += 1
            else:
                self.used_one_word_sentences += 1

            sister_example = self._example_containing(couple.sister_word, wn.synset(couple.sister_synset))
            if sister_example is None:
                raise KeyError('Examples not found')

            parents.append(parent)
            parent_sentences.append(parent_sentence)
            sister_words.append(couple.sister_word)
            sister_sentences.append(sister_example)
            values.append(value)

        parent_ids, parent_indexes, sister_ids, sister_indexes = self._build_pair_inputs(
            parents, parent_sentences, sister_words, sister_sentences)
        return parent_ids, parent_indexes, sister_ids, sister_indexes, values

    def calculate_spearmanr(self, couples):
        """Return ``scipy.stats.spearmanr`` between the pair scores and the BERT cosines.

        All batches are scored before correlating. The original returned
        inside the batch loop, so only the first 10 pairs were used.
        """
        parent_ids, parent_indexes, sister_ids, sister_indexes, sim_values = self._get_inputs_from(couples)
        cosines = self._batched_cosines(parent_ids, parent_indexes, sister_ids, sister_indexes, batch_size=10)
        return spearmanr(sim_values, cosines)

    def _get_model_definition_inputs(self, correlations):
        """Build padded BERT inputs comparing each word with its parent.

        Each ``ws`` in ``correlations`` is indexed with ``'w1'`` (word),
        ``'s1'`` (synset name) and ``'target_pos'``. Items whose synset has no
        usage example containing the word are skipped. The parent is still
        computed first for every item, as before.

        Returns:
            ``(parent_ids, parent_indexes, example_ids, indexes)``.

        Raises:
            KeyError: a word or parent cannot be used as a single WordPiece.
        """
        parents, parent_sentences = [], []
        words, examples = [], []
        for ws in correlations:
            parent = self._parent_of(ws['w1'], ws['target_pos'][0].lower(), ws['s1'])
            synset = wn.synset(ws['s1'])
            example = self._example_containing(ws['w1'], synset)
            if example is None:
                continue
            # Use the child's gloss when it mentions the parent, otherwise a one-word sentence.
            parent_sentence, _ = self._parent_sentence(parent, synset)
            parents.append(parent)
            parent_sentences.append(parent_sentence)
            words.append(ws['w1'])
            examples.append(example)
        return self._build_pair_inputs(parents, parent_sentences, words, examples)

    def compare_embeddings_lemma_definition(self, words_in_synsets):
        """Return, in input order, the cosine between each kept word and its parent embedding."""
        parent_ids, parent_indexes, example_ids, indexes = self._get_model_definition_inputs(words_in_synsets)
        return self._batched_cosines(parent_ids, parent_indexes, example_ids, indexes, batch_size=32)