# word_analogy.py
# This script evaluates a genetic embedding on the word analogy task outlined by Mikolov et al.


# Internal Imports

# External Imports
import json

# Globals
# Width, in bits, of one chunk for helpers that process binary embedding strings piecewise
chunk_size = 16


def main(embedding_path="", dim=0, analogy_path=""):
    """Evaluate a binary embedding on the word analogy task and print per-category accuracy.

    Settings are read from the "evaluation" -> "word_analogy" section of
    src/genetic_embedding/core/config.json; any argument that is supplied
    (truthy) overrides the corresponding configuration value.

    Args:
        embedding_path: Path to the chromosome file; falls back to the
            configured "embedding_path" when empty.
        dim: Number of bits per embedding; falls back to the configured
            "dimension" when 0.
        analogy_path: Path to the analogy dataset; falls back to the
            configured "analogy_path" when empty.

    Returns:
        None. Every tested analogy, its best match, and one statistics dict
        per category are printed to stdout.
    """
    # Start by loading the configuration
    with open("src/genetic_embedding/core/config.json") as config_file:
        config = json.load(config_file)["evaluation"]["word_analogy"]

    # Explicit arguments take precedence over the configuration file
    dimension = dim if dim else int(config["dimension"])
    embedding_loc = embedding_path if embedding_path else config["embedding_path"]
    analogy_loc = analogy_path if analogy_path else config["analogy_path"]

    # Import the chromosome
    print("PREPARATION: Importing candidate chromosome")
    embedding = import_chromosome(embedding_loc, dimension)

    # Import the word analogy dataset
    print("PREPARATION: Importing word analogy dataset")
    analogies = import_analogies(analogy_loc)

    # For each analogy section, attempt to test the embedding on it
    category_stats = []
    for category in analogies:
        analogy_count = len(analogies[category])
        category_counter = 0
        category_success_counter = 0
        for analogy in analogies[category]:
            # Ignore analogies containing a word the embedding does not support
            if not all(word in embedding for word in analogy):
                continue
            category_counter += 1
            print(analogy)
            # Use the first three words to build a guess for the fourth:
            # (second XOR first) OR third, rendered as a zero-padded bit string
            candidate_bits = (int(embedding[analogy[1]], 2) ^ int(embedding[analogy[0]], 2)) | int(embedding[analogy[2]], 2)
            candidate_embed = bin(candidate_bits)[2:].rjust(dimension, "0")
            # Find the vocabulary word closest to the candidate. The whole
            # vocabulary is searched (the analogy's own words included), and
            # only a strictly higher similarity replaces the current best, so
            # the first word seen wins ties.
            closest_word = ""
            highest_similarity = 0
            for key in embedding:
                candidate_similarity = similarity(candidate_embed, embedding[key], dimension)
                if candidate_similarity > highest_similarity:
                    closest_word = key
                    highest_similarity = candidate_similarity
                    # TO-DO: Re-organize this part to use a sorted list - can be used to determine how good/bad this is
            print(closest_word, highest_similarity, "-", analogy[3], similarity(candidate_embed, embedding[analogy[3]], dimension))
            # If the analogy is correct, increment the total
            if closest_word == analogy[3]:
                category_success_counter += 1
        category_stats.append({"category":category, "total":analogy_count, "used":category_counter, "correct":category_success_counter, "wrong":category_counter - category_success_counter})

    # Print out the statistics
    for stat in category_stats:
        print(stat)


def mrr_analysis(embedding_path="", dim=0, analogy_path="", log_output = False):
    """Compute rank-based statistics for an embedding on the word analogy task.

    Settings are read from the "evaluation" -> "word_analogy" section of
    src/genetic_embedding/core/config.json; any argument that is supplied
    (truthy) overrides the corresponding configuration value.

    For every analogy whose words are all in the embedding, a candidate is
    built as (first XOR second) OR third, the vocabulary is sorted by
    descending similarity to it, and the position of the fourth word in that
    ordering is its rank (0 = best).

    Args:
        embedding_path: Path to the chromosome file; falls back to the
            configured "embedding_path" when empty.
        dim: Number of bits per embedding; falls back to the configured
            "dimension" when 0.
        analogy_path: Path to the analogy dataset; falls back to the
            configured "analogy_path" when empty.
        log_output: When True, print each category's statistics and the
            overall statistics.

    Returns:
        A tuple (total_stats, category_stats). category_stats is a list with
        one dict per category ("category", "total", "used", "correct",
        "wrong", "mean_rank", "mrr"); mean_rank is 1-based. total_stats sums
        the counts, averages "mean_rank" and "mrr" across categories, stores
        the counts of ranks 0-4 under "top5-total", and gives the top-1/3/5
        hit fractions relative to the total number of analogies (unused ones
        included). Categories without any usable analogy keep 0.0 for
        "mean_rank"/"mrr" and are left out of the cross-category averages.
    """
    # Start by loading the configuration
    with open("src/genetic_embedding/core/config.json") as config_file:
        config = json.load(config_file)["evaluation"]["word_analogy"]

    # Explicit arguments take precedence over the configuration file
    dimension = dim if dim else int(config["dimension"])
    embedding = import_chromosome(embedding_path if embedding_path else config["embedding_path"], dimension)
    analogies = import_analogies(analogy_path if analogy_path else config["analogy_path"])

    # For each analogy section, attempt to test the embedding on it
    positionwise_dist = [0, 0, 0, 0, 0]
    category_stats = []
    for category in analogies:
        item_stats = {
            "category":category,
            "total":len(analogies[category]),
            "used":0,
            "correct":0,
            "wrong":0,
            "mean_rank":0.0,
            "mrr":0.0
        }
        for analogy in analogies[category]:
            # Only process the analogy if the embedding supports all the words
            if not all(word in embedding for word in analogy):
                continue
            item_stats["used"] += 1
            # Use the first three words to build a guess for the fourth
            candidate_embed = bitwise_or(xor(embedding[analogy[0]], embedding[analogy[1]]), embedding[analogy[2]])
            # Order the vocabulary from most to least similar to the guess;
            # sorted() is stable, so ties keep the embedding's own order
            sorted_vocab = sorted(embedding, reverse=True, key=lambda vocab_word: similarity(candidate_embed, embedding[vocab_word], dimension))
            # Get the (0-based) rank of the correct word
            rank = sorted_vocab.index(analogy[3])
            if rank == 0:
                item_stats["correct"] += 1
            else:
                item_stats["wrong"] += 1
            # Add to the position-wise distributions, if applicable
            if rank < len(positionwise_dist):
                positionwise_dist[rank] += 1
            item_stats["mean_rank"] += rank+1
            item_stats["mrr"] += 1/(rank+1)
        # Turn the accumulated sums into means; a category with no usable
        # analogy keeps its 0.0 placeholders instead of dividing by zero
        if item_stats["used"]:
            item_stats["mean_rank"] /= item_stats["used"]
            item_stats["mrr"] /= item_stats["used"]
        category_stats.append(item_stats)
        if log_output:
            print(item_stats)

    # Put the categories together to get a total breakdown
    total_stats = {
        "total":0,
        "used":0,
        "correct":0,
        "wrong":0,
        "mean_rank":0.0,
        "mrr":0.0,
        "top5-total":[],
        "top1-percent":0.0,
        "top3-percent":0.0,
        "top5-percent":0.0
    }
    evaluated_categories = 0
    for category in category_stats:
        total_stats["total"] += category["total"]
        total_stats["used"] += category["used"]
        total_stats["correct"] += category["correct"]
        total_stats["wrong"] += category["wrong"]
        # Only categories that were actually evaluated carry a meaningful mean
        if category["used"]:
            evaluated_categories += 1
            total_stats["mean_rank"] += category["mean_rank"]
            total_stats["mrr"] += category["mrr"]
    if evaluated_categories:
        total_stats["mean_rank"] /= evaluated_categories
        total_stats["mrr"] /= evaluated_categories
    total_stats["top5-total"] = positionwise_dist
    # With an empty dataset the fractions stay at their 0.0 defaults
    if total_stats["total"]:
        total_stats["top1-percent"] = positionwise_dist[0] / total_stats["total"]
        total_stats["top3-percent"] = sum(positionwise_dist[:3]) / total_stats["total"]
        total_stats["top5-percent"] = sum(positionwise_dist) / total_stats["total"]
    if log_output:
        print(total_stats)

    # Return the total stats & detailed breakdown
    return total_stats, category_stats


def import_chromosome(path, dim):
    """Load a chromosome file into a word -> bit-string mapping.

    Each non-blank line is stripped and split on tabs; the first field is the
    word and the second its embedding string. Blank lines are skipped.

    Args:
        path: Path to the chromosome file.
        dim: Expected length of every embedding string.

    Returns:
        A dict mapping each word to its embedding string. A warning is
        printed when any embedding's length differs from dim.

    Raises:
        IndexError: If a non-blank line has no tab-separated second field.
    """
    embedding = {}
    num_inconsistent = 0
    with open(path) as chromosome_file:
        for line in chromosome_file:
            line = line.strip()
            # Blank lines (e.g. trailing newlines) carry no entry
            if not line:
                continue
            fields = line.split("\t")
            embedding[fields[0]] = fields[1]
            if len(fields[1]) != dim:
                num_inconsistent += 1
    if num_inconsistent > 0:
        print("WARNING: {}/{} embeddings deviate from the standard {} bits".format(num_inconsistent, len(embedding), dim))
    return embedding


def import_analogies(path, uncased = True):
    """Load an analogy dataset grouped by category.

    A line starting with ":" opens a new category named by the rest of the
    line (surrounding whitespace removed); every other non-blank line is
    split on whitespace into one analogy. Analogies appearing before the
    first category line are collected under "misc". Blank lines are skipped.

    Args:
        path: Path to the analogy file.
        uncased: When True, analogy lines are lower-cased before splitting
            (category names are left untouched).

    Returns:
        A dict mapping category name to a list of analogies, each a list of
        words. The "misc" category is dropped when it ends up empty.
    """
    category = "misc"
    analogies = {category: []}
    with open(path) as analogy_file:
        for line in analogy_file:
            line = line.strip()
            # Blank lines carry no analogy and would break the prefix check
            if not line:
                continue
            if line.startswith(":"):
                # This line starts a new category
                category = line[1:].strip()
                analogies[category] = []
            else:
                # This line holds an analogy
                if uncased:
                    line = line.lower()
                analogies[category].append(line.split())
    if not analogies["misc"]:
        del analogies["misc"]
    return analogies


def similarity(embed1, embed2, dimension):
    """Count the bit positions on which two binary embedding strings agree.

    Only the leading bits are compared: at most `dimension` of them, and no
    more than the shorter string provides. Every compared bit counts, so
    dimensions that are not a multiple of 16 are handled as well.

    Args:
        embed1: First embedding, a string of "0"/"1" characters.
        embed2: Second embedding, a string of "0"/"1" characters.
        dimension: Maximum number of leading bits to compare.

    Returns:
        The number of compared positions holding the same bit (0 when there
        is nothing to compare).

    Raises:
        ValueError: If a compared prefix is not a valid base-2 string.
    """
    width = min(len(embed1), len(embed2), int(dimension))
    if width <= 0:
        return 0
    # XOR marks the differing positions; everything else is a match
    differing = bin(int(embed1[:width], 2) ^ int(embed2[:width], 2)).count("1")
    return width - differing


def distance(embed1, embed2, dimension):
    """Count the bit positions on which two binary embedding strings differ.

    Only the leading bits are compared: at most `dimension` of them, and no
    more than the shorter string provides. Every compared bit counts, so
    dimensions that are not a multiple of 16 are handled as well.

    Args:
        embed1: First embedding, a string of "0"/"1" characters.
        embed2: Second embedding, a string of "0"/"1" characters.
        dimension: Maximum number of leading bits to compare.

    Returns:
        The number of compared positions holding different bits (0 when
        there is nothing to compare).

    Raises:
        ValueError: If a compared prefix is not a valid base-2 string.
    """
    width = min(len(embed1), len(embed2), int(dimension))
    if width <= 0:
        return 0
    # Each set bit of the XOR is one differing position
    return bin(int(embed1[:width], 2) ^ int(embed2[:width], 2)).count("1")


def xor(embed1, embed2):
    """Return the bitwise XOR of two binary embedding strings.

    The result always has the length of embed1, including lengths that are
    not a multiple of 16. Only the first len(embed1) characters of embed2
    are used.

    Args:
        embed1: First operand, a string of "0"/"1" characters.
        embed2: Second operand; callers should pass a string at least as
            long as embed1 (a shorter one is right-aligned by integer
            parsing).

    Returns:
        A "0"/"1" string of len(embed1) characters; "" when embed1 is empty.

    Raises:
        ValueError: If an operand is not a valid base-2 string.
    """
    width = len(embed1)
    if width == 0:
        return ""
    combined = int(embed1, 2) ^ int(embed2[:width], 2)
    # bin() drops leading zeros, so pad back to the operand width
    return bin(combined)[2:].rjust(width, "0")


def bitwise_or(embed1, embed2):
    """Return the bitwise OR of two binary embedding strings.

    The result always has the length of embed1, including lengths that are
    not a multiple of 16. Only the first len(embed1) characters of embed2
    are used.

    Args:
        embed1: First operand, a string of "0"/"1" characters.
        embed2: Second operand; callers should pass a string at least as
            long as embed1 (a shorter one is right-aligned by integer
            parsing).

    Returns:
        A "0"/"1" string of len(embed1) characters; "" when embed1 is empty.

    Raises:
        ValueError: If an operand is not a valid base-2 string.
    """
    width = len(embed1)
    if width == 0:
        return ""
    combined = int(embed1, 2) | int(embed2[:width], 2)
    # bin() drops leading zeros, so pad back to the operand width
    return bin(combined)[2:].rjust(width, "0")


if __name__ == "__main__":
    # When run as a script, perform the MRR analysis using the values from
    # the configuration file and print the per-category and overall stats.
    # main() is the alternative evaluation and is currently disabled.
    #main()
    mrr_analysis(log_output = True)