import pandas
import torch
from keras_preprocessing.sequence import pad_sequences

# Target device for every tensor built in this module: the GPU when torch reports CUDA as available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
import pandas as pd

from utility.words_in_synset import SynsetCouple


class PedersenSynsetCouple_AsBertInput:
    """
    Prototype class for building BERT inputs from a list of SynsetCouple objects.

    Use :meth:`instantiate` to obtain the concrete subclass for a given mode; subclasses
    override :meth:`get_inputs_from`.
    """

    @staticmethod
    def instantiate(mode, tokenizer):
        """
        Build the concrete input builder for ``mode``.

        :param mode: one of 'one_word_sentence', 'word_in_example', 'definition_to_definition'.
        :param tokenizer: tokenizer handed to the subclass constructor.
        :return: an instance of the subclass that handles ``mode``.
        :raises NotImplementedError: if ``mode`` is not one of the supported modes.
        """
        if mode == 'one_word_sentence':
            return PedersenSynsetCouple_OneWordSentenceEmbedding(tokenizer)
        if mode == 'word_in_example':
            return PedersenSynsetCouple_ExampleEmbedding(tokenizer)
        if mode == 'definition_to_definition':
            return PedersenSynsetCouple_DefinitionsEmbedding(tokenizer)
        # `NotImplemented` (no "Error") is a constant, not an exception class: raising it
        # produced a TypeError and hid this message.
        raise NotImplementedError(
            "Unknown mode {!r}: 'one_word_sentence', 'word_in_example' and "
            "'definition_to_definition' are available modes".format(mode))

    def __init__(self, tokenizer):
        # Tokenizer used by subclasses to tokenize text and convert tokens to ids.
        self.tokenizer = tokenizer

    def get_inputs_from(self, pedersen_synset_couples, output_path=None):
        """
        Convert ``(similarity_value, SynsetCouple)`` pairs into model inputs.

        Must be overridden by subclasses.

        :raises NotImplementedError: always, on this base class.
        """
        raise NotImplementedError('Use one of subclasses: PedersenSynsetCouple_OneWordSentenceEmbedding, '
                                  'PedersenSynsetCouple_ExampleEmbedding, '
                                  'PedersenSynsetCouple_DefinitionsEmbedding')


class PedersenSynsetCouple_OneWordSentenceEmbedding(PedersenSynsetCouple_AsBertInput):
    """Encode both words of each couple as standalone one-word BERT sentences ``[CLS] word [SEP]``."""

    def __init__(self, tokenizer):
        super().__init__(tokenizer)

    def get_inputs_from(self, pedersen_synset_couples, output_path=None):
        """
        Build word inputs for each ``(similarity_value, SynsetCouple)`` pair.

        :param pedersen_synset_couples: iterable of ``(value, SynsetCouple)`` pairs.
        :param output_path: if given, the collected rows (w1, w2, sim_value) are written there as CSV.
        :return: 3-tuple ``(w1 ids, w2 ids, similarity values)``; the ids are tensors on ``device``.
        """
        # NOTE(review): torch.tensor needs equally long id lists, so this presumably relies on every
        # word in the batch tokenizing to the same number of tokens — confirm with callers.
        collected = {'w1': [], 'w2': [], 'sim_value': []}
        for value, couple in pedersen_synset_couples:
            collected['w1'].append(couple.w1)
            collected['w2'].append(couple.w2)
            collected['sim_value'].append(value)

        df = pd.DataFrame([])
        for column, values in collected.items():
            df[column] = values
        if output_path is not None:
            df.to_csv(output_path)

        # Tokenize all first words, then all second words, each wrapped as its own sentence.
        tokenized = {
            column: [self.tokenizer.tokenize("[CLS] " + word + " [SEP]") for word in df[column].values]
            for column in ('w1', 'w2')
        }
        ids = {
            column: [self.tokenizer.convert_tokens_to_ids(tokens) for tokens in token_lists]
            for column, token_lists in tokenized.items()
        }

        return (torch.tensor(ids['w1'], device=device),
                torch.tensor(ids['w2'], device=device),
                df.sim_value.values.tolist())


class PedersenSynsetCouple_DefinitionsEmbedding(PedersenSynsetCouple_AsBertInput):
    """Encode each word of a couple, and the definition of its synset, as standalone BERT sentences."""

    def __init__(self, tokenizer):
        super().__init__(tokenizer)

    def get_inputs_from(self, pedersen_synset_couples, output_path=None):
        """
        Build word and definition inputs for each ``(similarity_value, SynsetCouple)`` pair.

        :param pedersen_synset_couples: iterable of ``(value, SynsetCouple)`` pairs; definitions come
            from ``couple.s1.definition()`` and ``couple.s2.definition()``.
        :param output_path: if given, the collected rows are written there as CSV.
        :return: 5-tuple ``(w1 ids, padded def1 ids, w2 ids, padded def2 ids, similarity values)``;
            all but the last are tensors on ``device``.
        """
        collected = {'def1': [], 'def2': [], 'w1': [], 'w2': [], 'sim_value': []}
        for value, couple in pedersen_synset_couples:
            collected['def1'].append(couple.s1.definition())
            collected['def2'].append(couple.s2.definition())
            collected['w1'].append(couple.w1)
            collected['w2'].append(couple.w2)
            collected['sim_value'].append(value)

        df = pd.DataFrame([])
        for column, values in collected.items():
            df[column] = values
        if output_path is not None:
            df.to_csv(output_path)

        # Wrap every word and every definition as its own "[CLS] ... [SEP]" sentence and tokenize it.
        tokens = {}
        for column in ('w1', 'w2', 'def1', 'def2'):
            tokens[column] = [self.tokenizer.tokenize("[CLS] " + text + " [SEP]")
                              for text in df[column].values]

        def to_ids(token_lists):
            return [self.tokenizer.convert_tokens_to_ids(t) for t in token_lists]

        def to_padded_ids(token_lists):
            # maxlen is the longest token list, so the "post" truncation never actually cuts anything.
            return pad_sequences(to_ids(token_lists),
                                 maxlen=len(max(token_lists, key=len)),
                                 dtype="long", truncating="post", padding="post")

        ids_w1 = to_ids(tokens['w1'])
        ids_w2 = to_ids(tokens['w2'])
        ids_def1 = to_padded_ids(tokens['def1'])
        ids_def2 = to_padded_ids(tokens['def2'])

        return (torch.tensor(ids_w1, device=device),
                torch.tensor(ids_def1, device=device),
                torch.tensor(ids_w2, device=device),
                torch.tensor(ids_def2, device=device),
                df.sim_value.values.tolist())


class PedersenSynsetCouple_ExampleEmbedding(PedersenSynsetCouple_AsBertInput):
    """
    Encode each word of a couple alone (``[CLS] word [SEP]``) and inside a usage example of its
    synset, together with the position of the word's token within that tokenized example.
    """

    def __init__(self, tokenizer):
        super().__init__(tokenizer)

    @staticmethod
    def _find_example(word, examples):
        """Return the first example whose space-split words include ``word``, or None."""
        return next((example for example in examples if word in str(example).split(' ')), None)

    def _tokenize_sentence(self, text):
        """Tokenize ``text`` wrapped as the single BERT sentence ``[CLS] text [SEP]``."""
        return self.tokenizer.tokenize("[CLS] " + text + " [SEP]")

    @staticmethod
    def _token_position(token, tokenized_example, example):
        """Return the index of ``token`` in ``tokenized_example``; raise a descriptive ValueError if absent."""
        try:
            return tokenized_example.index(token)
        except ValueError as err:
            raise ValueError('token {!r} not found in tokenized example {!r}'.format(token, example)) from err

    def _padded_ids(self, token_lists):
        """Convert token lists to ids and right-pad them to the longest list (so nothing is truncated)."""
        return pad_sequences([self.tokenizer.convert_tokens_to_ids(t) for t in token_lists],
                             maxlen=len(max(token_lists, key=len)),
                             dtype="long", truncating="post", padding="post")

    def get_inputs_from(self, pedersen_synset_couples, output_path=None):
        """
        Build word and usage-example inputs for each ``(similarity_value, SynsetCouple)`` pair.

        For each couple, the first example of ``s1`` (resp. ``s2``) that contains ``w1`` (resp. ``w2``)
        as a space-separated word is selected.

        :param pedersen_synset_couples: iterable of ``(value, SynsetCouple)`` pairs.
        :param output_path: if given, the collected rows are written there as CSV.
        :return: 7-tuple ``(word ids, padded example ids, word positions in examples, sister word ids,
            padded sister example ids, sister word positions, similarity values)``; all but the last
            are tensors on ``device``.
        :raises Exception: 'OOV unexpected' when no example contains one of the two words.
        :raises KeyError: when a word does not tokenize to exactly one token between [CLS] and [SEP].
        :raises ValueError: when a word's token does not appear in its tokenized example.
        """
        columns = {'syn': [], 'word': [], 'example': [],
                   'sister_syn': [], 'sister_word': [], 'sister_example': [],
                   'sim_value': []}
        for value, couple in pedersen_synset_couples:
            example = self._find_example(couple.w1, couple.s1.examples())
            sister_example = self._find_example(couple.w2, couple.s2.examples())
            if example is None or sister_example is None:
                raise Exception('OOV unexpected')

            columns['syn'].append(couple.s1.name())
            columns['word'].append(couple.w1)
            columns['example'].append(example)
            columns['sister_syn'].append(couple.s2.name())
            columns['sister_word'].append(couple.w2)
            columns['sister_example'].append(sister_example)
            columns['sim_value'].append(value)

        df = pd.DataFrame([])
        for name, values in columns.items():
            df[name] = values
        if output_path is not None:
            df.to_csv(output_path)

        tokenized_word = [self._tokenize_sentence(w) for w in df.word.values]
        tokenized_sister = [self._tokenize_sentence(w) for w in df.sister_word.values]

        tokenized_examples, indexes = [], []
        tokenized_examples_sister, indexes_sister = [], []
        for word_tokens, sister_tokens, example, sister_example in zip(
                tokenized_word, tokenized_sister, df.example.values, df.sister_example.values):
            # Only words kept as a single token are supported: "[CLS] word [SEP]" must yield exactly
            # 3 tokens so that token 1 is the word. Report whichever of the two words failed.
            for tokens in (word_tokens, sister_tokens):
                if len(tokens) != 3:
                    raise KeyError('words tokenized as:' + str(tokens) + ' cause it is not in vocabulary')

            tokenized_example = self._tokenize_sentence(example)
            indexes.append(self._token_position(word_tokens[1], tokenized_example, example))
            tokenized_examples.append(tokenized_example)

            tokenized_example_sister = self._tokenize_sentence(sister_example)
            indexes_sister.append(
                self._token_position(sister_tokens[1], tokenized_example_sister, sister_example))
            tokenized_examples_sister.append(tokenized_example_sister)

        input_ids_word = [self.tokenizer.convert_tokens_to_ids(t) for t in tokenized_word]
        input_ids_example = self._padded_ids(tokenized_examples)
        input_ids_sister = [self.tokenizer.convert_tokens_to_ids(t) for t in tokenized_sister]
        input_ids_example_sister = self._padded_ids(tokenized_examples_sister)

        return (torch.tensor(input_ids_word, device=device),
                torch.tensor(input_ids_example, device=device),
                torch.tensor(indexes, device=device),
                torch.tensor(input_ids_sister, device=device),
                torch.tensor(input_ids_example_sister, device=device),
                torch.tensor(indexes_sister, device=device),
                df.sim_value.values.tolist())
