# train_embeddings.py
# This script handles training the embeddings using the specified configuration


# Internal Imports
from src.genetic_embedding.core.chromosome import Chromosome
from src.genetic_embedding.core.interface import population_manip
from src.common_utils import data_ops, directory, timestamp

# External Imports
import json
from alive_progress import alive_bar
import gc
#from pympler import muppy, summary
import time
#import copy


def _load_config(config_path):
    """Return the ``"scripts"`` section of the JSON configuration file at *config_path*."""
    with open(config_path, encoding="utf-8") as config_file:
        return json.load(config_file)["scripts"]


def _initialize_population(dimension, population_size, vocab, cooccurrence):
    """Build and score *population_size* new chromosomes of the given *dimension*.

    Each chromosome is constructed from *vocab* and has its fitness computed
    against *cooccurrence* immediately, so the returned list is ready for the
    evolutionary loop.
    """
    population = []
    print("POPULATION: Initializing {}-bit population".format(dimension))
    with alive_bar(population_size, bar="smooth", spinner="classic") as bar:
        for _ in range(population_size):
            chromosome = Chromosome(vocab, dim=dimension)
            chromosome.compute_fitness(vocab, cooccurrence)
            population.append(chromosome)
            bar()
    print("POPULATION: {}-bit population of size {} initialized".format(dimension, population_size))
    return population


def _evolve(population, population_size, vocab, cooccurrence):
    """Run one generation on *population* and return the pruned successor.

    Candidates are produced by crossover plus one mutation, scored, merged with
    the current population, sorted with ``population_manip.sort`` and truncated
    to *population_size*. Index 0 is treated as the best chromosome elsewhere in
    this script (assumes ``sort`` orders best-first — confirm in population_manip).
    """
    new_chromosomes = population_manip.crossover(population)
    # NOTE(review): mutation's result is appended as a single element, so it is
    # presumably one chromosome, not a list — confirm in population_manip.
    new_chromosomes.append(population_manip.mutation(population))
    for chromosome in new_chromosomes:
        chromosome.compute_fitness(vocab, cooccurrence)
    return population_manip.sort(population + new_chromosomes)[:population_size]


def _log_fitness(population, timestep, logging_path, current_time):
    """Print the three first/last fitness values and append them to the per-dimension log file."""
    log = "ITERATION {}: {}-bit population\n\tHigh:<{} {} {}>\n\tLow: <{} {} {}>".format(
        timestep,
        population[0].dimension,
        population[0].fitness,
        population[1].fitness,
        population[2].fitness,
        population[-3].fitness,
        population[-2].fitness,
        population[-1].fitness,
    )
    print(log)
    log_path = "{}/log-{}bit-{}.txt".format(logging_path, population[0].dimension, current_time)
    # Context manager guarantees the handle is closed even if the write fails.
    with open(log_path, "a", encoding="utf-8") as logfile:
        logfile.write(log + "\n")


def _write_embedding(chromosome, vocab, path):
    """Decode *chromosome* against *vocab* and write one ``word<TAB>bits`` line per word to *path*.

    The value for each word is written as its binary representation without the
    ``0b`` prefix. Decoding happens before the file is opened, so a decode
    failure does not leave an empty file behind.
    """
    embed = chromosome.decode(vocab)
    with open(path, "w", encoding="utf-8") as output_file:
        output_file.writelines(word + "\t" + bin(embed[word])[2:] + "\n" for word in embed)


def main(config_path="src/genetic_embedding/core/config.json"):
    """Train embeddings with the genetic algorithm described by the config file.

    Reads the ``"scripts"`` section of *config_path* and uses these keys:
    ``logging_path``, ``output_path`` (``<timestamp>`` is replaced with the run's
    timestamp), ``data_path``, ``window_size``, ``dimensions``,
    ``population_size``, ``max_timestep``, ``logging_interval`` and
    ``checkpoint_interval``.

    One population is evolved per entry in ``dimensions``. Fitness is logged
    every ``logging_interval`` timesteps and the top chromosome is checkpointed
    every ``checkpoint_interval`` timesteps. The final top chromosome of each
    population is written to ``embedding-<dim>bit.txt`` in the output folder.

    Args:
        config_path: Path to the JSON configuration file.
    """
    config = _load_config(config_path)
    population_size = config["population_size"]

    # Create the output files/folders
    current_time = timestamp.create()
    directory.create(config["logging_path"])
    output_path = config["output_path"].replace("<timestamp>", current_time)
    directory.create(output_path)

    # Open & clean the data, then build the vocabulary and co-occurrence matrix
    print("PREPARATION: Reading and cleaning the data for use")
    raw_data = data_ops.load_data(config["data_path"])
    text = data_ops.clean_data(raw_data)
    vocab = data_ops.load_vocab(config["data_path"], text)
    cooccurrence = data_ops.load_cooccurrence(config["data_path"], text, vocab, config["window_size"])

    # Initialize one population per requested dimension
    print("PREPARATION: Initializing the population(s)")
    populations = [
        _initialize_population(dimension, population_size, vocab, cooccurrence)
        for dimension in config["dimensions"]
    ]

    # Run the evolutionary algorithm on the populations
    print("EVOLUTION: Starting training for the selected population(s)")
    with alive_bar(config["max_timestep"], bar="smooth", spinner="classic") as bar:
        for timestep in range(config["max_timestep"]):
            for index, population in enumerate(populations):
                population = _evolve(population, population_size, vocab, cooccurrence)
                populations[index] = population

                if timestep % config["logging_interval"] == 0:
                    _log_fitness(population, timestep, config["logging_path"], current_time)

                if timestep % config["checkpoint_interval"] == 0:
                    dimension = population[0].dimension
                    print("CHECKPOINT: Saving checkpoint for {}-bit population".format(dimension))
                    _write_embedding(
                        population[0],
                        vocab,
                        output_path + "/ckpt-{}-embedding-{}bit.txt".format(timestep, dimension),
                    )

                # Full collection after every population update, presumably to bound
                # peak memory between generations; it is costly, so revisit if
                # profiling shows it is unnecessary.
                gc.collect()
            bar()

    # Once finished, save the completed embedding for each population
    for population in populations:
        dimension = population[0].dimension
        print("CHECKPOINT: Saving finished embedding for {}-bit population".format(dimension))
        _write_embedding(population[0], vocab, output_path + "/embedding-{}bit.txt".format(dimension))

# Only start training when run as a script, not when this module is imported.
if __name__ == "__main__":
    main()