import os

import torch

from BERT.DefBERT import ParentExampleHandler
from BERT.DefBERT import ParentModel
from keras_preprocessing.sequence import pad_sequences

from utility.randomfixedseed import Random

# Use the GPU when CUDA is available, otherwise the CPU; the tensors returned by the builders in this module are created on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from utility.words_in_synset import WordInSynset
from nltk.corpus import wordnet as wn
import pandas as pd


class DefinitionsAndLemma_AsBertInput:
    """Base class for builders that turn word/synset items into BERT input tensors.

    Subclasses override :meth:`get_inputs_from`; :meth:`instantiate` selects
    the subclass that corresponds to a mode string.
    """

    @staticmethod
    def instantiate(mode, tokenizer):
        """Create the input builder that corresponds to ``mode``.

        :param mode: ``'def_bert_cls'``, ``'def_bert_head'`` or ``'bert_head_example'``.
        :param tokenizer: tokenizer passed to the builder's constructor.
        :return: a new builder instance.
        :raises NotImplementedError: if ``mode`` is none of the known values.
        """
        if mode == 'def_bert_cls':
            return ExampleEmbedding(tokenizer)
        if mode == 'def_bert_head':
            return ParentExampleEmbedding(tokenizer)
        if mode == 'bert_head_example':
            return ParentFromExampleEmbedding(tokenizer)
        # ``NotImplemented`` is a comparison sentinel and is not callable, so
        # ``raise NotImplemented()`` surfaced as an unrelated TypeError.
        raise NotImplementedError('unknown mode: ' + repr(mode))

    def __init__(self, tokenizer):
        """Store the tokenizer used by :meth:`get_inputs_from`."""
        self.tokenizer = tokenizer

    def get_inputs_from(self, words_in_synsets, output_path=None):
        """Build model inputs from ``words_in_synsets``.

        The base implementation does nothing and returns ``None``; subclasses
        provide the actual behaviour.
        """
        pass


class ExampleEmbedding(DefinitionsAndLemma_AsBertInput):
    """Input builder pairing each synset definition with an example sentence containing the word."""

    def __init__(self, tokenizer):
        super().__init__(tokenizer)

    def _padded_ids(self, token_rows):
        """Convert token lists to ids and pad every row (at the end) to the longest row's length."""
        id_rows = [self.tokenizer.convert_tokens_to_ids(row) for row in token_rows]
        widest = max(len(row) for row in token_rows)
        return pad_sequences(id_rows, maxlen=widest, dtype="long", truncating="post", padding="post")

    def get_inputs_from(self, words_in_synsets, output_path=None):
        """Build definition and example tensors for the usable items of ``words_in_synsets``.

        An item is kept only when one of its synset's example sentences contains
        the word as a space-separated token; the first such sentence is used.
        ``output_path`` is accepted but not used by this builder.

        :param words_in_synsets: iterable of items exposing ``word`` and ``synset_name``.
        :return: tuple ``(definition ids, example ids, position of the word token
            in each tokenized example)``, each a tensor created on ``device``.
        :raises KeyError: if ``"[CLS] word [SEP]"`` does not tokenize to exactly three tokens.
        :raises ValueError: if the word token is absent from its tokenized example,
            or if no item was kept (nothing to pad).
        """
        kept_words = []
        kept_definitions = []
        kept_examples = []

        for item in words_in_synsets:
            synset = wn.synset(item.synset_name)
            # First example sentence in which the word appears verbatim.
            sentence = next(
                (candidate for candidate in synset.examples()
                 if item.word in str(candidate).split(' ')),
                None,
            )
            if sentence is None:
                continue
            kept_words.append(item.word)
            kept_definitions.append(synset.definition())
            kept_examples.append(sentence)

        word_sentences = ["[CLS] " + text + " [SEP]" for text in kept_words]
        definition_sentences = ["[CLS] " + text + " [SEP]" for text in kept_definitions]
        example_sentences = ["[CLS] " + text + " [SEP]" for text in kept_examples]

        word_tokens = [self.tokenizer.tokenize(sentence) for sentence in word_sentences]

        definition_tokens = []
        example_tokens = []
        positions = []

        for tokens, example_sentence, definition_sentence in zip(
                word_tokens, example_sentences, definition_sentences):
            # Exactly three tokens means [CLS], one word piece, [SEP].
            if len(tokens) != 3:
                raise KeyError('words tokenized as:' + str(tokens) + ' cause it is not in vocabulary')

            current_example = self.tokenizer.tokenize(example_sentence)
            positions.append(current_example.index(tokens[1]))
            example_tokens.append(current_example)
            definition_tokens.append(self.tokenizer.tokenize(definition_sentence))

        definition_ids = self._padded_ids(definition_tokens)
        example_ids = self._padded_ids(example_tokens)

        return (torch.tensor(definition_ids, device=device),
                torch.tensor(example_ids, device=device),
                torch.tensor(positions, device=device))


class ParentExampleEmbedding(DefinitionsAndLemma_AsBertInput):
    """Input builder pairing each word's example sentence with a sentence mentioning its parent.

    The parent sentence is the synset definition when that definition contains
    a word starting with the parent, otherwise the bare parent word.
    """

    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        self.parent_model = ParentModel('bert-base-uncased')

    @staticmethod
    def _is_close_form(candidate, stem):
        """Return True if ``candidate`` starts with ``stem`` and is fewer than three characters longer."""
        return candidate.startswith(stem) and (len(candidate) - len(stem)) in range(-3, 3)

    def _first_close_form(self, tokens, stem):
        """Return the position of the first token that is a close form of ``stem``, or ``None``."""
        for position, token in enumerate(tokens):
            if self._is_close_form(token, stem):
                return position
        return None

    def _padded_ids(self, token_rows):
        """Convert token lists to ids and pad every row (at the end) to the longest row's length."""
        id_rows = [self.tokenizer.convert_tokens_to_ids(row) for row in token_rows]
        widest = max(len(row) for row in token_rows)
        return pad_sequences(id_rows, maxlen=widest, dtype="long", truncating="post", padding="post")

    def get_inputs_from(self, words_in_synsets, output_path=None):
        """Build parent-sentence and example tensors for the usable items of ``words_in_synsets``.

        An item is kept only when one of its synset's example sentences contains
        the word as a space-separated token. The parent comes from
        ``self.parent_model.in_voc_parent``; when that raises ``KeyError`` the
        parent defaults to ``'entity'``.

        :param words_in_synsets: iterable of items exposing ``word``, ``pos`` and ``synset_name``.
        :param output_path: when not ``None``, the intermediate table
            (syns, parent, parent_example, word, example) is written there as CSV.
        :return: tuple ``(parent sentence ids, parent token positions, example ids,
            word token positions)``, each a tensor created on ``device``.
        :raises KeyError: if the word or the parent, wrapped in ``[CLS]``/``[SEP]``,
            does not tokenize to exactly three tokens.
        """
        syn_list = []
        parent_list = []
        parent_example_list = []

        word_list = []
        example_list = []

        for ws in words_in_synsets:
            try:
                parent, _ = self.parent_model.in_voc_parent(ws.word, ws.pos[0].lower(), ws.synset_name)
            except KeyError:
                # No parent could be resolved: use the generic default.
                parent = 'entity'

            synset = wn.synset(ws.synset_name)
            example = next(
                (sentence for sentence in synset.examples()
                 if ws.word in str(sentence).split(' ')),
                None,
            )
            if example is None:
                continue

            # Prefer the definition as the parent's context when it mentions
            # the parent (or a slightly longer form of it).
            definition = synset.definition()
            if any(self._is_close_form(piece, parent) for piece in definition.split(' ')):
                parent_example = definition
            else:
                parent_example = parent

            syn_list.append(ws.synset_name)

            parent_list.append(parent)
            word_list.append(ws.word)

            parent_example_list.append(parent_example)
            example_list.append(example)

        df = pd.DataFrame([])
        df['syns'] = syn_list

        df['parent'] = parent_list
        df['parent_example'] = parent_example_list

        df['word'] = word_list
        df['example'] = example_list
        if output_path is not None:
            df.to_csv(output_path)

        sentences_parent = ["[CLS] " + parent + " [SEP]" for parent in parent_list]
        sentences_word = ["[CLS] " + word + " [SEP]" for word in word_list]

        sentences_parent_examples = ["[CLS] " + text + " [SEP]" for text in parent_example_list]
        sentences_examples = ["[CLS] " + text + " [SEP]" for text in example_list]

        tokenized_parents = [self.tokenizer.tokenize(sentence) for sentence in sentences_parent]
        tokenized_words = [self.tokenizer.tokenize(sentence) for sentence in sentences_word]

        tokenized_parent_examples = []
        tokenized_examples = []

        indexes_parent = []
        indexes = []

        for word_tokens, parent_tokens, example_sentence, parent_sentence in zip(
                tokenized_words, tokenized_parents, sentences_examples, sentences_parent_examples):
            # Exactly three tokens means [CLS], one word piece, [SEP]. The error
            # is raised when either the word or the parent fails this check.
            if len(word_tokens) != 3 or len(parent_tokens) != 3:
                raise KeyError('words tokenized as:' + str(word_tokens) + ' cause it is not in vocabulary')

            tokenized_example = self.tokenizer.tokenize(example_sentence)
            # NOTE: stays None when no token matches; tensor creation below
            # cannot handle a None index.
            j = self._first_close_form(tokenized_example, word_tokens[1])
            tokenized_examples.append(tokenized_example)
            indexes.append(j)

            tokenized_parent_example = self.tokenizer.tokenize(parent_sentence)
            j_parent = self._first_close_form(tokenized_parent_example, parent_tokens[1])
            if j_parent is None:
                # Word-piece splitting can hide the parent inside the sentence.
                # Fall back to the bare "[CLS] parent [SEP]" tokens, where the
                # parent is at position 1, so no None index reaches torch.tensor.
                tokenized_parent_example = parent_tokens
                j_parent = 1
            tokenized_parent_examples.append(tokenized_parent_example)
            indexes_parent.append(j_parent)

        input_ids_parent_examples = self._padded_ids(tokenized_parent_examples)
        input_ids_examples = self._padded_ids(tokenized_examples)

        return (torch.tensor(input_ids_parent_examples, device=device),
                torch.tensor(indexes_parent, device=device),
                torch.tensor(input_ids_examples, device=device),
                torch.tensor(indexes, device=device))


class ParentFromExampleEmbedding(DefinitionsAndLemma_AsBertInput):
    """Input builder pairing each word's example sentence with a stored example sentence for its parent."""

    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        self.parent_model = ParentModel('bert-base-uncased')
        base_path = 'data_BERT/sentences/task1/parent_from_example'
        # One CSV of parent examples per part of speech (nouns and verbs).
        noun_csv = os.path.join(base_path, 'n_output_parent_0.csv')
        verb_csv = os.path.join(base_path, 'v_output_parent_0.csv')
        self.parent_example_handler = ParentExampleHandler(path_pos=[('n', noun_csv), ('v', verb_csv)])

    @staticmethod
    def _near_match(candidate, stem):
        """Return True if ``candidate`` starts with ``stem`` and is fewer than three characters longer."""
        return candidate.startswith(stem) and -3 <= len(candidate) - len(stem) < 3

    def _padded_ids(self, token_rows):
        """Convert token lists to ids and pad every row (at the end) to the longest row's length."""
        id_rows = [self.tokenizer.convert_tokens_to_ids(row) for row in token_rows]
        widest = max(len(row) for row in token_rows)
        return pad_sequences(id_rows, maxlen=widest, dtype="long", truncating="post", padding="post")

    def get_inputs_from(self, words_in_synsets, output_path=None):
        """Build parent-sentence and example tensors for the usable items of ``words_in_synsets``.

        An item is kept only when one of its synset's example sentences contains
        the word as a space-separated token. The parent sentence comes from
        ``self.parent_example_handler``; if it does not mention the parent (a
        slightly longer form or the ``ing`` form also counts) the bare parent
        word is used instead.

        :param words_in_synsets: iterable of items exposing ``word``, ``pos`` and ``synset_name``.
        :param output_path: when not ``None``, the intermediate table
            (syns, parent, parent_example, word, example) is written there as CSV.
        :return: tuple ``(parent sentence ids, parent token positions, example ids,
            word token positions)``, each a tensor created on ``device``.
        :raises KeyError: if the word or the parent, wrapped in ``[CLS]``/``[SEP]``,
            does not tokenize to exactly three tokens.
        """
        synset_names = []
        parents = []
        parent_sentences = []
        words = []
        word_sentences = []

        for item in words_in_synsets:
            try:
                parent, _ = self.parent_model.in_voc_parent(item.word, item.pos[0].lower(), item.synset_name)
            except KeyError:
                parent, _ = 'entity', wn.synsets('entity')[0]

            synset = wn.synset(item.synset_name)
            usage = next(
                (sentence for sentence in synset.examples()
                 if item.word in str(sentence).split(' ')),
                None,
            )
            if usage is None:
                continue

            parent_usage = self.parent_example_handler.get_example(pos=item.pos[0].lower(), parent=parent)
            mentions_parent = any(
                self._near_match(piece, parent) or piece == parent + 'ing'
                for piece in parent_usage.split()
            )
            if not mentions_parent:
                parent_usage = parent

            synset_names.append(item.synset_name)
            parents.append(parent)
            words.append(item.word)
            parent_sentences.append(parent_usage)
            word_sentences.append(usage)

        table = pd.DataFrame([])
        table['syns'] = synset_names
        table['parent'] = parents
        table['parent_example'] = parent_sentences
        table['word'] = words
        table['example'] = word_sentences
        if output_path is not None:
            table.to_csv(output_path)

        wrapped_parents = ["[CLS] " + text + " [SEP]" for text in table.parent.values]
        wrapped_words = ["[CLS] " + text + " [SEP]" for text in table.word.values]
        wrapped_parent_sentences = ["[CLS] " + text + " [SEP]" for text in table.parent_example.values]
        wrapped_word_sentences = ["[CLS] " + text + " [SEP]" for text in table.example.values]

        parent_token_rows = [self.tokenizer.tokenize(text) for text in wrapped_parents]
        word_token_rows = [self.tokenizer.tokenize(text) for text in wrapped_words]

        parent_sentence_tokens = []
        word_sentence_tokens = []
        parent_positions = []
        word_positions = []

        for word_tokens, parent_tokens, word_sentence, parent_sentence in zip(
                word_token_rows, parent_token_rows, wrapped_word_sentences, wrapped_parent_sentences):
            # Exactly three tokens means [CLS], one word piece, [SEP].
            if len(word_tokens) != 3 or len(parent_tokens) != 3:
                raise KeyError('words tokenized as:' + str(word_tokens) + ' cause it is not in vocabulary')

            word_stem = word_tokens[1]
            parent_stem = parent_tokens[1]

            current_word_sentence = self.tokenizer.tokenize(word_sentence)
            word_position = next(
                (position for position, token in enumerate(current_word_sentence)
                 if self._near_match(token, word_stem)),
                None,
            )
            word_sentence_tokens.append(current_word_sentence)
            word_positions.append(word_position)

            current_parent_sentence = self.tokenizer.tokenize(parent_sentence)
            parent_position = next(
                (position for position, token in enumerate(current_parent_sentence)
                 if self._near_match(token, parent_stem) or token == parent_stem + 'ing'),
                None,
            )
            if parent_position is None:
                # Parent not found in its sentence: use the bare parent tokens,
                # where the parent sits at position 1.
                current_parent_sentence = parent_tokens
                parent_position = 1
            parent_sentence_tokens.append(current_parent_sentence)
            parent_positions.append(parent_position)

        parent_ids = self._padded_ids(parent_sentence_tokens)
        word_ids = self._padded_ids(word_sentence_tokens)

        return (torch.tensor(parent_ids, device=device),
                torch.tensor(parent_positions, device=device),
                torch.tensor(word_ids, device=device),
                torch.tensor(word_positions, device=device))
