# word_analogy.py
# This script graphs a genetic embedding's performance on the word analogy task outlined by Mikolov et al., evaluated at successive checkpoints.


# Internal Imports
from src.genetic_embedding.scripts.evaluation import word_analogy

# External Imports
import json
import matplotlib
import matplotlib.pyplot as plt
from alive_progress import alive_bar

# Globals
# NOTE(review): chunk_size is not referenced anywhere in the visible part of this file — confirm whether it is still needed.
chunk_size = 16
# Directory prefix (relative path) that main() prepends to the saved graph's file name.
output_dir = "src/genetic_embedding/output/graphs/"


def main(scaling = False):
    """Plot word-analogy MRR against timestep for every configured run and save the figure.

    Reads the "graphics" section of src/genetic_embedding/core/config.json,
    evaluates each entry of config["performance"]["logfiles"] with
    calculate_performance, draws one curve per entry and writes
    "testing_performance.png" into output_dir.

    Args:
        scaling: When True, every curve is divided by its own final value and
            the y-axis label gains a "(Scaled)" suffix. Previously only the
            label changed while the plotted data stayed unscaled.

    Raises:
        ZeroDivisionError: If scaling is requested and a curve's final value is 0.
    """
    # Start by loading the configuration
    with open("src/genetic_embedding/core/config.json") as config_file:
        config = json.load(config_file)["graphics"]
    general = config["general"]
    performance = config["performance"]
    matplotlib.rcParams.update({'font.size': 12})

    # Add the timestep labels (x values are expressed in thousands of timesteps)
    timesteps = [x/1000 for x in range(0, performance["max_timestep"] + performance["interval"], performance["interval"])]

    # For each file specified, generate datapoints
    # (log-based alternative: generate_datapoints(general["log_path"] + logfile, performance["ckpt_interval"],
    #  performance["interval"], performance["max_timestep"], scaling))
    data = []
    for logfile in performance["logfiles"]:
        stem = logfile.split(".")[0]
        # Build the embedding directory name from the logfile stem: drop the first
        # 10 characters and the single character sitting 6 places from the end.
        # NOTE(review): relies on a fixed-width logfile naming scheme — verify against the config.
        embed_dir = general["embed_path"] + stem[10:-6] + stem[-5:]
        # The first two characters after the first "-" are used as the dimension.
        dimension = logfile.split("-")[1][:2]
        series = calculate_performance(embed_dir, dimension, performance["ckpt_interval"], performance["interval"], performance["max_timestep"], performance["analogy_path"])
        if scaling and series:
            # Normalise by the final value, matching generate_datapoints' scaling,
            # so the data agrees with the "(Scaled)" axis label.
            final_value = series[-1]
            series = [value / final_value for value in series]
        data.append(series)

    # Generate the plot
    fig = plt.figure()
    fig.supxlabel("Timestep (Thousands)")
    fig.supylabel("Testing Performance (Scaled)" if scaling else "Testing Performance")
    axl = fig.add_subplot(111)
    for index, logfile in enumerate(performance["logfiles"]):
        # (per-fold labelling alternative: label="Fold "+logfile.split("-")[-1][4:5])
        axl.plot(timesteps, data[index], c=general["colors"][index], label=logfile.split("-")[1][:2]+"-bit")
    plt.legend(loc = "lower right")
    fig.subplots_adjust(bottom=.15)
    # Save the plot
    plt.savefig(output_dir + "testing_performance.png")



def generate_datapoints(filename, ckpt_interval, graph_interval, max_timestep, scaling = False):
    """Extract graphable performance values from a training log.

    Every line of the log containing "Current Performance" is treated as one
    checkpoint; the third space-separated token of that line is parsed as the
    value. Checkpoints are assumed to be ckpt_interval timesteps apart,
    starting at timestep 0.

    Args:
        filename: Path of the log file to read.
        ckpt_interval: Timestep distance between consecutive matching lines.
        graph_interval: Only timesteps that are multiples of this are kept.
        max_timestep: Reading stops once the running timestep exceeds this.
        scaling: When True, every kept value is divided by the last kept value.

    Returns:
        List of floats, one per kept timestep, in file order.
    """
    # Collect only the lines that report a performance value
    with open(filename) as log_file:
        performance_lines = [entry for entry in log_file if "Current Performance" in entry]

    # Walk the lines while tracking the timestep each one corresponds to
    points = []
    timestep = 0
    for entry in performance_lines:
        value = float(entry.split(" ")[2])
        if timestep % graph_interval == 0:
            points.append(value)
        timestep += ckpt_interval
        if timestep > max_timestep:
            break

    if scaling:
        return [point / points[-1] for point in points]
    return points



def calculate_performance(filename, dimension, ckpt_interval, graph_interval, max_timestep, analogy_set):
    """Evaluate an embedding's checkpoints on the analogy set and collect the MRR values to graph.

    Checkpoint files are looked up as
    "<filename>/ckpt-<timestep>-embedding-<dimension>bit.txt" for every
    timestep from 0 up to max_timestep (inclusive) in steps of ckpt_interval.

    Args:
        filename: Directory holding the checkpoint files (no trailing slash).
        dimension: Embedding dimension as a string of digits; it is inserted
            into the checkpoint file name and passed to mrr_analysis as an int.
        ckpt_interval: Timestep distance between consecutive checkpoints.
        graph_interval: Only timesteps that are multiples of this are recorded.
        max_timestep: Last timestep (inclusive) to evaluate.
        analogy_set: Forwarded to word_analogy.mrr_analysis as analogy_path.

    Returns:
        List of total_stats["mrr"] values, one per recorded timestep, in
        increasing timestep order.
    """
    ckpt_pre = "ckpt-"
    ckpt_suf = "-embedding-" + dimension + "bit.txt"
    # The loop visits timesteps 0, ckpt_interval, ..., max_timestep inclusive, so
    # the bar must expect one more tick than max_timestep // ckpt_interval;
    # otherwise it overruns its declared total.
    total_checkpoints = max(int(max_timestep // ckpt_interval) + 1, 0)
    # For each checkpoint being graphed, calculate & record the MRR
    counter = 0
    result = []
    with alive_bar(total_checkpoints, bar="smooth", spinner="classic") as bar:
        while counter <= max_timestep:
            # NOTE(review): every checkpoint is evaluated even when its result is not
            # recorded; if mrr_analysis has no side effects this call could move
            # inside the graph_interval check — confirm before changing.
            total_stats, _ = word_analogy.mrr_analysis(filename + "/" + ckpt_pre + str(counter) + ckpt_suf, dim = int(dimension), analogy_path = analogy_set)
            if not counter % graph_interval:
                result.append(total_stats["mrr"])
            bar()
            counter += ckpt_interval
    return result


# Generate the graph with scaling disabled when this file is executed directly.
if __name__ == "__main__":
    main(scaling = False)