import argparse
import numpy as np

from util import load_legacy_glove, write_w2v, writeAnalogies, evalTerms, convert_legacy_to_keyvec, load_legacy_w2v, pruneWordVecs
from biasOps import normalize, identify_bias_subspace, project_onto_subspace, neutralize_and_equalize, equalize_and_soften, calculateDirectBias, equalize_and_soften
from evalBias import generateAnalogies, multiclass_evaluation
from loader import load_def_sets, load_analogy_templates, load_test_terms, load_eval_terms
from scipy.stats import ttest_rel, spearmanr

def parse_arguments():
    """Parse the command line from ``sys.argv``.

    Positional arguments, in order: ``embeddingPath``, ``vocabPath``, ``mode``.

    Boolean switches (each defaults to False; stored under the name without
    the leading dash):
      -hard  run neutralize-and-equalize debiasing
      -soft  run equalize-and-soften debiasing
      -w     write analogies / embeddings / statistics to disk
      -g     load the embeddings with the GloVe loader instead of word2vec
      -v     verbose output

    Returns the ``argparse.Namespace`` produced by ``parse_args()``.
    """
    cli = argparse.ArgumentParser()
    for positional_name in ('embeddingPath', 'vocabPath', 'mode'):
        cli.add_argument(positional_name)
    for switch in ('-hard', '-soft', '-w', '-g', '-v'):
        cli.add_argument(switch, action='store_true')
    return cli.parse_args()

def _print_analogies(title, analogies, limit=150):
    """Print ``title`` followed by the first ``limit`` ``(score, analogy)`` pairs."""
    print(title)
    for score, analogy in analogies[:limit]:
        # "%s %s" reproduces the output of the old Python 2 `print score, analogy`.
        print("%s %s" % (score, analogy))


def _evaluate_debiased(label, vectors, eval_targets, eval_attrs, biased_distribution):
    """Evaluate one set of debiased vectors against the biased baseline.

    Runs ``multiclass_evaluation`` on ``vectors`` and a paired t-test
    (``scipy.stats.ttest_rel``) between ``biased_distribution`` and the
    debiased distribution. Progress is printed under ``label``.

    Returns ``(debiased_mean, t_statistic, p_value)``.
    """
    print("%s Debiased Evaluation Results" % label)
    debiased_mean, debiased_distribution = multiclass_evaluation(vectors, eval_targets, eval_attrs)
    print("%s Debiased Mean: %s" % (label, debiased_mean))

    t_statistic, p_value = ttest_rel(biased_distribution, debiased_distribution)
    print("%s Debiased Cosine difference t-test %s" % (label, p_value))
    return debiased_mean, t_statistic, p_value


def _write_statistics(path, biased_mean, results):
    """Write the evaluation summary CSV to ``path``.

    ``results`` is a list of ``(method, debiased_mean, t_statistic, p_value)``
    tuples. One row is written per method. The first four columns keep the
    original layout, and the trailing ``Method`` column says which method
    the row belongs to. If no debiasing method was run, a single row holding
    only the biased mean is written (method ``none``).
    """
    with open(path, "w") as f:
        f.write("BiasedMean,DebiasedMean,Statistics,P-Value,Method\n")
        if not results:
            f.write(str(biased_mean) + ",,,,none\n")
        for method, debiased_mean, t_statistic, p_value in results:
            row = [str(biased_mean), str(debiased_mean), str(t_statistic), str(p_value), method]
            f.write(",".join(row) + "\n")


def main():
    """Run the debiasing pipeline selected by the command-line arguments.

    Steps:
      1. Load the vocabulary files and the embeddings (-g selects the GloVe
         loader), then prune the embeddings.
      2. Identify the bias subspace from the definitional sets.
      3. Optionally apply hard (-hard) and/or soft (-soft) debiasing.
      4. Generate analogies for the biased and each debiased embedding.
         With -w, write them under ``out/``. Otherwise, with -v, print
         the top 150.
      5. With -w, also write the embeddings.
      6. Evaluate every embedding with ``multiclass_evaluation``. With -w,
         write a statistics CSV with one row per debiasing method.
    """
    import os  # local import: only needed to create the output directory

    args = parse_arguments()

    out_dir = "out"
    # Flatten the vocabulary path into a filename-safe prefix for all outputs.
    outprefix = args.vocabPath.replace("/", "_").replace("\\", "_").replace(".", "_")

    def out_path(suffix):
        """Return the output file path for ``suffix``."""
        return out_dir + "/" + outprefix + suffix

    if args.w and not os.path.isdir(out_dir):
        # Without this, every write below fails if out/ does not exist yet.
        os.makedirs(out_dir)

    print("Loading vocabulary from {}".format(args.vocabPath))
    analogyTemplates = load_analogy_templates(args.vocabPath, args.mode)
    defSets = load_def_sets(args.vocabPath)
    # NOTE(review): the loaded test terms are not used anywhere below. The
    # call is kept so the loading behaviour matches the original script.
    testTerms = load_test_terms(args.vocabPath)

    # Flatten every analogy-template word list into the single neutral-word
    # list that both debiasing functions receive.
    neutral_words = [word for words in analogyTemplates.values() for word in words]

    print("Loading embeddings from {}".format(args.embeddingPath))
    load_embeddings = load_legacy_glove if args.g else load_legacy_w2v
    word_vectors, embedding_dim = load_embeddings(args.embeddingPath)

    print("Pruning Word Vectors... Starting with %s" % len(word_vectors))
    word_vectors = pruneWordVecs(word_vectors)
    print("\tEnded with %s" % len(word_vectors))

    print("Identifying bias subspace")
    subspace = identify_bias_subspace(word_vectors, defSets, 1, embedding_dim)[0]

    # (method name, debiased word vectors), in the order the methods run.
    # The name is used for file names ("hard"), headings ("Hard") and
    # evaluation labels ("HARD").
    methods = []
    if args.hard:
        print("Neutralizing and Equalizing")
        methods.append(("hard", neutralize_and_equalize(
            word_vectors, neutral_words, defSets.values(), subspace, embedding_dim)))
    if args.soft:
        print("Equalizing and Softening")
        methods.append(("soft", equalize_and_soften(
            word_vectors, neutral_words, defSets.values(), subspace, embedding_dim,
            verbose=args.v)))

    print("Generating Analogies")
    biasedAnalogies = generateAnalogies(analogyTemplates, convert_legacy_to_keyvec(word_vectors))
    debiasedAnalogies = [
        generateAnalogies(analogyTemplates, convert_legacy_to_keyvec(vectors))
        for _, vectors in methods
    ]

    if args.w:
        print("Writing biased analogies to disk")
        writeAnalogies(biasedAnalogies, out_path("_biasedAnalogiesOut.csv"))
        for (name, _), analogies in zip(methods, debiasedAnalogies):
            print("Writing %s debiased analogies to disk" % name)
            writeAnalogies(analogies, out_path("_%sDebiasedAnalogiesOut.csv" % name))
    elif args.v:
        _print_analogies("Biased Analogies (0-150)", biasedAnalogies)
        for (name, _), analogies in zip(methods, debiasedAnalogies):
            print("=" * 20 + "\n\n")
            _print_analogies("%s Debiased Analogies (0-150)" % name.capitalize(), analogies)

    if args.w:
        print("Writing data to disk")
        write_w2v(out_path("_" + args.mode + "_biasedEmbeddingsOut.w2v"), word_vectors)
        for name, vectors in methods:
            write_w2v(out_path("_%s_%sDebiasedEmbeddingsOut.w2v" % (args.mode, name)), vectors)

    print("Performing Evaluation")
    evalTargets, evalAttrs = load_eval_terms(args.vocabPath, args.mode)

    print("Biased Evaluation Results")
    biasedMean, biasedDistribution = multiclass_evaluation(word_vectors, evalTargets, evalAttrs)
    print("Biased Mean: %s" % (biasedMean,))

    # Collect every method's results. The original kept only the last
    # method's values, and crashed when no method was selected.
    results = []
    for name, vectors in methods:
        debiased_mean, t_statistic, p_value = _evaluate_debiased(
            name.upper(), vectors, evalTargets, evalAttrs, biasedDistribution)
        results.append((name, debiased_mean, t_statistic, p_value))

    if args.w:
        print("Writing statistics to disk")
        _write_statistics(out_path("_statistics.csv"), biasedMean, results)

if __name__ == "__main__":
    main()