import os

from statsmodels.stats.descriptivestats import sign_test
from scipy.stats import wilcoxon

from preprocessing.w2v_preprocessing_embedding import PreprocessingWord2VecEmbedding
from utility.models_utility import model_by
from utility.bert_vocab import BertVocab, BertW2VVocab
from baselines.Additive import Additive
from baselines.HeadModel import HeadModel
from defiNNet.DefiNNet import DefiNNet
from utility.distributions import UnknownDistribution
from utility.randomfixedseed import Random
from utility.similarity_evaluator import SimilarityEvaluator
from nltk.corpus import wordnet as wn


def predict_definitions_according_to(test_model, correlation):
    """Produce the model's definition prediction for one test correlation.

    The prediction call depends on the concrete model type, checked in this order:

    * ``HeadModel``: ``predict(w1, pos_tag=<first letter of target_pos, lowercased>, synset=s1)``
    * ``Additive``: ``predict_analyzed(first)``
    * ``DefiNNet``: ``predict_analyzed(first[0], w1_pos, first[1], w2_pos, target_pos)``

    :param test_model: model instance to query.
    :param correlation: dict describing one test example (keys used above).
    :return: whatever the selected prediction method returns, or ``None`` when
        ``test_model`` is none of the supported types.
    """
    prediction = None
    if isinstance(test_model, HeadModel):
        lemma = correlation['w1']
        short_pos = correlation['target_pos'][0].lower()
        prediction = test_model.predict(lemma, pos_tag=short_pos, synset=correlation['s1'])
    elif isinstance(test_model, Additive):
        prediction = test_model.predict_analyzed(correlation["first"])
    elif isinstance(test_model, DefiNNet):
        definition_words = correlation['first']
        prediction = test_model.predict_analyzed(definition_words[0],
                                                 correlation['w1_pos'],
                                                 definition_words[1],
                                                 correlation['w2_pos'],
                                                 correlation['target_pos'])
    return prediction


class Test_Definition_Lemma:
    """Compare a model's definition embeddings with the embeddings of the defined lemmas.

    For each test correlation, the model predicts a vector from the definition
    (via ``predict_definitions_according_to``). That vector is compared with
    ``model.preprocessor.get_vector(w1)`` using cosine similarity. The negated
    similarity values are plotted and written under ``<output_path>/seed_<seed>/``.
    """

    def __init__(self, model, output_path_dir):
        """Store the model and make sure the output directory exists.

        :param model: model under evaluation; must expose ``preprocessor.get_vector``.
        :param output_path_dir: directory for the results. It is created when
            absent, together with any missing parent directories.
        """
        self.model = model
        # makedirs also creates missing parents and is a no-op if the directory exists.
        os.makedirs(output_path_dir, exist_ok=True)
        self.output_path = output_path_dir

    def run(self, model_name, tests, seed, pos=None):
        """Score every selected test and save the resulting distribution.

        :param model_name: name used in the output file names.
        :param tests: list of correlation dicts, with the keys read here ('w1',
            'target_pos') and by ``predict_definitions_according_to``.
        :param seed: seed label; outputs go to a ``seed_<seed>`` sub-directory.
        :param pos: optional POS filter. Only tests whose ``target_pos`` starts
            with ``pos.upper()`` are used; ``None`` keeps every test.
        :return: ``(distr.mu, distr.std, cosines)``. ``distr`` is the
            ``UnknownDistribution`` built from ``cosines``, and ``cosines`` holds
            the negated similarity of each selected test, in input order.
        """
        tests_examples = [test for test in tests
                          if pos is None or test['target_pos'].startswith(pos.upper())]

        evaluator = SimilarityEvaluator('cosine_similarity')
        cosines = []
        for correlation in tests_examples:
            prediction_def = predict_definitions_according_to(self.model, correlation)
            prediction_word = self.model.preprocessor.get_vector(correlation["w1"])
            pred = evaluator.similarity_function(prediction_def, prediction_word)
            # For DefiNNet the similarity result is indexed and only its first
            # element is kept (presumably a batch of one -- TODO confirm).
            cosines.append(-pred[0] if type(self.model) is DefiNNet else -pred)

        print('\t'.join(['tests', str(len(tests_examples))]))
        distr = UnknownDistribution(cosines)

        seed_dir = os.path.join(self.output_path, 'seed_' + str(seed))
        os.makedirs(seed_dir, exist_ok=True)

        # ``pos`` defaults to None; use '' in the file names instead of failing
        # on ``None + str`` (same convention as Test_BertDefinition_Lemma).
        pos_label = '' if pos is None else pos
        distr.save(output_path=os.path.join(seed_dir, pos_label + '_' + model_name + '.png'),
                   title='Cosines distributions')
        with open(os.path.join(seed_dir, 'cosines_' + pos_label + '_' + model_name + '.txt'), 'w+') as f:
            f.writelines(str(cosine) + '\t#\n' for cosine in cosines)

        return distr.mu, distr.std, cosines


class Test_BertDefinition_Lemma:
    """Definition-vs-lemma similarity task restricted to BERT-comparable examples.

    A test is kept only if all of the following hold:

    * its ``target_pos`` matches the optional POS filter;
    * its lemma ``w1`` is in the BERT vocabulary read from ``bert_vocab_path``;
    * ``w1`` appears as a space-separated token in at least one WordNet usage
      example of synset ``s1``.
    """

    def __init__(self, model, output_path_dir, bert_vocab_path):
        """Store the model, create the output directory and load the BERT vocabulary.

        :param model: model under evaluation.
        :param output_path_dir: directory for the results. It is created when
            absent, together with any missing parent directories.
        :param bert_vocab_path: vocabulary file passed to ``BertVocab.read_vocab_file``.
        """
        self.model = model
        os.makedirs(output_path_dir, exist_ok=True)
        self.output_path = output_path_dir
        self.bert_vocab = BertVocab.read_vocab_file(vocab_file=bert_vocab_path)

    def _select_examples(self, tests, pos):
        """Return the tests that satisfy the filters described on the class, in input order."""
        selected = []
        for test in tests:
            if pos is not None and not test['target_pos'].startswith(pos.upper()):
                continue
            if test['w1'] not in self.bert_vocab:
                continue
            # The lemma must occur verbatim (space-delimited) in a usage example of its synset.
            examples = wn.synset(test['s1']).examples()
            if any(test['w1'] in str(example).split(' ') for example in examples):
                selected.append(test)
        return selected

    def run(self, model_name, tests, pos=None):
        """Score the selected tests and save the resulting distribution.

        :param model_name: model name. 'defBERT' and 'defBERT_CLS' are scored with
            ``model.compare_embeddings_lemma_definition``; every other model uses
            ``predict_definitions_according_to`` plus cosine similarity.
        :param tests: list of correlation dicts.
        :param pos: optional POS filter; ``None`` keeps every POS and yields ''
            in the output file names.
        :return: ``(distr.mu, distr.std, cosines)`` for the selected tests.
        """
        tests_examples = self._select_examples(tests, pos)

        if pos is None:
            pos = ''
        # NOTE(review): this dump goes to the current working directory, not
        # output_path, and is rewritten for every model.
        with open('test_bert_' + pos + '.txt', 'w+') as f:
            lines = [''.join(key + ':' + str(value) + '\t' for key, value in test.items()) + '\n'
                     for test in tests_examples]
            f.writelines(lines)

        cosines = []
        evaluator = SimilarityEvaluator('cosine_similarity')
        if model_name not in ('defBERT_CLS', 'defBERT'):
            for correlation in tests_examples:
                prediction_def = predict_definitions_according_to(self.model, correlation)
                prediction_word = self.model.preprocessor.get_vector(correlation["w1"])
                pred = evaluator.similarity_function(prediction_def, prediction_word)
                # DefiNNet's similarity result is indexed; only its first element is kept.
                cosines.append(-pred[0] if type(self.model) is DefiNNet else -pred)
        else:
            cosines = self.model.compare_embeddings_lemma_definition(tests_examples)

        print('\t'.join(['tests', str(len(tests_examples))]))

        distr = UnknownDistribution(cosines)

        distr.save(output_path=os.path.join(self.output_path, pos + '_' + model_name + '.png'),
                   title='Cosines distributions')
        with open(os.path.join(self.output_path, 'cosines_' + pos + '_' + model_name + '.txt'), 'w+') as f:
            f.writelines(str(cosine) + '\t#\n' for cosine in cosines)

        return distr.mu, distr.std, cosines


def all_descendant_files_of(base):
    """Return the paths of the regular files directly inside ``base``.

    Despite the name, the scan is not recursive: only immediate children are
    considered and sub-directories are skipped. ``os.listdir`` yields entries in
    arbitrary order, so the names are sorted. This keeps the order of the loaded
    test examples, and of everything derived from it, reproducible across runs
    and file systems.

    :param base: directory to scan.
    :return: sorted list of ``os.path.join(base, name)`` for each regular file.
    """
    candidates = (os.path.join(base, name) for name in sorted(os.listdir(base)))
    return [path for path in candidates if os.path.isfile(path)]


def test_definition(model_mappings=None):
    """Run the definition-vs-lemma similarity task on word2vec-based models.

    Steps:

    1. For every model and POS tag ('n', 'v'), load the tab-separated test
       files from ``data/definitions_test/<pos>/``.
    2. For each seed, score them with ``Test_Definition_Lemma`` and print the
       resulting mu/std.
    3. Compare each target model with every available baseline, using a sign
       test and a one-sided (``alternative='greater'``) Wilcoxon signed-rank test.

    :param model_mappings: optional ``{name: spec}`` dict, where each spec is
        passed to ``model_by``. Defaults to DefiNNet, the additive baseline and
        the head baseline on the GoogleNews vectors.
    """
    base_data_dir = 'data/task1_1_sim_definition_lemma'
    os.makedirs(base_data_dir, exist_ok=True)

    if model_mappings is None:
        pretrained_model = 'data/pretrained_embeddings/GoogleNews-vectors-negative300.bin'

        model_mappings = {
            'definnet': ['definnet', pretrained_model],
            'additive': ['additive_model', pretrained_model],
            'head': ['head_model', pretrained_model],
        }

    pos_tags = ['n', 'v']
    seeds = ['19']
    definitions_path = 'data/definitions_test/'

    # Column layout of the tab-separated test files.
    s1_index = 0
    w1_index = 1
    first_indexes = [2, 4]  # columns 2 and 3 form correlation['first']
    target_pos_index = 4
    w1_pos = 5
    w2_pos = 6

    model_cosines = {}

    for model_name in model_mappings:
        model_cosines[model_name] = {}
        test_model = model_by(model_mappings[model_name])

        output_path_dir = os.path.join(base_data_dir, model_name)
        task = Test_Definition_Lemma(test_model, output_path_dir=output_path_dir)

        for pos in pos_tags:
            model_cosines[model_name][pos] = {}
            test_paths = all_descendant_files_of(os.path.join(definitions_path, pos))

            tests = []
            for test_path in test_paths:
                with open(test_path, 'r') as f:
                    for raw_line in f:
                        # Drop the line terminator so it doesn't end up in the last
                        # column (w2_pos); blank lines carry no test.
                        stripped = raw_line.rstrip('\r\n')
                        if not stripped:
                            continue
                        fields = stripped.split('\t')
                        tests.append({
                            's1': fields[s1_index],
                            'w1': fields[w1_index],
                            'first': fields[first_indexes[0]: first_indexes[1]],
                            'target_pos': fields[target_pos_index],
                            'w1_pos': fields[w1_pos],
                            'w2_pos': fields[w2_pos],
                        })

            for seed in seeds:
                Random.set_seed(seed)
                distrmu, distrstd, cosines = task.run(model_name, tests, seed, pos)
                model_cosines[model_name][pos][seed] = cosines
                print('\t'.join([model_name, seed, pos, str(distrmu), str(distrstd)]))

    targets = ['definnet']
    # 'head' is the key of the head baseline in the default mapping; 'parent' is
    # kept in case a caller-supplied mapping uses that name.
    baselines = ['additive', 'head', 'parent']
    for target in targets:
        if target not in model_mappings:
            continue
        for baseline in baselines:
            if baseline not in model_mappings:
                continue
            for pos in pos_tags:
                for seed in seeds:
                    target_cosines = model_cosines[target][pos][seed]
                    baseline_cosines = model_cosines[baseline][pos][seed]
                    # Paired tests: every model was scored on the same tests in the same order.
                    _, p_value = sign_test(
                        samp=[t - b for t, b in zip(target_cosines, baseline_cosines)], mu0=0)
                    print("\t".join(['sign', target, baseline, pos, seed, "p-value=" + str(p_value)]))

                    _, p_value = wilcoxon(x=target_cosines, y=baseline_cosines, alternative='greater')
                    print("\t".join(['wilcoxon', target, baseline, pos, seed, "p-value=" + str(p_value)]))


def test_definition_with_bert(model_mappings=None):
    """Run the definition-vs-lemma task on the BERT-comparable subset of the tests.

    Steps:

    1. For every model and POS tag, load the tab-separated test files from
       ``data/definitions_test/<pos>/``.
    2. Keep the rows accepted by ``BertW2VVocab.comparable_definitions``.
    3. Let the policy memorize the accepted rows under
       ``data/comparable_definitions_test_v4/<pos>/``.
    4. Score the kept rows with ``Test_BertDefinition_Lemma``.
    5. Compare each target model against every other model with a sign test
       and a one-sided Wilcoxon signed-rank test.

    :param model_mappings: optional ``{name: spec}`` dict, where each spec is
        passed to ``model_by``. Defaults to the word2vec models plus defBERT and
        defBERT_CLS.
    """
    base_data_dir = 'data/task1_bert_sim_definition_lemma'
    os.makedirs(base_data_dir, exist_ok=True)

    pretrained_model = 'data/pretrained_embeddings/GoogleNews-vectors-negative300.bin'

    if model_mappings is None:
        model_mappings = {
            'definnet': ['definnet', pretrained_model],
            'additive': ['additive_model', pretrained_model],
            'head': ['head_model', pretrained_model],
            'defBERT': ['defBERT', 'bert-base-uncased'],
            'defBERT_CLS': ['defBERT_CLS', 'bert-base-uncased']
        }

    pos_tags = ['n', 'v']
    definitions_path = 'data/definitions_test/'
    bert_vocab_path = 'data/bert_vocabulary_in_synset.txt'
    output_comparable = 'data/comparable_definitions_test_v4'

    # Column layout of the tab-separated test files.
    s1_index = 0
    w1_index = 1
    first_indexes = [2, 4]  # columns 2 and 3 form correlation['first']
    target_pos_index = 4
    w1_pos = 5
    w2_pos = 6

    model_cosines = {}

    preprocessor = PreprocessingWord2VecEmbedding(pretrained_model, binary=True).model
    for model_name in model_mappings:
        model_cosines[model_name] = {}
        test_model = model_by(model_mappings[model_name])

        output_path_dir = os.path.join(base_data_dir, model_name)
        task = Test_BertDefinition_Lemma(test_model,
                                         output_path_dir=output_path_dir,
                                         bert_vocab_path=bert_vocab_path)

        for pos in pos_tags:
            test_paths = all_descendant_files_of(os.path.join(definitions_path, pos))

            tests = []
            policy = BertW2VVocab(preprocessor, bert_vocab_path=bert_vocab_path)
            for test_path in test_paths:
                with open(test_path, 'r') as f:
                    for raw_line in f:
                        # Drop the line terminator so it doesn't end up in the last
                        # column (w2_pos); blank lines carry no test.
                        stripped = raw_line.rstrip('\r\n')
                        if not stripped:
                            continue
                        fields = stripped.split('\t')
                        correlation = {
                            's1': fields[s1_index],
                            'w1': fields[w1_index],
                            'first': fields[first_indexes[0]: first_indexes[1]],
                            'target_pos': fields[target_pos_index],
                            'w1_pos': fields[w1_pos],
                            'w2_pos': fields[w2_pos]
                        }

                        if policy.comparable_definitions(correlation, pos):
                            tests.append(correlation)
                            # The policy receives the raw line, terminator included.
                            policy.add(raw_line)

                    output_comparable_pos = os.path.join(output_comparable, pos)
                    os.makedirs(output_comparable_pos, exist_ok=True)
                    policy.memorize(output_comparable_pos, original_file=test_path)

            distrmu, distrstd, cosines = task.run(model_name, tests, pos)
            model_cosines[model_name][pos] = cosines
            print('\t'.join([model_name, pos, str(distrmu), str(distrstd)]))

    targets = ['definnet', 'defBERT']
    for target in targets:
        if target not in model_mappings:
            continue
        for baseline in [name for name in model_mappings if name != target]:
            for pos in pos_tags:
                target_cosines = model_cosines[target][pos]
                baseline_cosines = model_cosines[baseline][pos]
                # NOTE(review): the paired tests assume every model scored the same
                # selected tests in the same order (same vocab file and filters).
                _, p_value = sign_test(
                    samp=[t - b for t, b in zip(target_cosines, baseline_cosines)], mu0=0)
                print("\t".join(['sign', target, baseline, pos, "p-value=" + str(p_value)]))

                _, p_value = wilcoxon(x=target_cosines, y=baseline_cosines, alternative='greater')
                print("\t".join(['wilcoxon', target, baseline, pos, "p-value=" + str(p_value)]))


