import os
from pytorch_transformers import BertTokenizer, BertModel
from torch.nn import CosineSimilarity

import torch
from utility.randomfixedseed import Random
from utility.words_in_synset import SynsetCouple

# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

from BERT_eval_sister_terms_similarity.impl_pedersen_test_eval_sister_terms import InVocabulary_PedersenSimilaritySister_Test
from BERT_sister_terms_similarity.collect_sister_terms_similarities import compute_sister_terms, \
    retrieve_synset_couples_divided_by_value_of_similarity
from utility.cluster import ClusterMinDiam
from utility.distributions import UnknownDistribution


def evaluate_sister_terms_similarity(model_name, base_data_dir, sister_terms_path):
    """Compute sister-term data for a BERT model and run the in-example mini-list evaluation.

    :param model_name: pretrained model identifier passed to ``BertTokenizer``/``BertModel.from_pretrained``
        and to the downstream helpers.
    :param base_data_dir: base data directory; created (with parents) if missing.
    :param sister_terms_path: directory where sister-term files and evaluation results are written;
        created (with parents) if missing.
    """
    # makedirs creates missing parents and tolerates already-existing directories,
    # unlike the previous exists()/mkdir() pair.
    os.makedirs(base_data_dir, exist_ok=True)
    os.makedirs(sister_terms_path, exist_ok=True)

    seeds = ['19']
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name, output_hidden_states=True)
    model.to(device)

    print('COMPUTE SISTER TERMS')
    compute_sister_terms(seeds, sister_terms_path, model_name=model_name, check_if_computed=True)

    print('SPEARMAN CORRELATIONS IN VOC SISTER IN EXAMPLES')
    similarity_measures = ['path', 'wup', 'res']
    destination_dir = 'micro_list_sister_in_example_test'
    # Forward the model actually being evaluated instead of a hard-coded name.
    calculate_spearman_correlation_similarities_sister_in_example(model, tokenizer,
                                                                  sister_terms_path, destination_dir,
                                                                  model_name=model_name,
                                                                  seeds=seeds,
                                                                  similarity_measures=similarity_measures)


def calculate_spearman_correlation_similarities_sister_in_example(test_model, test_tokenizer,
                                                                  root_data_model, destination_dir,
                                                                  model_name,
                                                                  seeds=None, similarity_measures=None):
    """Score mini-lists of sister-term couples whose words occur in their synsets' examples.

    For every seed and similarity measure, the in-vocabulary couples read from
    ``root_data_model/seed_<seed>/in_voc_sister_terms_{positive,negative}.txt`` are grouped by
    similarity value. Only couples whose two words both appear (as space-separated tokens) in an
    example sentence of their own synset are kept. These are sampled into mini-lists of 7 couples
    (at most 500 lists), and each list is scored with ``InVocabulary_PedersenSimilaritySister_Test``
    in ``'word_in_example'`` mode. The Spearman correlations are saved as a histogram and
    summarised on stdout.

    :param test_model: model handed to the tester.
    :param test_tokenizer: tokenizer handed to the tester.
    :param root_data_model: directory containing the ``seed_<seed>`` input folders.
    :param destination_dir: name of the output folder, created under ``root_data_model``.
    :param model_name: not used by this function; kept for interface compatibility.
    :param seeds: seeds given as strings (default ``['19']``).
    :param similarity_measures: measure names (default ``['path', 'wup', 'res']``).
    """
    if similarity_measures is None:
        similarity_measures = ['path', 'wup', 'res']
    if seeds is None:
        seeds = ['19']

    destination_dir = os.path.join(root_data_model, destination_dir)
    os.makedirs(destination_dir, exist_ok=True)

    K = 15  # number of min-diameter clusters requested from ClusterMinDiam
    evaluator = CosineSimilarity()

    for seed in seeds:
        Random.set_seed(int(seed))
        input_dir = os.path.join(root_data_model, 'seed_' + seed)
        seed_dir = os.path.join(destination_dir, 'seed_' + seed)

        for measure in similarity_measures:
            os.makedirs(seed_dir, exist_ok=True)

            n_couple_clusters = retrieve_synset_couples_divided_by_value_of_similarity(
                positive_input_path=os.path.join(input_dir, 'in_voc_sister_terms_positive.txt'),
                negative_input_path=os.path.join(input_dir, 'in_voc_sister_terms_negative.txt'),
                measure_name=measure)
            available_clusters = _couples_with_words_in_examples(n_couple_clusters)

            for pos_filter in [None]:  # ['n', 'v', None]:
                pos_suffix = '' if pos_filter is None else '_' + pos_filter
                k_clusters = ClusterMinDiam.k_clusters_of_min_diameter(k=K, n_clusters=available_clusters,
                                                                       pos=pos_filter)

                tests_path = os.path.join(seed_dir, measure + pos_suffix + '_sister_in_example.txt')
                tests = collect_test_of_size(n_test=500, test_size=7, k_clusters=k_clusters,
                                             ouput_path=tests_path)

                # A fresh list per POS filter: results of different filters must not be pooled
                # into the same histogram.
                correlations = []
                for test in tests:
                    tester = InVocabulary_PedersenSimilaritySister_Test.instantiate('word_in_example', test_tokenizer,
                                                                                    test_model)
                    correlations.append(tester.run(test, evaluator).correlation)

                distribution = UnknownDistribution(data=correlations)
                distribution.save(output_path=os.path.join(seed_dir, measure + pos_suffix + '_hist.png'),
                                  title=f"{measure} mini-lists spearman results")

                print('\t'.join([seed, measure, 'word_in_example', pos_suffix, str(distribution.mu),
                                 str(distribution.std), '# tests = ' + str(len(tests))]))


def _has_example_with_word(word, examples):
    """Return True if ``word`` is one of the space-separated tokens of any sentence in ``examples``."""
    return any(word in str(example).split(' ') for example in examples)


def _couples_with_words_in_examples(n_couple_clusters):
    """Keep only couples whose both words occur in an example sentence of their own synset.

    :param n_couple_clusters: mapping cluster center -> iterable of ``SynsetCouple``.
    :return: mapping center -> list of retained couples, in the original order; centers with
        no retained couple are omitted.
    """
    available_clusters = {}
    for center in n_couple_clusters:
        for couple in n_couple_clusters[center]:
            couple: SynsetCouple = couple
            if (_has_example_with_word(couple.w1, couple.s1.examples())
                    and _has_example_with_word(couple.w2, couple.s2.examples())):
                available_clusters.setdefault(center, []).append(couple)
    return available_clusters


def collect_test_of_size(n_test, test_size, k_clusters: dict, ouput_path=None):
    """Randomly draw up to ``n_test`` tests of ``test_size`` items, each item from a distinct cluster.

    Every test picks ``test_size`` different non-empty clusters and removes one random item from
    each. Sampling stops early as soon as fewer than ``test_size`` non-empty clusters remain.
    Each test is sorted before being stored.

    NOTE: items are removed from the caller's ``k_clusters`` lists in place (sampling without
    replacement across tests).

    :param n_test: maximum number of tests to build.
    :param test_size: number of items per test.
    :param k_clusters: mapping cluster center -> list of items. Items written to the output file
        must be ``(similarity_value, couple)`` pairs.
    :param ouput_path: if given, the tests are written there as a tab-separated file (UTF-8).
    :return: list of tests, each a sorted list of items.
    """
    tests = []

    while len(tests) < n_test:
        available_centers = [center for center in k_clusters if len(k_clusters[center]) != 0]
        if len(available_centers) < test_size:
            break  # not enough non-empty clusters to build another test

        test = []
        for _ in range(test_size):
            center = Random.randomchoice(available_centers)
            available_centers.remove(center)  # one item per cluster within a test

            d = Random.randomchoice(k_clusters[center])
            test.append(d)
            k_clusters[center].remove(d)

        if len(test) != test_size:  # only possible for a negative test_size
            break
        # NOTE(review): tuples tie-break on the second element when similarity values are equal,
        # which requires the couple objects to be orderable — confirm for SynsetCouple.
        test.sort()
        tests.append(test)

    if ouput_path is not None:
        with open(ouput_path, 'w', encoding='utf-8') as output:
            output.write('\t'.join(['TEST_N', 'S1', 'S2', 'W1', 'W2', 'S1_POS', 'SIMILARITY', '#\n']))
            for test_number, test in enumerate(tests, start=1):
                for similarity_value, couple in test:
                    output.write('\t'.join([str(test_number), couple.s1.name(), couple.s2.name(),
                                            couple.w1, couple.w2, couple.s_pos, str(similarity_value), '#\n']))
    return tests
