# regression_based.py
# This file handles calculating fitness in a regression-based manner inspired by the GloVe model


# Internal Imports
from src.common_utils import directory

# External Imports
import json
import math
import random
import numpy as np
from itertools import combinations

# Global Variables
# Width (in bits / characters) of each slice that the bit-string helpers in this
# module (similarity, distance, xor, bitwise_or) convert with int(..., 2) at a time
chunk_size = 16


'''
----------sample_function_comment----------
- <function_description>
-----Inputs-----
- <input_listing> - description
-----Output-----
- <output_listing> - description
'''

'''
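----------calculate_correctness_fitness----------
- This function calculates fitness as the fraction of word analogies the decoded embedding gets correct, minus the fraction of vocabulary word pairs that share an identical embedding
-----Inputs-----
- chromosome - The chromosome for which to calculate the fitness
- vocab - The vocab to use
- method - "get_relationship" (default) compares the relationships of the two word pairs; any other value tries to predict the fourth word
-----Output-----
- fitness - The fitness of the chromosome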
'''
def calculate_correctness_fitness(chromosome, vocab, method = "get_relationship"):
    """Score a chromosome by how many word analogies its embedding gets right.

    The analogy file path is read from
    ``["evaluation"]["word_analogy"]["analogy_path"]`` in the config file.
    Every non-blank line that does not start with ":" (category headers) is an
    analogy ``w0 w1 w2 w3`` and is tested with one of two methods:

    - ``"get_relationship"`` (default): ``w1 ^ w0`` is compared with
      ``w3 ^ w2``; the analogy is correct when their similarity exceeds
      ``floor(dimension - dimension / 4)``.
    - any other value: ``(w1 ^ w0) | w2`` is compared with every embedding;
      the analogy is correct when the most similar one belongs to ``w3``.

    Inputs:
        chromosome: object providing ``decode(vocab, binary=False)`` and
            ``dimension``.
        vocab: vocabulary passed to ``chromosome.decode``; iterating it yields
            the words used for the identical-embedding penalty.
        method: analogy test to use (see above).

    Returns:
        float: ``correct / analogies - identical_pairs / word_pairs``. A ratio
        whose denominator is 0 (no analogies, or fewer than two words)
        contributes 0 instead of raising ZeroDivisionError.
    """
    # Decode the chromosome
    # NOTE(review): values decoded with binary=False are combined with ^ / |
    # below, yet `similarity` slices its arguments as bit strings; confirm the
    # decoded value type matches what `similarity` expects.
    embedding = chromosome.decode(vocab, binary = False)

    # Load the analogy location from the configuration file
    with open("src/genetic_embedding/core/config.json") as config_file:
        analogy_loc = json.load(config_file)["evaluation"]["word_analogy"]["analogy_path"]

    total_correct = 0
    total_analogy = 0
    # Similarity threshold as a function of the dimension (3/4 of the bits, floored)
    similarity_threshold = math.floor(chromosome.dimension - (chromosome.dimension / 4))
    with open(analogy_loc) as analogy_file:
        for line in analogy_file:
            line = line.strip()
            # Blank lines would crash an index check; ":" lines are category headers
            if not line or line.startswith(":"):
                continue
            analogy = line.lower().split()
            total_analogy += 1

            if method == "get_relationship":
                # Compare the "relationship" of each word pair
                relationships = (
                    embedding[analogy[1]] ^ embedding[analogy[0]],
                    embedding[analogy[3]] ^ embedding[analogy[2]],
                )
                candidate_similarity = similarity(relationships[0], relationships[1], chromosome.dimension)
                correct = candidate_similarity > similarity_threshold
            else:
                # Predict the fourth word, then find the closest embedding to the prediction
                candidate_embed = (embedding[analogy[1]] ^ embedding[analogy[0]]) | embedding[analogy[2]]
                closest_word = ""
                highest_similarity = 0
                for key in embedding.keys():
                    candidate_similarity = similarity(candidate_embed, embedding[key], chromosome.dimension)
                    if candidate_similarity > highest_similarity:
                        closest_word = key
                        highest_similarity = candidate_similarity
                correct = closest_word == analogy[3]

            if correct:
                total_correct += 1

    # Penalize distinct words that collapse onto the same embedding
    words = sorted(vocab)
    total_word_pairs = len(words) * (len(words) - 1) // 2
    total_same_pairs = sum(
        1 for word1, word2 in combinations(words, 2) if embedding[word1] == embedding[word2]
    )

    correct_ratio = total_correct / total_analogy if total_analogy else 0.0
    same_ratio = total_same_pairs / total_word_pairs if total_word_pairs else 0.0
    return correct_ratio - same_ratio


'''
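----------calculate_bitcount_fitness----------
- This function calculates fitness as the number of bits, summed over all analogies, at which ((w1 XOR w2) OR w3) matches w4
-----Inputs-----
- chromosome - The chromosome for which to calculate the fitness
- vocab - The vocab to use
- analogy_loc - Path to the analogy file (defaults to the data path in the configuration file)
-----Output-----
- fitness - The fitness of the chromosome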
'''
def calculate_bitcount_fitness(chromosome, vocab, analogy_loc = "", dropout_probability = 0):
    """Score a chromosome by bit-level agreement on each word analogy.

    For every analogy line ``w0 w1 w2 w3`` the prediction
    ``(w0 XOR w1) OR w2`` is XOR-ed with ``w3``. The number of ``"0"`` bits,
    i.e. positions where the prediction matches ``w3``, is added to the
    fitness. Blank lines and ":" category headers are skipped.

    Inputs:
        chromosome: object providing ``decode(vocab)``; the decoded values
            are treated as bit strings by ``xor``/``bitwise_or``.
        vocab: vocabulary passed to ``chromosome.decode``.
        analogy_loc: path to the analogy file. When empty, the path is read
            from ``["scripts"]["data_path"]`` in the config file.
        dropout_probability: fraction of analogies to skip at random
            (0 keeps every analogy).

    Returns:
        int: total number of matching bits across all tested analogies.
    """
    # Decode the chromosome
    embedding = chromosome.decode(vocab)

    # Fall back to the data path from the configuration file
    if not analogy_loc:
        with open("src/genetic_embedding/core/config.json") as config_file:
            analogy_loc = json.load(config_file)["scripts"]["data_path"]

    total_fitness = 0
    with open(analogy_loc) as analogy_file:
        for line in analogy_file:
            line = line.strip()
            # Blank lines would crash an index check; ":" lines are category headers
            if not line or line.startswith(":"):
                continue
            # One random draw per analogy line (even at probability 0), so the
            # global NumPy RNG stream is consumed exactly as before
            if np.random.uniform(0, 1) < dropout_probability:
                continue
            analogy = line.lower().split()

            prediction = xor(embedding[analogy[0]], embedding[analogy[1]])
            prediction = bitwise_or(prediction, embedding[analogy[2]])
            # "0" bits in the XOR mark positions where the prediction equals w3
            total_fitness += xor(prediction, embedding[analogy[3]]).count("0")
    return total_fitness



'''
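----------get_analogy_subset----------
- This function writes a random subset (about 5%) of the analogies to data/word_analogy_subset/questions-words.txt
-----Inputs-----
- None
-----Output-----
- None - the subset is written to disk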
'''
def get_analogy_subset():
    """Write a random sample of the word-analogy file to a subset directory.

    Reads the analogy path from ``["evaluation"]["word_analogy"]["analogy_path"]``
    in the config file. Each analogy line is kept with probability
    ``KEEP_PERCENTAGE / 100`` and written to
    ``data/word_analogy_subset/questions-words.txt``. Category headers (":")
    and blank lines are never written. The directory's ``vocab.json`` and
    ``co_occurrence.json`` are then deleted, presumably so they are rebuilt for
    the new subset (TODO confirm).

    Returns:
        None
    """
    KEEP_PERCENTAGE = 5 # The percent of analogies to keep as a subset
    SUBSET_DIRECTORY = "data/word_analogy_subset/"

    # Load the analogy location from the configuration file
    with open("src/genetic_embedding/core/config.json") as config_file:
        analogy_loc = json.load(config_file)["evaluation"]["word_analogy"]["analogy_path"]

    # Create the subset directory; both files are closed (and flushed) on exit
    directory.create(SUBSET_DIRECTORY)
    with open(SUBSET_DIRECTORY + "questions-words.txt", "w") as subset_file, \
            open(analogy_loc) as analogy_file:
        for line in analogy_file:
            # Skip category headers and blank lines
            if line.startswith(":") or not line.strip():
                continue
            # Keep the analogy on a random chance (based on KEEP_PERCENTAGE)
            if np.random.uniform(0, 1) <= KEEP_PERCENTAGE / 100:
                subset_file.write(line)

    # Delete the vocab & co-occurrence files, if needed
    directory.delete(SUBSET_DIRECTORY + "vocab.json")
    directory.delete(SUBSET_DIRECTORY + "co_occurrence.json")



def similarity(embed1, embed2, dimension, chunk = None):
    """Count the bit positions at which two bit-string embeddings agree.

    The first ``dimension`` characters of each string are compared ``chunk``
    characters at a time. ``chunk`` defaults to the module-level
    ``chunk_size``. A trailing partial chunk is included when ``dimension``
    is not a multiple of the chunk width.

    Returns:
        int: number of equal bits among the first ``dimension`` positions.
    """
    if chunk is None:
        chunk = chunk_size
    agreeing_bits = 0
    # range() needs integer arguments; stepping by the chunk width avoids
    # the float from dimension / chunk_size
    for start in range(0, dimension, chunk):
        end = min(start + chunk, dimension)
        differing = int(embed1[start:end], 2) ^ int(embed2[start:end], 2)
        agreeing_bits += (end - start) - bin(differing).count("1")
    return agreeing_bits


def distance(embed1, embed2, dimension, chunk = None):
    """Count the bit positions at which two bit-string embeddings differ.

    This is the Hamming distance over the first ``dimension`` characters,
    compared ``chunk`` characters at a time. ``chunk`` defaults to the
    module-level ``chunk_size``. A trailing partial chunk is included when
    ``dimension`` is not a multiple of the chunk width.

    Returns:
        int: number of differing bits among the first ``dimension`` positions.
    """
    if chunk is None:
        chunk = chunk_size
    differing_bits = 0
    # range() needs integer arguments; stepping by the chunk width avoids
    # the float from dimension / chunk_size
    for start in range(0, dimension, chunk):
        end = min(start + chunk, dimension)
        differing_bits += bin(int(embed1[start:end], 2) ^ int(embed2[start:end], 2)).count("1")
    return differing_bits


def xor(embed1, embed2, chunk = None):
    """Return the bitwise XOR of two bit strings as a bit string.

    The strings are processed ``chunk`` characters at a time. ``chunk``
    defaults to the module-level ``chunk_size``. The result has the same
    length as ``embed1``, leading zeros included; a trailing partial chunk is
    kept rather than dropped. ``embed2`` is read at the same positions as
    ``embed1``.
    """
    if chunk is None:
        chunk = chunk_size
    pieces = []
    for start in range(0, len(embed1), chunk):
        piece = embed1[start:start + chunk]
        width = len(piece)
        combined = int(piece, 2) ^ int(embed2[start:start + width], 2)
        pieces.append(bin(combined)[2:].rjust(width, "0"))
    return "".join(pieces)


def bitwise_or(embed1, embed2, chunk = None):
    """Return the bitwise OR of two bit strings as a bit string.

    The strings are processed ``chunk`` characters at a time. ``chunk``
    defaults to the module-level ``chunk_size``. The result has the same
    length as ``embed1``, leading zeros included; a trailing partial chunk is
    kept rather than dropped. ``embed2`` is read at the same positions as
    ``embed1``.
    """
    if chunk is None:
        chunk = chunk_size
    pieces = []
    for start in range(0, len(embed1), chunk):
        piece = embed1[start:start + chunk]
        width = len(piece)
        combined = int(piece, 2) | int(embed2[start:start + width], 2)
        pieces.append(bin(combined)[2:].rjust(width, "0"))
    return "".join(pieces)



if __name__ == "__main__":
    # Running this module as a script regenerates the random analogy subset
    get_analogy_subset()