# train_embedding.py
# This script handles training the embedding using the specified configuration


# Internal Imports
from src.genetic_embedding.core.chromosome import Chromosome
from src.genetic_embedding.core.interface import population_manip
from src.common_utils import data_ops, directory, timestamp

# External Imports
import json
import os
import gc
from alive_progress import alive_bar

# Globals
# Path of the JSON checkpoint file, relative to the current working directory.
# main() and processing_burst() read it to resume a run, processing_burst()
# rewrites it at the end of every burst, and main() deletes it when training
# has finished.
checkpoint_location = "src/genetic_embedding/scripts/checkpoint.txt"


def main():
    """Run embedding training, resuming from a checkpoint when one exists.

    Reads the "scripts" section of src/genetic_embedding/core/config.json,
    restores the run timestamp and timestep from the checkpoint file if it is
    present (otherwise a new timestamp is created and the run starts at
    timestep 1), creates the logging and output directories, and then calls
    processing_burst() repeatedly until the timestep counter reaches
    config["max_timestep"]. The checkpoint file is removed afterwards.

    Raises:
        ValueError: if bursts still need to run but config["processing_burst"]
            is not positive, because the timestep counter would never advance.
    """
    # Start by loading the configuration file
    with open("src/genetic_embedding/core/config.json") as config_file:
        config = json.load(config_file)["scripts"]

    # Create the output files/folders
    counter = 1
    if os.path.exists(checkpoint_location):
        print("CHECKPOINT: Loading previously-checkpointed program execution")
        with open(checkpoint_location, "r") as checkpoint_file:
            checkpoint = json.load(checkpoint_file)
        current_time = checkpoint["timestamp"]
        counter = checkpoint["timestep"]
        # The checkpoint also carries the whole population; drop it here so it
        # is not kept alive for the rest of the run.
        del checkpoint
        print("CHECKPOINT: Previous checkpoint loaded")
    else:
        current_time = timestamp.create()
    #---Logging
    directory.create(config["logging_path"])
    #---Embedding Output
    output_path = config["output_path"].replace("<timestamp>", current_time)
    directory.create(output_path)

    # A non-positive burst size would leave the counter stuck below the
    # maximum timestep and the loop below would never terminate.
    burst_size = config["processing_burst"]
    if counter < config["max_timestep"] and burst_size <= 0:
        raise ValueError(
            "config['processing_burst'] must be positive, got {}".format(burst_size)
        )

    # While we haven't reached the maximum timestep, run processing bursts
    print("PROCESS: Starting processing bursts")
    while counter < config["max_timestep"]:
        processing_burst(config, counter, current_time)
        gc.collect()
        # Increment the counter by the amount of processing done
        counter += burst_size

    # Delete the checkpoint file. It does not exist when no burst ran and
    # there was no earlier checkpoint, so a missing file is not an error.
    try:
        os.remove(checkpoint_location)
    except FileNotFoundError:
        pass


def _export_embedding(path, chromosome, vocab):
    """Decode `chromosome` with `vocab` and write the result to `path`.

    One line is written per word: the word, a tab, and the output of bin()
    for the word's decoded value with its first two characters removed.
    """
    with open(path, "w") as output_file:
        embed = chromosome.decode(vocab)
        for word in embed:
            output_file.write(word + "\t" + bin(embed[word])[2:] + "\n")


def _save_checkpoint(current_time, timestep, population):
    """Write the run timestamp, timestep and population to the checkpoint file.

    The JSON is written to a temporary sibling file and then moved over the
    real checkpoint, so an interruption mid-write cannot leave a truncated
    checkpoint behind.
    """
    checkpoint = {
        "timestamp": current_time,
        "timestep": timestep,
        "population": [
            {"embed": chromosome.embedding, "fitness": chromosome.fitness}
            for chromosome in population
        ],
    }
    # Serialize first so a serialization error never touches the files on disk
    payload = json.dumps(checkpoint)
    temp_location = checkpoint_location + ".tmp"
    with open(temp_location, "w") as checkpoint_file:
        checkpoint_file.write(payload)
    os.replace(temp_location, checkpoint_location)


def processing_burst(config, timestep, current_time):
    """Run one burst of training iterations and checkpoint the result.

    Loads the data, vocabulary and cooccurrence matrix, restores the
    population from the checkpoint file (or initializes a new one when no
    checkpoint exists), runs config["processing_burst"] iterations of
    crossover/mutation followed by sorting and pruning, and finally saves a
    checkpoint. The finished embedding is exported as well once the timestep
    has reached config["max_timestep"].

    Args:
        config: configuration dict; the keys read here are "data_path",
            "window_size", "dimension", "population_size",
            "processing_burst", "logging_interval", "logging_path",
            "checkpoint_interval", "output_path" and "max_timestep".
        timestep: timestep at which this burst starts.
        current_time: run timestamp string; it replaces "<timestamp>" in
            config["output_path"] and is part of the log file name.
    """
    # Load the text, vocab, and cooccurrence matrix
    raw_data = data_ops.load_data(config["data_path"])
    text = data_ops.clean_data(raw_data)
    vocab = data_ops.load_vocab(config["data_path"], text)
    cooccurrence = data_ops.load_cooccurrence(config["data_path"], text, vocab, config["window_size"])
    # The raw and cleaned text are not needed once the vocab and matrix exist
    del raw_data
    del text

    output_dir = config["output_path"].replace("<timestamp>", current_time)

    # If there's a checkpoint file, load it. Otherwise, start a new training phase
    if os.path.exists(checkpoint_location):
        print("PREPARATION: Loading previous checkpoint population")
        # Get the checkpoint information
        with open(checkpoint_location, "r") as checkpoint_file:
            checkpoint = json.load(checkpoint_file)
        # Get the population
        population = [
            Chromosome(embed = entry["embed"], fit = entry["fitness"], dim = config["dimension"])
            for entry in checkpoint["population"]
        ]
        del checkpoint
        print("POPULATION: Checkpointed {}-bit population loaded at timestep {}".format(config["dimension"], timestep))
    else:
        # Initialize the population
        print("PREPARATION: Initializing the {}-bit population".format(config["dimension"]))
        population = []
        for _ in range(config["population_size"]):
            temp_chromosome = Chromosome(vocab, dim = config["dimension"])
            temp_chromosome.compute_fitness(vocab, cooccurrence)
            population.append(temp_chromosome)
        print("POPULATION: {}-bit population of size {} initialized".format(config["dimension"], config["population_size"]))

    # Run the processing loop for the burst amount
    with alive_bar(config["processing_burst"], bar="smooth", spinner="classic") as bar:
        for _ in range(config["processing_burst"]):
            # Execute mutations & crossovers
            # NOTE(review): each result below is append()-ed as a single
            # element; if these helpers return lists, extend() may have been
            # intended -- confirm against population_manip.
            #---Crossover
            new_chromosomes = population_manip.crossover(population)
            #---Mutation
            new_chromosomes.append(population_manip.mutation(population))
            #---Random Crossover
            new_chromosomes.append(population_manip.crossover(population, selection_method = "random", crossover_span = 225))
            #---Random Mutation
            new_chromosomes.append(population_manip.mutation(population, selection_method = "random", mutation_amt = 150))

            # Compute fitness for the new chromosomes
            for chromosome in new_chromosomes:
                chromosome.compute_fitness(vocab, cooccurrence)

            # Merge old and new chromosomes, sort them, and prune the result
            # back to the configured population size
            population = population_manip.sort(population + new_chromosomes)[:config["population_size"]]

            # If the current iteration meets the logging interval, log the current fitnesses
            if timestep % config["logging_interval"] == 0:
                log = "ITERATION {}: {}-bit population\n\tHigh:<{} {} {}>\n\tLow: <{} {} {}>".format(timestep, config["dimension"], population[0].fitness, population[1].fitness, population[2].fitness, population[-3].fitness, population[-2].fitness, population[-1].fitness)
                print(log)
                log_path = "{}/log-{}bit-{}.txt".format(config["logging_path"], population[0].dimension, current_time)
                # The context manager closes the log file even if the write fails
                with open(log_path, "a") as logfile:
                    logfile.write(log + "\n")

            # If the current iteration meets the checkpoint interval, save the current embedding as a checkpoint
            if timestep % config["checkpoint_interval"] == 0:
                print("CHECKPOINT: Saving checkpoint for {}-bit population".format(population[0].dimension))
                # Export the chromosome at the front of the sorted population
                _export_embedding(
                    output_dir + "/ckpt-{}-embedding-{}bit.txt".format(timestep, population[0].dimension),
                    population[0],
                    vocab,
                )

            # Increment the timestep
            timestep += 1
            bar()

    # Save a checkpoint
    print("CHECKPOINT: Processing burst finished. Saving checkpoint")
    _save_checkpoint(current_time, timestep, population)
    print("CHECKPOINT: Checkpoint saved.")

    # If we hit the max timestep, also save the completed embedding
    if timestep >= config["max_timestep"]:
        print("EXPORT: Saving finished embedding")
        _export_embedding(
            output_dir + "/embedding-{}bit.txt".format(population[0].dimension),
            population[0],
            vocab,
        )
        print("EXPORT: Finished embedding exported successfully")

# Run the training when this file is executed.
# NOTE(review): the call is not behind an `if __name__ == "__main__":` guard,
# so importing this module also starts training -- confirm that is intended.
main()