import sys
import os

import argparse
import matplotlib.pyplot as plt
import numpy as np

def getParser():
    """Build the command-line parser for the rank-profile plotter.

    Options (registered in this order):
      -x/--base        required; name of the rank file plotted on the x-axis
      -y/--hypothesis  required; name of the rank file plotted on the y-axis
      -r/--ranksdir    directory containing rank files (default ./results/ranks)
      -p/--plotsdir    directory to save plots (default ./results/plots)

    Returns:
        argparse.ArgumentParser: parser whose help output includes defaults.
    """
    results_root = "./results"
    cli = argparse.ArgumentParser(
        description="parser for arguments",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Mandatory inputs: which rank files go on each axis.
    required_opts = (
        ("-x", "--base", "base file x-axis"),
        ("-y", "--hypothesis", "combined file y-axis"),
    )
    for short_flag, long_flag, help_text in required_opts:
        cli.add_argument(short_flag, long_flag, type=str, help=help_text, required=True)

    # Optional directories, each defaulting to a sub-folder of results_root.
    dir_opts = (
        ("-r", "--ranksdir", "directory containing rank files", "ranks"),
        ("-p", "--plotsdir", "directory to save the plots", "plots"),
    )
    for short_flag, long_flag, help_text, subdir in dir_opts:
        cli.add_argument(
            short_flag,
            long_flag,
            type=str,
            help=help_text,
            default=os.path.join(results_root, subdir),
        )
    return cli

def render(careranks, bertranks, name, diffs, **kwargs):
    """Draw a rank-vs-rank scatter plot on a new pyplot figure.

    Args:
        careranks: x coordinates of the scatter ("CaRE ranks" axis).
        bertranks: y coordinates of the scatter ("CaRE+BERT ranks" axis).
        name: title of the plot.
        diffs: two-element sequence; ``diffs[0]`` is written near the
            bottom-right corner and ``diffs[1]`` near the top-left corner.
        **kwargs: optional ``xlim`` / ``ylim`` (upper axis limits, default
            5000), ``identity`` (x and y values of the red reference line,
            default ``[0, max(xlim, ylim)]``), ``fontweight`` (default
            ``'bold'``) and ``fontsize`` (default 10; tick labels use half).

    The new figure is left as the current pyplot figure. Nothing is returned.
    """
    x_max = kwargs.get('xlim', 5000)
    y_max = kwargs.get('ylim', 5000)
    diagonal = kwargs.get('identity', [0, max(x_max, y_max)])
    _, axes = plt.subplots()
    weight = kwargs.get('fontweight', 'bold')
    size = kwargs.get('fontsize', 10)
    # Offset of the corner annotations. It is in data units, so its visual
    # size depends on the axis limits.
    margin = 10 * size

    axes.set_xlim(0, x_max)
    axes.set_ylim(0, y_max)
    # Reference line; plt.subplots() made `axes` current, so this matches plt.plot.
    axes.plot(diagonal, diagonal, c='r')

    tick_style = {'fontsize': size / 2, 'weight': weight}
    plt.xticks(**tick_style)
    plt.yticks(**tick_style)

    text_style = {'weight': weight, 'fontsize': size}
    axes.set_xlabel('CaRE ranks', **text_style)
    axes.set_ylabel('CaRE+BERT ranks', **text_style)
    axes.set_title(name, **text_style)

    top_left_label, bottom_right_label = str(diffs[1]), str(diffs[0])
    axes.text(10, y_max - margin, top_left_label, **text_style)
    axes.text(x_max - 1.6 * margin, 10, bottom_right_label, **text_style)

    plt.scatter(careranks, bertranks, c='b')

def plotit(params):
    """Plot per-example ranks of the base run against the hypothesis run.

    Loads ``<ranksdir>/<base>.valid.npy`` for the x-axis and
    ``<ranksdir>/<hypothesis>.valid.npy`` for the y-axis, then scatters
    column 3 of one against column 3 of the other. The figure is saved to
    ``<plotsdir>/<hypothesis>/rank_profile.valid.png`` and then shown.

    Args:
        params: namespace with ``base``, ``hypothesis``, ``ranksdir`` and
            ``plotsdir`` string attributes (as produced by getParser).

    Raises:
        ValueError: if the two rank arrays have different shapes.
    """
    sp = "valid"
    # The base run goes on the x-axis and the hypothesis run on the y-axis,
    # matching the CLI help ("-x base file x-axis", "-y combined file
    # y-axis"). These paths were previously swapped.
    basefile = os.path.join(params.ranksdir, "%s.%s.npy" % (params.base, sp))
    hypothesisfile = os.path.join(params.ranksdir, "%s.%s.npy" % (params.hypothesis, sp))
    care = np.load(basefile)
    bert = np.load(hypothesisfile)
    # NOTE(review): column 3 is assumed to hold the rank; confirm against
    # whatever writes these .npy files.
    careranks = care[:, 3]
    bertranks = bert[:, 3]
    if careranks.shape != bertranks.shape:
        # Guard against silent broadcasting (e.g. a length-1 array).
        raise ValueError(
            "rank arrays differ in shape: %s vs %s" % (careranks.shape, bertranks.shape)
        )
    diff = careranks - bertranks
    # diffs[0]: base rank > hypothesis rank (points below the identity line).
    # diffs[1]: base rank < hypothesis rank (points above the identity line).
    diffs = [np.count_nonzero(diff > 0), np.count_nonzero(diff < 0)]
    # Square axes with a little headroom past the largest rank.
    maxrank = max(careranks.max(), bertranks.max()) + 10
    identity = [0, maxrank]
    render(careranks, bertranks, params.hypothesis, diffs, identity=identity,
           xlim=maxrank, ylim=maxrank, fontsize=25, fontweight='bold')
    outdir = os.path.join(params.plotsdir, params.hypothesis)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(outdir, exist_ok=True)
    outfile = os.path.join(outdir, 'rank_profile.%s.png' % sp)
    fig = plt.gcf()
    fig.set_size_inches(8, 8)
    plt.savefig(outfile, dpi=100)
    plt.show()

def main():
    """Parse command-line arguments and produce the rank-profile plot.

    Exits with status 0 for ``-h/--help``. Exits with status 1 for invalid
    arguments; argparse has already printed the usage message by then.
    """
    parser = getParser()
    try:
        params = parser.parse_args()
    except SystemExit as exc:
        # argparse reports both --help (code 0) and usage errors (code 2)
        # through SystemExit. The previous bare `except:` turned every one
        # of them, including a successful --help, into exit status 1.
        sys.exit(0 if exc.code in (0, None) else 1)
    plotit(params)

if __name__ == "__main__":
    main()

