# population.py
# This file handles the population construct for use in training binary word embeddings


# Internal Imports
from src.genetic_embedding.core.chromosome import Chromosome
from src.genetic_embedding.core.interface import population_manip
from src.common_utils import directory, timestamp
from src.genetic_embedding.scripts.evaluation import word_analogy

# External Imports
from alive_progress import alive_bar
import random


class Population:
    """Fixed-size collection of chromosomes evolved to train binary word embeddings.

    The population is kept ordered best-first by :meth:`sort`, which relies on
    the ordering comparisons defined by ``Chromosome`` (presumably by fitness
    -- confirm in ``chromosome.py``).  Each call to :meth:`evolution_cycle`
    breeds a handful of new chromosomes, merges them with the current members
    and truncates back to ``pop_size``.

    Constructing a ``Population`` has side effects: it creates the logging and
    output directories, spawns and evaluates ``pop_size`` chromosomes, and
    writes an initial checkpoint and log entry.
    """

    # Constructor
    def __init__(
        self,
        pop_size,
        dimension,
        vocab,
        logging_interval=100,
        checkpoint_interval=1000,
        goal_performance=.8,
        maximum_iteration=100000,
        logging_path="src/genetic_embedding/output/logs",
        output_path="embedding/genetic_embedding/output_embeddings/<timestamp>",
        evaluation_filepath="",
        fold_number=0,
    ):
        """Configure the population, create its directories and spawn its members.

        Args:
            pop_size: Number of chromosomes kept in the population.
            dimension: Embedding dimension handed to each ``Chromosome``; also
                used in log and checkpoint file names as the bit count.
            vocab: Vocabulary handed to each ``Chromosome`` and to its fitness
                computation.
            logging_interval: Write a log entry every this many evolution
                cycles; a falsy value disables periodic logging.
            checkpoint_interval: Write a checkpoint every this many evolution
                cycles; a falsy value disables periodic checkpoints.
            goal_performance: Stored as ``performance_goal``; not consulted by
                the methods of this class.
            maximum_iteration: Stored as ``max_iteration``; not consulted by
                the methods of this class.
            logging_path: Directory that receives the log file.
            output_path: Directory that receives checkpoints.  Every
                ``<timestamp>`` placeholder is replaced by the creation
                timestamp of this population.
            evaluation_filepath: Path forwarded to fitness computation and to
                the MRR analysis run at every checkpoint.
            fold_number: Fold index; when greater than zero it is appended to
                the output directory and to the log file name.
        """
        # Population properties
        self.pop_size = pop_size
        self.dimension = dimension
        self.vocab = vocab
        self.function_evals = 0  # incremented once per evolution cycle

        # Manipulation properties
        self.selection_method = "tournament"
        self.crossover_method = "uniform"
        self.selection_pressure = 5
        self.mutation_rate = .01

        # Evolution properties
        self.injection_interval = 5000
        self.log_interval = logging_interval
        self.ckpt_interval = checkpoint_interval
        self.performance_goal = goal_performance
        self.max_iteration = maximum_iteration
        self.members = [None] * pop_size

        # Logging properties
        self.timestamp = timestamp.create()
        self.logging_path = logging_path
        self.output_path = output_path.replace("<timestamp>", self.timestamp)
        self.current_performance = 0.0
        self.mean_rank = 0.0
        self.perf_eval_used = 0
        self.performance_step = 0  # value of function_evals at the last checkpoint

        # Evaluation properties
        self.eval_filepath = evaluation_filepath
        self.fold_num = fold_number
        if self.fold_num > 0:
            # NOTE(review): the suffix is appended without a separator, so
            # ".../<timestamp>" becomes ".../<timestamp>fold1" -- confirm that
            # this is the intended directory layout before changing it.
            self.output_path = self.output_path + f'fold{self.fold_num}'

        # Create the logging & output directories
        directory.create(self.logging_path)
        directory.create(self.output_path)

        # Spawn in the individuals
        self.init_population()

    def __getitem__(self, key):
        """Return the member(s) at ``key`` (index or slice into ``members``)."""
        return self.members[key]

    def __str__(self):
        """Return a multi-line summary of the population for the log file.

        Lists up to three members from each end of ``members``.  The
        performance line is included only when a checkpoint was taken at the
        current ``function_evals`` count, i.e. when the stored statistics
        describe the current state.
        """
        # Slices instead of fixed indices so that a population with fewer than
        # three members can still be printed instead of raising IndexError.
        top = " ".join(str(member) for member in self.members[:3])
        bottom = " ".join(str(member) for member in self.members[-3:])

        string = "ITERATION {}: {}-bit population, size {}\n".format(self.function_evals, self.dimension, self.pop_size)
        string += "\tTop 3 Performers: {}\n".format(top)
        string += "\tBottom 3 Performers: {}\n".format(bottom)
        string += "\tTotal Function Evaluations: {}".format(self.function_evals)
        if self.function_evals == self.performance_step:
            string += "\n\tCurrent Performance: {}\n\t(mean rank {} across {} analogies & {} vocab words)".format(self.current_performance, self.mean_rank, self.perf_eval_used, len(self.vocab))
        return string

    def init_population(self, strategy="random"):
        """Spawn and evaluate ``pop_size`` chromosomes, then checkpoint and log.

        Args:
            strategy: Currently unused; every member is created directly via
                the ``Chromosome`` constructor.
        """
        # TODO: abstract this out to population_manip/initialization & call it
        # through an interface here.
        with alive_bar(self.pop_size, bar="smooth", spinner="classic") as bar:
            for i in range(self.pop_size):
                member = Chromosome(vocab=self.vocab, dim=self.dimension)
                member.compute_fitness(self.vocab, filepath=self.eval_filepath)
                self.members[i] = member
                bar()

        # Order the fresh population best-first: the checkpoint decodes
        # members[0] as the top performer, the log prints the ends of the list,
        # and evolution_cycle assumes a sorted population.
        self.sort()

        # Checkpoint & log the initial population
        self.checkpoint()
        self.log()

    def sort(self):
        """Reorder ``members`` in descending order (best first).

        ``members`` is rebound to a new list; the previous list object is left
        untouched.
        """
        self.members = sorted(self.members, reverse=True)

    def select(self):
        """Run the configured selection method on the population.

        Returns:
            Whatever ``population_manip.select`` returns for the current
            members and ``selection_method``.
        """
        return population_manip.select(self.members, method=self.selection_method)

    def evolution_cycle(self):
        """Run one generation: breed, evaluate, merge, truncate and report.

        Assumes the population starts out sorted (except for a randomly
        injected solution from a previous cycle).  After the new chromosomes
        are merged in, the best ``pop_size`` are kept, ``function_evals`` is
        incremented by one, and the checkpoint, log and injection intervals
        are serviced.
        """
        # NOTE: multithreading could go here (several crossovers per
        # iteration), but it may throw off the function-evaluation count.

        # The numeric arguments below are tuning constants whose meaning is
        # defined by population_manip.
        #---Crossover
        new_chromosomes = population_manip.crossover(self.members)
        #---Mutation
        new_chromosomes.append(population_manip.mutation(self.members))
        new_chromosomes.append(population_manip.mutation(self.members, mutation_amt=155))
        #---Random Crossover
        new_chromosomes += population_manip.crossover(self.members, selection_method="random", crossover_span=225)
        #---Random Mutation
        new_chromosomes.append(population_manip.mutation(self.members, selection_method="random", mutation_amt=150))

        # Compute fitness for the new chromosomes
        for chromosome in new_chromosomes:
            chromosome.compute_fitness(self.vocab, filepath=self.eval_filepath)

        # Merge the offspring with the current members and keep pop_size of them
        temp_population = self.members + new_chromosomes
        self.members = population_manip.sort(temp_population)[:self.pop_size]
        # Guarantee best-first order regardless of what the helper returned
        self.sort()

        # If we hit an interval of some kind, execute the associated functions
        self.function_evals += 1
        if self.ckpt_interval and not self.function_evals % self.ckpt_interval:
            self.checkpoint()
        if self.log_interval and not self.function_evals % self.log_interval:
            self.log()
        # Injection needs at least one non-top slot; with a single member there
        # is nothing that may be replaced, so it is skipped.
        if self.injection_interval and self.pop_size > 1 and not self.function_evals % self.injection_interval:
            # Replace a random (not top-performing) chromosome in the population with a random one
            random_solution = Chromosome(self.vocab, dim=self.dimension)
            random_solution.compute_fitness(self.vocab, filepath=self.eval_filepath)
            self.members[random.choice(range(1, self.pop_size))] = random_solution

    def log(self):
        """Append the current summary (``str(self)``) to the log file.

        The file lives in ``logging_path`` and is named after the dimension and
        timestamp, with a ``-fold<N>`` suffix when ``fold_num`` is positive.
        """
        fold_suffix = "-fold{}".format(self.fold_num) if self.fold_num > 0 else ""
        logloc = "{}/log-{}bit-{}{}.txt".format(self.logging_path, self.dimension, self.timestamp, fold_suffix)
        with open(logloc, "a") as logfile:
            logfile.write(str(self) + "\n")

    def checkpoint(self):
        """Write the first member's embedding to disk and refresh the statistics.

        The decoded embedding of ``members[0]`` is written as one
        ``word<TAB>value`` line per entry, then ``word_analogy.mrr_analysis``
        is run on that file and its ``used``, ``mean_rank`` and ``mrr`` results
        are stored on the population.
        """
        checkpoint = self.output_path + "/ckpt-{}-embedding-{}bit.txt".format(self.function_evals, self.dimension)
        with open(checkpoint, "w") as output_file:
            # Decode the chromosome at the head of the (best-first) population
            embed = self.members[0].decode(self.vocab)
            # For each word in the embedding, write it to the file
            for word in embed:
                output_file.write(word + "\t" + embed[word] + "\n")

        # Check & record the MRR of the current model
        total_stats, _ = word_analogy.mrr_analysis(checkpoint, dim=self.dimension, analogy_path=self.eval_filepath)
        self.performance_step = self.function_evals
        self.perf_eval_used = total_stats["used"]
        self.mean_rank = total_stats["mean_rank"]
        self.current_performance = total_stats["mrr"]

    def export_population(self):
        """Placeholder: not implemented yet.

        Intended to be added later so already-trained populations can be
        refined further.
        """
        pass