# semantic_analogy_generator.py
# This file handles pre-processing datasets and building the vocabulary used to generate semantic analogies for training the genetic embedding


# Internal Imports
import src.common_utils.directory as directory

# External Imports
import os
import json
from datasets import load_dataset, Dataset, concatenate_datasets
import pandas as pd
import re
from gensim.parsing.preprocessing import remove_stopwords
import nltk
from nltk.tokenize import word_tokenize

# Global Variables


def process_dataset(name, text_column='text'):
    """Pre-process a dataset and write it, plus its vocabulary, to disk.

    Every split returned by ``load_dataset(name)`` is concatenated into one
    dataset, each sample is run through ``preprocess_sample``, and two files
    are written into ``config['data_dir'] + name + '/'``:

    * ``dataset.json`` -- the processed dataset, as returned by
      ``dataset.to_dict()``.
    * ``<name>.vocab`` -- the sorted unique tokens of the text column, one
      per line.

    Args:
        name: Dataset identifier passed to ``load_dataset``; also used for
            the output sub-directory and the vocab file name.
        text_column: Name of the column that holds the raw text.
    """
    # Get some configuration parameters (the `with` block closes the file)
    with open("src/genetic_embedding/core/config.json") as config_file:
        config = json.load(config_file)["general"]
    output_dir = config['data_dir'] + name + '/'

    # Load the dataset and fetch the tokenizer data needed by word_tokenize
    data_splits = load_dataset(name)
    nltk.download('punkt')

    # Combine all of the splits together
    dataset = concatenate_datasets([data_splits[split] for split in data_splits.keys()])

    # Pre-process the samples
    dataset = dataset.map(lambda sample: preprocess_sample(sample, text_column))

    # Materialise the columns once; reused for both the JSON dump and the vocab
    columns = dataset.to_dict()

    # Output the samples to a JSON dataset file
    directory.create(output_dir)
    with open(output_dir + 'dataset.json', 'w') as output_file:
        json.dump(columns, output_file)

    # Build the vocabulary (used to generate analogies) by iterating the
    # tokenized column directly. Collecting it as a side effect of
    # `dataset.map` is unreliable: `map` may reuse a cached result without
    # calling the function, which would leave the vocabulary empty.
    vocab = set()
    for tokens in columns[text_column]:
        vocab.update(tokens)

    # Write the sorted vocab to a file, one word per line
    with open(output_dir + name + '.vocab', 'w') as output_file:
        output_file.writelines(word + '\n' for word in sorted(vocab))


def preprocess_sample(sample, text_column):
    """Return a copy of ``sample`` with its text column cleaned and tokenized.

    Every field other than ``text_column`` is copied over unchanged. In the
    text, each run of characters outside ``A-Za-z0-9`` is replaced by a single
    space and the result is lower-cased; it is then passed through
    ``remove_stopwords`` and ``word_tokenize``, and the resulting token list
    is stored under ``text_column``.
    """
    # Carry over all fields except the text, which is rebuilt below
    processed = {
        field: sample[field]
        for field in sample.keys()
        if field != text_column
    }

    # Collapse anything that is not alphanumeric into spaces, then lower-case
    cleaned = re.sub(r'[^A-Za-z0-9]+', ' ', sample[text_column])
    cleaned = cleaned.lower()

    # Remove stopwords before tokenizing
    without_stopwords = remove_stopwords(cleaned)
    processed[text_column] = word_tokenize(without_stopwords)
    return processed


def add_vocab_from_sample(sample, text_column, vocab):
    """Append every not-yet-known word of a tokenized sample to ``vocab``.

    Args:
        sample: Mapping whose ``text_column`` entry is an iterable of words.
        text_column: Key of the tokenized text within ``sample``.
        vocab: List of known words, modified in place; new words are appended
            in the order they are first seen.

    Returns:
        None. The result is delivered by mutating ``vocab``.
    """
    # Track membership in a set so each lookup is O(1) rather than a scan of
    # the (potentially large) vocab list; the list keeps insertion order.
    known = set(vocab)
    for word in sample[text_column]:
        if word not in known:
            known.add(word)
            vocab.append(word)


if __name__ == "__main__":
    # Executed as a script: process the sample "imdb" dataset
    process_dataset("imdb")