import os

from nltk.corpus import wordnet as wn


class BertVocab:
    """Loader for the tab-separated BERT vocabulary file."""

    @staticmethod
    def read_vocab_file(vocab_file):
        """Read a vocabulary file into a membership dict.

        Each non-blank line must contain exactly four tab-separated fields;
        only the first field (the word) is kept. Repeated words collapse into
        a single entry.

        Args:
            vocab_file: path to the vocabulary file (decoded as UTF-8).

        Returns:
            dict mapping every word to 1; callers use it for ``in`` lookups.

        Raises:
            ValueError: if a non-blank line does not have exactly four fields.
        """
        bert_vocab = {}
        with open(vocab_file, 'r', encoding='utf-8') as file:
            # Stream the file line by line instead of readlines(), so the
            # whole file is never held in memory at once.
            for line in file:
                # Blank lines (e.g. an extra trailing newline) would otherwise
                # fail the strict four-field unpacking below.
                if not line.strip():
                    continue
                word, _, _, _ = line.split('\t')
                bert_vocab[word] = 1

        return bert_vocab


class BertW2VVocab:
    """Select records whose word is usable by both an embedding model and BERT.

    A record is "comparable" when its word is present in both the pretrained
    model's vocabulary and the BERT vocabulary file, and at least one example
    sentence of the associated WordNet synset contains that word as a
    space-separated token. Accepted raw lines can be collected with ``add``
    and written out with ``memorize``.
    """

    def __init__(self, pretained_model, bert_vocab_path='data/bert_vocabulary_in_synset.txt'):
        """
        Args:
            pretained_model: object exposing a ``vocab`` attribute that
                supports ``in`` membership tests. NOTE(review): presumably a
                gensim ``KeyedVectors`` from before 4.0 (where ``.vocab`` still
                exists) — confirm.
            bert_vocab_path: tab-separated vocabulary file parsed by
                ``BertVocab.read_vocab_file``.
        """
        # The "pretained" spelling is kept because it is part of the public
        # interface (keyword argument and attribute name).
        self.pretained_model_vocab = pretained_model.vocab
        self.bert_vocab = BertVocab.read_vocab_file(bert_vocab_path)
        # Raw lines accepted via add(); written out by memorize().
        self.comparable_lines = []

    def _in_both_vocabs(self, word):
        """Return True if word is in the pretrained-model vocab and the BERT vocab."""
        return word in self.pretained_model_vocab and word in self.bert_vocab

    @staticmethod
    def _has_example_containing(synset, word):
        """Return True if any of ``synset.examples()`` contains word as a token.

        Tokens are obtained by splitting on single spaces only, so attached
        punctuation or inflections ("dog.", "dogs") do not match "dog".
        """
        return any(word in str(example).split(' ') for example in synset.examples())

    def comparable_definitions(self, correlation, pos=None):
        """Return True if a correlation record can be compared.

        Args:
            correlation: mapping with at least the keys 'target_pos', 'w1' and
                's1' ('s1' must be a synset name accepted by ``wn.synset``).
                'target_pos' is only read when pos is given.
            pos: optional case-insensitive prefix that 'target_pos' must start
                with. Defaults to None (no filter), matching
                ``comparable_synset_oov_couple``.

        Raises:
            KeyError: if a required key is missing from correlation.
            Whatever ``wn.synset`` raises for an unknown synset name.
        """
        if pos is not None and not correlation['target_pos'].startswith(pos.upper()):
            return False

        word = correlation['w1']
        if not self._in_both_vocabs(word):
            return False

        return self._has_example_containing(wn.synset(correlation['s1']), word)

    def comparable_synset_oov_couple(self, couple, pos=None):
        """Return True if the couple's second word can be compared.

        Args:
            couple: object exposing ``target_pos`` (read only when pos is
                given), ``second`` (the word) and ``synset_second`` (an object
                with an ``examples()`` method, e.g. a WordNet synset).
            pos: optional case-insensitive prefix that ``target_pos`` must
                start with; None disables the filter.
        """
        if pos is not None and not couple.target_pos.startswith(pos.upper()):
            return False

        if not self._in_both_vocabs(couple.second):
            return False

        return self._has_example_containing(couple.synset_second, couple.second)

    def memorize(self, output_comparable, original_file):
        """Write the collected lines to ``output_comparable/<basename of original_file>``.

        The target file is created or truncated. Lines are written verbatim
        (``writelines`` adds no separators), so each line should carry its own
        newline.
        """
        out_path = os.path.join(output_comparable, os.path.basename(original_file))
        # 'w' rather than 'w+': the file is only written, never read back.
        with open(out_path, 'w') as f:
            f.writelines(self.comparable_lines)

    def add(self, line):
        """Queue line to be written by memorize()."""
        self.comparable_lines.append(line)
