import argparse
import os
import pickle

import numpy as np



parser = argparse.ArgumentParser()

# NOTE: main() does not read --summaries_path or --label_metric; they are
# accepted so existing command lines keep working.
parser.add_argument("--summaries_path", type=str, default="../summaries/")
parser.add_argument(
    "--scored_summaries_path",
    type=str,
    default="../reranking_data/",
    help="root directory holding <set>/<set>_scored_summaries_*.pkl files",
)
parser.add_argument("--set", type=str, default="test")
# Integer (default 11490). Previously declared type=str with an int default, so
# its type depended on whether the flag was passed. main() only uses it to build
# the input file name.
parser.add_argument("--dataset_size", type=int, default=11490)
parser.add_argument("--model_name", type=str, default="pegasus_unsupervised")
parser.add_argument("--num_beams", type=int, default=15)
parser.add_argument(
    "--label_metric",
    type=str,
    default="mean_rouge",
    help='one of "rouge_1", "rouge_2", "rouge_l", "mean_rouge"',
)

args = parser.parse_args()

print("*" * 50)
print(args)


def main(args):
    """Print statistics about the scores of each document's candidate summaries.

    Loads the pickled scored summaries selected by the command-line arguments
    and prints, over the whole set:

    * the average min / Q1 / mean / Q3 / max candidate score per document,
    * the average spread (max - min), standard deviation and number of
      distinct scores per document,
    * for each candidate position (beam search rank), its mean score, its mean
      rank as computed by ``rank`` (0 = best score), and the percentage of
      documents in which it holds the best score.

    Args:
        args: parsed CLI namespace; reads ``scored_summaries_path``, ``set``,
            ``model_name``, ``dataset_size`` and ``num_beams``.

    Raises:
        ValueError: if the documents do not all have the same number of
            candidate scores.
    """
    # load summaries + scores from <root>/<set>/<set>_scored_summaries_...pkl.
    # os.path.join keeps this working whether or not the root ends with "/".
    # NOTE: pickle.load can execute arbitrary code; only load files you produced.
    scored_summaries_path = os.path.join(
        args.scored_summaries_path,
        args.set,
        "{}_scored_summaries_{}_{}_beams_{}.pkl".format(
            args.set, args.model_name, args.dataset_size, args.num_beams
        ),
    )
    with open(scored_summaries_path, "rb") as f:
        scored_summaries = pickle.load(f)
        print("loaded new data!", len(scored_summaries))
    if len(scored_summaries) == 0:
        # The per-document statistics below cannot be computed on an empty array.
        print("No scored summaries found, nothing to analyze.")
        return

    # collect scores: element [1] of each entry holds that document's candidate
    # scores (presumably element [0] holds the summaries -- TODO confirm)
    all_scores = np.array(
        [np.array(scored_summaries[i][1]) for i in range(len(scored_summaries))]
    )
    print(all_scores.shape)
    if all_scores.ndim != 2:
        raise ValueError(
            "Expected the same number of candidate scores for every document, "
            "got an array of shape {}".format(all_scores.shape)
        )
    # Size per-position outputs from the data rather than --num_beams, so a
    # mismatch can neither raise IndexError nor print phantom positions.
    num_candidates = all_scores.shape[1]
    if num_candidates != args.num_beams:
        print(
            "Warning: found {} candidates per document but --num_beams is {}".format(
                num_candidates, args.num_beams
            )
        )

    # overall stats, averaged over documents (axis 1 = one document's candidates)
    avg_min = np.mean(np.min(all_scores, 1))
    avg_q1 = np.mean(np.percentile(all_scores, 25, 1))
    avg_mean = np.mean(np.mean(all_scores, 1))
    avg_q3 = np.mean(np.percentile(all_scores, 75, 1))
    avg_max = np.mean(np.max(all_scores, 1))
    print("\nAvg min / q1 / mean / q3 / max ROUGE scores across set of candidates:")
    print(avg_min, avg_q1, avg_mean, avg_q3, avg_max)
    avg_spread = np.mean(np.max(all_scores, 1) - np.min(all_scores, 1))
    avg_std = np.mean(np.std(all_scores, 1))
    avg_unique = np.mean([len(np.unique(x)) for x in all_scores])
    print("Avg spread (max-min) / std / # unique scores across set of candidates:")
    print(avg_spread, avg_std, avg_unique)
    print("\n")

    # build ranks (0 = best score; equal scores are ranked in order of position)
    all_ranks = np.array([rank(x) for x in all_scores])
    print(all_ranks.shape)

    print(all_scores[0:2])
    print(all_ranks[0:2])

    # mean score per beam search rank
    mean_scores = np.mean(all_scores, axis=0)
    print("\nMean scores for each beam search rank:")
    for i, score in enumerate(mean_scores):
        print(i, "mean score: {:.4f}".format(score))

    # mean rank per beam search rank
    mean_ranks = np.mean(all_ranks, axis=0)
    print("\nMean ranks for each beam search rank:")
    for i, mean_rank in enumerate(mean_ranks):
        print(i, "mean rank: {:.4f}".format(mean_rank))

    # distribution of top score: the position holding rank 0 in each document
    top_positions = np.argmin(all_ranks, axis=1)
    counts = np.bincount(top_positions, minlength=num_candidates) / len(all_ranks)
    counts *= 100  # values are percentages of documents
    print("\n% of times each beam search rank is the top score:")
    for i, pct in enumerate(counts):
        print(i, "fraction of times this rank is the best score: {:.4f}".format(pct))


def rank(x):
    """Return the 0-based descending-order rank of every element of ``x``.

    The largest value gets rank 0. Equal values receive consecutive ranks, with
    the earlier position in ``x`` getting the smaller rank. A NaN never compares
    equal to anything, so its rank stays 0 and its slot in the sorted order is
    never assigned to another element.

    Returns a float array with the same length as ``x``.
    """
    descending = np.sort(x)[::-1]
    taken = np.zeros(len(x), dtype=bool)
    result = np.zeros(len(x))
    for pos, value in enumerate(x):
        # Earliest still-unclaimed slot in the sorted order holding this value.
        free_matches = np.flatnonzero((descending == value) & ~taken)
        if free_matches.size:
            slot = free_matches[0]
            result[pos] = slot
            taken[slot] = True
    return result



if __name__ == "__main__":
    # Script entry point: run the analysis with the arguments parsed at import.
    main(args)
