
__author__ = 'Nick Dingwall and Christopher Potts'

%matplotlib inline
import random
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial.distance import euclidean
from mittens import Mittens
import utils

# Project-wide plot styling; expects 'mittens.mplstyle' to be findable
# by matplotlib (e.g. next to this file or on the style path).
plt.style.use('mittens.mplstyle')

def get_random_count_matrix(n_words):
    """Create a random symmetric, integer-valued count matrix.

    Entries are drawn from an exponential distribution (scale 3.0),
    halved, symmetrized by adding the transpose, and floored. The goal
    is to provide some structure for GloVe to learn even with small
    vocabularies.
    """
    half = np.random.exponential(3.0, size=(n_words, n_words)) / 2
    symmetric = half + half.T
    return np.floor(symmetric)

def get_random_embedding_lookup(embedding_dim, vocab, percentage_embedded=0.5):
    """Map a random subset of `vocab` to random embeddings.

    A fraction `percentage_embedded` of the words is sampled without
    replacement, and each sampled word is assigned a uniform random
    vector of dimension `embedding_dim`. The scale mimics the range
    used when initializing GloVe parameters.
    """
    vocab_size = len(vocab)
    # Xavier-style bound, matching GloVe's parameter initialization.
    bound = 2.0 * np.sqrt(6.0 / (vocab_size + embedding_dim))
    n_embedded = int(vocab_size * percentage_embedded)
    sampled_words = random.sample(vocab, n_embedded)
    return {word: np.random.uniform(-bound, bound, size=embedding_dim)
            for word in sampled_words}

def distance_test(mittens, G, embedding_dict, verbose=False):
    """Measure how far each learned vector moved from its initialization.

    Parameters
    ----------
    mittens : fitted Mittens model. `G_start` holds the random
        initializations; `original_embedding` (fetched through
        `mittens.sess.run`) holds the pretrained warm-start vectors.
    G : np.ndarray
        Learned embedding matrix returned by `mittens.fit`.
    embedding_dict : dict
        Maps words of the form 'w_{i}' to pretrained embeddings; words
        present here were warm-started.
    verbose : bool
        If True, print the mean distance for each group.

    Returns
    -------
    dict mapping 'warm' / 'no warm' to lists of Euclidean distances
    from each word's initial vector to its learned vector.
    """
    dists = defaultdict(list)
    warm_start = mittens.G_start
    warm_orig = mittens.sess.run(mittens.original_embedding)
    for i in range(G.shape[0]):
        # Words in `embedding_dict` were initialized from pretrained
        # vectors; the rest got the random GloVe initialization.
        if "w_{}".format(i) in embedding_dict:
            init = warm_orig[i]
            key = 'warm'
        else:
            init = warm_start[i]
            key = 'no warm'
        dists[key].append(euclidean(init, G[i]))
    if verbose:
        # Bug fix: these means were previously computed unconditionally
        # and discarded, and `verbose` was never consulted.
        print("warm mean: {}".format(np.mean(dists['warm'])))
        print("no warm mean: {}".format(np.mean(dists['no warm'])))
    return dists

def simulations(n_trials=5, n_words=500, embedding_dim=50, max_iter=1000,
        mus=(0.001, 0.1, 0.5, 0, 1, 5, 10)):
    """Runs the simulations described in the paper. For `n_trials`, we

    * Generate a random count matrix
    * Generate initial embeddings for half the vocabulary.
    * For each of the specified `mus`:
        * Run Mittens at `mu` for `max_iter` times.
        * Assess the expected GloVe correlation between counts and
          representation dot products.
        * Get the mean distance from each vector to its initial
          embedding, with the expectation that Mittens will keep
          the learned embeddings closer on average, as governed
          by `mu`.

    The return value is a `pd.DataFrame` containing all the values
    we need for the plots.
    """
    # NOTE: the default for `mus` is a tuple rather than a list to
    # avoid the mutable-default-argument pitfall; any iterable works.
    data = []
    vocab = ['w_{}'.format(i) for i in range(n_words)]
    for trial in range(1, n_trials + 1):
        # Fresh counts and warm-start embeddings for each trial.
        X = get_random_count_matrix(n_words)
        embedding_dict = get_random_embedding_lookup(embedding_dim, vocab)
        for mu in mus:
            mittens = Mittens(n=embedding_dim, max_iter=max_iter, mittens=mu)
            G = mittens.fit(X, vocab=vocab, initial_embedding_dict=embedding_dict)
            correlations = utils.correlation_test(X, G)
            dists = distance_test(mittens, G, embedding_dict)
            data.append({
                'trial': trial,
                'mu': mu,
                'corr_log_cooccur': correlations['log_cooccur'],
                'corr_prob': correlations['prob'],
                'corr_pmi': correlations['pmi'],
                'warm_distance_mean': np.mean(dists['warm']),
                'no_warm_distance_mean': np.mean(dists['no warm'])})
    return pd.DataFrame(data)

# Run the full simulation suite (slow: n_trials * len(mus) Mittens fits).
data_df = simulations()

def get_corr_stats(vals, correlation_value='corr_prob'):
    """Helper function for `correlation_plot`: returns a one-row frame
    with the group's mean correlation and the distance from the mean
    to the lower confidence-interval bound, in the format that pandas
    expects for error bars.
    """
    series = vals[correlation_value]
    center = series.mean()
    lower, _ = utils.get_ci(series)
    return pd.DataFrame([{'mean': center, 'err': center - lower}])

def correlation_plot(data_df, correlation_value='corr_prob'):
    """Produces Figure 1a: mean Pearson correlation per `mu`, with
    confidence-interval error bars, saved as a PDF.
    """
    corr_df = data_df.groupby('mu').apply(lambda x: get_corr_stats(x, correlation_value))
    corr_df = corr_df.reset_index().sort_values("mu", ascending=False)
    ax = corr_df.plot.barh(
        x='mu', y='mean', xerr='err', 
        legend=False, color=['gray'], 
        lw=1, edgecolor='black')
    ax.set_xlabel(r'Mean Pearson $\rho$')
    ax.set_ylabel(r'$\mu$')
    # Bug fix: `savefig` has no `layout` keyword; use
    # `bbox_inches='tight'` (as `distance_plot` does) to trim margins.
    plt.savefig(
        "naacl18-short/img/correlations-{}.pdf".format(correlation_value),
        bbox_inches='tight')

# Figure 1a variants: one plot per correlation measure.
correlation_plot(data_df, correlation_value='corr_log_cooccur')

correlation_plot(data_df, correlation_value='corr_prob')

correlation_plot(data_df, correlation_value='corr_pmi')

def get_dist_stats(x):
    """Helper function for `distance_plot`: returns a one-row frame
    with mean distances and lower confidence-interval half-widths for
    both the warm-started and randomly initialized words, in the
    format that pandas expects.
    """
    row = {}
    for col, label in (('warm_distance_mean', 'initial embedding'),
                       ('no_warm_distance_mean', 'no initial embedding')):
        center = x[col].mean()
        lower_bound = utils.get_ci(x[col])[0]
        row[label] = center
        row[label + '_ci'] = center - lower_bound
    return pd.DataFrame([row])

def distance_plot(data_df):
    """Produces Figure 1b: mean distance from initialization per `mu`,
    split by whether the word was warm-started from a pretrained
    embedding, with confidence-interval error bars, saved as a PDF.
    """
    dist_df = data_df.groupby('mu').apply(get_dist_stats)
    dist_df = dist_df.reset_index(level=1).sort_index(ascending=False)
    err_df = dist_df[['initial embedding_ci', 'no initial embedding_ci']]
    err_df.columns = ['pretrained initialization', 'random initialization']
    mean_df = dist_df[['initial embedding', 'no initial embedding']]
    # Bug fix: assign the display labels directly. The original indexed
    # `dist_df` with columns it does not have ('pretrained
    # initialization', 'random initialization'), raising a KeyError.
    mean_df.columns = ['pretrained initialization', 'random initialization']
    ax = mean_df.plot.barh(
        color=['#0499CC', '#FFFFFF'], 
        xerr=err_df, lw=1, edgecolor='black')
    ax.set_xlabel('Mean distance from initialization')
    # Render the Greek letter, matching `correlation_plot`.
    ax.set_ylabel(r'$\mu$')
    legend = plt.legend(loc='center left', bbox_to_anchor=(0.4, 1.15))
    plt.savefig("naacl18-short/img/distances.pdf", 
                bbox_extra_artists=(legend,), 
                bbox_inches='tight')

# --- Notebook-cell residue below: caching results and plot tweaks. ---

# Reload previously saved simulation results (replaces the run above).
data_df = pd.read_csv("data_df.csv")

# Redundant: pandas is already imported at the top of the file.
import pandas as pd

# Persist the simulation results for later reuse.
data_df.to_csv("data_df.csv")

# Recompute the per-mu distance summary used by `distance_plot`.
dist_df = data_df.groupby('mu').apply(get_dist_stats)


# Bare expression: displays `dist_df` when run as a notebook cell;
# a no-op when run as a script.
dist_df

distance_plot(data_df)

from matplotlib import rc


# Use TeX text rendering with a Helvetica sans-serif font for
# publication-quality figures.
rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
## for Palatino and other serif fonts use:
#rc('font',**{'family':'serif','serif':['Palatino']})
rc('text', usetex=True)
