import os

from gensim.models import KeyedVectors

from utility.bert_vocab import BertW2VVocab
from utility.words_in_synset import WordInSynset, SynsetOOVCouple


class DefinitionsOOVSisterTerms_Joiner_V2:
    """Join OOV definition rows with the sister terms of the same OOV target.

    ``sister_terms_path`` and every file in ``definitions_paths`` are read as
    tab-separated text whose first line is a header (skipped). For every
    ``(oov, sister)`` pair from the sister-terms file and every definition row
    whose target ``equals`` ``oov``, one output row is written: the definition
    fields (without their last column) followed by the sister synset name,
    the sister word and a ``#`` terminator.
    """

    # Header of the joined output file; the trailing "#\n" mirrors the row terminator.
    _HEADER = '\t'.join(['target_synset', 'target_word', 'data_w1', 'data_w2', 'target_pos', 'w1_pos', 'w2_pos',
                         'definition', 'sister_synset', 'sister_word', "#\n"])

    def __init__(self, definitions_paths, sister_terms_path, output_path):
        """
        :param definitions_paths: iterable of paths of tab-separated definition files.
        :param sister_terms_path: path of the tab-separated sister-terms file.
        :param output_path: path of the joined file, overwritten by :meth:`join`.
        """
        self.definitions_paths = definitions_paths
        self.sister_terms_path = sister_terms_path
        self.output_path = output_path

    @staticmethod
    def _data_rows(path):
        """Yield the tab-split fields of every line of *path* except the first (header)."""
        with open(path, 'r') as rows_file:
            next(rows_file, None)  # skip the header line
            for line in rows_file:
                yield line.split('\t')

    def _read_sister_terms(self):
        """Return ``(oov, sister)`` WordInSynset pairs from the sister-terms file.

        Columns used: 0 = OOV synset, 1 = sister synset, 2 = OOV word,
        3 = sister word, 4 = POS (shared by both sides).
        """
        return [
            (WordInSynset(word=split[2], synset_name=split[0], pos=split[4]),
             WordInSynset(word=split[3], synset_name=split[1], pos=split[4]))
            for split in self._data_rows(self.sister_terms_path)
        ]

    def _read_definitions(self):
        """Return ``(oov, fields)`` pairs from every definitions file.

        ``oov`` is built from columns 1 (word), 0 (synset) and 4 (POS);
        ``fields`` is the split row with its last column removed.
        """
        definitions = []
        for definitions_path in self.definitions_paths:
            for split in self._data_rows(definitions_path):
                oov_s = WordInSynset(word=split[1], synset_name=split[0], pos=split[4])
                split.pop()  # drop the last column (presumably the '#\n' terminator)
                definitions.append((oov_s, split))
            print('def', len(definitions))
        return definitions

    @staticmethod
    def _attach_sister_terms(sister_terms, definitions):
        """Return one tab-joined output line per (sister pair, matching definition).

        :param sister_terms: list of ``(oov, sister)``; ``oov`` must provide ``equals``
            and ``sister`` must provide ``synset_name`` and ``word``.
        :param definitions: list of ``(target, fields)`` where ``fields`` is a list of str.
        :returns: list of lines ending in ``'\\t#\\n'``; ``definitions`` is not modified.
        """
        output_lines = []
        for oov_s, w2_s2 in sister_terms:
            for candidate_oov_s, definition_line in definitions:
                if oov_s.equals(candidate_oov_s):
                    # Build a new row: extending definition_line in place would leak this
                    # sister's columns into every later match of the same definition.
                    row = definition_line + [w2_s2.synset_name, w2_s2.word, '#\n']
                    output_lines.append('\t'.join(row))
        return output_lines

    def join(self):
        """Read both inputs, join them and (over)write ``output_path`` with a header line."""
        sister_terms = self._read_sister_terms()
        definitions = self._read_definitions()
        output_lines = self._attach_sister_terms(sister_terms, definitions)

        print('out', len(output_lines))
        with open(self.output_path, 'w') as output_file:
            output_file.write(self._HEADER)
            output_file.writelines(output_lines)


class DefinitionsBertComparableOOVSisterTerms_Joiner_V2:
    """Join OOV definition rows with sister terms, keeping only comparable couples.

    ``sister_terms_path`` and every file in ``definitions_paths`` are read as
    tab-separated text whose first line is a header (skipped). For every
    ``(oov, sister)`` pair and every definition row whose target ``equals``
    ``oov``, an output row is built (definition fields without their last
    column + sister synset + sister word + ``#``). The row is written only if
    ``BertW2VVocab.comparable_synset_oov_couple`` accepts the SynsetOOVCouple
    built from it.

    Output row layout (matches ``_HEADER``; assumes definition rows keep
    8 fields after the last column is dropped — TODO confirm against the data):
    0 target_synset, 1 target_word, 2-3 data_w1/data_w2, 4 target_pos,
    5 w1_pos, 6 w2_pos, 7 definition, 8 sister_synset, 9 sister_word.
    """

    DEFAULT_PRETRAINED_MODEL_PATH = 'data/pretrained_embeddings/GoogleNews-vectors-negative300.bin'
    DEFAULT_BERT_VOCAB_PATH = 'data/bert_vocabulary_in_synset.txt'

    # Header of the joined output file; the trailing "#\n" mirrors the row terminator.
    _HEADER = '\t'.join(['target_synset', 'target_word', 'data_w1', 'data_w2', 'target_pos', 'w1_pos', 'w2_pos',
                         'definition', 'sister_synset', 'sister_word', "#\n"])

    def __init__(self, definitions_paths, sister_terms_path, output_path,
                 pretrained_model_path=DEFAULT_PRETRAINED_MODEL_PATH,
                 bert_vocab_path=DEFAULT_BERT_VOCAB_PATH):
        """
        :param definitions_paths: iterable of paths of tab-separated definition files.
        :param sister_terms_path: path of the tab-separated sister-terms file.
        :param output_path: path of the joined file, overwritten by :meth:`join`.
        :param pretrained_model_path: binary word2vec file loaded by :meth:`join`.
        :param bert_vocab_path: vocabulary file passed to ``BertW2VVocab``.
        """
        self.definitions_paths = definitions_paths
        self.sister_terms_path = sister_terms_path
        self.output_path = output_path
        self.pretrained_model_path = pretrained_model_path
        self.bert_vocab_path = bert_vocab_path

    def _load_policy(self):
        """Load the word2vec model and wrap it in a ``BertW2VVocab`` policy."""
        pretrained_embeddings_model = KeyedVectors.load_word2vec_format(self.pretrained_model_path, binary=True)
        return BertW2VVocab(pretrained_embeddings_model, bert_vocab_path=self.bert_vocab_path)

    @staticmethod
    def _data_rows(path):
        """Yield the tab-split fields of every line of *path* except the first (header)."""
        with open(path, 'r') as rows_file:
            next(rows_file, None)  # skip the header line
            for line in rows_file:
                yield line.split('\t')

    def _read_sister_terms(self):
        """Return ``(oov, sister)`` WordInSynset pairs from the sister-terms file.

        Columns used: 0 = OOV synset, 1 = sister synset, 2 = OOV word,
        3 = sister word, 4 = POS (shared by both sides).
        """
        return [
            (WordInSynset(word=split[2], synset_name=split[0], pos=split[4]),
             WordInSynset(word=split[3], synset_name=split[1], pos=split[4]))
            for split in self._data_rows(self.sister_terms_path)
        ]

    def _read_definitions(self):
        """Return ``(oov, fields)`` pairs from every definitions file.

        ``oov`` is built from columns 1 (word), 0 (synset) and 4 (POS);
        ``fields`` is the split row with its last column removed.
        """
        definitions = []
        for definitions_path in self.definitions_paths:
            for split in self._data_rows(definitions_path):
                oov_s = WordInSynset(word=split[1], synset_name=split[0], pos=split[4])
                split.pop()  # drop the last column (presumably the '#\n' terminator)
                definitions.append((oov_s, split))
        return definitions

    @staticmethod
    def _couple_from_row(row):
        """Build the SynsetOOVCouple described by a joined output row (see class docstring)."""
        return SynsetOOVCouple(oov=row[1], synset_oov=row[0],
                               first=row[2:4],
                               second=row[9], synset_second=row[8],
                               target_pos=row[4], w1_pos=row[5], w2_pos=row[6])

    def join(self):
        """Load the policy, join the inputs and (over)write ``output_path`` with the comparable rows."""
        policy = self._load_policy()
        sister_terms = self._read_sister_terms()
        definitions = self._read_definitions()

        output_lines = []
        for oov_s, w2_s2 in sister_terms:
            for candidate_oov_s, definition_line in definitions:
                if not oov_s.equals(candidate_oov_s):
                    continue
                # Build a new row: extending definition_line in place would make later
                # matches of the same definition carry earlier sisters' columns, so
                # row[8]/row[9] would keep pointing at the first sister term.
                row = definition_line + [w2_s2.synset_name, w2_s2.word, '#\n']
                if policy.comparable_synset_oov_couple(self._couple_from_row(row)):
                    output_lines.append('\t'.join(row))

        print('out', len(output_lines))
        with open(self.output_path, 'w') as output_file:
            output_file.write(self._HEADER)
            output_file.writelines(output_lines)


def all_descendant_files_of(base):
    """Return the paths of all files found anywhere below the directory *base*.

    The tree is walked bottom-up (``os.walk(..., topdown=False)``), so files of
    deeper subdirectories are listed before those of their ancestors. A
    missing *base* yields an empty list.
    """
    return [
        os.path.join(directory, file_name)
        for directory, _subdirectories, file_names in os.walk(base, topdown=False)
        for file_name in file_names
    ]


def oov_sister_terms_definitions(seeds=('19',)):
    """Join OOV definitions with positive and negative sister terms for each seed.

    For each seed and polarity ``t`` in ('positive', 'negative') this reads
    ``data/similarity_pedersen_test/sister_terms/seed_<seed>/oov_sister_terms_<t>.txt``,
    joins it with every file under ``data/oov_definitions/examples`` and writes
    ``data/similarity_pedersen_test/oov_sister_terms_with_definitions/seed_<seed>/
    oov_definition_sister_terms_<t>.txt``.

    :param seeds: seed identifiers used to name the ``seed_<seed>`` directories
        (default ``('19',)``, the previously hard-coded seed).
    """
    base_path = 'data/similarity_pedersen_test'
    sister_terms_dir = os.path.join(base_path, 'sister_terms')
    output_base_dir = os.path.join(base_path, 'oov_sister_terms_with_definitions')

    definitions_paths = all_descendant_files_of('data/oov_definitions/examples')
    print(definitions_paths)
    for seed in seeds:
        seed_dir = os.path.join(output_base_dir, f'seed_{seed}')
        # Creates missing parents too and is a no-op if the directory exists.
        os.makedirs(seed_dir, exist_ok=True)

        for t in ['positive', 'negative']:
            sister_terms_path = os.path.join(sister_terms_dir, f'seed_{seed}', f'oov_sister_terms_{t}.txt')
            output_path = os.path.join(seed_dir, f'oov_definition_sister_terms_{t}.txt')

            joiner = DefinitionsOOVSisterTerms_Joiner_V2(definitions_paths, sister_terms_path, output_path)
            joiner.join()


def oov_sister_terms_definitions_comparable(seeds=('19',)):
    """Join OOV definitions with sister terms, keeping only comparable couples, per seed.

    For each seed and polarity ``t`` in ('positive', 'negative') this reads
    ``data/similarity_pedersen_test/sister_terms/seed_<seed>/oov_sister_terms_<t>.txt``,
    joins it with every file under ``data/oov_definitions/examples`` through
    ``DefinitionsBertComparableOOVSisterTerms_Joiner_V2`` and writes
    ``data/similarity_pedersen_test/oov_sister_terms_with_definitions/seed_<seed>/
    oov_definition_sister_terms_<t>_comparable.txt``.

    :param seeds: seed identifiers used to name the ``seed_<seed>`` directories
        (default ``('19',)``, the previously hard-coded seed).
    """
    base_path = 'data/similarity_pedersen_test'
    sister_terms_dir = os.path.join(base_path, 'sister_terms')
    output_base_dir = os.path.join(base_path, 'oov_sister_terms_with_definitions')

    definitions_paths = all_descendant_files_of('data/oov_definitions/examples')
    for seed in seeds:
        seed_dir = os.path.join(output_base_dir, f'seed_{seed}')
        # Creates missing parents too and is a no-op if the directory exists.
        os.makedirs(seed_dir, exist_ok=True)

        for t in ['positive', 'negative']:
            sister_terms_path = os.path.join(sister_terms_dir, f'seed_{seed}', f'oov_sister_terms_{t}.txt')
            output_path = os.path.join(seed_dir, f'oov_definition_sister_terms_{t}_comparable.txt')

            joiner = DefinitionsBertComparableOOVSisterTerms_Joiner_V2(definitions_paths, sister_terms_path,
                                                                       output_path)
            joiner.join()