import os

import pandas as pd
from nltk.corpus import wordnet as wn

from scipy.stats import spearmanr

from defiNNet.DefiNNet import DefiNNet
from sister_terms_similarity.pedersen_similarities import SimilarityFunction, SisterInVocPair
from sister_terms_similarity.test_sister_terms_oov import collect_test_of_size, predict_target_sister_according_to
from utility.cluster import ClusterMinDiam
from utility.distributions import UnknownDistribution
from utility.models_utility import model_by
from utility.randomfixedseed import Random
from utility.similarity_evaluator import SimilarityEvaluator


def test_all_models_on_in_voc_micro_lists(model_mappings, root_data_path, destination_dir,
                                   similarities_function_names=None,
                                   k=15, n_test=2000, test_size=7):
    """Score every model on in-vocabulary sister-term micro-lists, one measure at a time.

    For each similarity measure the in-vocabulary pairs are grouped by their
    similarity value, handed to ``ClusterMinDiam.k_clusters_of_min_diameter``
    and then to ``collect_test_of_size`` to obtain the micro-list tests. Each
    model is then run on every test: the evaluator score of each pair is
    negated and correlated (Spearman) with the reference similarity values.
    For every (measure, model) combination the distribution of correlations is
    saved as an image and a tab-separated summary line is printed.

    :param model_mappings: mapping from model name to the value given to
        ``model_by`` to build that model.
    :param root_data_path: directory holding ``in_voc_sister_terms_positive.txt``
        and ``in_voc_sister_terms_negative.txt``; outputs are written below it.
    :param destination_dir: output directory, relative to ``root_data_path``.
    :param similarities_function_names: measure names to evaluate; defaults to
        ``['wup', 'res', 'path']`` when ``None``.
    :param k: value forwarded as ``k`` to ``ClusterMinDiam.k_clusters_of_min_diameter``.
    :param n_test: value forwarded as ``n_test`` to ``collect_test_of_size``.
    :param test_size: value forwarded as ``test_size`` to ``collect_test_of_size``.
    :return: nested dict ``{measure: {model_name: [spearmanr result, ...]}}``
        with one ``spearmanr`` result per test.
    """
    Random.set_seed(19)

    destination_dir = os.path.join(root_data_path, destination_dir)
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also creates
    # missing intermediate directories when destination_dir is nested.
    os.makedirs(destination_dir, exist_ok=True)

    if similarities_function_names is None:
        similarities_function_names = ['wup', 'res', 'path']

    evaluator = SimilarityEvaluator('cosine_similarity')

    spearman_scores = {}
    for measure in similarities_function_names:
        measure_dir = os.path.join(destination_dir, measure)
        os.makedirs(measure_dir, exist_ok=True)
        spearman_scores[measure] = {}

        n_couple_clusters = retrieve_in_voc_pairs_divided_by_value_of_similarity(
            positive_input_path=os.path.join(root_data_path, 'in_voc_sister_terms_positive.txt'),
            negative_input_path=os.path.join(root_data_path, 'in_voc_sister_terms_negative.txt'),
            measure_name=measure)

        k_clusters = ClusterMinDiam.k_clusters_of_min_diameter(k=k, n_clusters=n_couple_clusters)

        # 'ouput_path' (sic) is the keyword this call has always used for collect_test_of_size.
        tests = collect_test_of_size(n_test=n_test, test_size=test_size, k_clusters=k_clusters,
                                     ouput_path=os.path.join(measure_dir, measure + '_micro_lists_test.txt'))

        for model_name in model_mappings:
            model = model_by(model_mappings[model_name])

            model_dir = os.path.join(measure_dir, model_name)
            os.makedirs(model_dir, exist_ok=True)

            model_scores = []
            spearman_scores[measure][model_name] = model_scores

            # Indexed access is kept deliberately: only len() and integer
            # indexing are known to be supported by the container of tests.
            for i in range(len(tests)):
                similarity_values = []
                cosines = []
                for similarity_value, couple in tests[i]:
                    target, sister = predict_target_sister_according_to(model, couple)

                    pred = evaluator.similarity_function(target, sister)
                    # For DefiNNet models the score is taken from the first
                    # element of the evaluator output; other models use the
                    # output directly.
                    score = pred[0] if isinstance(model, DefiNNet) else pred
                    # The score is negated before ranking.
                    # NOTE(review): the reason for the sign flip is not visible
                    # here -- confirm against the evaluator/measure conventions.
                    cosines.append(- score)

                    similarity_values.append(similarity_value)

                model_scores.append(spearmanr(cosines, similarity_values))

            distribution = UnknownDistribution(data=[x.correlation for x in model_scores])
            distribution.save(output_path=os.path.join(model_dir, measure + '_' + model_name + '_hist_test.png'),
                              title=f"{measure} mini-lists spearman results")
            print('\t'.join([measure, model_name, str(distribution.mu), str(distribution.std), str(len(tests))]))

    return spearman_scores


def retrieve_in_voc_pairs_divided_by_value_of_similarity(positive_input_path, negative_input_path, measure_name):
    """Group the positive and negative in-vocabulary pairs by their similarity value.

    Both inputs are read as tab-separated files; each row's ``S1`` and ``S2``
    columns are resolved to WordNet synsets and scored with the similarity
    function selected by ``measure_name``.

    :param positive_input_path: path of the tab-separated file of positive pairs.
    :param negative_input_path: path of the tab-separated file of negative pairs.
    :param measure_name: name handed to ``SimilarityFunction.by_name``.
    :return: dict mapping each similarity value to the list of
        ``SisterInVocPair`` objects that obtained it (positive rows first,
        then negative rows, each in file order).
    """
    measure = SimilarityFunction.by_name(measure_name)
    pairs_by_value = {}

    # Both files are parsed up front, before any row is scored.
    frames = [pd.read_csv(path, sep='\t') for path in (positive_input_path, negative_input_path)]

    for frame in frames:
        for _, row in frame.iterrows():
            first_synset = wn.synset(row['S1'])
            second_synset = wn.synset(row['S2'])

            value = measure(first_synset, second_synset)
            pairs_by_value.setdefault(value, []).append(SisterInVocPair(row))

    return pairs_by_value