# train_embeddings.py
# This script trains the embeddings using the specified configuration


# Internal Imports
from src.genetic_embedding.core.chromosome import Chromosome
from src.genetic_embedding.core.interface import population_manip
from src.common_utils import data_ops, directory, timestamp
from src.genetic_embedding.scripts.evaluation import word_analogy

# External Imports
import json
from alive_progress import alive_bar


def _write_embedding(path, chromosome, vocab):
    """Decode ``chromosome`` against ``vocab`` and write the result to ``path``.

    The file is overwritten. Each line is the word, a tab character, the string
    ``chromosome.decode(vocab)`` maps that word to, and a newline.
    """
    # Decode before opening so a failing decode does not leave an empty file behind.
    embed = chromosome.decode(vocab)
    with open(path, "w") as output_file:
        output_file.writelines(word + "\t" + embed[word] + "\n" for word in embed)


def _append_log(path, text):
    """Append ``text`` followed by a newline to the log file at ``path``.

    A context manager is used so the handle is released even if the write fails.
    """
    with open(path, "a") as logfile:
        logfile.write(text + "\n")


def main():
    """Train an embedding with the genetic algorithm described by the config file.

    Reads the ``"scripts"`` section of ``src/genetic_embedding/core/config.json``,
    loads/cleans the data at ``data_path``, builds the vocabulary and co-occurrence
    matrix, then evolves a population of ``population_size`` chromosomes for
    ``max_timestep`` generations. Every ``checkpoint_interval`` steps the current
    best chromosome is written as a checkpoint; every ``logging_interval`` steps
    the best/worst fitnesses are printed and appended to a log file (plus the
    checkpoint's MRR when both intervals coincide). Finally the best chromosome is
    written to ``<output_path>/embedding-<dim>bit.txt``.
    """
    # Start by loading the configuration file
    with open("src/genetic_embedding/core/config.json") as config_file:
        config = json.load(config_file)["scripts"]

    population_size = config["population_size"]
    dimension = config["dimension"]

    # Create the output files/folders
    current_time = timestamp.create()
    #---Logging
    directory.create(config["logging_path"])
    #---Embedding Output
    output_path = config["output_path"].replace("<timestamp>", current_time)
    directory.create(output_path)

    # Next, open & clean the data for use
    print("PREPARATION: Reading and cleaning the data for use")
    raw_data = data_ops.load_data(config["data_path"])
    text = data_ops.clean_data(raw_data)
    vocab = data_ops.load_vocab(config["data_path"], text, remove_numeric=True)
    # Drop references to the (potentially large) raw text as soon as it is no longer needed.
    del raw_data

    # Build the co-occurrence matrix
    cooccurrence = data_ops.load_cooccurrence(config["data_path"], text, vocab, config["window_size"])
    del text

    # Initialize the population, scoring each chromosome as it is created
    print("PREPARATION: Initializing the {}-bit population".format(dimension))
    population = []
    with alive_bar(population_size, bar="smooth", spinner="classic") as bar:
        for _ in range(population_size):
            chromosome = Chromosome(vocab, dim=dimension)
            chromosome.compute_fitness(vocab, cooccurrence)
            population.append(chromosome)
            # Update the progress bar
            bar()
    print("POPULATION: {}-bit population of size {} initialized".format(dimension, population_size))

    # Now, run the evolutionary algorithm on the population
    with alive_bar(config["max_timestep"], bar="smooth", spinner="classic") as bar:
        # One generation per timestep; the number of generations is config["max_timestep"]
        for timestep in range(config["max_timestep"]):
            # Execute mutations & crossovers
            #---Crossover
            new_chromosomes = population_manip.crossover(population)
            #---Mutation
            new_chromosomes.append(population_manip.mutation(population))
            new_chromosomes.append(population_manip.mutation(population, mutation_amt=155))
            #---Random Crossover
            new_chromosomes += population_manip.crossover(population, selection_method="random", crossover_span=225)
            #---Random Mutation
            new_chromosomes.append(population_manip.mutation(population, selection_method="random", mutation_amt=150))

            # Only the offspring need scoring; surviving members keep the fitness
            # computed when they were created.
            for chromosome in new_chromosomes:
                chromosome.compute_fitness(vocab, cooccurrence)

            # Merge parents and offspring, sort, and prune back to population_size.
            # population[0] is treated as the top performer below, so sort is
            # presumably best-first — NOTE(review): confirm in population_manip.
            population = population_manip.sort(population + new_chromosomes)[:population_size]

            # Checkpoint state for this step, computed once and shared by the
            # checkpoint writer and the MRR logging below.
            is_checkpoint = not (timestep % config["checkpoint_interval"])
            checkpoint_path = "{}/ckpt-{}-embedding-{}bit.txt".format(output_path, timestep, population[0].dimension)

            # If the current iteration meets the checkpoint interval, save the current embedding as a checkpoint
            if is_checkpoint:
                print("CHECKPOINT: Saving checkpoint for {}-bit population".format(population[0].dimension))
                _write_embedding(checkpoint_path, population[0], vocab)

            # If the current iteration meets the logging interval, log the current fitnesses
            if not (timestep % config["logging_interval"]):
                log = "ITERATION {}: {}-bit population\n\tHigh:<{} {} {}>\n\tLow: <{} {} {}>".format(timestep, dimension, population[0].fitness, population[1].fitness, population[2].fitness, population[-3].fitness, population[-2].fitness, population[-1].fitness)
                if is_checkpoint:
                    # The checkpoint for this step was written just above, so its MRR can be evaluated now
                    total_stats, _ = word_analogy.mrr_analysis(checkpoint_path)
                    log += "\nUsing {} analogies & {} vocab words, total MRR is: {}".format(total_stats["used"], len(vocab), total_stats["mrr"])
                print(log)
                _append_log("{}/log-{}bit-{}.txt".format(config["logging_path"], population[0].dimension, current_time), log)

            # Update the progress bar
            bar()

    print("EXPORT: Exporting finished embedding")
    _write_embedding("{}/embedding-{}bit.txt".format(output_path, population[0].dimension), population[0], vocab)
    print("EXPORT: Finished embedding exported successfully")
    print("EXPORT: Finished embedding exported successfully")

# Only train when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()