import os

from pytorch_transformers import BertTokenizer, BertModel
from scipy.stats import wilcoxon
from torch.nn import CosineSimilarity

import torch

from BERT_eval_definitions_sister_terms_similarity.impl_pedersen_test_definitions_sister import \
    Definition_PedersenSimilaritySister_Test

from BERT_sister_terms_similarity.collect_sister_terms_similarities import compute_oov_sister_terms, \
    retrieve_synset_couples_divided_by_value_of_similarity
from utility.cluster import ClusterMinDiam
from utility.distributions import UnknownDistribution
from utility.randomfixedseed import Random

# Module-wide torch device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
from utility.words_in_synset import SynsetCouple, WordInSynset


def evaluate_sister_terms_definitions_similarity(model_name, base_data_dir, oov_sister_terms_path):
    """Run the T3 sister-terms evaluation and compare modes with one-sided Wilcoxon tests.

    Loads the BERT tokenizer/model named ``model_name`` and makes sure the OOV sister-terms
    data are computed under ``oov_sister_terms_path``. It then collects the per-mini-list
    Spearman results for every evaluation mode. Finally, for each seed and similarity measure,
    it prints one row per baseline with the p-value of the test "``def_bert_head`` is greater
    than the baseline".

    :param model_name: name/path given to ``from_pretrained`` and forwarded to the helpers.
    :param base_data_dir: directory created if missing (not otherwise used here).
    :param oov_sister_terms_path: root directory for the sister-terms data and per-mode outputs.
    """
    # makedirs also creates missing parent directories and tolerates an existing directory.
    os.makedirs(base_data_dir, exist_ok=True)
    os.makedirs(oov_sister_terms_path, exist_ok=True)

    seeds = ['19']

    tokenizer = BertTokenizer.from_pretrained(model_name)
    bert_model = BertModel.from_pretrained(model_name, output_hidden_states=True)
    bert_model.to(device)

    # (mode, output sub-directory). Directory names are kept verbatim because results
    # already stored on disk live under them.
    modes = [('bert_head_example', 'test_bert_head_example'),
             ('def_bert_head', 'test_def_bert_head'),
             ('bert_wordpieces', 'test_bert_wordpiecese'),
             ('def_bert_cls', 'test_def_bert_cls')]
    similarity_measures = ['path', 'wup', 'res']
    cosines_model = {}

    print('T3 COMPUTE SISTER TERMS')
    compute_oov_sister_terms(seeds, oov_sister_terms_path, model_name=model_name, check_if_computed=True)

    print('T3 SPEARMAN CORRELATIONS SISTER IN EXAMPLES')
    for mode, destination_dir in modes:
        # Forward the caller's model name rather than a hard-coded one.
        cosines_model[mode] = calculate_spearman_correlation_similarities_sister_in_example(
            bert_model, tokenizer, oov_sister_terms_path, destination_dir,
            model_name=model_name,
            seeds=seeds,
            similarity_measures=similarity_measures,
            mode=mode)

    target_mode = 'def_bert_head'
    # (baseline mode, label printed in the result row), in the order the rows are reported.
    baselines = [('bert_head_example', 'parent_from_example'),
                 ('def_bert_cls', 'def_bert_cls'),
                 ('bert_wordpieces', 'bert_wordpieces')]

    for seed in seeds:
        for measure in similarity_measures:
            cosines_target = [x.correlation for x in cosines_model[target_mode][seed][measure]]
            for baseline_mode, baseline_label in baselines:
                cosines_baseline = [x.correlation for x in cosines_model[baseline_mode][seed][measure]]
                # Paired one-sided test: are the target's correlations greater than the baseline's?
                _, p_value = wilcoxon(x=cosines_target, y=cosines_baseline, alternative='greater')
                print('\t'.join(
                    [target_mode + ' greater that ' + baseline_mode, seed, measure,
                     target_mode, baseline_label, str(p_value)]))


def collect_scores_from(score_file):
    """Read the scores stored in the second tab-separated column of ``score_file``.

    Each non-blank line is expected to look like ``<label>\\t<score>[\\t...]``. Blank
    (whitespace-only) lines are skipped, so a stray empty line does not abort parsing.

    :param score_file: path of the text file to read.
    :return: list of floats, one per non-blank line, in file order.
    :raises IndexError: if a non-blank line has no second tab-separated column.
    :raises ValueError: if the second column is not a valid float.
    """
    with open(score_file, 'r') as file:
        # float() tolerates the trailing newline left on the last column.
        return [float(line.split('\t')[1]) for line in file if line.strip()]


def _count_sign_differences(model_score, target_score):
    """Count paired differences ``model - target`` that are strictly positive / strictly negative.

    Pairs whose difference is exactly zero (or NaN) are counted in neither total.

    :param model_score: scores of the model under test.
    :param target_score: scores of the reference model, paired by position.
    :return: ``(pos, neg)`` counts.
    """
    differences = [model - target for model, target in zip(model_score, target_score)]
    pos = sum(1 for diff in differences if diff > 0)
    neg = sum(1 for diff in differences if diff < 0)
    return pos, neg


def wilcoxontest(models_to_test, target_model, measures, score_dir, seeds):
    """Print a one-sided Wilcoxon test of each model against ``target_model``, per seed and measure.

    Scores are read with ``collect_scores_from`` from
    ``<score_dir>/seed_<seed>/<measure>/spearman_<measure>/<model>_spearman_correlation.txt``.
    Every model other than the target gets one printed row. The row contains ``n1`` (number of
    non-zero paired differences), ``k`` (how many of those are positive, i.e. favour the model)
    and the p-value of the test "model is greater than target".

    :param models_to_test: model names to compare; ``target_model`` itself is skipped.
    :param target_model: name of the reference model.
    :param measures: similarity-measure names (sub-directories of each seed directory).
    :param score_dir: root directory of the score files.
    :param seeds: seed strings (``seed_<seed>`` sub-directories).
    :raises ValueError: if a model's score file and the target's hold different numbers of scores.
    """
    for seed in seeds:
        seed_dir = os.path.join(score_dir, 'seed_' + seed)
        for measure in measures:
            print('\t'.join(['wilcoxon', seed, measure]))

            measure_dir = os.path.join(seed_dir, measure, 'spearman_' + measure)
            target_score_file = os.path.join(measure_dir, target_model + '_spearman_correlation.txt')
            target_model_score = collect_scores_from(target_score_file)

            for model in models_to_test:
                if model == target_model:
                    continue
                model_score_file = os.path.join(measure_dir, model + '_spearman_correlation.txt')
                model_score = collect_scores_from(model_score_file)

                # The test is paired position by position, so both files must hold the same
                # number of scores; report a mismatch explicitly instead of miscounting.
                if len(model_score) != len(target_model_score):
                    raise ValueError(
                        f'{model_score_file} has {len(model_score)} scores but '
                        f'{target_score_file} has {len(target_model_score)}')

                pos, neg = _count_sign_differences(model_score, target_model_score)
                n1 = pos + neg
                k = pos

                _, p_value = wilcoxon(x=model_score, y=target_model_score, alternative='greater')
                print('\t'.join(
                    [model + ' greater that ' + target_model, seed, measure, model, target_model, f'n1 = {n1}',
                     f'k = {k}', str(p_value)]))


def calculate_spearman_correlation_similarities_sister_in_example(test_model, test_tokenizer,
                                                                  root_data_model, destination_dir,
                                                                  model_name,
                                                                  seeds=None, similarity_measures=None,
                                                                  mode='def_bert_cls'):
    """Evaluate one tester ``mode`` on random mini-lists of sister-term couples.

    For every seed and similarity measure, the function:

    * reads ``oov_sister_terms_positive.txt`` / ``oov_sister_terms_negative.txt`` from
      ``root_data_model/seed_<seed>/``;
    * keeps only the couples whose second synset has an example sentence containing the
      second word;
    * groups them with ``ClusterMinDiam`` (k=15);
    * draws up to 500 mini-lists of 7 couples with ``collect_test_of_size``;
    * runs a ``Definition_PedersenSimilaritySister_Test`` of the given ``mode`` on each
      mini-list.

    A histogram of the resulting correlations is saved and a summary row is printed.

    :param test_model: model passed to ``Definition_PedersenSimilaritySister_Test.instantiate``.
    :param test_tokenizer: tokenizer passed to the same factory.
    :param root_data_model: directory holding the ``seed_<seed>`` input files; must already exist.
    :param destination_dir: sub-directory of ``root_data_model`` where outputs are written.
    :param model_name: not used by this function's body.
    :param seeds: list of seed strings; defaults to ``['19']``.
    :param similarity_measures: measure names; defaults to ``['path', 'wup', 'res']``.
    :param mode: tester mode. ``'def_bert_head'`` and ``'bert_wordpieces'`` additionally save
        one CSV per mini-list.
    :return: ``{seed: {measure: [result of tester.run(...), ...]}}``. Each result exposes a
        ``.correlation`` attribute (used for the histogram below).
    """
    if similarity_measures is None:
        similarity_measures = ['path', 'wup', 'res']
    if seeds is None:
        seeds = ['19']

    # Outputs live in a sub-directory of root_data_model (os.mkdir: the parent must exist).
    destination_dir = os.path.join(root_data_model, destination_dir)
    if not os.path.exists(destination_dir):
        os.mkdir(destination_dir)

    # Number of clusters requested from ClusterMinDiam.k_clusters_of_min_diameter below.
    K = 15

    return_cosines = {}
    evaluator = CosineSimilarity()
    for seed in seeds:
        # Seeded once per seed: the random draws of every measure for this seed consume the
        # same random state in sequence, so statement order affects which mini-lists are drawn.
        Random.set_seed(int(seed))
        spearman = {}

        for measure in similarity_measures:
            spearman[measure] = []

            seed_dir = os.path.join(destination_dir, 'seed_' + seed)
            if not os.path.exists(seed_dir):
                os.mkdir(seed_dir)

            # Inputs come from the seed directory under root_data_model, not destination_dir.
            n_couple_clusters = retrieve_synset_couples_divided_by_value_of_similarity(
                positive_input_path=os.path.join(root_data_model, 'seed_' + seed, 'oov_sister_terms_positive.txt'),
                negative_input_path=os.path.join(root_data_model, 'seed_' + seed, 'oov_sister_terms_negative.txt'),
                measure_name=measure)

            # Keep only couples whose sister synset (s2) has at least one example sentence that
            # contains the sister word (w2) as a space-separated token; group them by center.
            available_clusters = {}
            for center in n_couple_clusters:
                for el in n_couple_clusters[center]:
                    el: SynsetCouple = el
                    sister_examples = el.s2.examples()

                    sister_example = None
                    for i in range(0, len(sister_examples)):
                        if el.w2 in str(sister_examples[i]).split(' '):
                            sister_example = sister_examples[i]
                            break

                    if sister_example is not None:
                        if center not in available_clusters:
                            available_clusters[center] = []
                        available_clusters[center].append(el)

            # A single pass with pos=None. pos is then turned into a filename suffix:
            # '' for None, '_<pos>' otherwise.
            for pos in [None]:
                k_clusters = ClusterMinDiam.k_clusters_of_min_diameter(k=K, n_clusters=available_clusters, pos=pos)
                if pos is None:
                    pos = ''
                else:
                    pos = '_' + pos

                # Up to 500 mini-lists of 7 couples (fewer if clusters run out); the drawn
                # lists are also written to output_path.
                output_path = os.path.join(seed_dir, measure + pos + '_sister_in_example.txt')
                tests = collect_test_of_size(n_test=500, test_size=7, k_clusters=k_clusters,
                                             ouput_path=output_path)

                # mode_dir is the same path as seed_dir.
                mode_dir = os.path.join(seed_dir)
                if not os.path.exists(mode_dir):
                    os.mkdir(mode_dir)

                # These two modes also save one CSV per mini-list under <seed_dir>/<mode>/<measure>/.
                if mode == 'def_bert_head' or mode == 'bert_wordpieces':
                    mode_saver_path = os.path.join(mode_dir, mode)
                    if not os.path.exists(mode_saver_path):
                        os.mkdir(mode_saver_path)
                    mode_saver_path = os.path.join(mode_saver_path, measure)
                    if not os.path.exists(mode_saver_path):
                        os.mkdir(mode_saver_path)

                i = 0
                for test in tests:
                    # A new tester instance is created for every mini-list.
                    tester = Definition_PedersenSimilaritySister_Test.instantiate(mode, test_tokenizer, test_model)
                    if mode == 'def_bert_head' or mode == 'bert_wordpieces':
                        sperman_coeff = tester.run(test, evaluator,
                                                   output_path=os.path.join(mode_saver_path, str(i) + '.csv'))
                    else:
                        sperman_coeff = tester.run(test, evaluator)

                    i += 1
                    spearman[measure].append(sperman_coeff)

                # Histogram of the per-mini-list correlations for this mode and measure.
                distribution = UnknownDistribution(data=[x.correlation for x in spearman[measure]])
                output_path = os.path.join(mode_dir, mode + '_' + measure + pos + '_hist.png')
                distribution.save(output_path=output_path,
                                  title=f"{measure} mini-lists spearman results")

                # Summary row: seed, measure, mode, pos suffix, distribution.mu, distribution.std,
                # number of mini-lists.
                print('\t'.join([seed, measure, mode, pos, str(distribution.mu), str(distribution.std),
                                 str(len(tests))]))
        return_cosines[seed] = spearman

    return return_cosines


def collect_test_of_size(n_test, test_size, k_clusters: dict, ouput_path=None):
    """Draw up to ``n_test`` mini-lists, each with one item from ``test_size`` distinct clusters.

    Each mini-list is built as follows:

    * pick ``test_size`` different non-empty clusters at random;
    * take one random item from each picked cluster;
    * remove every drawn item from ``k_clusters``.

    Because drawn items are removed, ``k_clusters`` is MUTATED in place and no item is used
    twice. Drawing stops as soon as fewer than ``test_size`` clusters still have items. Each
    mini-list is sorted before being stored.

    :param n_test: maximum number of mini-lists to draw.
    :param test_size: number of items per mini-list.
    :param k_clusters: mapping cluster center -> list of ``(similarity_value, couple)`` items.
    :param ouput_path: if not None, the mini-lists are written there as a tab-separated table.
    :return: list of mini-lists, each a sorted list of items.
    """
    drawn_tests = []

    for _ in range(n_test):
        candidates = [center for center, members in k_clusters.items() if len(members) != 0]
        # Not enough distinct non-empty clusters left to fill a whole mini-list.
        if test_size > len(candidates):
            break

        picked = []
        for _slot in range(test_size):
            chosen_center = Random.randomchoice(candidates)
            candidates.remove(chosen_center)

            item = Random.randomchoice(k_clusters[chosen_center])
            picked.append(item)
            k_clusters[chosen_center].remove(item)

        if len(picked) != test_size:
            break
        picked.sort()
        drawn_tests.append(picked)

    if ouput_path is not None:
        header = '\t'.join(['TEST_N', 'S1', 'S2', 'W1', 'W2', 'S1_POS', 'SIMILARITY', '#\n'])
        with open(ouput_path, 'w+') as output:
            output.write(header)
            # Mini-lists are numbered from 1 in the output file.
            for test_number, test in enumerate(drawn_tests, start=1):
                for similarity_value, couple in test:
                    output.write('\t'.join([str(test_number), couple.s1.name(), couple.s2.name(),
                                            couple.w1, couple.w2, couple.s_pos, str(similarity_value), '#\n']))
    return drawn_tests