import os
from defiNNet.DefiNNet import DefiNNet
from utility.training import all_files_of, write_w2v_example_excluding, load_dataset_from, split_in
from word_embeddings_benchmarks_test.test_benchmark import all_words_in_tasks


# Syntactic labels handed to both the dataset writer and DefiNNet.train below.
# The entries look like Penn Treebank phrase/POS tags plus hyphen-prefixed
# function tags and punctuation symbols -- TODO confirm against the parser
# that produced the example files.
# NOTE(review): element order is kept exactly as-is; consumers may rely on
# each tag's position in the list -- verify before reordering.
tagset = [
    'ADJP', '-ADV', 'ADVP', '-BNF', 'CC', 'CD', '-CLF', '-CLR', 'CONJP', '-DIR',
    'DT', '-DTV', 'EX', '-EXT', 'FRAG', 'FW', '-HLN', 'IN', 'INTJ', 'JJ',
    'JJR', 'JJS', '-LGS', '-LOC', 'LS', 'LST', 'MD', '-MNR', 'NAC', 'NN',
    'NNS', 'NNP', 'NNPS', '-NOM', 'NP', 'NX', 'PDT', 'POS', 'PP', '-PRD',
    'PRN', 'PRP', '-PRP', 'PRP$', 'PRP-S', 'PRT', '-PUT', 'QP', 'RB', 'RBR',
    'RBS', 'RP', 'RRC', 'S', 'SBAR', 'SBARQ', '-SBJ', 'SINV', 'SQ', 'SYM',
    '-TMP', 'TO', '-TPC', '-TTL', 'UCP', 'UH', 'VB', 'VBD', 'VBG', 'VBN',
    'VBP', 'VBZ', '-VOC', 'VP', 'WDT', 'WHADJP', 'WHADVP', 'WHNP', 'WHPP', 'WP',
    'WP$', 'WP-S', 'WRB', 'X', 'AFX', '#', '$', '-LRB-', '\"', '(',
    ')', ',', '.', ':', '``',
]

# Words collected from the benchmark tasks; passed to the dataset writer
# below as `words_in_test`.
test_words = all_words_in_tasks()

# Roots of the rule-example files. The trailing v / n / a presumably stand
# for verbs, nouns and adjectives -- TODO confirm against the data layout.
base1, base2, base3 = (
    "data/rules/examples/v",
    "data/rules/examples/n",
    "data/rules/examples/a",
)

# Every example file found under the three roots, concatenated in root order.
input_paths = (
    all_files_of(base1)
    + all_files_of(base2)
    + all_files_of(base3)
)

# Directory that receives the generated dataset archive.
output_dir = 'data/datasets'
# makedirs(exist_ok=True) creates any missing parent (e.g. 'data') and is a
# no-op when the directory is already there, so there is no window between
# an existence check and the creation call.
os.makedirs(output_dir, exist_ok=True)

# Sister-term evaluation files handed to the dataset writer as `test_paths`:
# one 'positive' and one 'negative' file per seed (only seed '19' is listed).
test_paths = [
    'data/similarity_pedersen_test/oov_sister_terms_with_definitions/seed_' + seed
    + '/oov_definition_sister_terms_' + polarity + '.txt'
    for seed in ['19']
    for polarity in ['positive', 'negative']
]

# Location of the dataset archive inside output_dir.
dataset_path = os.path.join(output_dir, 'google_w2v_example_v4_no_sister_terms_test.npz')

# Build the dataset from the example files. The helper receives the
# sister-term files and the benchmark vocabulary; judging by its name and
# the `words_in_test` argument it presumably leaves those items out of the
# written examples -- TODO confirm in utility.training.
write_w2v_example_excluding(
    input_paths=input_paths,
    test_paths=test_paths,
    output_path=dataset_path,
    tagset=tagset,
    words_in_test=test_words,
    save_for_test=True,
)

# Load the five parallel collections stored at dataset_path.
(
    dataset_data,
    dataset_target,
    dataset_target_pos,
    dataset_w1_pos,
    dataset_w2_pos,
) = load_dataset_from(path=dataset_path)

# split_in returns the pair (validation, train); 0.20 is presumably the
# share of examples held out for validation -- TODO confirm in utility.training.
validation, train = split_in(
    0.20,
    dataset_data,
    dataset_target,
    dataset_target_pos,
    dataset_w1_pos,
    dataset_w2_pos,
)

# NOTE: the test_* names below hold the *validation* split, not the
# sister-term files listed in test_paths.
test_data, test_target, test_target_pos, test_w1_pos, test_w2_pos = validation
train_data, train_target, train_target_pos, train_w1_pos, train_w2_pos = train

# Path of the pretrained embeddings file passed to DefiNNet.train.
pretrained_embeddings_path = "data/pretrained_embeddings/GoogleNews-vectors-negative300.bin"
# Second name for the same path; nothing else in this script reads it.
pretrained = pretrained_embeddings_path

# Train on the training split, reporting its size first.
print("Training on " + str(len(train_data)))
model = DefiNNet.train(pretrained_embeddings_path, train, tagset)

# Evaluate on the held-out validation split, then save the model as
# "denn.h5" (a relative path, so it resolves against the current directory).
print("Validation on " + str(len(test_data)))
model.test(validation)
model.save("denn.h5")


