import sys
import json
import h5py
import numpy as np
from timeit import default_timer as timer

import torch
from torch.autograd import Variable
import torch.nn.functional as F

import options
import visdial.metrics as metrics
from utils import utilities as utils
from dataloader import VisDialDataset
from torch.utils.data import DataLoader

from sklearn.metrics.pairwise import pairwise_distances

from six.moves import range
import visdial.loss.loss_utils as loss_utils
from visdial.loss.infoGain import imgPrior, prepareBatch, normProb
from visdial.loss.rl_imGuess_loss import Ranker, rl_rollout_search

# Margin-based (margin=0.1) pairwise ranking loss. It is used by the
# 'cosine_similarity' retrieval mode of the evaluators below, applied to
# (ground-truth image embedding, dialog embedding) pairs.
pairwiseRanking_criterion = loss_utils.PairwiseRankingLoss(margin=0.1)


def imgLoader(dataloader, dataset):
    '''
        Gathers the image features of every batch yielded by `dataloader`
        into one tensor.

        Each raw batch is passed through prepareBatch(dataset, batch) and
        its 'img_feat' entry is kept. The per-batch features are then
        concatenated along dimension 0, in dataloader iteration order.

        Arguments:
            dataloader : iterable of raw batches
            dataset    : dataset instance forwarded to prepareBatch

        Returns:
            Tensor holding the image features of all batches.
    '''
    per_batch_feats = [
        Variable(prepareBatch(dataset, raw)['img_feat'], requires_grad=False)
        for raw in dataloader
    ]
    return torch.cat(per_batch_feats, 0)


def rankQBot(qBot, dataset, split, exampleLimit=None, verbose=0, im_retrieval_mode='mse'):
    '''
        Evaluates Q-Bot performance on image retrieval when it is shown
        ground truth captions, questions and answers. Q-Bot does not
        generate dialog in this setting - it only encodes ground truth
        captions and dialog in order to perform image retrieval by
        predicting FC-7 image features after each round of dialog.

        Arguments:
            qBot    : Q-Bot
            dataset : VisDialDataset instance
            split   : Dataset split, can be 'val' or 'test'

            exampleLimit      : Maximum number of data points to use from
                                the dataset split. If None, all data points.
                                Evaluation runs over whole batches, so up to
                                batchSize - 1 extra examples may be included.
            verbose           : If truthy, print the percentile mean rank
                                (with standard-error bounds) of every round.
            im_retrieval_mode : 'mse' ranks images by Euclidean distance
                                between predicted and ground-truth features
                                (smaller is better); 'cosine_similarity'
                                ranks by the dot product of the dialog
                                embedding with the image embedding (larger
                                is better).

        Returns:
            (rankMetrics of the final round, list of per-round rankMetrics).
            Index 0 is the caption-only round. Every dict holds the output
            of metrics.computeMetrics plus 'percentile' and 'featLoss'
            ('logProbs' for every round but the last). The final-round dict
            also holds 'logProbsMean', 'featLossMean', 'winrateMean' and
            'perplexityMean'.

        Raises:
            ValueError: if im_retrieval_mode is not one of the two modes.
    '''
    print("image retrieval mode is: {}".format(im_retrieval_mode))
    if im_retrieval_mode not in ("mse", "cosine_similarity"):
        # Fail fast: any other mode yields no predictions to rank.
        raise ValueError(
            "Unsupported im_retrieval_mode: {}".format(im_retrieval_mode))
    # Distances are ranked ascending, similarities descending.
    descending = im_retrieval_mode == "cosine_similarity"

    def _score_matrix(pred, gt):
        # Row i scores example i's prediction against every ground-truth row.
        if descending:
            return np.matmul(pred, gt.transpose())
        return pairwise_distances(pred, gt)

    def _match_rank(scores, i):
        # 1-based position of ground truth i within row i of `scores` when
        # that row is ordered best-first. Rows are sorted one at a time so
        # only O(n) extra memory is needed on top of the n x n matrix.
        order = scores[i].argsort()
        if descending:
            order = order[::-1]
        # flatnonzero(...)[0] rather than int(np.where(...)[0]): converting
        # a 1-element array to a Python int is deprecated since NumPy 1.25.
        return int(np.flatnonzero(order == i)[0]) + 1

    def _count_wins(scores):
        # Number of rows whose own ground truth is ranked first.
        return sum(1 for i in range(scores.shape[0])
                   if _match_rank(scores, i) <= 1)

    batchSize = dataset.batchSize
    numRounds = dataset.numRounds
    if exampleLimit is None:
        numExamples = dataset.numDataPoints[split]
    else:
        numExamples = exampleLimit
    numBatches = (numExamples - 1) // batchSize + 1
    original_split = dataset.split
    dataset.split = split
    try:
        dataloader = DataLoader(
            dataset,
            batch_size=batchSize,
            shuffle=True,
            num_workers=0,
            collate_fn=dataset.collate_fn)

        # Per-batch ground-truth features, plus per-round predictions and
        # losses: index 0 is the caption-only round, 1..numRounds the
        # dialog rounds.
        gtImgFeatures = []
        roundwiseFeaturePreds = [[] for _ in range(numRounds + 1)]
        logProbsAll = [[] for _ in range(numRounds)]
        featLossAll = [[] for _ in range(numRounds + 1)]
        perplexityAll = [[] for _ in range(numRounds)]
        start_t = timer()

        # Per-round count of examples whose ground-truth image is ranked
        # first within their batch; normalised by num_games after the loop.
        win_rate = [0] * (numRounds + 1)
        num_games = 0

        # Image features of every example the dataloader yields; the
        # imGuess ranker and random initial guesses index into this bank.
        all_im_feat = imgLoader(dataloader, dataset)
        im_ranker = Ranker()

        for idx, batch in enumerate(dataloader):
            if idx == numBatches:
                break

            # Keep only tensor-like entries, moved to GPU or made contiguous.
            if dataset.useGPU:
                batch = {key: v.cuda() for key, v in batch.items()
                         if hasattr(v, 'cuda')}
            else:
                batch = {key: v.contiguous() for key, v in batch.items()
                         if hasattr(v, 'cuda')}

            with torch.no_grad():
                caption = Variable(batch['cap'])
                captionLens = Variable(batch['cap_len'])
                gtQuestions = Variable(batch['ques'])
                gtQuesLens = Variable(batch['ques_len'])
                answers = Variable(batch['ans'])
                ansLens = Variable(batch['ans_len'])
                if im_retrieval_mode == "mse":
                    if qBot.imgEncodingMode == "imGuess":
                        gtFeatures = qBot.forwardImage(
                            Variable(batch['img_feat']))
                    else:
                        gtFeatures = Variable(batch['img_feat'])
                else:
                    gtFeatures = qBot.multimodalpredictIm(
                        Variable(batch['img_feat']))
                image = Variable(batch['img_feat'])
                if qBot.imgEncodingMode == "imGuess":
                    # NOTE(review): recomputed for every batch, whereas
                    # rankQABots does this once before its loop. Hoist it if
                    # update_rep depends only on model weights.
                    im_ranker.update_rep(qBot, all_im_feat)

                qBot.reset()
                qBot.observe(-1, caption=caption, captionLens=captionLens)
                if qBot.new_questioner:
                    qBot.observe_im(image)

                if qBot.imgEncodingMode == "imGuess":
                    # Random initial guess. torch.randint's upper bound is
                    # exclusive, so size(0) makes every image eligible.
                    act_index = torch.randint(
                        0, all_im_feat.size(0), (image.size(0), 1))
                    predicted_image = all_im_feat[act_index].squeeze(1)
                    qBot.observe_im(predicted_image)

                # Round 0: retrieval from the caption alone.
                if im_retrieval_mode == "mse":
                    predFeatures = qBot.predictImage()
                    featLoss = F.mse_loss(predFeatures, gtFeatures)
                    featLossAll[0].append(torch.mean(featLoss))
                    roundwiseFeaturePreds[0].append(predFeatures)
                    if qBot.imgEncodingMode == "imGuess":
                        act_index = im_ranker.nearest_neighbor(
                            predFeatures.data)
                        # NOTE(review): no .squeeze(1) here, unlike the
                        # per-round update below; confirm the shape that
                        # nearest_neighbor returns.
                        predicted_image = all_im_feat[act_index]
                    round_scores = _score_matrix(
                        predFeatures.cpu().numpy(), gtFeatures.cpu().numpy())
                else:  # cosine_similarity
                    dialogEmbedding = qBot.multimodalpredictText()
                    featLoss = pairwiseRanking_criterion(
                        gtFeatures, dialogEmbedding)
                    featLossAll[0].append(torch.sum(featLoss))
                    roundwiseFeaturePreds[0].append(dialogEmbedding)
                    round_scores = _score_matrix(
                        dialogEmbedding.cpu().numpy(),
                        gtFeatures.cpu().numpy())
                # Every example in the batch counts as one game.
                win_rate[0] += _count_wins(round_scores)
                num_games += round_scores.shape[0]

                for rnd in range(numRounds):
                    if qBot.imgEncodingMode == "imGuess":
                        qBot.observe_im(predicted_image)
                    qBot.observe(
                        rnd,
                        ques=gtQuestions[:, rnd],
                        quesLens=gtQuesLens[:, rnd])
                    qBot.observe(
                        rnd, ans=answers[:, rnd], ansLens=ansLens[:, rnd])
                    logProbsCurrent = qBot.forward()

                    # Cross entropy and perplexity of the ground-truth
                    # question for this round.
                    gtQuesRound = gtQuestions[:, rnd].contiguous()
                    logProbsAll[rnd].append(
                        utils.maskedNll(logProbsCurrent, gtQuesRound))
                    perplexityAll[rnd].append(
                        utils.maskedPerplexity(logProbsCurrent, gtQuesRound))

                    if im_retrieval_mode == "mse":
                        predFeatures = qBot.predictImage()
                        featLoss = F.mse_loss(predFeatures, gtFeatures)
                        featLossAll[rnd + 1].append(torch.mean(featLoss))
                        roundwiseFeaturePreds[rnd + 1].append(predFeatures)
                        if qBot.imgEncodingMode == "imGuess":
                            act_index = im_ranker.nearest_neighbor(
                                predFeatures.data)
                            predicted_image = \
                                all_im_feat[act_index].squeeze(1)
                        round_scores = _score_matrix(
                            predFeatures.cpu().numpy(),
                            gtFeatures.cpu().numpy())
                    else:  # cosine_similarity
                        dialogEmbedding = qBot.multimodalpredictText()
                        featLoss = pairwiseRanking_criterion(
                            gtFeatures, dialogEmbedding)
                        featLossAll[rnd + 1].append(torch.sum(featLoss))
                        roundwiseFeaturePreds[rnd + 1].append(dialogEmbedding)
                        round_scores = _score_matrix(
                            dialogEmbedding.cpu().numpy(),
                            gtFeatures.cpu().numpy())
                    win_rate[rnd + 1] += _count_wins(round_scores)

                gtImgFeatures.append(gtFeatures)

                end_t = timer()
                delta_t = " Time: %5.2fs" % (end_t - start_t)
                start_t = end_t
                progressString = "\r[Qbot] Evaluating split '%s' [%d/%d]\t" + delta_t
                sys.stdout.write(progressString % (split, idx + 1, numBatches))
                sys.stdout.flush()

        sys.stdout.write("\n")
        win_rate = [x / num_games for x in win_rate]
        print("The winning rates for {} are: {}".format(split, win_rate))

        gtFeatures = torch.cat(gtImgFeatures, 0).data.cpu().numpy()
        rankMetricsRounds = []
        # Rank against the pool actually evaluated. With exampleLimit this
        # is smaller than len(dataset), and ranks can never exceed it.
        poolSize = gtFeatures.shape[0]

        logProbsAll = [torch.stack(lprobs, 0).mean() for lprobs in logProbsAll]
        perplexityAll = [torch.cat(perplexity, 0).mean().data.item()
                         for perplexity in perplexityAll]
        featLossAll = [torch.stack(floss, 0).mean() for floss in featLossAll]
        roundwiseLogProbs = torch.stack(logProbsAll, 0).data.cpu().numpy()
        roundwiseFeatLoss = torch.stack(featLossAll, 0).data.cpu().numpy()
        logProbsMean = roundwiseLogProbs.mean()
        featLossMean = roundwiseFeatLoss.mean()
        perplexityMean = sum(perplexityAll) / len(perplexityAll)
        print("The Perplxity of current Questioner is: {}".format(perplexityMean))
        winrateMean = sum(win_rate) / len(win_rate)

        if verbose:
            print("Percentile mean rank (round, mean, low, high)")
        for rnd in range(numRounds + 1):
            predFeatures = torch.cat(roundwiseFeaturePreds[rnd],
                                     0).data.cpu().numpy()
            # num_examples x num_examples
            dists = _score_matrix(predFeatures, gtFeatures)
            ranks = np.array([_match_rank(dists, i)
                              for i in range(dists.shape[0])])
            rankMetrics = metrics.computeMetrics(
                Variable(torch.from_numpy(ranks)))
            meanRank = ranks.mean()
            se = ranks.std() / np.sqrt(poolSize)
            meanPercRank = 100 * (1 - (meanRank / poolSize))
            percRankLow = 100 * (1 - ((meanRank + se) / poolSize))
            percRankHigh = 100 * (1 - ((meanRank - se) / poolSize))
            if verbose:
                print((rnd, meanPercRank, percRankLow, percRankHigh))
            rankMetrics['percentile'] = meanPercRank
            rankMetrics['featLoss'] = roundwiseFeatLoss[rnd]
            if rnd < len(roundwiseLogProbs):
                rankMetrics['logProbs'] = roundwiseLogProbs[rnd]
            rankMetricsRounds.append(rankMetrics)

        rankMetricsRounds[-1]['logProbsMean'] = logProbsMean
        rankMetricsRounds[-1]['featLossMean'] = featLossMean
        rankMetricsRounds[-1]['winrateMean'] = winrateMean
        rankMetricsRounds[-1]['perplexityMean'] = perplexityMean
        return rankMetricsRounds[-1], rankMetricsRounds
    finally:
        # Restore the caller's split even if evaluation raises.
        dataset.split = original_split


def rankQABots(qBot, aBot, dataset, split, exampleLimit=None, beamSize=1, im_retrieval_mode='mse'):
    '''
        Evaluates Q-Bot and A-Bot performance on image retrieval where
        both agents must converse with each other without any ground truth
        dialog. The common caption shown to both agents is not the ground
        truth caption, but is instead a caption generated (pre-computed)
        by a pre-trained captioning model (neuraltalk2).

        Arguments:
            qBot    : Q-Bot
            aBot    : A-Bot
            dataset : VisDialDataset instance
            split   : Dataset split, can be 'val' or 'test'

            exampleLimit      : Maximum number of data points to use from
                                the dataset split. If None, all data points.
                                Evaluation runs over whole batches, so up to
                                batchSize - 1 extra examples may be included.
            beamSize          : Beam width passed to both bots'
                                forwardDecode (called with
                                inference='greedy') when generating
                                utterances.
            im_retrieval_mode : 'mse' ranks images by Euclidean distance
                                between predicted and ground-truth features
                                (smaller is better); 'cosine_similarity'
                                ranks by the dot product of the dialog
                                embedding with the image embedding (larger
                                is better).

        Returns:
            (rankMetrics of the final round, list of per-round rankMetrics).
            Index 0 is the caption-only round. Every dict holds the output
            of metrics.computeMetrics plus 'percentile'; the final-round
            dict also holds 'winrateMean'.

        Raises:
            ValueError: if im_retrieval_mode is not one of the two modes.
    '''
    print("Image Encoding Mode is: {}".format(qBot.imgEncodingMode))
    if im_retrieval_mode not in ("mse", "cosine_similarity"):
        # Fail fast: any other mode yields no predictions to rank.
        raise ValueError(
            "Unsupported im_retrieval_mode: {}".format(im_retrieval_mode))
    # Distances are ranked ascending, similarities descending.
    descending = im_retrieval_mode == "cosine_similarity"

    def _score_matrix(pred, gt):
        # Row i scores example i's prediction against every ground-truth row.
        if descending:
            return np.matmul(pred, gt.transpose())
        return pairwise_distances(pred, gt)

    def _match_rank(scores, i):
        # 1-based position of ground truth i within row i of `scores` when
        # that row is ordered best-first. Rows are sorted one at a time so
        # only O(n) extra memory is needed on top of the n x n matrix.
        order = scores[i].argsort()
        if descending:
            order = order[::-1]
        # flatnonzero(...)[0] rather than int(np.where(...)[0]): converting
        # a 1-element array to a Python int is deprecated since NumPy 1.25.
        return int(np.flatnonzero(order == i)[0]) + 1

    def _count_wins(scores):
        # Number of rows whose own ground truth is ranked first.
        return sum(1 for i in range(scores.shape[0])
                   if _match_rank(scores, i) <= 1)

    batchSize = dataset.batchSize
    numRounds = dataset.numRounds
    if exampleLimit is None:
        numExamples = dataset.numDataPoints[split]
    else:
        numExamples = exampleLimit
    numBatches = (numExamples - 1) // batchSize + 1
    original_split = dataset.split
    dataset.split = split
    try:
        dataloader = DataLoader(
            dataset,
            batch_size=batchSize,
            shuffle=False,
            num_workers=0,
            collate_fn=dataset.collate_fn)

        # Per-batch ground-truth features and per-round predictions: index
        # 0 is the caption-only round, 1..numRounds the dialog rounds.
        gtImgFeatures = []
        roundwiseFeaturePreds = [[] for _ in range(numRounds + 1)]

        start_t = timer()

        # Per-round count of examples whose ground-truth image is ranked
        # first within their batch; normalised by num_games after the loop.
        win_rate = [0] * (numRounds + 1)
        num_games = 0

        # Image features of every example the dataloader yields; the
        # imGuess ranker and random initial guesses index into this bank.
        all_im_feat = imgLoader(dataloader, dataset)
        im_ranker = Ranker()

        # Switch Q-Bot to eval mode and build the ranker's representations
        # once for the whole split.
        qBot.eval()
        qBot.reset()
        if qBot.imgEncodingMode == "imGuess":
            im_ranker.update_rep(qBot, all_im_feat)

        for idx, batch in enumerate(dataloader):
            if idx == numBatches:
                break

            # Keep only tensor-like entries, moved to GPU or made contiguous.
            if dataset.useGPU:
                batch = {key: v.cuda() for key, v in batch.items()
                         if hasattr(v, 'cuda')}
            else:
                batch = {key: v.contiguous() for key, v in batch.items()
                         if hasattr(v, 'cuda')}

            with torch.no_grad():
                caption = Variable(batch['cap'])
                captionLens = Variable(batch['cap_len'])
                if im_retrieval_mode == "mse":
                    if qBot.imgEncodingMode == "imGuess":
                        gtFeatures = qBot.forwardImage(
                            Variable(batch['img_feat']))
                    else:
                        gtFeatures = Variable(batch['img_feat'])
                else:
                    gtFeatures = qBot.multimodalpredictIm(
                        Variable(batch['img_feat']))
                image = Variable(batch['img_feat'])

                aBot.eval()
                aBot.reset()
                aBot.observe(-1, image=image, caption=caption,
                             captionLens=captionLens)
                qBot.eval()
                qBot.reset()
                qBot.observe(-1, caption=caption, captionLens=captionLens)
                if qBot.new_questioner:
                    qBot.observe_im(image)

                if qBot.imgEncodingMode == "imGuess":
                    # Random initial guess. torch.randint's upper bound is
                    # exclusive, so size(0) makes every image eligible.
                    act_index = torch.randint(
                        0, all_im_feat.size(0), (image.size(0), 1))
                    predicted_image = all_im_feat[act_index].squeeze(1)
                    qBot.observe_im(predicted_image)

                # Round 0: retrieval from the caption alone.
                if im_retrieval_mode == "mse":
                    predFeatures = qBot.predictImage()
                    roundwiseFeaturePreds[0].append(predFeatures)
                    if qBot.imgEncodingMode == "imGuess":
                        act_index = im_ranker.nearest_neighbor(
                            predFeatures.data)
                        predicted_image = all_im_feat[act_index]
                        # Q-Bot observes its current best guess right away.
                        qBot.observe_im(predicted_image)
                    round_scores = _score_matrix(
                        predFeatures.cpu().numpy(), gtFeatures.cpu().numpy())
                else:  # cosine_similarity
                    dialogEmbedding = qBot.multimodalpredictText()
                    roundwiseFeaturePreds[0].append(dialogEmbedding)
                    round_scores = _score_matrix(
                        dialogEmbedding.cpu().numpy(),
                        gtFeatures.cpu().numpy())
                # Every example in the batch counts as one game.
                win_rate[0] += _count_wins(round_scores)
                num_games += round_scores.shape[0]

                for rnd in range(numRounds):
                    # Q-Bot asks and A-Bot answers; both observe the
                    # generated question and answer.
                    questions, quesLens = qBot.forwardDecode(
                        inference='greedy', beamSize=beamSize)
                    qBot.observe(rnd, ques=questions, quesLens=quesLens)
                    aBot.observe(rnd, ques=questions, quesLens=quesLens)
                    answers, ansLens = aBot.forwardDecode(
                        inference='greedy', beamSize=beamSize)
                    aBot.observe(rnd, ans=answers, ansLens=ansLens)
                    qBot.observe(rnd, ans=answers, ansLens=ansLens)
                    if qBot.new_questioner:
                        qBot.observe_im(image)
                    if qBot.imgEncodingMode == "imGuess":
                        qBot.observe_im(predicted_image)

                    if im_retrieval_mode == "mse":
                        predFeatures = qBot.predictImage()
                        roundwiseFeaturePreds[rnd + 1].append(predFeatures)
                        if qBot.imgEncodingMode == "imGuess":
                            act_index = im_ranker.nearest_neighbor(
                                predFeatures.data)
                            predicted_image = all_im_feat[act_index]
                        round_scores = _score_matrix(
                            predFeatures.cpu().numpy(),
                            gtFeatures.cpu().numpy())
                    else:  # cosine_similarity
                        dialogEmbedding = qBot.multimodalpredictText()
                        roundwiseFeaturePreds[rnd + 1].append(dialogEmbedding)
                        round_scores = _score_matrix(
                            dialogEmbedding.cpu().numpy(),
                            gtFeatures.cpu().numpy())
                    win_rate[rnd + 1] += _count_wins(round_scores)

                gtImgFeatures.append(gtFeatures)

                end_t = timer()
                delta_t = " Rate: %5.2fs" % (end_t - start_t)
                start_t = end_t
                progressString = "\r[Qbot] Evaluating split '%s' [%d/%d]\t" + delta_t
                sys.stdout.write(progressString % (split, idx + 1, numBatches))
                sys.stdout.flush()
        sys.stdout.write("\n")
        win_rate = [x / num_games for x in win_rate]
        print("The winning rates for {} are: {}".format(split, win_rate))

        gtFeatures = torch.cat(gtImgFeatures, 0).data.cpu().numpy()
        rankMetricsRounds = []
        # Rank against the pool actually evaluated. The old
        # `assert len(ranks) == len(dataset)` failed whenever exampleLimit
        # restricted evaluation to part of the split.
        poolSize = gtFeatures.shape[0]

        winrateMean = sum(win_rate) / len(win_rate)
        print("Percentile mean rank (round, mean, low, high)")
        for rnd in range(numRounds + 1):
            predFeatures = torch.cat(roundwiseFeaturePreds[rnd],
                                     0).data.cpu().numpy()
            # num_examples x num_examples
            dists = _score_matrix(predFeatures, gtFeatures)
            # Rank of the i-th prediction's true image among all pool images.
            ranks = np.array([_match_rank(dists, i)
                              for i in range(dists.shape[0])])
            rankMetrics = metrics.computeMetrics(
                Variable(torch.from_numpy(ranks)))
            meanRank = ranks.mean()
            se = ranks.std() / np.sqrt(poolSize)
            meanPercRank = 100 * (1 - (meanRank / poolSize))
            percRankLow = 100 * (1 - ((meanRank + se) / poolSize))
            percRankHigh = 100 * (1 - ((meanRank - se) / poolSize))
            print((rnd, meanPercRank, percRankLow, percRankHigh))
            rankMetrics['percentile'] = meanPercRank
            rankMetricsRounds.append(rankMetrics)

        rankMetricsRounds[-1]['winrateMean'] = winrateMean
        return rankMetricsRounds[-1], rankMetricsRounds
    finally:
        # Restore the caller's split even if evaluation raises.
        dataset.split = original_split


def rankQBot_guess(qBot, dataset, split, exampleLimit=None, verbose=0, policy="random", im_retrieval_mode="mse", imgEncodingMode=None):
    '''
        Evaluates Q-Bot on image retrieval and on deciding *when* to guess,
        while it encodes ground-truth captions, questions and answers
        (Q-Bot does not generate dialog in this setting).

        Before every dialog round Q-Bot's guess policy
        (qBot.determine_action) decides, per game, whether to guess now.
        A guess is scored as a win when Q-Bot's current prediction ranks the
        true image first among the images of its own batch. Games that never
        guess are scored on the prediction made after the last played round.
        We report the winning rate and the average round at which guesses
        were made, plus per-round retrieval metrics over the whole split.

        Arguments:
            qBot    : Q-Bot
            dataset : VisDialDataset instance
            split   : Dataset split, can be 'val' or 'test'

            exampleLimit      : Maximum number of data points to use from
                                the dataset split. If None, all data points.
            verbose           : If truthy, print per-round percentile ranks.
            policy            : Guess policy forwarded to determine_action.
            im_retrieval_mode : 'mse' (regress image features, rank by
                                Euclidean distance) or 'cosine_similarity'
                                (rank by dot product between the dialog
                                embedding and the projected image features).
            imgEncodingMode   : Forwarded to determine_action.

        Returns:
            (metrics dict of the last evaluated round, list of per-round
            metric dicts). The last dict additionally carries
            'logProbsMean', 'featLossMean' and 'winrateMean'.

        Raises:
            ValueError: if im_retrieval_mode is not supported.
    '''
    print("The current im_retrieval_mode is: {}".format(im_retrieval_mode))
    print("The current guess policy: {}".format(policy))
    if im_retrieval_mode not in ("mse", "cosine_similarity"):
        raise ValueError(
            "Unsupported im_retrieval_mode: {}".format(im_retrieval_mode))
    batchSize = dataset.batchSize
    numRounds = dataset.numRounds
    if exampleLimit is None:
        numExamples = dataset.numDataPoints[split]
    else:
        numExamples = exampleLimit
    numBatches = (numExamples - 1) // batchSize + 1
    original_split = dataset.split
    dataset.split = split
    # Evaluation should be reproducible, so iterate in a fixed order
    # (rankQABots_guess does the same).
    dataloader = DataLoader(
        dataset,
        batch_size=batchSize,
        shuffle=False,
        num_workers=0,
        collate_fn=dataset.collate_fn)

    gtImgFeatures = []
    # Slot 0 holds predictions made from the caption only; slot r + 1 holds
    # predictions made after dialog round r.
    roundwiseFeaturePreds = [[] for _ in range(numRounds + 1)]
    # For every stored prediction, the row of its true image in the
    # concatenation of gtImgFeatures. Games stop at different rounds, so
    # later slots only cover some batches and positions no longer line up.
    roundwiseGtIndex = [[] for _ in range(numRounds + 1)]
    logProbsAll = [[] for _ in range(numRounds)]
    featLossAll = [[] for _ in range(numRounds + 1)]

    def _true_match_ranks(preds, gt, gtIndex):
        # 1-based rank of gt[gtIndex[i]] among all rows of gt for prediction i.
        if im_retrieval_mode == "mse":
            # Smaller Euclidean distance is better.
            order = np.argsort(pairwise_distances(preds, gt), axis=1)
        else:
            # Larger similarity is better, so reverse the ascending order.
            order = np.argsort(np.matmul(preds, gt.transpose()), axis=1)[:, ::-1]
        # Each row of `order` is a permutation, so exactly one match exists.
        return np.argmax(order == np.asarray(gtIndex)[:, None], axis=1) + 1

    def _evaluate_round(slot, gt, gtIndex):
        # Predict from the dialog seen so far, record the prediction, its loss
        # and its gt rows in `slot`, and return the per-game within-batch
        # top-1 outcome (+1 correct, -1 wrong) as a float32 array.
        if im_retrieval_mode == "mse":
            preds = qBot.predictImage()
            featLossAll[slot].append(torch.mean(F.mse_loss(preds, gt)))
        else:
            preds = qBot.multimodalpredictText()
            featLossAll[slot].append(
                torch.sum(pairwiseRanking_criterion(gt, preds)))
        roundwiseFeaturePreds[slot].append(preds)
        roundwiseGtIndex[slot].append(gtIndex)
        ranks = _true_match_ranks(
            preds.cpu().numpy(), gt.cpu().numpy(), np.arange(gt.size(0)))
        return np.where(ranks == 1, 1, -1).astype(np.float32)

    start_t = timer()

    # Winning statistics accumulated over all games (one game per example).
    win_rate = 0
    decision_making_turns = 0
    num_games = 0
    for idx, batch in enumerate(dataloader):
        if idx == numBatches:
            break

        if dataset.useGPU:
            batch = {key: v.cuda() for key, v in batch.items()
                     if hasattr(v, 'cuda')}
        else:
            batch = {key: v.contiguous() for key, v in batch.items()
                     if hasattr(v, 'cuda')}

        with torch.no_grad():
            caption = batch['cap']
            captionLens = batch['cap_len']
            gtQuestions = batch['ques']
            gtQuesLens = batch['ques_len']
            answers = batch['ans']
            ansLens = batch['ans_len']
            image = batch['img_feat']
            gtFeatures = image
            if im_retrieval_mode == "cosine_similarity":
                # Rank in the joint embedding space.
                gtFeatures = qBot.multimodalpredictIm(image)

            numInBatch = gtFeatures.size(0)
            # Rows of this batch inside the concatenated gtImgFeatures.
            batchGtIndex = np.arange(num_games, num_games + numInBatch)
            num_games += numInBatch
            device = gtFeatures.device
            # Done[i] > 0 once game i has made its guess.
            Done = torch.zeros(numInBatch, device=device)
            # Round at which each game guessed; default is the full dialog.
            End_Rounds = np.ones(numInBatch) * numRounds

            qBot.reset()
            qBot.observe(-1, caption=caption, captionLens=captionLens)
            if qBot.new_questioner:
                qBot.observe_im(image)
            current_guess_result = _evaluate_round(0, gtFeatures, batchGtIndex)

            rnd = 0
            while rnd < numRounds and bool((Done == 0).any()):
                guess_action = qBot.determine_action(
                    Done, policy=policy, imgEncodingMode=imgEncodingMode)
                guess_action = guess_action.to(device)
                action = guess_action.cpu().numpy()

                # A guess made now wins iff the current prediction is top-1.
                winning_game = action * current_guess_result
                winning_game[winning_game < 0] = 0
                win_rate += np.sum(winning_game)

                End_Rounds[action == 1] = rnd + 1
                Done = torch.where(guess_action > 0, guess_action, Done)

                qBot.observe(
                    rnd,
                    ques=gtQuestions[:, rnd],
                    quesLens=gtQuesLens[:, rnd])
                qBot.observe(
                    rnd, ans=answers[:, rnd], ansLens=ansLens[:, rnd])
                logProbsCurrent = qBot.forward()
                # Cross entropy of the ground-truth question.
                logProbsAll[rnd].append(
                    utils.maskedNll(logProbsCurrent,
                                    gtQuestions[:, rnd].contiguous()))
                current_guess_result = _evaluate_round(
                    rnd + 1, gtFeatures, batchGtIndex)
                rnd += 1

            # Games that never guessed are scored on their last prediction.
            winning_game = (1 - Done).cpu().numpy() * current_guess_result
            winning_game[winning_game < 0] = 0
            win_rate += np.sum(winning_game)

            decision_making_turns += np.sum(End_Rounds)

            gtImgFeatures.append(gtFeatures)

        end_t = timer()
        delta_t = " Time: %5.2fs" % (end_t - start_t)
        start_t = end_t
        progressString = "\r[Qbot] Evaluating split '%s' [%d/%d]\t" + delta_t
        sys.stdout.write(progressString % (split, idx + 1, numBatches))
        sys.stdout.flush()

    sys.stdout.write("\n")
    win_rate = win_rate / num_games
    print("The winning rates is: {}".format(win_rate))

    average_guess_turns = decision_making_turns / num_games
    print("The average guess_turns is: {}".format(average_guess_turns))

    gtFeatures = torch.cat(gtImgFeatures, 0).cpu().numpy()
    # The retrieval pool is every image actually evaluated (respects
    # exampleLimit), not the whole dataset.
    poolSize = gtFeatures.shape[0]
    rankMetricsRounds = []

    # Rounds are played in order, so the slots that received data form a
    # prefix; drop the trailing rounds no batch reached.
    logProbsAll = [torch.stack(lp, 0).mean() for lp in logProbsAll if lp]
    featLossAll = [torch.stack(fl, 0).mean() for fl in featLossAll if fl]
    roundwiseLogProbs = torch.stack(logProbsAll, 0).cpu().numpy()
    roundwiseFeatLoss = torch.stack(featLossAll, 0).cpu().numpy()
    logProbsMean = roundwiseLogProbs.mean()
    featLossMean = roundwiseFeatLoss.mean()

    if verbose:
        print("Percentile mean rank (round, mean, low, high)")
    for rnd in range(len(featLossAll)):
        predFeatures = torch.cat(roundwiseFeaturePreds[rnd], 0).cpu().numpy()
        gtIndex = np.concatenate(roundwiseGtIndex[rnd])
        # Rank each still-playing game's prediction against the full pool.
        ranks = _true_match_ranks(predFeatures, gtFeatures, gtIndex)
        rankMetrics = metrics.computeMetrics(torch.from_numpy(ranks))
        meanRank = ranks.mean()
        # Standard error of the mean over the ranks actually averaged.
        se = ranks.std() / np.sqrt(len(ranks))
        meanPercRank = 100 * (1 - (meanRank / poolSize))
        percRankLow = 100 * (1 - ((meanRank + se) / poolSize))
        percRankHigh = 100 * (1 - ((meanRank - se) / poolSize))
        if verbose:
            print((rnd, meanPercRank, percRankLow, percRankHigh))
        rankMetrics['percentile'] = meanPercRank
        rankMetrics['featLoss'] = roundwiseFeatLoss[rnd]
        # Slot 0 (caption only) has no question log-probs of its own.
        if rnd < len(roundwiseLogProbs):
            rankMetrics['logProbs'] = roundwiseLogProbs[rnd]
        rankMetricsRounds.append(rankMetrics)

    rankMetricsRounds[-1]['logProbsMean'] = logProbsMean
    rankMetricsRounds[-1]['featLossMean'] = featLossMean
    rankMetricsRounds[-1]['winrateMean'] = win_rate
    dataset.split = original_split
    return rankMetricsRounds[-1], rankMetricsRounds


def rankQABots_guess(qBot, aBot, dataset, split, exampleLimit=None, beamSize=1, policy='random', im_retrieval_mode='mse', imgEncodingMode=None):
    '''
        Evaluates Q-Bot and A-Bot on the image-guessing game, where both
        agents converse with each other without any ground-truth dialog.
        The common caption shown to both agents is not the ground-truth
        caption, but a caption generated (pre-computed) by a pre-trained
        captioning model (neuraltalk2).

        Before every round Q-Bot's guess policy (qBot.determine_action)
        decides, per game, whether to guess now. A guess is scored as a win
        when Q-Bot's current prediction ranks the true image first among
        the images of its own batch. Games that never guess are scored on
        the prediction made after the last played round.

        Arguments:
            qBot    : Q-Bot
            aBot    : A-Bot
            dataset : VisDialDataset instance
            split   : Dataset split, can be 'val' or 'test'

            exampleLimit      : Maximum number of data points to use from
                                the dataset split. If None, all data points.
            beamSize          : Beam search width for generating utterances.
            policy            : Guess policy forwarded to determine_action.
            im_retrieval_mode : 'mse' (Euclidean distance on regressed
                                features) or 'cosine_similarity' (dot
                                product in the joint embedding space).
            imgEncodingMode   : Forwarded to determine_action.

        Returns:
            (metrics dict of the last fully evaluated round, list of
            per-round metric dicts). The last dict also carries
            'winrateMean'. Only rounds reached by every evaluated game are
            ranked.

        Raises:
            ValueError: if im_retrieval_mode is not supported.
    '''
    print("The current im_retrieval_mode is: {}".format(im_retrieval_mode))
    if im_retrieval_mode not in ("mse", "cosine_similarity"):
        raise ValueError(
            "Unsupported im_retrieval_mode: {}".format(im_retrieval_mode))
    batchSize = dataset.batchSize
    numRounds = dataset.numRounds
    if exampleLimit is None:
        numExamples = dataset.numDataPoints[split]
    else:
        numExamples = exampleLimit
    numBatches = (numExamples - 1) // batchSize + 1
    original_split = dataset.split
    dataset.split = split
    dataloader = DataLoader(
        dataset,
        batch_size=batchSize,
        shuffle=False,
        num_workers=0,
        collate_fn=dataset.collate_fn)

    gtImgFeatures = []
    # Slot 0 holds caption-only predictions; slot r + 1 those after round r.
    roundwiseFeaturePreds = [[] for _ in range(numRounds + 1)]

    def _true_match_ranks(preds, gt, gtIndex):
        # 1-based rank of gt[gtIndex[i]] among all rows of gt for prediction i.
        if im_retrieval_mode == "mse":
            # Smaller Euclidean distance is better.
            order = np.argsort(pairwise_distances(preds, gt), axis=1)
        else:
            # Larger similarity is better, so reverse the ascending order.
            order = np.argsort(np.matmul(preds, gt.transpose()), axis=1)[:, ::-1]
        # Each row of `order` is a permutation, so exactly one match exists.
        return np.argmax(order == np.asarray(gtIndex)[:, None], axis=1) + 1

    def _predict_and_judge(slot, gt):
        # Predict from the dialog so far, store it in `slot`, and return the
        # per-game within-batch top-1 outcome (+1 correct, -1 wrong).
        if im_retrieval_mode == "mse":
            preds = qBot.predictImage()
        else:
            preds = qBot.multimodalpredictText()
        roundwiseFeaturePreds[slot].append(preds)
        ranks = _true_match_ranks(
            preds.cpu().numpy(), gt.cpu().numpy(), np.arange(gt.size(0)))
        return np.where(ranks == 1, 1, -1).astype(np.float32)

    start_t = timer()

    # Winning statistics accumulated over all games (one game per example).
    win_rate = 0
    decision_making_turns = 0
    num_games = 0
    for idx, batch in enumerate(dataloader):
        if idx == numBatches:
            break

        if dataset.useGPU:
            batch = {key: v.cuda() for key, v in batch.items()
                     if hasattr(v, 'cuda')}
        else:
            batch = {key: v.contiguous() for key, v in batch.items()
                     if hasattr(v, 'cuda')}

        with torch.no_grad():
            caption = batch['cap']
            captionLens = batch['cap_len']
            image = batch['img_feat']
            gtFeatures = image
            if im_retrieval_mode == "cosine_similarity":
                # Rank in the joint embedding space.
                gtFeatures = qBot.multimodalpredictIm(image)

            numInBatch = gtFeatures.size(0)
            num_games += numInBatch
            device = gtFeatures.device
            # Done[i] > 0 once game i has made its guess.
            Done = torch.zeros(numInBatch, device=device)
            # Round at which each game guessed; default is the full dialog.
            End_Rounds = np.ones(numInBatch) * numRounds

            aBot.eval()
            aBot.reset()
            aBot.observe(-1, image=image, caption=caption,
                         captionLens=captionLens)
            qBot.eval()
            qBot.reset()
            qBot.observe(-1, caption=caption, captionLens=captionLens)
            if qBot.new_questioner:
                qBot.observe_im(image)
            current_guess_result = _predict_and_judge(0, gtFeatures)

            rnd = 0
            while rnd < numRounds and bool((Done == 0).any()):
                guess_action = qBot.determine_action(
                    Done, policy=policy, imgEncodingMode=imgEncodingMode)
                guess_action = guess_action.to(device)
                action = guess_action.cpu().numpy()

                # A guess made now wins iff the current prediction is top-1.
                winning_game = action * current_guess_result
                winning_game[winning_game < 0] = 0
                win_rate += np.sum(winning_game)

                End_Rounds[action == 1] = rnd + 1
                Done = torch.where(guess_action > 0, guess_action, Done)

                # One round of generated dialog between the two bots.
                questions, quesLens = qBot.forwardDecode(
                    inference='greedy', beamSize=beamSize)
                qBot.observe(rnd, ques=questions, quesLens=quesLens)
                aBot.observe(rnd, ques=questions, quesLens=quesLens)
                answers, ansLens = aBot.forwardDecode(
                    inference='greedy', beamSize=beamSize)
                aBot.observe(rnd, ans=answers, ansLens=ansLens)
                qBot.observe(rnd, ans=answers, ansLens=ansLens)
                if qBot.new_questioner:
                    qBot.observe_im(image)
                current_guess_result = _predict_and_judge(rnd + 1, gtFeatures)
                rnd += 1

            # Games that never guessed are scored on their last prediction.
            winning_game = (1 - Done).cpu().numpy() * current_guess_result
            winning_game[winning_game < 0] = 0
            win_rate += np.sum(winning_game)

            decision_making_turns += np.sum(End_Rounds)

            gtImgFeatures.append(gtFeatures)

        end_t = timer()
        delta_t = " Rate: %5.2fs" % (end_t - start_t)
        start_t = end_t
        progressString = "\r[Qbot] Evaluating split '%s' [%d/%d]\t" + delta_t
        sys.stdout.write(progressString % (split, idx + 1, numBatches))
        sys.stdout.flush()

    sys.stdout.write("\n")
    win_rate = win_rate / num_games
    print("The winning rates is: {}".format(win_rate))

    average_guess_turns = decision_making_turns / num_games
    print("The average guess_turns is: {}".format(average_guess_turns))

    gtFeatures = torch.cat(gtImgFeatures, 0).cpu().numpy()
    # The retrieval pool is every image actually evaluated (respects
    # exampleLimit), not len(dataset).
    poolSize = gtFeatures.shape[0]
    rankMetricsRounds = []

    print("Percentile mean rank (round, mean, low, high)")
    for rnd in range(numRounds + 1):
        roundPreds = roundwiseFeaturePreds[rnd]
        # Only rank rounds that every game reached; rounds are played in
        # order, so the first incomplete (or empty) round ends the table.
        # Round 0 is always complete.
        if not roundPreds or sum(p.size(0) for p in roundPreds) != poolSize:
            break
        predFeatures = torch.cat(roundPreds, 0).cpu().numpy()
        ranks = _true_match_ranks(predFeatures, gtFeatures, np.arange(poolSize))
        rankMetrics = metrics.computeMetrics(torch.from_numpy(ranks))
        meanRank = ranks.mean()
        se = ranks.std() / np.sqrt(poolSize)
        meanPercRank = 100 * (1 - (meanRank / poolSize))
        percRankLow = 100 * (1 - ((meanRank + se) / poolSize))
        percRankHigh = 100 * (1 - ((meanRank - se) / poolSize))
        print((rnd, meanPercRank, percRankLow, percRankHigh))
        rankMetrics['percentile'] = meanPercRank
        rankMetricsRounds.append(rankMetrics)

    dataset.split = original_split
    rankMetricsRounds[-1]['winrateMean'] = win_rate
    return rankMetricsRounds[-1], rankMetricsRounds
