'''
Reference implementation of node2vec.

Author: Aditya Grover

For more details, refer to the paper:
node2vec: Scalable Feature Learning for Networks
Aditya Grover and Jure Leskovec
Knowledge Discovery and Data Mining (KDD), 2016
'''

import argparse
import numpy as np
import networkx as nx
import node2vec
from gensim.models import Word2Vec


def parse_args(argv=None):
    '''
    Parses the node2vec arguments.

    :param argv: optional list of argument strings to parse. When None (the
        default) argparse falls back to sys.argv[1:], so existing callers
        behave exactly as before; passing a list lets other code (and tests)
        drive the parser directly.
    :return: argparse.Namespace holding the parsed options.
    '''
    parser = argparse.ArgumentParser(description="Run node2vec.")

    parser.add_argument('--input', nargs='?', default='dummy.edgelist', help='Input graph path')

    parser.add_argument('--output', nargs='?', default='dummy',
                        help='Embeddings path prefix; "_<window>.emb" is appended per context size.')

    parser.add_argument('--dimensions', type=int, default=300,
                        help='Number of dimensions. Default is 300 to match fastText.')

    parser.add_argument('--walk-length', type=int, default=80,
                        help='Length of walk per source. Default is 80.')

    parser.add_argument('--num-walks', type=int, default=10,
                        help='Number of walks per source. Default is 10.')

    # Several window sizes may be given, separated by underscores; one
    # embedding file is trained per size.
    parser.add_argument('--window-size', type=str, default="5_10_20",
                        help='Underscore-separated context size(s) for optimization, '
                             'e.g. "5_10_20". Default is 5_10_20.')

    parser.add_argument('--iter', default=1, type=int,
                        help='Number of epochs in SGD. Default is 1.')

    parser.add_argument('--workers', type=int, default=6,
                        help='Number of parallel workers. Default is 6.')

    parser.add_argument('--p', type=float, default=1,
                        help='Return hyperparameter. Default is 1.')

    parser.add_argument('--q', type=float, default=1,
                        help='In-out hyperparameter. Default is 1.')

    # Both flags share dest='weighted' so that either one actually changes
    # args.weighted; set_defaults() then applies the default to both actions.
    parser.add_argument('--weighted', dest='weighted', action='store_true',
                        help='Treat the input graph as weighted. Default is weighted.')
    parser.add_argument('--unweighted', dest='weighted', action='store_false',
                        help='Treat the input graph as unweighted (every edge gets weight 1).')
    parser.set_defaults(weighted=True)

    # Same pattern for directedness: both flags write to args.directed.
    parser.add_argument('--directed', dest='directed', action='store_true',
                        help='Graph is (un)directed. Default is undirected.')
    parser.add_argument('--undirected', dest='directed', action='store_false',
                        help='Treat the input graph as undirected.')
    parser.set_defaults(directed=False)

    return parser.parse_args(argv)


def read_graph():
    '''
    Load the edge list at args.input into a networkx graph.

    Node ids are parsed as int and the file is first read into a DiGraph.
    When args.weighted is set, the extra column is parsed as a float
    'weight' attribute; otherwise every edge is assigned weight 1, so each
    edge carries a 'weight' attribute in both modes. Unless args.directed is
    set, the graph is converted to undirected before being returned.
    '''
    reader_options = {'nodetype': int, 'create_using': nx.DiGraph()}
    if args.weighted:
        reader_options['data'] = (('weight', float),)

    graph = nx.read_edgelist(args.input, **reader_options)

    if not args.weighted:
        for source, target in graph.edges():
            graph[source][target]['weight'] = 1

    return graph if args.directed else graph.to_undirected()


def learn_embeddings(walks):
    '''
    Learn embeddings by optimizing the Skipgram objective using SGD.

    One Word2Vec skip-gram model (sg=1) is trained per context size listed in
    args.window_size (underscore-separated, e.g. "5_10_20"). Each model's
    vectors are written in word2vec text format to "<args.output>_<size>.emb".

    :param walks: iterable of walks, each an iterable of node ids. Node ids
        are converted to str because Word2Vec works on string tokens.
    :raises ValueError: if args.window_size contains a non-integer piece.
    '''
    import inspect

    # Materialize the corpus once as lists of string tokens; it is reused for
    # every model trained below, so it must be re-iterable.
    str_walks = [[str(node) for node in walk] for walk in walks]

    context_sizes = [int(size) for size in args.window_size.split("_")]

    # gensim 4.0 renamed Word2Vec's `size` -> `vector_size` and
    # `iter` -> `epochs`; use whichever names the installed version accepts.
    w2v_params = inspect.signature(Word2Vec.__init__).parameters
    if 'vector_size' in w2v_params:
        version_kwargs = {'vector_size': args.dimensions, 'epochs': args.iter}
    else:
        version_kwargs = {'size': args.dimensions, 'iter': args.iter}

    for cont_size in context_sizes:
        print("Context size " + str(cont_size) + "...")
        save_path = args.output + "_" + str(cont_size) + ".emb"
        model = Word2Vec(str_walks, window=cont_size, min_count=0, sg=1,
                         workers=args.workers, **version_kwargs)
        model.wv.save_word2vec_format(save_path)
        print("Saved to " + save_path)


def main(args):
    '''
    Run the end-to-end node2vec pipeline.

    Loads the input graph, wraps it in a node2vec.Graph configured with the
    return (p) and in-out (q) hyperparameters, precomputes transition
    probabilities, simulates random walks, and trains/saves Word2Vec
    embeddings from those walks.

    :param args: parsed command-line namespace from parse_args(). Note that
        read_graph() and learn_embeddings() read the module-level `args`
        rather than this parameter.
    '''
    print("Context sizes: " + args.window_size)

    print("Loading graph into memory...")
    graph_nx = read_graph()
    print("Done.")

    print("Creating node2vec graph object...")
    walker = node2vec.Graph(graph_nx, args.directed, args.p, args.q)
    print("Done.")

    print("Preprocessing transition probabilities...")
    walker.preprocess_transition_probs()
    print("Done.")

    print("Creating random walks...")
    walk_corpus = walker.simulate_walks(args.num_walks, args.walk_length)
    print("Done.")

    print("Running Word2Vec on walks...")
    learn_embeddings(walk_corpus)
    print("Done.")


if __name__ == "__main__":
    # Bind the parsed options to the module-level name `args`: read_graph()
    # and learn_embeddings() read it as a global, so it must exist before
    # main() runs. Do not inline this as main(parse_args()).
    args = parse_args()
    main(args)
