import os

from scipy.stats import wilcoxon
from pytorch_transformers import BertTokenizer, BertModel
from BERT_eval_sim_definition_lemma.impl_test_definitions_lemma import Test_Definition_Lemma
from utility.word_in_vocabulary import collected_vocabulary_words_in_synset

import torch

# Device used for the BERT model: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def all_descendant_files_of(base):
    """Return the paths of the regular files that sit directly inside ``base``.

    Each returned path is ``base`` joined with the entry name. Entries that
    are not files (for example sub-directories) are skipped and are not
    traversed, so only the first directory level is inspected.

    :param base: directory to list.
    :return: list of file paths, in the order reported by ``os.listdir``.
    """
    candidates = (os.path.join(base, entry) for entry in os.listdir(base))
    return [candidate for candidate in candidates if os.path.isfile(candidate)]


def _run_definition_lemma_task(task, task_name, pos_tags, seeds, run_size):
    """Run ``task`` for every (pos, seed) pair and collect the returned cosines.

    One tab-separated line ``task_name, pos, seed, mu, std`` is printed per run.

    :param task: object exposing ``run(run_size, pos, seed)`` that returns a
        ``(mu, std, cosines)`` triple.
    :param task_name: label printed at the start of each result line.
    :param pos_tags: part-of-speech tags to evaluate.
    :param seeds: seeds (as strings, since they are joined into the output line).
    :param run_size: first positional argument forwarded to ``task.run``.
    :return: nested dict ``{pos: {seed: cosines}}``.
    """
    cosines_by_pos = {}
    for pos in pos_tags:
        cosines_by_pos[pos] = {}
        for seed in seeds:
            distrmu, distrstd, cosines = task.run(run_size, pos, seed)
            cosines_by_pos[pos][seed] = cosines
            print('\t'.join([task_name, pos, seed, str(distrmu), str(distrstd)]))
    return cosines_by_pos


def evaluate_definition_lemma_embeddings(model_name='bert-base-uncased', base_data_dir='data',
                                         output_path='data/task1_sim_definition_lemma'):
    """Run the definition/lemma similarity tasks and compare them statistically.

    The function:

    1. collects the vocabulary words found in synsets for ``model_name``
       (verbs and nouns) under ``base_data_dir``;
    2. loads the BERT tokenizer and model named ``model_name``;
    3. runs the ``def_bert_cls``, ``def_bert_head`` and ``bert_head_example``
       tasks, printing one tab-separated summary line per run;
    4. prints the p-value of a one-sided (``alternative='greater'``) Wilcoxon
       signed-rank test of ``def_bert_head`` against each baseline.

    :param model_name: pretrained model identifier, used both for the
        vocabulary collection and for loading the tokenizer/model.
    :param base_data_dir: directory holding the collected vocabulary; created
        if missing.
    :param output_path: directory under which each task receives its own
        ``output_path``; created if missing.
    :return: None; results are printed to standard output.
    """
    word_in_synset_path = os.path.join(base_data_dir, 'words_in_synsets_model_' + model_name)
    # makedirs also creates missing parents and tolerates existing directories,
    # so a nested or already-present path no longer raises.
    os.makedirs(word_in_synset_path, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)

    filename = 'vocabulary_in_synset.txt'
    pos_tags = ['v', 'n']

    words_in_synsets = collected_vocabulary_words_in_synset(model_name, pos_tags=pos_tags,
                                                            output_path=os.path.join(word_in_synset_path, filename))

    # Load the model the caller asked for so that it matches the vocabulary
    # collected above (previously 'bert-base-uncased' was hard-coded here).
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name, output_hidden_states=True)
    model.to(device)

    # Seeds stay strings: they are joined into the printed lines and passed
    # to the task unchanged.
    seeds = ['19']
    # First positional argument of ``run``; its exact meaning is defined by
    # Test_Definition_Lemma, which is not visible here.
    run_size = 115000

    task_names = ['def_bert_cls', 'def_bert_head', 'bert_head_example']
    model_cosines = {}
    for task_name in task_names:
        task = Test_Definition_Lemma.instantiate(task_name,
                                                 model, tokenizer,
                                                 words_in_synsets,
                                                 output_path=os.path.join(output_path, task_name))
        model_cosines[task_name] = _run_definition_lemma_task(task, task_name, pos_tags, seeds, run_size)

    targets = ['def_bert_head']
    baselines = ['bert_head_example', 'def_bert_cls']
    for target in targets:
        for baseline in baselines:
            for pos in pos_tags:
                for seed in seeds:
                    # One-sided paired test: are the target's cosines greater
                    # than the baseline's?
                    statistic, p_value = wilcoxon(x=model_cosines[target][pos][seed],
                                                  y=model_cosines[baseline][pos][seed], alternative='greater')
                    print("\t".join(['wilcoxon', target, baseline, pos, "p-value=" + str(p_value)]))
