'''
Date: 2021-06-10 09:42:07
LastEditors: Wu Xianze (wuxianze.0@bytedance.com)
LastEditTime: 2021-06-30 21:31:33
'''
import os
import json
import argparse
import seaborn as sns
import numpy as np
# from sklearn.manifold import TSNE
from MulticoreTSNE import MulticoreTSNE as TSNE
import matplotlib.pyplot as plt

def label_points(xs, ys, vals, ax):
    """Write the string form of each value next to its point on ``ax``.

    ``xs``, ``ys`` and ``vals`` are walked in lockstep; every label is placed
    at ``(x + .02, y)`` with font size 8 through ``ax.text``.
    """
    x_shift = .02
    for px, py, label in zip(xs, ys, vals):
        caption = str(label)
        ax.text(px + x_shift, py, caption, fontdict=dict(size=8))

def readJsonl(fname):
    """Load a JSON Lines file into a list of decoded objects.

    Args:
        fname: path of the file to read; each non-blank line must hold one
            JSON document.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    datas = []
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(fname, 'r', encoding='utf-8') as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF)
                # instead of failing inside json.loads("").
                continue
            datas.append(json.loads(line))
    print("read {} instances from {}".format(len(datas), fname))
    return datas

def visualizeTokenEmb(args):
    """Plot a 2-D t-SNE projection of token embeddings, one image per document.

    The first ``args.k`` records of the JSONL files ``args.emb`` and
    ``args.oracle`` are paired positionally. For every pair the vectors under
    ``"embeddings"[i]["embedding"]`` are reduced to two dimensions and drawn
    as a scatter plot whose hue is 1 for points whose index is listed in the
    oracle record's ``"oracle_idx"`` and 0 otherwise. The figure is saved as
    ``<args.o>/<doc_id>.png``.

    Args:
        args: namespace providing ``emb`` and ``oracle`` (JSONL paths),
            ``o`` (output directory, created if missing) and ``k`` (number of
            leading records to plot).

    NOTE: ``args.p`` (perplexity) and ``args.n_iter`` are not consulted here;
    t-SNE runs with the library defaults apart from ``n_components`` and
    ``n_jobs``.
    """
    embedding_infos = readJsonl(args.emb)[:args.k]
    oracle_infos = readJsonl(args.oracle)[:args.k]
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.o, exist_ok=True)

    for emb_info, oracle_info in zip(embedding_infos, oracle_infos):
        doc_id = emb_info['doc_id']
        emb_list = emb_info["embeddings"]
        embedding_array = np.array([item["embedding"] for item in emb_list])

        model = TSNE(n_components=2, n_jobs=10)
        node_pos = model.fit_transform(embedding_array)
        # Columns of the projection are the x and y coordinates.
        x_pos = node_pos[:, 0]
        y_pos = node_pos[:, 1]

        # Build the lookup set once so membership is O(1) per point instead
        # of a linear scan of the oracle list for every index.
        oracle_idx = set(oracle_info['oracle_idx'])
        is_oracles = [
            1 if i in oracle_idx else 0
            for i in range(embedding_array.shape[0])
        ]

        sns.scatterplot(x=x_pos, y=y_pos, hue=is_oracles)
        graph_path = os.path.join(args.o, "{}.png".format(doc_id))
        plt.savefig(graph_path)
        print("save an image to {}".format(graph_path))
        # Clear the shared figure so the next document starts from a blank plot.
        plt.clf()
    

if __name__ == "__main__":
    # Command-line entry point: parse the options and run the selected mode.
    parser = argparse.ArgumentParser()
    parser.add_argument('--emb', type=str, help="embedding file")
    parser.add_argument('--oracle', type=str, help="oracle file")
    parser.add_argument('-o', type=str, help="output dir", default="/opt/tiger/sumtest/graphs/token_emb/en_train_zh_test")
    parser.add_argument('-m', type=str, help="mode", default="visualizeTokenEmb")
    parser.add_argument('-k', type=int, help="top k instances", default=20)
    parser.add_argument(
        '--n-iter', type=int, help="number of iteration",
        default=5000
    )
    parser.add_argument(
        '-p', type=int, help="perplexity",
        default=30
    )
    args = parser.parse_args()

    # SECURITY: the mode used to be dispatched with eval() on the raw -m
    # string, which executes arbitrary Python supplied on the command line.
    # Modes are now resolved through an explicit table; register new mode
    # functions here.
    modes = {
        "visualizeTokenEmb": visualizeTokenEmb,
    }
    handler = modes.get(args.m)
    if handler is None:
        parser.error(
            "unknown mode {!r}; choose from: {}".format(
                args.m, ", ".join(sorted(modes))
            )
        )
    handler(args)