# training_convergence.py
# This script plots training-convergence curves (fitness vs. timestep) for the log files listed in the config


# Internal Imports

# External Imports
import json
import matplotlib
import matplotlib.pyplot as plt

# Globals
# NOTE(review): chunk_size is not referenced anywhere in this file; kept in
# case another module imports it.
chunk_size: int = 16
# Destination directory for rendered graphs, relative to the working
# directory. main() concatenates the file name directly onto it, so the
# trailing slash is required.
output_dir: str = "src/genetic_embedding/output/graphs/"


def main(scaling = False):
    """Plot fitness-vs-timestep curves for every configured log file.

    Reads the "graphics" section of src/genetic_embedding/core/config.json,
    extracts one fitness series per entry in
    config["convergence"]["logfiles"] (via generate_datapoints), draws all
    series on a single axes and saves the figure as
    ``output_dir + "training_convergence.png"``.

    Args:
        scaling: Forwarded to generate_datapoints; when True every series is
            divided by its own final value. Also selects the y-axis label.
    """
    # Start by loading the configuration
    with open("src/genetic_embedding/core/config.json") as config_file:
        config = json.load(config_file)["graphics"]
    general = config["general"]
    convergence = config["convergence"]
    matplotlib.rcParams.update({'font.size': 14})

    interval = convergence["interval"]
    max_timestep = convergence["max_timestep"]
    logfiles = convergence["logfiles"]

    # x-axis: timesteps 0..max_timestep (inclusive when divisible) in thousands
    time = [x / 1000 for x in range(0, max_timestep + interval, interval)]

    # Generate one series per log file before touching matplotlib state
    data = [
        generate_datapoints(general["log_path"] + file, convergence["log_interval"],
                            interval, max_timestep, scaling)
        for file in logfiles
    ]

    fig = plt.figure()
    try:
        fig.supxlabel("Timestep (Thousands)")
        # Only claim the values are scaled when they actually are
        fig.supylabel("Fitness (Scaled)" if scaling else "Fitness")
        axl = fig.add_subplot(111)
        for index, (file, series) in enumerate(zip(logfiles, data)):
            # Legend label: second "-"-separated part of the file name with its
            # last three characters dropped, e.g. presumably a bit width — TODO
            # confirm against the log naming scheme.
            axl.plot(time, series, c=general["colors"][index],
                     label=file.split("-")[1][:-3] + "-bit")
        axl.legend(loc="lower right")
        fig.subplots_adjust(bottom=.15)
        # Save the plot
        fig.savefig(output_dir + "training_convergence.png")
    finally:
        # Release the figure so repeated calls don't accumulate open figures
        plt.close(fig)



def generate_datapoints(filename, log_interval, graph_interval, max_timestep, scaling = False):
    """Extract sampled fitness values from a training log.

    Only lines containing "Top 3" are considered. The n-th such line is taken
    to be the entry for timestep ``n * log_interval``. Its fitness is the
    fourth token when the line is split on single spaces. An entry is kept
    when its timestep is a multiple of ``graph_interval``. Reading stops as
    soon as the timestep exceeds ``max_timestep``.

    Args:
        filename: Path to the log file.
        log_interval: Timesteps between consecutive "Top 3" lines.
        graph_interval: Timestep spacing of the returned points.
        max_timestep: Last timestep (inclusive) that may be returned.
        scaling: If True, divide every value by the last kept value so the
            series ends at 1.0.

    Returns:
        A list of floats, one per kept timestep. It is empty when the log has
        no matching lines, including when ``scaling`` is True.

    Raises:
        ZeroDivisionError: if ``scaling`` is True and the last kept value is 0.
    """
    result = []
    timestep = 0
    with open(filename) as input_file:
        # Stream the file so reading stops once we pass max_timestep
        for line in input_file:
            if "Top 3" not in line:
                continue
            if timestep % graph_interval == 0:
                result.append(float(line.split(" ")[3]))
            timestep += log_interval
            if timestep > max_timestep:
                break
    # Nothing to normalise against when no points were collected
    if not scaling or not result:
        return result
    final = result[-1]
    return [value / final for value in result]


if __name__ == "__main__":
    # When run as a script, plot each curve normalised to its final fitness.
    main(scaling=True)