import sys
import os

import argparse
import numpy as np
import json
import multiprocessing as mp
from matplotlib import pyplot as plt
from MulticoreTSNE import MulticoreTSNE as TSNE

def getParser():
    """Build the command-line parser for this plotting script.

    Options (all optional, defaults shown in ``--help``):
        --ds          dataset name, used as the sub-directory of ``emb/``
        --typefile    file containing entity types
        --entsfile    file containing the entity list (parallel to the type file)
        --perplexity  t-SNE perplexity (int)
        --niter       t-SNE number of iterations (int)
    """
    option_specs = (
        ("--ds", str, "dataset name", "r20kf"),
        ("--typefile", str, "file containing entity types", "r20kf"),
        ("--entsfile", str, "file containing list of entities (parallel to type file)", "r20kf"),
        ("--perplexity", int, "TSNE perplexity", 30),
        ("--niter", int, "TSNE number of iterations", 1000),
    )
    arg_parser = argparse.ArgumentParser(
        description="parser for arguments",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flag, value_type, help_text, default_value in option_specs:
        arg_parser.add_argument(flag, type=value_type, help=help_text, default=default_value)
    return arg_parser

def extractAllEnts(typefile, entsfile, min_count=500):
    """Group entity ids by predicted type, keeping only frequent types.

    Args:
        typefile: path to a JSON file whose object maps keys ``"reverb<N>"`` to
            objects holding a ``'pred'`` list of type labels.
        entsfile: path to a text file with one integer entity id per line.
            Line ``N`` (0-based; blank lines are skipped but still counted) is
            paired with key ``"reverb<N>"`` in *typefile*.
        min_count: a type is kept only if strictly more than this many entity
            entries carry it. The default of 500 matches the former
            hard-coded threshold.

    Returns:
        ``(type2ents, types)``. ``type2ents`` maps each kept type to the list of
        entity ids carrying it, in file order. ``types`` is the list of kept
        types in first-seen order.

    Raises:
        KeyError: if a non-blank line has no matching ``"reverb<N>"`` entry.
        ValueError: if a non-blank line is not an integer.
    """
    with open(typefile, 'r') as fin:
        types = json.load(fin)
    type2ents = {}
    with open(entsfile, 'r') as fin:
        # enumerate counts every physical line, so blank lines still advance
        # the "reverb<N>" index and stay aligned with the type file.
        for line_num, line in enumerate(fin):
            line = line.strip()
            if not line:
                continue
            ent = int(line)
            for typ in types["reverb%d" % line_num]['pred']:
                type2ents.setdefault(typ, []).append(ent)
    kept = {typ: ents for typ, ents in type2ents.items() if len(ents) > min_count}
    return kept, list(kept)

def extractEnts(typefile, entsfile):
    """Bucket entity ids into persons and locations by their predicted types.

    Line ``N`` of *entsfile* (0-based, blank lines skipped but counted) holds an
    integer entity id whose predicted types are ``types["reverb<N>"]['pred']``
    in the JSON *typefile*. An entity goes to "person" if it is predicted
    ``/person`` but not ``/location``, and to "location" in the mirrored case.
    Entities with both or neither type are dropped. The "organization" bucket
    is always returned empty.

    Returns:
        dict with keys "location", "organization", "person" -> lists of ids.
    """
    with open(typefile, 'r') as fh:
        type_preds = json.load(fh)
    buckets = {"location": [], "organization": [], "person": []}
    with open(entsfile, 'r') as fh:
        for idx, raw in enumerate(fh):
            raw = raw.strip()
            if not raw:
                continue
            ent_id = int(raw)
            preds = type_preds["reverb%d" % idx]['pred']
            is_person = "/person" in preds
            is_location = '/location' in preds
            if is_person and not is_location:
                buckets["person"].append(ent_id)
            elif is_location and not is_person:
                buckets["location"].append(ent_id)
    return buckets

def get_cmap(n, name='hsv'):
    '''Return a colormap whose entries 0, 1, ..., n-1 are distinct colours.

    Calling the returned object with an integer index ``i`` in ``range(n)``
    yields an RGBA tuple. *name* must be a registered matplotlib colormap name.

    The colormap is resampled to ``n + 1`` entries rather than ``n``. The
    default 'hsv' map is cyclic (red at both ends), so with exactly ``n``
    entries index 0 and index n-1 would get the same colour.
    '''
    import matplotlib  # only pyplot is imported at file level

    lut_size = n + 1
    try:
        # matplotlib >= 3.6: colormap registry plus Colormap.resampled.
        # (cm.get_cmap is deprecated since 3.7 and removed in 3.9.)
        return matplotlib.colormaps[name].resampled(lut_size)
    except AttributeError:
        # Older matplotlib without the registry or without resampled().
        return plt.cm.get_cmap(name, lut_size)

def save_for_tensorboard(dataset, names):
    """Re-export a dataset's NP and type embeddings as tab-separated text.

    Loads ``emb/<dataset>/np_embed.npy`` and ``emb/<dataset>/type_embed.npy``,
    then writes each array next to its source as ``np_embed.txt`` /
    ``type_embed.txt`` with tab delimiters. Both files are loaded before
    anything is written. *names* is accepted but not used.
    """
    base_dir = os.path.join("emb", dataset)
    prefixes = [os.path.join(base_dir, stem) for stem in ('np_embed', 'type_embed')]
    arrays = [np.load(prefix + ".npy") for prefix in prefixes]
    for prefix, arr in zip(prefixes, arrays):
        np.savetxt(prefix + ".txt", arr, delimiter="\t")


def plotEntityTyping(params):
    """Project NP and type embeddings to 2-D with t-SNE and plot them by entity type.

    Reads ``emb/<params.ds>/np_embed.npy`` and ``type_embed.npy`` and runs
    MulticoreTSNE on each with ``params.perplexity`` and ``params.niter``.
    Entities are grouped by type via
    ``extractAllEnts(params.typefile, params.entsfile)``, with one colour per
    type. The plots are written to ``np_tc.png`` and ``type_tc.png`` in the same
    directory.
    """
    perplexity = params.perplexity
    n_iter = params.niter
    indir = os.path.join("emb", params.ds)
    np_embed = np.load(os.path.join(indir, 'np_embed.npy'))
    type_embed = np.load(os.path.join(indir, 'type_embed.npy'))

    # Leave one core free, but never request fewer than one worker:
    # on a single-CPU machine cpu_count() - 1 would be 0.
    n_jobs = max(1, mp.cpu_count() - 1)

    def _embed_2d(matrix):
        # A fresh TSNE instance per matrix, so the two fits are independent.
        tsne = TSNE(n_components=2, n_jobs=n_jobs, perplexity=perplexity, n_iter=n_iter)
        return tsne.fit_transform(matrix)

    np_tsne = _embed_2d(np_embed)
    type_tsne = _embed_2d(type_embed)
    # NOTE(review): the title is fixed and does not follow params.ds.
    dataset = "ReVerb20KF"
    entIDlist, typelist = extractAllEnts(params.typefile, params.entsfile)
    print(typelist)
    colors = get_cmap(len(typelist))
    # One (1, 4) RGBA row per type, so scatter treats it as a single colour.
    type2colors = {
        typ: np.array(colors(idx)).reshape(1, 4)
        for idx, typ in enumerate(typelist)
    }
    plotit(np_tsne, [], entIDlist, dataset, os.path.join(indir, 'np_tc.png'),
           axlabel="NP", K=1000000, colors=type2colors)
    plotit(type_tsne, [], entIDlist, dataset, os.path.join(indir, 'type_tc.png'),
           axlabel="Type", K=1000000, colors=type2colors)

def plotAll(params):
    """Plot NP and type t-SNE embeddings, coloured by person/location type.

    Reads ``emb/<params.ds>/np_embed.npy``, ``type_embed.npy``, ``ent2id.json``
    and one ``<label>.txt`` name list per label below. It runs t-SNE
    (perplexity 10, 2000 iterations) on both embedding matrices and writes
    ``np_tc.png`` and ``type_tc.png`` into the same directory.

    Raises:
        KeyError: if a name in a ``<label>.txt`` file is missing from ent2id.json.
    """
    indir = os.path.join("emb", params.ds)
    np_embed = np.load(os.path.join(indir, 'np_embed.npy'))
    type_embed = np.load(os.path.join(indir, 'type_embed.npy'))
    with open(os.path.join(indir, 'ent2id.json'), 'r') as fin:
        ent2id = json.load(fin)
    labels = ['location', 'person', 'dates']
    entlist = {}
    entIDlist = {}
    for label in labels:
        entlist[label] = []
        entIDlist[label] = []
        with open(os.path.join(indir, "%s.txt" % label), 'r') as fin:
            for line in fin:
                name = line.strip()
                entlist[label].append(name)
                entIDlist[label].append(ent2id[name])
    # Leave one core free, but never request fewer than one worker:
    # on a single-CPU machine cpu_count() - 1 would be 0.
    n_jobs = max(1, mp.cpu_count() - 1)
    # NOTE(review): perplexity/n_iter are fixed here and ignore
    # params.perplexity / params.niter.
    tsne = TSNE(n_components=2, n_jobs=n_jobs, perplexity=10, n_iter=2000)
    np_tsne = tsne.fit_transform(np_embed)
    tsne = TSNE(n_components=2, n_jobs=n_jobs, perplexity=10, n_iter=2000)
    type_tsne = tsne.fit_transform(type_embed)
    dataset = "ReVerb20KF"
    # The label-file IDs built above are replaced by type-predicted IDs here.
    # entlist is still passed along, but plotit as visible in this file does
    # not read it.
    entIDlist = extractEnts(params.typefile, params.entsfile)
    plotit(np_tsne, entlist, entIDlist, dataset, os.path.join(indir, 'np_tc.png'), axlabel="NP")
    plotit(type_tsne, entlist, entIDlist, dataset, os.path.join(indir, 'type_tc.png'), axlabel="Type")

def plotit(embed, entlist, entIDlist, dataset, fname=None, axlabel="TSNE", K=5, colors=None):
    """Scatter-plot 2-D embedding rows, one colour per label.

    Args:
        embed: array indexed as ``embed[ids][:, 0]`` and ``[:, 1]``. It is
            presumably the (n_points, 2) t-SNE output; TODO confirm for any
            new caller.
        entlist: not used (kept for interface compatibility).
        entIDlist: mapping label -> sequence of row indices into *embed*.
        dataset: plot title.
        fname: output image path; if None the figure is shown interactively.
        axlabel: prefix for the axis labels.
        K: only the first K ids of each label are plotted.
        colors: mapping label -> matplotlib colour. If None, a fixed palette
            for 'location', 'person', 'sports', 'organization', 'dates' is used.
            Every label in *entIDlist* must have an entry.
    """
    if colors is None:
        colors = {'location':'r', 'person':'b', 'sports':'g', 'organization': 'g', 'dates':'g'}
    fontsize = 20
    weight = 'medium'
    # Create exactly one figure, so the size, points, labels and spine styling
    # all land on the figure that is saved or shown.
    fig, ax = plt.subplots(figsize=(8, 8))
    for label, idlist in entIDlist.items():
        points = embed[idlist][:K]
        ax.scatter(points[:, 0], points[:, 1], c=colors[label])

    ax.set_xlabel('%s TSNE Dim-1' % axlabel, weight=weight, fontsize=fontsize)
    ax.set_ylabel('%s TSNE Dim-2' % axlabel, weight=weight, fontsize=fontsize)
    ax.set_title(dataset, weight=weight, fontsize=fontsize)
    for axis in ['top', 'bottom', 'left', 'right']:
        ax.spines[axis].set_linewidth(10)

    try:
        if fname is None:
            plt.show()
        else:
            fig.savefig(fname)
    finally:
        # Release the figure; otherwise every call leaks one open figure.
        plt.close(fig)

def main(argv=None):
    """Parse command-line arguments and run the entity-typing t-SNE plot.

    Args:
        argv: optional argument list. If None, argparse reads ``sys.argv[1:]``,
            which matches the previous behaviour.
    """
    # Let argparse handle its own exits: it prints usage and exits with status 2
    # on bad arguments, and with status 0 after --help. Catching SystemExit with
    # a bare ``except`` would turn ``--help`` into a failure (status 1).
    params = getParser().parse_args(argv)
    plotEntityTyping(params)

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()

