import sys
import os

import argparse
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

def getParser():
    """Build the command-line parser for the rank-comparison plotter.

    Returns:
        An ``argparse.ArgumentParser`` with options for the animation
        interval, the ranks directory, the dataset name, the maximum rank
        shown on the axes, a ``--save`` flag and the output directory.
        Defaults are shown in ``--help`` output.
    """
    p = argparse.ArgumentParser(
        description="parser for arguments",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (flags, type, help, default) for every typed option, in registration order.
    typed_options = (
        (("-i", "--interval"), int, "animation interval(milliseconds)", 1000),
        (("-r", "--ranks"), str, "directory containing the ranks", "ranks"),
        (("-d", "--dataset"), str, "dataset", "reverb20k"),
        (("-m", "--maxrank"), int, "maximum rank (limit for plot)", 12000),
    )
    for flags, kind, text, fallback in typed_options:
        p.add_argument(*flags, type=kind, help=text, default=fallback)
    p.add_argument(
        "--save",
        dest="save",
        action="store_true",
        default=False,
        help="flag to save individual outputs",
    )
    p.add_argument(
        "-o", "--outdir", type=str, help="output directory to save the plots", default="plots"
    )
    return p

# Module-level figure shared by init()/update() for the interactive animation.
# NOTE: this runs at import time. plt.scatter below draws on the axes that
# plt.subplots() just made current, so the statement order matters.
fig, ax = plt.subplots()
xdata, ydata = [], []  # NOTE(review): not referenced anywhere in the visible code
xlim = 12000  # default axis limit; plotit()/saveit() overwrite it from --maxrank
ylim = xlim
ln = plt.scatter([], [])  # empty scatter whose point offsets update() replaces each frame
identity = [0, 20000]  # endpoints of the red y = x reference line (does not reach past 20000)
gap = 700  # offset in data units used to place the count labels near the corners
berttext = ax.text(gap//2, ylim-gap, 0)  # top-left label; update() writes diffs[i][1] here
caretext = ax.text(xlim-gap, gap//2, 0)  # bottom-right label; update() writes diffs[i][0] here


def init():
    """Reset the shared axes before the animation's first frame.

    Applies the module-level ``xlim``/``ylim`` (plotit() sets them from
    --maxrank), draws the red y = x reference line, labels both axes and
    moves the two count labels to their corners.

    Returns:
        A one-element tuple holding the scatter artist, the form
        FuncAnimation accepts from an ``init_func``.
    """
    ax.set_xlim(0, xlim)
    ax.set_ylim(0, ylim)
    # pyplot calls target the current axes (the module-level ``ax`` when
    # driven by plotit's animation).
    plt.plot(identity, identity, c='r')
    plt.xlabel('CaRE ranks')
    plt.ylabel('CaRE+BERT ranks')
    half_gap = gap // 2
    label_positions = (
        (berttext, (half_gap, ylim - gap)),  # top-left corner
        (caretext, (xlim - gap, half_gap)),  # bottom-right corner
    )
    for label, position in label_positions:
        label.set_position(position)
    return (ln,)

def update(i, ranks, names, diffs):
    """Redraw the shared figure for animation frame ``i``.

    Args:
        i: frame index into ``ranks``, ``names`` and ``diffs``.
        ranks: per-frame pairs whose element 0 is plotted on x and element 1 on y.
        names: per-frame titles shown as the figure's suptitle.
        diffs: per-frame counts; element 1 goes to the top-left label,
            element 0 to the bottom-right label.

    Returns:
        A one-element tuple holding the updated scatter artist.
    """
    fig.suptitle(names[i])
    frame_counts = diffs[i]
    berttext.set_text(str(frame_counts[1]))
    caretext.set_text(str(frame_counts[0]))
    pair = ranks[i]
    # Stack x and y as rows, then transpose into an (N, 2) array of points.
    points = np.vstack([pair[0], pair[1]]).T
    ln.set_offsets(points)
    return (ln,)

def render(i, ranks, names, diffs, **kwargs):
    """Draw comparison ``i`` as a standalone, styled figure.

    A new figure is created and left as pyplot's current figure, so callers
    may also fetch it with ``plt.gcf()``.

    Args:
        i: index into ``ranks``, ``names`` and ``diffs``.
        ranks: per-comparison pairs; element 0 is plotted on x (CaRE ranks),
            element 1 on y (CaRE+BERT ranks).
        names: per-comparison titles.
        diffs: per-comparison counts; element 1 is drawn top-left, element 0
            bottom-right.
        **kwargs: optional styling/limits:
            ``fontweight`` (default ``'bold'``), ``fontsize`` (default 10;
            tick labels use half of it), ``xlim``/``ylim`` (default: the
            module-level ``xlim``/``ylim``).

    Returns:
        The newly created matplotlib Figure. The caller owns it and should
        close it when done.
    """
    fontweight = kwargs.get('fontweight', 'bold')
    fontsize = kwargs.get('fontsize', 10)
    # Honour explicit limits; previously these kwargs were accepted but ignored.
    x_max = kwargs.get('xlim', xlim)
    y_max = kwargs.get('ylim', ylim)
    fig, ax = plt.subplots()
    ax.set_xlim(0, x_max)
    ax.set_ylim(0, y_max)
    ax.plot(identity, identity, c='r')
    # plt.xticks/yticks restyle the current axes, which plt.subplots() just made ``ax``.
    plt.xticks(fontsize=fontsize/2, weight=fontweight)
    plt.yticks(fontsize=fontsize/2, weight=fontweight)
    ax.set_xlabel('CaRE ranks', weight=fontweight, fontsize=fontsize)
    ax.set_ylabel('CaRE+BERT ranks', weight=fontweight, fontsize=fontsize)
    ax.set_title(names[i], weight=fontweight, fontsize=fontsize)
    ax.text(10, y_max-gap, str(diffs[i][1]), weight=fontweight, fontsize=fontsize)
    ax.text(x_max-1.6*gap, 10, str(diffs[i][0]), weight=fontweight, fontsize=fontsize)
    ax.scatter(ranks[i][0], ranks[i][1], c='b')
    return fig

def readdata(params):
    """Load rank vectors for each CaRE/BERT weighting and compare each to CaRE alone.

    Files are read from ``<params.ranks>/<params.dataset>/`` and named
    ``"%.2f_CaRE_%.2f_BERT.npy" % (care_weight, bert_weight)``. Column 3 of
    each loaded array is used as the rank vector. NOTE(review): this assumes
    2-D arrays with at least 4 columns whose rows line up across files;
    confirm against the code that writes them.

    The baseline is the ``(1.0, 0.0)`` file (CaRE only). It is compared to
    the ``(0.0, 1.0)`` and ``(1.0, 0.0)`` files, then to ``(1.0, w)`` and
    ``(w, 1.0)`` for w in 0.2, 0.4, 0.6, 0.8, and finally to ``(1.0, 1.0)``.

    Args:
        params: object with ``ranks`` (root directory) and ``dataset``
            (sub-directory) attributes, e.g. parsed CLI arguments.

    Returns:
        ranks: list of ``(baseline_ranks, compared_ranks)`` pairs.
        names: file stem of each compared file (used as titles/output names).
        diffs: list of ``[count(baseline > compared), count(baseline < compared)]``.
    """
    template = os.path.join(params.ranks, params.dataset, "%.2f_CaRE_%.2f_BERT.npy")
    weights = [(0.0, 1.0), (1.0, 0.0)]
    for wt in (0.2, 0.4, 0.6, 0.8):
        weights.append((1.0, wt))
        weights.append((wt, 1.0))
    weights.append((1.0, 1.0))
    files = [template % w for w in weights]

    care = np.load(template % (1.0, 0.0))[:, 3]
    ranks, names, diffs = [], [], []
    for curfile in files:
        bert = np.load(curfile)[:, 3]
        ranks.append((care, bert))
        diff = care - bert
        # count_nonzero counts in C; the builtin sum() walked the array element by element.
        diffs.append([int(np.count_nonzero(diff > 0)), int(np.count_nonzero(diff < 0))])
        # basename/splitext are portable; splitting on '/' kept the whole
        # path on Windows, where os.path.join uses '\\'.
        names.append(os.path.splitext(os.path.basename(curfile))[0])
    return ranks, names, diffs

def plotit(params):
    """Show all comparisons as an interactive animation on the shared figure.

    Sets the module-level axis limits from ``params.maxrank``, loads the data
    with readdata(), and steps through one frame per comparison with init()
    and update(), ``params.interval`` milliseconds apart, without repeating.
    Blocks in ``plt.show()`` until the window is closed.
    """
    global xlim, ylim
    xlim = ylim = params.maxrank
    data = readdata(params)  # (ranks, names, diffs), forwarded to update()
    frame_count = len(data[0])
    # Keep a live reference to the animation until plt.show() returns;
    # otherwise it can be garbage-collected and stop updating.
    anim = FuncAnimation(
        fig,
        update,
        frames=range(frame_count),
        fargs=data,
        init_func=init,
        interval=params.interval,
        repeat=False,
    )
    plt.show()

def saveit(params):
    """Render every comparison to its own PNG in ``<outdir>/<dataset>/``.

    Closes the current (module-level animation) figure, sets the axis limits
    from ``params.maxrank``, then for each comparison calls render(), resizes
    the figure to 16x10 inches and saves it at 100 dpi as ``<name>.png``.
    The output directory is created if missing.
    """
    global xlim, ylim
    # The animation figure is not used when saving; render() makes fresh ones.
    plt.close()
    xlim = params.maxrank
    ylim = params.maxrank
    ranks, names, diffs = readdata(params)
    outdir = os.path.join(params.outdir, params.dataset)
    # savefig does not create missing directories.
    os.makedirs(outdir, exist_ok=True)
    for i in range(len(ranks)):
        render(i, ranks, names, diffs, xlim=params.maxrank, ylim=params.maxrank, fontsize=40, fontweight='bold')
        # render() leaves its new figure as pyplot's current figure.
        fig = plt.gcf()
        fig.set_size_inches(16, 10)
        fig.savefig(os.path.join(outdir, "%s.png" % names[i]), dpi=100)
        # Release each figure once saved; otherwise all of them stay open.
        plt.close(fig)

def main():
    """Parse the command line, then either save the PNGs or show the animation.

    Argument errors are left to argparse. It prints usage and exits with
    status 2 on bad input, and exits with status 0 after ``--help``. The
    previous bare ``except:`` turned both into status 1 and would also have
    swallowed KeyboardInterrupt.
    """
    params = getParser().parse_args()
    if params.save:
        saveit(params)
    else:
        plotit(params)

# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()

