# train_embeddings.py
# This script handles training the embeddings using the specified configuration (src/genetic_embedding/core/config.json)


# Internal Imports
#from src.genetic_embedding.core.chromosome import Chromosome
from src.genetic_embedding.core.population import Population
#from src.genetic_embedding.core.interface import population_manip
from src.common_utils import data_ops#, directory, timestamp
#from src.genetic_embedding.scripts.evaluation import word_analogy

# External Imports
import json
from alive_progress import alive_bar
#from threading import Thread
from multiprocessing import Process


def main():
    """Train genetic embeddings for every configuration in the config file.

    Reads the ``"training"`` section of ``src/genetic_embedding/core/config.json``,
    prepares the vocabulary (and co-occurrence data) from ``config["data_path"]``,
    then expands the cross-product of ``max_timestep`` x ``population_size`` x
    ``dimensions`` x ``evaluation_path`` into individual training runs.

    A run whose evaluation entry has ``folds <= 0`` is trained in this process.
    Otherwise one child process per fold (numbered from 1) is started, and all
    of them are joined before moving on to the next configuration.
    """
    import itertools

    # Start by loading the configuration file
    with open("src/genetic_embedding/core/config.json", encoding="utf-8") as config_file:
        config = json.load(config_file)["training"]

    # Open & clean the data for use
    print("PREPARATION: Reading and cleaning the data for use")
    raw_data = data_ops.load_data(config["data_path"])
    text = data_ops.clean_data(raw_data)
    vocab = data_ops.load_vocab(config["data_path"], text, remove_numeric = True)
    del raw_data

    # Build the co-occurrence matrix.
    # NOTE(review): the result is not used anywhere below; the call is kept in
    # case load_cooccurrence builds/caches data as a side effect -- confirm.
    cooccurrence = data_ops.load_cooccurrence(config["data_path"], text, vocab, config["window_size"])
    del text
    # Drop the reference before spawning worker processes / long training runs.
    del cooccurrence

    # Build the list of configurations: one entry per combination, in the same
    # order as nested loops over max_timestep > population_size > dimensions >
    # evaluation_path. A non-positive fold count means "no folds".
    configurations = []
    for max_timestep, pop_size, dimension, evaluation in itertools.product(
        config["max_timestep"],
        config["population_size"],
        config["dimensions"],
        config["evaluation_path"],
    ):
        folds = evaluation["folds"]
        configurations.append({
            "num_folds": folds if folds > 0 else 0,
            "folds": [{
                "dimension": dimension,
                "population_size": pop_size,
                "max_timestep": max_timestep,
                "evaluation_path": evaluation["file"],
            }],
        })

    # Start processing: spawn & run a population for each configuration
    for item in configurations:
        run_config = item["folds"][0]

        if item["num_folds"] == 0:
            # No folds: train a single population in this process.
            population = init_population(vocab, config, run_config, 0)
            run_population(population, config, run_config, 0)
            continue

        # Folds requested: train one population per fold, each in its own
        # process, so the folds run concurrently.
        print("CONFIGURATION: Multiple folds detected. Spawning threads for {}-bit embeddings for {}".format(run_config["dimension"], run_config["evaluation_path"]))
        processes = []
        for fold in range(1, item["num_folds"] + 1):
            population = init_population(vocab, config, run_config, fold)
            process = Process(target=run_population, args=[population, config, run_config, fold])
            process.start()
            processes.append((fold, process))

        # Wait for every fold to finish before starting the next configuration,
        # and surface failures instead of silently ignoring them.
        for fold, process in processes:
            process.join()
            if process.exitcode != 0:
                print("WARNING: fold {} exited with code {}".format(fold, process.exitcode))



def init_population(vocab, global_config, configuration, foldNum):
    """Construct the Population for a single training run.

    Population size and embedding dimension come from ``configuration``.
    Logging/checkpoint intervals and the logging/output paths come from
    ``global_config``. Any ``<fold>`` placeholder in
    ``configuration["evaluation_path"]`` is substituted with ``foldNum``.
    ``foldNum`` is also passed through as ``fold_number``.

    Returns:
        The newly constructed Population.
    """
    dim = configuration["dimension"]
    print("PREPARATION: Initializing the {}-bit population".format(dim))

    # Keyword arguments for the Population constructor, gathered in one place.
    settings = {
        "pop_size": configuration["population_size"],
        "dimension": dim,
        "vocab": vocab,
        "logging_interval": global_config["logging_interval"],
        "checkpoint_interval": global_config["checkpoint_interval"],
        "logging_path": global_config["logging_path"],
        "output_path": global_config["output_path"],
        "evaluation_filepath": configuration["evaluation_path"].replace('<fold>', str(foldNum)),
        "fold_number": foldNum,
    }
    population = Population(**settings)

    print("POPULATION: {}-bit population of size {} initialized".format(dim, settings["pop_size"]))
    return population



def run_population(population, global_config, configuration, foldNum):
    """Evolve ``population`` until it reaches the goal or its evaluation budget.

    First appends the evaluation dataset to the run's log file. The log file
    name gets a ``-fold<N>`` suffix when ``foldNum > 0``.

    Then calls ``population.evolution_cycle()`` while
    ``population.function_evals < configuration["max_timestep"]`` and
    ``population.current_performance < global_config["goal_performance"]``.
    Finally checkpoints and logs once more. ``main`` also uses this function as
    the target of a ``multiprocessing.Process`` for fold runs.

    Args:
        population: the initialized Population to train.
        global_config: the "training" section of the config file.
        configuration: one run's settings (``max_timestep``, ``evaluation_path``, ...).
        foldNum: fold index; 0 means the run is not part of a fold split.
    """
    # Log what dataset this is being trained on. Use the fold number passed in
    # (the same value init_population hands to Population as fold_number).
    fold_suffix = "-fold{}".format(foldNum) if foldNum > 0 else ""
    logloc = "{}/log-{}bit-{}{}.txt".format(population.logging_path, population.dimension, population.timestamp, fold_suffix)
    # Resolve the <fold> placeholder the same way init_population does, so the
    # log names the actual file used rather than the template.
    evaluation_file = configuration["evaluation_path"].replace('<fold>', str(foldNum))
    with open(logloc, "a") as logfile:
        logfile.write("Training on: " + evaluation_file + "\n")

    # Until either convergence or it exceeds the maximum timestep, run evolution cycles.
    # NOTE(review): the bar's total is max_timestep but it advances once per
    # evolution cycle, while the loop is bounded by function_evals; if one cycle
    # performs several evaluations the bar will not reach its total -- confirm.
    max_timestep = configuration["max_timestep"]
    goal = global_config["goal_performance"]
    with alive_bar(max_timestep, bar="smooth", spinner="classic") as bar:
        while population.function_evals < max_timestep and population.current_performance < goal:
            population.evolution_cycle()
            bar()
    # TODO: log how long evolution took and its rate (if alive-progress exposes it).

    # Log & checkpoint one last time
    population.checkpoint()
    population.log()

# Only run when executed as a script. Without this guard, multiprocessing's
# "spawn" start method (the default on Windows and macOS) re-imports this module
# in every child Process, which would call main() again and recursively spawn work.
if __name__ == "__main__":
    main()