# -*- coding: UTF-8 -*-

import os
import sys
import logging
from gensim.test.utils import common_texts
from gensim.models import Word2Vec

#logging.basicConfig(format='%(asctime)s : %(levelname)s : %(module)s:%(lineno)d : %(funcName)s(%(threadName)s) : %(message)s', level=logging.DEBUG)



'''
There are two classifiers: the CoDA classifier and the DUTA classifier.
Each classifier classifies a BM dataset into 10 classes.
Using the classified datasets, we build separate word2vec models according to class.
Then, given a keyword, we query each model for the most similar words.
'''

# Keys identifying the two corpora / models.
CODA = 'coda'
DUTA = 'duta'

# Root directory of the training corpus for each model key (relative to the
# current working directory).
PATH = {
    CODA: '../res/corpus/category_annotated_corpus_v4.1_20210513/txt_preprocessed',
    DUTA: '../res/corpus/DUTA_10K_Masked_Dataset',
}
BENCHMARK_PATH = ''


class Sentences:
    """Stream whitespace-tokenized lines from a flat directory of corpus files.

    Corpus file names are expected to have the form ``num-cat-lang-sha256``;
    only the ``cat`` field is used, to optionally restrict iteration to one
    category. Entries that are not regular files, or whose names do not have
    exactly four ``-``-separated fields, are skipped.

    The object is re-iterable (each ``iter()`` starts over from the first
    file), which is what gensim's ``Word2Vec`` expects of its ``sentences``
    argument.

    Args:
        dirname: Directory whose files are read (not recursively).
        cat: If given, only files whose ``cat`` field equals this value are
            read; if ``None`` (default), every corpus file is read.
    """

    def __init__(self, dirname, cat=None):
        self.dirname = dirname
        self.cat = cat

    def __iter__(self):
        """Yield each line of each selected file as a list of tokens."""
        # Sort so the corpus is streamed in the same order on every pass and
        # on every platform; os.listdir() order is arbitrary.
        for fname in sorted(os.listdir(self.dirname)):
            path = os.path.join(self.dirname, fname)
            if not os.path.isfile(path):
                continue  # sub-directories etc. are not corpus files
            parts = fname.split('-')
            if len(parts) != 4:
                # Not a corpus file (e.g. an OS/editor artefact); skip it
                # instead of aborting the whole training run.
                logging.debug('skipping %s: name is not num-cat-lang-sha256', path)
                continue
            _num, cat, _lang, _sha256 = parts
            if self.cat is not None and cat != self.cat:
                continue
            # Explicit encoding: the platform default differs (e.g. Windows).
            with open(path, encoding='utf-8') as f:
                for line in f:
                    # str.split() with no argument already drops surrounding
                    # whitespace, so a preceding strip() is unnecessary.
                    yield line.split()



class DutaSentences(Sentences):
    """Stream whitespace-tokenized lines from every file under a directory tree.

    Unlike ``Sentences``, file names are not parsed and sub-directories are
    descended into recursively. The ``cat`` argument accepted by the inherited
    constructor is stored but NOT used for filtering: every file under
    ``dirname`` is read.

    Constructor (inherited): ``DutaSentences(dirname, cat=None)``.
    """

    def __iter__(self):
        """Yield each line of each file under ``self.dirname`` as a token list."""
        for root, dirnames, fnames in os.walk(self.dirname):
            # Sorting dirnames in place makes os.walk descend in a
            # deterministic order; files are likewise read in sorted order.
            dirnames.sort()
            for fname in sorted(fnames):
                # Explicit encoding: the platform default differs (e.g. Windows).
                with open(os.path.join(root, fname), encoding='utf-8') as f:
                    for line in f:
                        # split() with no argument already drops surrounding
                        # whitespace.
                        yield line.split()


def main(word, topn=20):
    """Load or train word2vec models and print the neighbours of *word*.

    For each corpus key (``CODA`` and ``DUTA``), a model is loaded from
    ``./<key>/<key>_model`` if that file exists. Otherwise it is trained on
    the corpus under ``PATH[key]`` and saved there for later runs. Under a
    banner line, the ``topn`` words most similar to ``word`` are printed one
    per line. If the lookup raises ``KeyError``, its message is printed
    instead.

    Args:
        word: Query keyword.
        topn: Number of neighbours printed per model (default 20).

    Returns:
        dict mapping each corpus key to its ``Word2Vec`` model.
    """
    models = {}
    for name in (CODA, DUTA):
        model_dir = f'./{name}'
        model_path = f'{model_dir}/{name}_model'
        if os.path.exists(model_path):
            model = Word2Vec.load(model_path)
        else:
            sentences = DutaSentences(PATH[name]) if name == DUTA else Sentences(PATH[name])
            model = Word2Vec(sentences=sentences, vector_size=100, window=5, min_count=1, workers=4)
            # Make sure the target directory exists before saving; otherwise
            # the first training run would fail at the save step.
            os.makedirs(model_dir, exist_ok=True)
            model.save(model_path)
        models[name] = model

        print('*' * 10 + f' {name} ' + '*' * 10)
        try:
            neighbours = model.wv.most_similar(word, topn=topn)
        except KeyError as exception:
            # Raised for a keyword missing from this model's vocabulary.
            print(exception)
        else:
            print('\n'.join(similar for similar, _score in neighbours))
    return models



if __name__ == '__main__':
    # The query keyword is the first command-line argument. Exit with a usage
    # message instead of an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} KEYWORD')
    main(sys.argv[1])

