import os
from scipy.stats import spearmanr
import pandas as pd
from nltk.corpus import wordnet as wn
from statsmodels.stats.descriptivestats import sign_test
from scipy.stats import wilcoxon

from baselines.W2VModel import W2VModel
from utility.models_utility import model_by
from baselines.Additive import Additive
from baselines.FastTextModel import FastTextModel
from baselines.HeadModel import HeadModel
from defiNNet.DefiNNet import DefiNNet
from sister_terms_similarity.pedersen_similarities import SimilarityFunction, SisterOOVPair
from utility.cluster import ClusterMinDiam
from utility.distributions import UnknownDistribution
from utility.randomfixedseed import Random
from utility.similarity_evaluator import SimilarityEvaluator


def signtest(spearman_scores, measures, models_to_test):
    """Run a sign test on paired per-micro-list Spearman correlations of two models.

    For every measure and every ``(model_1_name, model_2_name)`` pair, the
    paired differences ``rho_1 - rho_2`` are passed to statsmodels'
    ``sign_test`` with ``mu0=0``. One tab-separated line is printed per pair
    containing the measure, the two model names, ``n1`` (number of non-zero
    differences), ``k`` (number of positive differences) and the p-value.
    A ``sign`` header line is printed before each measure's results.

    :param spearman_scores: ``spearman_scores[measure][model_name]`` is a list of
        objects exposing ``.correlation`` (e.g. ``scipy.stats.spearmanr`` results),
        aligned by micro-list index across models.
    :param measures: iterable of measure names to test.
    :param models_to_test: iterable of ``(model_1_name, model_2_name)`` pairs.
    :raises ValueError: if the two models have a different number of scores.
    """
    for measure in measures:
        print('sign')
        for (model_1_name, model_2_name) in models_to_test:
            model_1_score = [x.correlation for x in spearman_scores[measure][model_1_name]]
            model_2_score = [x.correlation for x in spearman_scores[measure][model_2_name]]
            if len(model_1_score) != len(model_2_score):
                raise ValueError(
                    f'{model_1_name} and {model_2_name} have a different number of {measure} scores '
                    f'({len(model_1_score)} vs {len(model_2_score)})')

            differences = [s1 - s2 for s1, s2 in zip(model_1_score, model_2_score)]
            # Zero differences (ties) count neither as positive nor as negative.
            k = sum(1 for d in differences if d > 0)
            n1 = k + sum(1 for d in differences if d < 0)

            _, p_value = sign_test(samp=differences, mu0=0)
            # The measure is included so lines from different measures are distinguishable,
            # matching the output format of wilcoxontest.
            print('\t'.join([measure, model_1_name, model_2_name, f'n1 = {n1}', f'k = {k}', str(p_value)]))


def wilcoxontest(spearman_scores, measures, models_to_test):
    """Compare paired per-micro-list Spearman correlations with a one-sided Wilcoxon test.

    For each measure a ``wilcoxon`` header line is printed. Then, for every
    ``(model_1_name, model_2_name)`` pair, ``scipy.stats.wilcoxon`` is run with
    ``alternative='greater'`` on the two correlation lists. One tab-separated
    line is printed containing the measure, both model names, ``n1`` (number of
    non-zero paired differences), ``k`` (number of positive differences) and
    the p-value.

    :param spearman_scores: ``spearman_scores[measure][model_name]`` is a list of
        objects exposing ``.correlation``, aligned by micro-list index.
    :param measures: iterable of measure names to test.
    :param models_to_test: iterable of ``(model_1_name, model_2_name)`` pairs.
    """
    for measure_name in measures:
        print('wilcoxon')
        for first, second in models_to_test:
            first_rhos = [result.correlation for result in spearman_scores[measure_name][first]]
            second_rhos = [result.correlation for result in spearman_scores[measure_name][second]]

            # Pair by index over the first model's scores.
            deltas = [first_rhos[idx] - second_rhos[idx] for idx in range(len(first_rhos))]
            positives = sum(1 for delta in deltas if delta > 0)
            negatives = sum(1 for delta in deltas if delta < 0)

            _, p_value = wilcoxon(x=first_rhos, y=second_rhos, alternative='greater')
            fields = [measure_name, first, second,
                      f'n1 = {positives + negatives}', f'k = {positives}', str(p_value)]
            print('\t'.join(fields))


def retrieve_oov_pairs_divided_by_value_of_similarity(positive_input_path, negative_input_path, measure_name):
    """Group positive and negative OOV sister-term pairs by their WordNet similarity value.

    Both inputs are read as tab-separated tables whose ``target_synset`` and
    ``sister_synset`` columns hold names passed to ``wn.synset``. For each row,
    the similarity function selected by ``measure_name`` is evaluated on the
    two synsets. The row, wrapped in ``SisterOOVPair``, is appended to the list
    stored under that value.

    :param positive_input_path: path of the positive-pairs TSV file.
    :param negative_input_path: path of the negative-pairs TSV file.
    :param measure_name: name resolved through ``SimilarityFunction.by_name``.
    :return: dict mapping similarity value -> list of ``SisterOOVPair``. Positive
        rows are appended before negative rows, each in file order.
    """
    similarity_of = SimilarityFunction.by_name(measure_name)
    # Both files are loaded before any similarity is computed.
    tables = [pd.read_csv(path, sep='\t') for path in (positive_input_path, negative_input_path)]

    grouped = {}
    for table in tables:
        for _, row in table.iterrows():
            value = similarity_of(wn.synset(row['target_synset']), wn.synset(row['sister_synset']))
            grouped.setdefault(value, []).append(SisterOOVPair(row))
    return grouped


def collect_test_of_size(n_test, test_size, k_clusters: dict, ouput_path=None):
    """Draw up to ``n_test`` micro-lists of ``test_size`` items, each item from a distinct cluster.

    For every micro-list, ``test_size`` distinct non-empty clusters are picked
    with ``Random.randomchoice`` and one random element is drawn from each. The
    drawn element is removed from its cluster, so ``k_clusters`` is MUTATED and
    an element is never reused across micro-lists. Sampling stops early as soon
    as fewer than ``test_size`` non-empty clusters remain.

    :param n_test: maximum number of micro-lists to build.
    :param test_size: number of elements per micro-list.
    :param k_clusters: dict mapping cluster center -> list of
        ``(similarity_value, couple)`` tuples.
    :param ouput_path: if given, the micro-lists are also written there as a
        tab-separated file (misspelt name kept for backward compatibility).
    :return: list of micro-lists, each sorted by similarity value.
    """
    tests = []
    for _ in range(n_test):
        available_centers = [center for center in k_clusters if len(k_clusters[center]) != 0]
        # Each draw consumes one center, so "drawn + still available" is constant while a
        # micro-list is filled: one upfront check tells whether it can be completed.
        if len(available_centers) < test_size:
            break

        test = []
        for _ in range(test_size):
            center = Random.randomchoice(available_centers)
            available_centers.remove(center)

            d = Random.randomchoice(k_clusters[center])
            test.append(d)
            k_clusters[center].remove(d)

        if len(test) != test_size:  # only reachable for a negative test_size
            break
        # Sort on the similarity value only: comparing whole tuples would fall back to
        # comparing the couple objects on ties, which may not be orderable.
        test.sort(key=lambda item: item[0])
        tests.append(test)

    if ouput_path is not None:
        header = ['TEST_N', 'target_synset', 'target_word', 'data_w1', 'data_w2', 'target_pos', 'w1_pos', 'w2_pos',
                  'definition', 'sister_synset', 'sister_word', 'value', '#\n']
        with open(ouput_path, 'w') as output:
            output.write('\t'.join(header))
            for test_n, test in enumerate(tests):
                for similarity_value, couple in test:
                    output.write('\t'.join([str(test_n)] + couple.to_list() + [str(similarity_value), '#\n']))
    return tests


def predict_target_sister_according_to(test_model, couple):
    """Return ``(target_vector, sister_vector)`` for ``couple`` under ``test_model``.

    How the target is built depends on the model class:

    - ``HeadModel``: ``predict`` on the target word with its POS tag and synset.
    - ``Additive``: ``predict_analyzed`` on the two definition words.
    - ``DefiNNet``: ``predict_analyzed`` on the definition words and POS tags.
    - ``FastTextModel`` / ``W2VModel``: ``predict`` on the target word.

    The sister vector comes from ``preprocessor.get_vector`` for the first three
    model kinds and from ``predict`` for the last two. Branches are checked in
    the order listed above.

    :raises TypeError: if ``test_model`` is none of the supported model classes.
        Previously ``None`` was returned, which made callers fail when unpacking.
    """
    if isinstance(test_model, HeadModel):
        prediction = test_model.predict(couple.target_word, pos_tag=couple.target_pos[0].lower(),
                                        synset=couple.target_synset)
        return prediction, test_model.preprocessor.get_vector(couple.sister_word)

    if isinstance(test_model, Additive):
        prediction = test_model.predict_analyzed([couple.data_w1, couple.data_w2])
        return prediction, test_model.preprocessor.get_vector(couple.sister_word)

    if isinstance(test_model, DefiNNet):
        prediction = test_model.predict_analyzed(couple.data_w1, couple.w1_pos, couple.data_w2,
                                                 couple.w2_pos, couple.target_pos)
        return prediction, test_model.preprocessor.get_vector(couple.sister_word)

    if isinstance(test_model, (FastTextModel, W2VModel)):
        # Both embeddings look up target and sister words the same way.
        return test_model.predict(couple.target_word), test_model.predict(couple.sister_word)

    raise TypeError(f'Unsupported model type: {type(test_model).__name__}')


def test_all_models_on_micro_lists(model_mappings, root_data_path, destination_dir,
                                   similarities_function_names=None, *,
                                   k=15, n_test=1000, test_size=7, seed=99):
    """Evaluate every model on random micro-lists of OOV sister-term pairs, per similarity measure.

    For each measure:

    - Pairs from ``oov_definition_sister_terms_{positive,negative}.txt`` in
      ``root_data_path`` are grouped by similarity value and clustered into
      ``k`` clusters.
    - Up to ``n_test`` micro-lists of ``test_size`` pairs are sampled; the lists
      are also written to ``<measure>_micro_lists_test.txt``.
    - For every model and micro-list, the Spearman correlation between the
      negated evaluator output and the similarity values is recorded.
    - A histogram of the correlations is saved and a summary line is printed
      (measure, model, mean, std, number of micro-lists).

    :param model_mappings: dict mapping a display name to the spec passed to ``model_by``.
    :param root_data_path: directory containing the input files.
    :param destination_dir: output directory, relative to ``root_data_path``.
    :param similarities_function_names: measure names; defaults to ``['wup', 'res', 'path']``.
    :param k: number of similarity clusters (default 15).
    :param n_test: maximum number of micro-lists per measure (default 1000).
    :param test_size: number of pairs per micro-list (default 7).
    :param seed: value passed to ``Random.set_seed`` (default 99).
    :return: ``spearman_scores[measure][model_name]`` -> list of ``spearmanr`` results,
        one per micro-list.
    """
    Random.set_seed(int(seed))
    destination_dir = os.path.join(root_data_path, destination_dir)
    os.makedirs(destination_dir, exist_ok=True)

    if similarities_function_names is None:
        similarities_function_names = ['wup', 'res', 'path']

    evaluator = SimilarityEvaluator('cosine_similarity')

    spearman_scores = {}
    for measure in similarities_function_names:
        measure_dir = os.path.join(destination_dir, measure)
        os.makedirs(measure_dir, exist_ok=True)
        spearman_scores[measure] = {}

        n_couple_clusters = retrieve_oov_pairs_divided_by_value_of_similarity(
            positive_input_path=os.path.join(root_data_path, 'oov_definition_sister_terms_positive.txt'),
            negative_input_path=os.path.join(root_data_path, 'oov_definition_sister_terms_negative.txt'),
            measure_name=measure)
        k_clusters = ClusterMinDiam.k_clusters_of_min_diameter(k=k, n_clusters=n_couple_clusters)
        tests = collect_test_of_size(n_test=n_test, test_size=test_size, k_clusters=k_clusters,
                                     ouput_path=os.path.join(measure_dir, measure + '_micro_lists_test.txt'))

        for model_name in model_mappings:
            model = model_by(model_mappings[model_name])
            model_dir = os.path.join(measure_dir, model_name)
            os.makedirs(model_dir, exist_ok=True)

            scores = []
            spearman_scores[measure][model_name] = scores
            for test in tests:
                similarity_values = []
                cosines = []
                for similarity_value, couple in test:
                    target, sister = predict_target_sister_according_to(model, couple)
                    pred = evaluator.similarity_function(target, sister)
                    # NOTE(review): DefiNNet's output is indexed with [0], presumably because its
                    # prediction is batched. isinstance matches the dispatch used for prediction.
                    # The value is negated — presumably the evaluator returns a distance-like
                    # score; confirm in SimilarityEvaluator.
                    cosines.append(-(pred[0] if isinstance(model, DefiNNet) else pred))
                    similarity_values.append(similarity_value)
                scores.append(spearmanr(cosines, similarity_values))

            distribution = UnknownDistribution(data=[x.correlation for x in scores])
            distribution.save(output_path=os.path.join(model_dir, measure + '_' + model_name + '_hist_test.png'),
                              title=f"{measure} mini-lists spearman results")
            print('\t'.join([measure, model_name, str(distribution.mu), str(distribution.std), str(len(tests))]))

    return spearman_scores


def test_all_models_on_micro_lists_with_bert_vocab(model_mappings, root_data_path, destination_dir,
                                                   similarities_function_names=None, *,
                                                   k=15, n_test=1000, test_size=7, seed=99):
    """Evaluate every model on micro-lists built from the BERT-comparable OOV sister-term pairs.

    This does the same as ``test_all_models_on_micro_lists`` with two differences:

    - The inputs are ``oov_definition_sister_terms_{positive,negative}_comparable.txt``.
    - Models named ``defBERT`` or ``BERT_wordpieces`` score each micro-list
      themselves via ``calculate_spearmanr``.

    For each measure, pairs are grouped by similarity value and clustered into
    ``k`` clusters. Up to ``n_test`` micro-lists of ``test_size`` pairs are
    sampled and written to ``<measure>_micro_lists_test.txt``. Per model, the
    Spearman correlations are collected, a histogram is saved and a summary
    line is printed (measure, model, mean, std, number of micro-lists).

    :param model_mappings: dict mapping a display name to the spec passed to ``model_by``.
    :param root_data_path: directory containing the input files.
    :param destination_dir: output directory, relative to ``root_data_path``.
    :param similarities_function_names: measure names; defaults to ``['wup', 'res', 'path']``.
    :param k: number of similarity clusters (default 15).
    :param n_test: maximum number of micro-lists per measure (default 1000).
    :param test_size: number of pairs per micro-list (default 7).
    :param seed: value passed to ``Random.set_seed`` (default 99).
    :return: ``spearman_scores[measure][model_name]`` -> list of Spearman results,
        one per micro-list.
    """
    Random.set_seed(int(seed))
    destination_dir = os.path.join(root_data_path, destination_dir)
    os.makedirs(destination_dir, exist_ok=True)

    if similarities_function_names is None:
        similarities_function_names = ['wup', 'res', 'path']

    evaluator = SimilarityEvaluator('cosine_similarity')

    spearman_scores = {}
    for measure in similarities_function_names:
        measure_dir = os.path.join(destination_dir, measure)
        os.makedirs(measure_dir, exist_ok=True)
        spearman_scores[measure] = {}

        n_couple_clusters = retrieve_oov_pairs_divided_by_value_of_similarity(
            positive_input_path=os.path.join(root_data_path, 'oov_definition_sister_terms_positive_comparable.txt'),
            negative_input_path=os.path.join(root_data_path, 'oov_definition_sister_terms_negative_comparable.txt'),
            measure_name=measure)
        k_clusters = ClusterMinDiam.k_clusters_of_min_diameter(k=k, n_clusters=n_couple_clusters)
        tests = collect_test_of_size(n_test=n_test, test_size=test_size, k_clusters=k_clusters,
                                     ouput_path=os.path.join(measure_dir, measure + '_micro_lists_test.txt'))

        for model_name in model_mappings:
            model = model_by(model_mappings[model_name])
            model_dir = os.path.join(measure_dir, model_name)
            os.makedirs(model_dir, exist_ok=True)

            # These models compute the micro-list correlation themselves.
            scores_itself = model_name in ('defBERT', 'BERT_wordpieces')

            scores = []
            spearman_scores[measure][model_name] = scores
            for test in tests:
                if scores_itself:
                    scores.append(model.calculate_spearmanr(test))
                    continue

                similarity_values = []
                cosines = []
                for similarity_value, couple in test:
                    target, sister = predict_target_sister_according_to(model, couple)
                    pred = evaluator.similarity_function(target, sister)
                    # NOTE(review): DefiNNet's output is indexed with [0], presumably because its
                    # prediction is batched. isinstance matches the dispatch used for prediction.
                    # The value is negated — presumably the evaluator returns a distance-like
                    # score; confirm in SimilarityEvaluator.
                    cosines.append(-(pred[0] if isinstance(model, DefiNNet) else pred))
                    similarity_values.append(similarity_value)
                scores.append(spearmanr(cosines, similarity_values))

            distribution = UnknownDistribution(data=[x.correlation for x in scores])
            distribution.save(output_path=os.path.join(model_dir, measure + '_' + model_name + '_hist_test.png'),
                              title=f"{measure} mini-lists spearman results")
            print('\t'.join([measure, model_name, str(distribution.mu), str(distribution.std), str(len(tests))]))

    return spearman_scores
