# regression_based.py
# This file handles calculating fitness in a regression-based manner inspired by the GloVe model


# Internal Imports

# External Imports
import math
from itertools import combinations


'''
----------sample_function_comment----------
- <function_description>
-----Inputs-----
- <input_listing> - description
-----Output-----
- <output_listing> - description
'''

'''
----------calculate_fitness----------
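- This function calculates the fitness as the negated sum of squared errors between each word pair's embedding similarity and its sigmoid-scaled log co-occurrence (a GloVe-inspired regression objective)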
-----Inputs-----
- chromosome - The chromosome for which to calculate the fitness
- vocab - The vocab to use
- cooccurrence - The co-occurrence matrix
-----Output-----
- fitness - The fitness of the chromosome
'''
def calculate_fitness(chromosome, vocab, cooccurrence):
    """Return the fitness of a chromosome as a negated sum of squared errors.

    For every unordered pair of distinct vocabulary words that has an entry
    in ``cooccurrence``, the predicted similarity (number of matching bits
    between the two integer embeddings, adjusted by ``simcheck``) is compared
    with the target (``scaled_sigmoid`` of the base-10 log of the pair's
    co-occurrence value). The squared differences are summed and negated, so
    a fitness closer to zero is better.

    Args:
        chromosome: Object providing ``decode(vocab, binary=False)`` and a
            ``dimension`` attribute.
        vocab: Iterable of words; it is sorted before pairing, so a pair is
            only looked up as ``cooccurrence[first][second]`` with
            ``first`` sorting before ``second``.
        cooccurrence: Nested mapping ``cooccurrence[word1][word2]``.

    Returns:
        float: ``-(sum of squared errors)``; ``0.0`` when no pair qualifies.

    Raises:
        ValueError: If a looked-up co-occurrence value is not positive
            (``math.log10`` domain error).
    """
    # Decode the chromosome to an integer-based representation for easy calculation
    embedding = chromosome.decode(vocab, binary=False)
    dimension = chromosome.dimension

    # Loss contributed by a single word pair
    def inter_word_loss(word1, word2):
        # Matching bits = dimension - differing bits. Counting "0" characters
        # in bin(xor) would miss leading zeros (bin(0) is just "0b0"), so
        # identical embeddings would never score `dimension`.
        # NOTE(review): assumes each embedding is a non-negative int that
        # fits in `dimension` bits - confirm against chromosome.decode.
        differing_bits = bin(embedding[word1] ^ embedding[word2]).count("1")
        prediction = simcheck(dimension - differing_bits, dimension)
        # math.log10 is more accurate than math.log(x, 10)
        ground_truth = scaled_sigmoid(math.log10(cooccurrence[word1][word2]), dimension)
        return (prediction - ground_truth) ** 2

    # Sum lazily over the pairs instead of materialising every pair and
    # every loss value in intermediate lists.
    loss = sum(
        inter_word_loss(word1, word2)
        for word1, word2 in combinations(sorted(vocab), 2)
        if word1 in cooccurrence and word2 in cooccurrence[word1] and word1 != word2
    )

    # Return the fitness (always a float, even when no pair contributed)
    return 0.0 - loss


'''
----------simcheck----------
- This function checks the similarity between two words and penalizes identical embeddings: when the bit count equals the dimension, the value is multiplied by alpha (5); otherwise it is returned unchanged
-----Inputs-----
- bitcount - The value of the XOR 0-count
- dimension - The dimension of the chromosome
-----Output-----
- value - The new value after the similarity check
'''
def simcheck(bitcount, dimension, alpha=5):
    """Inflate the similarity value when it equals the full dimension.

    Args:
        bitcount: The similarity value to check (the caller passes the
            number of matching bits between two embeddings).
        dimension: The dimension of the chromosome.
        alpha: Multiplier applied when ``bitcount == dimension``. Defaults
            to 5, the value that was previously hard-coded.

    Returns:
        ``alpha * bitcount`` when ``bitcount`` equals ``dimension``,
        otherwise ``bitcount`` unchanged.
    """
    # A full match is scaled up - presumably so the squared-error loss
    # discourages two words sharing one embedding (verify against caller).
    if bitcount == dimension:
        return alpha * bitcount
    return bitcount


'''
----------scaled_sigmoid----------
- This function takes a value and returns that value crushed to between 0 and the dimension supplied (Using the sigmoid function)
-----Inputs-----
- x - The value to scale
- dimension - The dimension of the chromosome
-----Output-----
- y - The resulting scaled value
'''
def scaled_sigmoid(x, dimension):
    """Squash ``x`` to a value between 0 and ``dimension`` with a sigmoid.

    Args:
        x: The value to scale.
        dimension: The upper bound of the output range.

    Returns:
        float: ``dimension / (1 + e**-x)``.
    """
    if x >= 0:
        return dimension / (1 + math.exp(-x))
    # For negative x, math.exp(-x) raises OverflowError once -x exceeds
    # roughly 709; the algebraically equal form below only evaluates
    # exp of a negative number, which at worst underflows to 0.0.
    exp_x = math.exp(x)
    return dimension * (exp_x / (1 + exp_x))