'''
Reference implementation of node2vec.

Author: Aditya Grover

For more details, refer to the paper:
node2vec: Scalable Feature Learning for Networks
Aditya Grover and Jure Leskovec
Knowledge Discovery and Data Mining (KDD), 2016
'''

import argparse
import numpy as np
import networkx as nx
import node2vec
from gensim.models import Word2Vec
from tqdm import tqdm
import random


def parse_args():
    '''
    Parses the node2vec arguments.

    Reads the command line (``sys.argv``) and returns an
    ``argparse.Namespace`` with the attributes ``input``, ``output``,
    ``dimensions``, ``walk_length``, ``num_walks``, ``window_size``,
    ``iter``, ``workers``, ``p``, ``q``, ``weighted`` and ``directed``.

    ``--weighted``/``--unweighted`` and ``--directed``/``--undirected`` are
    paired switches: each pair writes to a single attribute (``weighted`` /
    ``directed``), so the negative flag really turns the option off.

    ``window_size`` is kept as an underscore-separated string (for example
    ``"5_10_20"``); it is split into integers by the code that consumes it.
    '''
    parser = argparse.ArgumentParser(description="Run node2vec.")
    # e.g. colex_from_AllBabelNet_Concepts/cross_colex_binary.edgelist
    parser.add_argument('--input', nargs='?', default='dummy.edgelist', help='Input graph path')
    # e.g. colex_from_AllBabelNet_Concepts/cross_colex_binary_MYNode2Vec
    parser.add_argument('--output', nargs='?', default='dummy', help='Embeddings path')

    parser.add_argument('--dimensions', type=int, default=300,
                        help='Number of dimensions. Default is 300 to match fastText.')

    parser.add_argument('--walk-length', type=int, default=80,
                        help='Length of walk per source. Default is 80.')

    parser.add_argument('--num-walks', type=int, default=10,
                        help='Number of walks per source. Default is 10.')

    # Kept as a string so several sizes can be given in one argument.
    parser.add_argument('--window-size', type=str, default="5_10_20",
                        help='Underscore-separated context size(s) for optimization. '
                             'Default is 5_10_20.')

    parser.add_argument('--iter', default=1, type=int,
                        help='Number of epochs in SGD')

    parser.add_argument('--workers', type=int, default=6,
                        help='Number of parallel workers. Default is 6.')

    parser.add_argument('--p', type=float, default=1,
                        help='Return hyperparameter. Default is 1.')

    parser.add_argument('--q', type=float, default=1,
                        help='Inout hyperparameter. Default is 1.')

    # Both flags share dest='weighted'; with separate dests the negative flag
    # would only set an unrelated attribute and leave `weighted` untouched.
    parser.add_argument('--weighted', dest='weighted', action='store_true',
                        help='Boolean specifying (un)weighted. Default is weighted.')
    parser.add_argument('--unweighted', dest='weighted', action='store_false',
                        help='Ignore edge weights and give every edge weight 1.0.')
    parser.set_defaults(weighted=True)

    # Same pairing for the direction switch.
    parser.add_argument('--directed', dest='directed', action='store_true',
                        help='Graph is (un)directed. Default is undirected.')
    parser.add_argument('--undirected', dest='directed', action='store_false',
                        help='Treat the graph as undirected (the default).')
    parser.set_defaults(directed=False)  # our graphs are NEVER directed

    return parser.parse_args()


def read_edgelist():
    '''
    Reads the input network into a dictionary.

    The path and the weighting mode come from the module-level ``args``
    (``args.input`` and ``args.weighted``), not from parameters.

    Each non-blank line is expected to hold whitespace-separated fields
    ``node1 node2 [weight]``. The weight column is only read when
    ``args.weighted`` is true; otherwise every edge gets weight 1.0.

    Returns:
        dict mapping the string ``node1 + "_" + node2`` to a float weight.
        If the same key occurs more than once, the last line wins.

    Raises:
        IndexError: a line has fewer fields than required.
        ValueError: a weight cannot be parsed as a float.
    '''
    # The context manager closes the file even if parsing fails below.
    with open(args.input, 'r') as edge_file:
        lines = edge_file.readlines()

    edgelist = {}  # initialize edgelist
    for line in tqdm(lines):
        # split() without arguments drops the trailing newline only when one
        # is present, so a final line without a newline keeps its last digit.
        pieces = line.split()
        if not pieces:
            continue  # blank line: nothing to record
        node1 = pieces[0]
        node2 = pieces[1]
        if args.weighted:
            edge_weight = float(pieces[2])
        else:
            # Unweighted mode does not need (or look at) a third column.
            edge_weight = 1.0
        edgelist[node1 + "_" + node2] = edge_weight
    return edgelist

def create_graph(edgelist):
    """
    Build a symmetric adjacency dictionary from an edge dictionary.

    Args:
        edgelist: dict mapping ``"node1_node2"`` strings to edge weights.
            The key is split on ``"_"`` and the first two pieces are used as
            the endpoints, so node names are assumed not to contain ``"_"``.

    Returns:
        dict of dicts ``G`` where ``G[a][b]`` and ``G[b][a]`` both hold the
        weight of the edge between ``a`` and ``b``.
    """
    G = {}
    for edge, weight in tqdm(edgelist.items()):
        endpoints = edge.split("_")
        node1 = endpoints[0]
        node2 = endpoints[1]
        # Add the edge to both nodes' connections in the graph. Both
        # directions must be real assignments: an annotation statement
        # (`G[b][a]: w`) stores nothing, which drops the reverse edge.
        G.setdefault(node1, {})[node2] = weight
        G.setdefault(node2, {})[node1] = weight
    return G

def preprocess_transition_probs(G):
    """
    Turn each node's outgoing edge weights into transition probabilities.

    For every node, each neighbour's weight is divided by the total of that
    node's weights. The nested dictionaries of ``G`` are modified in place
    and the same ``G`` object is returned.
    """
    for neighbours in tqdm(G.values()):
        # Accumulate the normalising constant for this node.
        weight_sum = 0
        for raw_weight in neighbours.values():
            weight_sum += raw_weight
        # Overwriting values of existing keys does not resize the dict, so
        # it is safe to do while iterating over its keys.
        for target in neighbours:
            neighbours[target] = neighbours[target] / weight_sum
    return G

def random_walks(G, num_walks, walk_length):
    """
    Generate random walks.

    Args:
        G: dict of dicts; ``G[node][neighbour]`` is the weight used to pick
            ``neighbour`` as the next step from ``node``.
        num_walks: number of passes over the graph; every pass starts one
            walk from each node of ``G``.
        walk_length: number of steps taken per walk, so each walk contains
            ``walk_length + 1`` nodes including its start node.

    Returns:
        list of walks, each a list of nodes, ordered pass by pass and, within
        a pass, in the iteration order of ``G``.
    """
    # The neighbour list and cumulative weights of a node never change during
    # the walks, so build them once per node instead of once per step.
    samplers = {}
    for node, next_node_dict in G.items():
        sample_list = []
        cum_probs = []
        cum_prob = 0
        for dst_node, prob in next_node_dict.items():
            sample_list.append(dst_node)
            cum_prob += prob
            cum_probs.append(cum_prob)
        samplers[node] = (sample_list, cum_probs)

    walks = []
    for walk_iter in range(num_walks):
        print("Walk " + str(walk_iter) + "...")
        for node in tqdm(G):
            new_walk = [node]
            walk_node = node
            for _ in range(walk_length):
                sample_list, cum_probs = samplers[walk_node]
                # Passing cum_weights lets random.choices use the
                # precomputed running totals directly.
                walk_node = random.choices(sample_list, cum_weights=cum_probs, k=1)[0]
                new_walk.append(walk_node)
            walks.append(new_walk)
    return walks

def learn_embeddings(walks):
    '''
    Train one skip-gram Word2Vec model per context size and save each one.

    The context sizes, output prefix, dimensionality, worker count and epoch
    count are read from the module-level ``args``. ``args.window_size`` is an
    underscore-separated string of integers; for each size ``w`` the vectors
    are written in word2vec text format to ``args.output + "_" + w + ".emb"``.

    Args:
        walks: iterable of walks, each an iterable of nodes. Nodes are
            converted to strings before training.

    Returns:
        None.
    '''
    # Stringify every node of every walk.
    sentences = [[str(node) for node in walk] for walk in walks]

    # Parse all sizes up front so a malformed value fails before any training.
    window_sizes = [int(token) for token in args.window_size.split("_")]

    for window in window_sizes:
        print("Context size " + str(window) + "...")
        save_path = args.output + "_" + str(window) + ".emb"
        # NOTE(review): `size` and `iter` look like the gensim 3.x keyword
        # names - confirm against the gensim version this project pins.
        model = Word2Vec(
            sentences,
            size=args.dimensions,
            window=window,
            min_count=0,
            sg=1,
            workers=args.workers,
            iter=args.iter,
        )
        model.wv.save_word2vec_format(save_path)
        print("Saved to " + save_path)


def main(args):
    '''
    Run the whole pipeline: load edges, build the graph, normalise the
    weights, sample random walks and train the embeddings.

    ``args`` supplies ``window_size``, ``num_walks`` and ``walk_length``
    here. ``read_edgelist`` and ``learn_embeddings`` take their settings from
    the module-level ``args`` rather than from this parameter.
    '''
    print("Context sizes: " + args.window_size)

    # Stage 1: edge list file -> {"node1_node2": weight}
    print("Loading graph into memory...")
    edges = read_edgelist()
    print("Done.")

    # Stage 2: edge dictionary -> adjacency dictionary
    print("Creating custom graph object...")
    graph = create_graph(edges)
    print("Done.")

    # Stage 3: edge weights -> per-node transition probabilities
    print("Preprocessing transition probabilities...")
    graph = preprocess_transition_probs(graph)
    print("Done.")

    # Stage 4: sample the walks that serve as training sentences
    print("Creating random walks...")
    sampled_walks = random_walks(graph, args.num_walks, args.walk_length)
    print("Done.")

    # Stage 5: train and save the embeddings
    print("Running Word2Vec on walks...")
    learn_embeddings(sampled_walks)
    print("Done.")


if __name__ == "__main__":
    # `args` has to stay a module-level name: read_edgelist() and
    # learn_embeddings() read it as a global instead of taking a parameter.
    args = parse_args()
    main(args)
