import torch
import pandas as pd
from keras_preprocessing.sequence import pad_sequences
from scipy.stats import spearmanr
from torch.nn import CosineSimilarity
from torch.utils.data import DataLoader, TensorDataset
import numpy
from nltk.corpus import wordnet as wn
from pytorch_transformers import BertTokenizer, BertModel
from sister_terms_similarity.pedersen_similarities import SisterOOVPair

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class OOVExampleHandler:
    """Load and look up example sentences for out-of-vocabulary (OOV) words.

    Examples come from CSV files with ``word`` and ``sentence`` columns and are
    grouped under a caller-chosen part-of-speech key.
    """

    # Cache used by get_instance_from: the last built handler and the (pos, path) it was built from.
    instance = None
    path = None
    pos = None

    @staticmethod
    def get_instance_from(pos, path):
        """Return a handler for the single ``(pos, path)`` pair, reusing the cached one if it matches.

        Only the most recently requested pair is cached. Both ``pos`` and ``path``
        are part of the cache key, because the handler stores its examples under ``pos``.
        """
        if (OOVExampleHandler.instance is None
                or OOVExampleHandler.path != path
                or OOVExampleHandler.pos != pos):
            # Build first, then record the key, so a failed load never leaves a stale cache entry.
            instance = OOVExampleHandler([(pos, path)])
            OOVExampleHandler.instance = instance
            OOVExampleHandler.path = path
            OOVExampleHandler.pos = pos
        return OOVExampleHandler.instance

    def __init__(self, path_pos):
        """Load examples from every ``(pos, path)`` pair in ``path_pos``.

        When several files share a ``pos``, their examples are merged; later files
        override earlier ones for the same word.
        """
        self.example_by_pos = {}
        self.not_found = []
        for pos, path in path_pos:
            self.example_by_pos.setdefault(pos, {}).update(OOVExampleHandler.from_csv(path))

    @staticmethod
    def from_csv(path):
        """Read ``path`` and return ``{word: sentence}``, both lower-cased.

        Tab characters are stripped from the word. A row is kept only when the word
        appears as a space-separated token of its sentence. Rows with a missing word
        or sentence are skipped. If a word occurs on several rows, the last matching
        row wins.
        """
        df = pd.read_csv(path)
        example_by_words = {}
        for word, sentence in zip(df.word.values, df.sentence.values):
            # Empty CSV cells are read back as NaN floats; they cannot hold an example.
            if not isinstance(word, str) or not isinstance(sentence, str):
                continue
            oov = word.lower().replace('\t', '')
            sentence = sentence.lower()
            if oov in sentence.split(' '):
                example_by_words[oov] = sentence
        return example_by_words

    def get_example(self, pos, oov):
        """Return the stored example sentence for ``oov`` (case-insensitive) under ``pos``, or None."""
        return self.example_by_pos.get(pos, {}).get(oov.lower())


class BERTwordpieces:
    """Score OOV/sister-term pairs by the cosine of BERT contextual embeddings.

    For each pair, both words are located inside an example sentence. The OOV
    embedding is the sum of the first BERT output vectors over its word pieces.
    The sister embedding is the output vector at its single word piece. The
    evaluation result is Spearman's correlation between the gold values and
    these cosines.
    """

    # Fallback CSV (word, sentence) used when WordNet has no example containing the OOV word.
    DEFAULT_OOV_EXAMPLES_PATH = (
        'data/similarity_pedersen_test/oov_sister_terms_with_definitions/'
        'seed_19/oov_in_sentence.csv'
    )
    # Value used to pad token-id rows; positions equal to it are masked out of attention.
    # NOTE(review): assumes no real token has this id ([PAD] is 0 in bert-base vocabularies — confirm for other models).
    PAD_ID = 0

    def __init__(self, model_name='bert-base-uncased', oov_examples_path=None):
        """Load tokenizer, model and fallback OOV example sentences.

        :param model_name: name or path passed to ``from_pretrained``.
        :param oov_examples_path: CSV with ``word``/``sentence`` columns; defaults to
            ``DEFAULT_OOV_EXAMPLES_PATH``.
        """
        if oov_examples_path is None:
            oov_examples_path = BERTwordpieces.DEFAULT_OOV_EXAMPLES_PATH
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.bert = BertModel.from_pretrained(model_name, output_hidden_states=True)
        self.bert.to(device)
        # Inference only: make sure dropout is disabled.
        self.bert.eval()
        self.oov_example_handler = OOVExampleHandler(path_pos=[('_', oov_examples_path)])

    @staticmethod
    def _find_sub_list(sublist, l):
        """Return inclusive ``(start, end)`` indexes of the first occurrence of ``sublist`` in ``l``.

        Returns None when ``sublist`` is empty or does not occur in ``l``.
        """
        size = len(sublist)
        if size == 0:
            return None
        for start in range(len(l) - size + 1):
            if l[start:start + size] == sublist:
                return start, start + size - 1
        return None

    @staticmethod
    def _first_example_containing(synset_name, word):
        """Return the first WordNet example of ``synset_name`` that has ``word`` as a space-separated token, else None."""
        for example in wn.synset(synset_name).examples():
            if word in str(example).split(' '):
                return example
        return None

    def _pad_ids(self, token_lists):
        """Convert token lists to ids and right-pad them with ``PAD_ID`` to the longest row."""
        ids = [self.tokenizer.convert_tokens_to_ids(tokens) for tokens in token_lists]
        return pad_sequences(ids, maxlen=max(len(row) for row in ids), dtype="long",
                             truncating="post", padding="post", value=BERTwordpieces.PAD_ID)

    def _get_inputs_from(self, couples):
        """Build padded BERT inputs for ``couples``.

        :param couples: iterable of ``(value, pair)``. ``pair`` is a ``SisterOOVPair``;
            its ``target_synset``, ``target_word``, ``sister_synset`` and ``sister_word``
            are read. ``value`` is the gold similarity.
        :returns: ``(oov_ids, oov_spans, sister_ids, sister_positions, values)``.
            The first four are tensors on ``device``:
            - padded ids of the OOV example sentences;
            - inclusive ``[start, end]`` word-piece spans of the OOV word in them;
            - padded ids of the sister example sentences;
            - the word-piece index of the sister word in them.
            ``values`` lists the gold scores in the same order.
        :raises KeyError: if the sister word has no WordNet example containing it, if it
            is split into several word pieces, or if a word cannot be found in its
            tokenized example.
        :raises ValueError: if ``couples`` is empty.
        """
        oov_words, oov_examples, sister_words, sister_examples, values = [], [], [], [], []
        for value, couple in couples:
            oov_example = self._first_example_containing(couple.target_synset, couple.target_word)
            if oov_example is None:
                oov_example = self.oov_example_handler.get_example(pos='_', oov=couple.target_word)
            sister_example = self._first_example_containing(couple.sister_synset, couple.sister_word)
            if sister_example is None:
                raise KeyError('Examples not found')
            if oov_example is None:
                # No sentence anywhere: use the bare word as its own context.
                oov_example = couple.target_word

            oov_words.append(couple.target_word)
            oov_examples.append(oov_example)
            sister_words.append(couple.sister_word)
            sister_examples.append(sister_example)
            values.append(value)

        if not values:
            raise ValueError('couples is empty')

        tokenized_oov_examples, oov_spans = [], []
        tokenized_sister_examples, sister_positions = [], []
        for oov_word, oov_example, sister_word, sister_example in zip(
                oov_words, oov_examples, sister_words, sister_examples):
            sister_pieces = self.tokenizer.tokenize("[CLS] " + sister_word + " [SEP]")
            # [CLS] + one piece + [SEP]: the sister word must be a single vocabulary entry.
            if len(sister_pieces) != 3:
                raise KeyError('words tokenized as:' + str(sister_pieces) + ' cause it is not in vocabulary')

            oov_pieces = self.tokenizer.tokenize("[CLS] " + oov_word + " [SEP]")[1:-1]
            tokenized_oov_example = self.tokenizer.tokenize("[CLS] " + oov_example + " [SEP]")
            span = BERTwordpieces._find_sub_list(oov_pieces, tokenized_oov_example)
            if span is None:
                raise KeyError('word pieces ' + str(oov_pieces) + ' not found in ' + str(tokenized_oov_example))

            tokenized_sister_example = self.tokenizer.tokenize("[CLS] " + sister_example + " [SEP]")
            if sister_pieces[1] not in tokenized_sister_example:
                raise KeyError('word piece ' + sister_pieces[1] + ' not found in ' + str(tokenized_sister_example))

            tokenized_oov_examples.append(tokenized_oov_example)
            oov_spans.append(list(span))
            tokenized_sister_examples.append(tokenized_sister_example)
            sister_positions.append(tokenized_sister_example.index(sister_pieces[1]))

        return torch.tensor(self._pad_ids(tokenized_oov_examples), device=device), \
            torch.tensor(oov_spans, device=device), \
            torch.tensor(self._pad_ids(tokenized_sister_examples), device=device), \
            torch.tensor(sister_positions, device=device), \
            values

    def _last_hidden_state(self, input_ids):
        """Run BERT on padded ``input_ids``, with padding masked out, and return the first model output."""
        attention_mask = (input_ids != BERTwordpieces.PAD_ID).long()
        return self.bert(input_ids, attention_mask=attention_mask)[0]

    def calculate_spearmanr(self, couples):
        """Return Spearman's correlation between gold similarities and BERT cosines.

        :param couples: same format as accepted by ``_get_inputs_from``.
        :returns: the result of ``scipy.stats.spearmanr(values, cosines)``.
        """
        evaluator = CosineSimilarity()
        oov_ids, oov_spans, sister_ids, sister_positions, sim_values = self._get_inputs_from(couples)
        dataloader = DataLoader(TensorDataset(oov_ids, oov_spans, sister_ids, sister_positions),
                                batch_size=10)

        cosines = []
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for batch_oov_ids, batch_oov_spans, batch_sister_ids, batch_sister_positions in dataloader:
                hidden_oov = self._last_hidden_state(batch_oov_ids)
                hidden_sister = self._last_hidden_state(batch_sister_ids)
                for row in range(hidden_oov.size(0)):
                    start, end = batch_oov_spans[row].tolist()
                    # OOV embedding: sum over all of its word pieces.
                    embedding_oov = hidden_oov[row, start:end + 1].sum(dim=0)
                    embedding_sister = hidden_sister[row, int(batch_sister_positions[row])]
                    cos = evaluator(embedding_oov.unsqueeze(0).cpu(), embedding_sister.unsqueeze(0).cpu())
                    cosines.append(cos.item())

        # Correlate only after every batch has been scored.
        return spearmanr(sim_values, cosines)