"""
Introduction: Training pipeline for the multi-task reinforcement learning dialog agents (Q-Bot and A-Bot).
Author: Mingyang Zhou
"""
import os
import gc
import random
import pprint
from six.moves import range
from markdown2 import markdown
from time import gmtime, strftime
from timeit import default_timer as timer

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

import options
from dataloader import VisDialDataset, VisDialDataset_20ImGuess
from torch.utils.data import DataLoader
from eval_utils.rank_answerer import rankABot
from eval_utils.rank_questioner import rankQBot, rankQBot_guess, rankQABots_guess
from utils import utilities as utils
from utils.visualize import VisdomVisualize
import visdial.loss.loss_utils as loss_utils
from visdial.loss.rl_imGuess_loss import Ranker, rl_rollout_search
from visdial.loss.infoGain import imgPrior, prepareBatch

#-------------------------------------------------------------------------
# Setup
#-------------------------------------------------------------------------

# Parse the command line; the option definitions live in the separate
# `options` module.
params = options.readCommandLine()

# Seed every random number generator in use so that runs are repeatable.
seed = params['randomSeed']
random.seed(seed)
torch.manual_seed(seed)
if params['useGPU']:
    torch.cuda.manual_seed_all(seed)

# Build the dataset over all three splits.
splits = ['train', 'val', 'test']
dataset = VisDialDataset_20ImGuess(params, splits)

# Copy dataset-derived sizes into params; attributes the dataset does not
# expose are skipped.
transfer = ['vocabSize', 'numOptions', 'numRounds']
_MISSING = object()
for key in transfer:
    value = getattr(dataset, key, _MISSING)
    if value is not _MISSING:
        params[key] = value

# Make sure the checkpoint root and the run-specific save path exist.
for folder in ('checkpoints', params['savePath']):
    os.makedirs(folder, exist_ok=True)


# Loading Modules
parameters = []      # parameters that will be handed to the optimizer
aBot = None
qBot = None
qBot_Target = None   # The Target Model for Policy Improvement

# RL fine-tuning modes; each of them loads both an A-Bot and a Q-Bot.
rlModes = ('rl-full-QAf', 'rl-guess-QAf', 'rl-full-QAf-imGuess',
           'rl-full-QAf-imGuess-WordRL')

# Loading a pretrained A-Bot; all of its parameters are optimized.
# params['trainMode'] is deliberately re-read at every check because the
# parameters returned by loadModel are merged into `params` in between.
if params['trainMode'] == 'sl-abot' or params['trainMode'] in rlModes:
    aBot, loadedParams, optim_state = utils.loadModel(params, 'abot')
    for key in loadedParams:
        params[key] = loadedParams[key]
    parameters.extend(aBot.parameters())

    # NOTE(review): 'rl-full-QAf-language-gen' is not accepted by the
    # enclosing condition, so this branch can only run if the loaded
    # parameters above overwrite params['trainMode'] - confirm intent.
    # The proxy answerer's parameters are not added to `parameters`.
    if params['trainMode'] == 'rl-full-QAf-language-gen':
        aBot_Aproxy, loadedParams, optim_state = utils.loadModel(
            params, 'abot')

# Loading a pretrained Q-Bot and, for 'rl-full-QAf-imGuess', a second copy
# that serves as the target network.
if params['trainMode'] == 'sl-qbot' or params['trainMode'] in rlModes:
    qBot, loadedParams, optim_state = utils.loadModel(params, 'qbot')
    if params['trainMode'] == 'rl-full-QAf-imGuess':
        qBot_Target, loadedParams, optim_state = utils.loadModel(
            params, 'qbot')

    # The most recently loaded parameters win (original author flagged this
    # merge as a potential issue).
    for key in loadedParams:
        params[key] = loadedParams[key]

    # Optionally freeze parts of the Q-Bot before collecting its
    # trainable parameters.
    if params['trainMode'] in rlModes and params['freezeQFeatNet']:
        qBot.freezeFeatNet()
    if params['trainMode'] == 'rl-full-QAf-imGuess' and params['freezeQImgNet']:
        qBot.freezeImgNet()

    # Only parameters which require a gradient update are optimized; the
    # target network's parameters are never added.
    parameters.extend([p for p in qBot.parameters() if p.requires_grad])


# Setup pytorch dataloader over the training split.
dataset.split = 'train'
loaderOptions = dict(
    batch_size=params['batchSize'],
    # NOTE(review): unshuffled - presumably so that the rows of the feature
    # matrix built by imgLoader() follow dataset order; confirm.
    shuffle=False,
    num_workers=params['numWorkers'],
    drop_last=True,
    collate_fn=dataset.collate_fn,
    pin_memory=False)
dataloader = DataLoader(dataset, **loaderOptions)

# Initializing visdom environment for plotting data
vizOptions = dict(
    enable=bool(params['enableVisdom']),
    env_name=params['visdomEnv'],
    server=params['visdomServer'],
    port=params['visdomServerPort'])
viz = VisdomVisualize(**vizOptions)

# Record the full configuration on stdout and in the visdom environment.
pprint.pprint(params)
paramsText = pprint.pformat(params, indent=4)
viz.addText(paramsText)

# Setup optimizer
if params['continue']:
    # Continuing from a loaded checkpoint restores the following
    startIterID = params['ckpt_iterid'] + 1  # Iteration ID
    lRate = params['ckpt_lRate']  # Learning rate
    print("Continuing training from iterId[%d]" % startIterID)
else:
    # Beginning training normally, without any checkpoint
    lRate = params['learningRate']
    startIterID = 0

optimizer = optim.Adam(parameters, lr=lRate)
if params['continue']:  # Restoring optimizer state
    print("Restoring optimizer state dict from checkpoint")
    # `optim_state` comes from the most recent utils.loadModel call made
    # while loading the bots.
    optimizer.load_state_dict(optim_state)
runningLoss = None

# Element-wise (unreduced) squared error; callers take the means themselves.
# NOTE: `reduce=False` is the legacy spelling of `reduction='none'`; it is
# kept as-is to stay compatible with the PyTorch version this was written for.
mse_criterion = nn.MSELoss(reduce=False)

numIterPerEpoch = len(dataset) // params['batchSize']
print('\n%d iter per epoch.' % numIterPerEpoch)

# Curriculum: rounds before `rlRound` use the supervised objective, rounds
# from `rlRound` on use the RL objective. The training loop lowers rlRound
# by one per epoch.
if params['useCurriculum']:
    # First value of rlRound for a fresh run.
    curriculumStart = params['numRounds'] - 2
    if params['continue']:
        # Account for the epochs already completed before the checkpoint.
        # Derived from the same starting value as a fresh run instead of a
        # hard-coded 8, which was only correct for numRounds == 10.
        rlRound = max(0, curriculumStart - (startIterID // numIterPerEpoch))
    else:
        rlRound = curriculumStart
else:
    rlRound = 0

#-------------------------------------------------------------------------
# Training
#-------------------------------------------------------------------------


def batch_iter(dataloader, numEpochs=None):
    """Yield every batch of `dataloader`, repeated for several epochs.

    Args:
        dataloader: any re-iterable collection of batches; it is iterated
            from the start once per epoch.
        numEpochs: number of passes over `dataloader`. When None (the
            default) the value is read from the module-level
            ``params['numEpochs']``, which is the original behaviour.

    Yields:
        Tuples ``(epochId, idx, batch)`` where `epochId` is the zero-based
        epoch number and `idx` the zero-based position of `batch` within
        that epoch.
    """
    if numEpochs is None:
        numEpochs = params['numEpochs']
    for epochId in range(numEpochs):
        for idx, batch in enumerate(dataloader):
            yield epochId, idx, batch


def imgLoader(dataloader, dataset):
    """Collect the image features of every batch into a single tensor.

    Each batch produced by `dataloader` is passed through
    ``prepareBatch(dataset, batch)`` and its ``'img_feat'`` entry is kept.

    Args:
        dataloader: iterable of batches.
        dataset: dataset object forwarded to `prepareBatch`.

    Returns:
        The per-batch ``'img_feat'`` tensors concatenated along dimension 0,
        in the order the batches were produced.

    Raises:
        ValueError: if `dataloader` yields no batches (previously this
            surfaced as an opaque error from ``torch.cat``).
    """
    all_im_feat = []
    for batch in dataloader:
        batch = prepareBatch(dataset, batch)
        # The features are inputs only, so they never require gradients.
        all_im_feat.append(Variable(batch['img_feat'], requires_grad=False))
    if not all_im_feat:
        raise ValueError(
            'imgLoader: dataloader produced no batches, cannot build the '
            'image feature matrix')
    return torch.cat(all_im_feat, 0)

# #Manually set rl_round
# rlRound = 7

# Wall-clock reference for the periodic timing printout.
start_t = timer()

# Best validation win rates seen so far: from the Q-Bot ranking and from the
# joint Q-Bot/A-Bot ranking respectively.
best_val_winrate = best_val_QA_winrate = 0

# Image features of every batch of the training loader, concatenated.
all_im_feat = imgLoader(dataloader, dataset)

# Initialize the ranker
im_ranker = Ranker()

# Initialize for language generation; toggled per epoch in the training
# loop when alternate training is enabled.
rl_language_training = False

# In this mode the ranker is updated once, from the target network.
targetDrivesRanker = params['trainMode'] == "rl-full-QAf-imGuess"
if targetDrivesRanker:
    im_ranker.update_rep(qBot_Target, all_im_feat)

# Main training loop: one optimizer step per batch, followed by learning
# rate decay, curriculum annealing, periodic logging, per-epoch validation
# and checkpointing.
for epochId, idx, batch in batch_iter(dataloader):
    # Keeping track of iterId and epoch
    iterId = startIterID + idx + (epochId * numIterPerEpoch)
    epoch = iterId // numIterPerEpoch
    gc.collect()

    # Moving current batch to GPU, if available
    if dataset.useGPU:
        batch = {key: v.cuda() if hasattr(v, 'cuda')
                 else v for key, v in batch.items()}

    image = Variable(batch['img_feat'], requires_grad=False)
    caption = Variable(batch['cap'], requires_grad=False)
    captionLens = Variable(batch['cap_len'], requires_grad=False)
    gtQuestions = Variable(batch['ques'], requires_grad=False)
    gtQuesLens = Variable(batch['ques_len'], requires_grad=False)
    gtAnswers = Variable(batch['ans'], requires_grad=False)
    gtAnsLens = Variable(batch['ans_len'], requires_grad=False)
    # Named `ansOptions` so that the imported `options` module is not
    # rebound to a tensor.
    ansOptions = Variable(batch['opt'], requires_grad=False)
    optionLens = Variable(batch['opt_len'], requires_grad=False)
    gtAnsId = Variable(batch['ans_id'], requires_grad=False)
    # Dataset indices of the examples in this batch
    gtIdx = Variable(torch.LongTensor(batch['index']), requires_grad=False)

    # Update the Ranker with the current Q-Bot
    if params['trainMode'] in ["sl-qbot", "rl-full-QAf-imGuess-WordRL"]:
        im_ranker.update_rep(qBot, all_im_feat)

    # Initializing optimizer and losses
    optimizer.zero_grad()
    loss = 0
    qBotLoss = 0
    aBotLoss = 0
    rlLoss = 0
    featLoss = 0
    qBotRLLoss = 0
    aBotRLLoss = 0
    predFeatures = None
    initialGuess = None
    numRounds = params['numRounds']

    # Setting training modes for both bots and observing captions, images
    # where needed
    if aBot:
        aBot.train()
        aBot.reset()
        if params['rlAbotReward']:
            aBot.reset_reinforce()
        aBot.observe(-1, image=image, caption=caption, captionLens=captionLens)
    if qBot:
        qBot.train()
        qBot.reset()
        qBot.reset_reinforce()
        qBot.observe(-1, caption=caption, captionLens=captionLens)
        # Observe the group of images every round
        if qBot.new_questioner:
            qBot.observe_im(image)
        # Randomly sample one image per example to start from
        if params['imgEncodingMode'] == "imGuess":
            with torch.no_grad():
                # The upper bound of torch.randint is exclusive, so pass the
                # number of images itself; `size(0) - 1` could never pick
                # the last image and failed outright for a single image.
                act_index = torch.randint(
                    0, all_im_feat.size(0), (image.size(0), 1))
                predicted_image = all_im_feat[act_index].squeeze(1)
                qBot.observe_im(predicted_image)

        # Image guess made before any dialog round, scored against the
        # Q-Bot's own projection of the true image. Kept under `if qBot`
        # so that modes without a Q-Bot (sl-abot) do not dereference None.
        initialGuess = qBot.predictImage()
        imageProjection = qBot.forwardImage(image)
        prevFeatDist = mse_criterion(initialGuess, imageProjection)
        featLoss += torch.mean(prevFeatDist)
        prevFeatDist = torch.mean(prevFeatDist, 1)

    if params['trainMode'] == 'rl-full-QAf-imGuess':
        # In Policy Improvement We don't update qBot_Target
        if qBot_Target:
            qBot_Target.eval()
            qBot_Target.reset()
            qBot_Target.reset_reinforce()
            qBot_Target.observe(-1, caption=caption, captionLens=captionLens)
            # Observe the predicted image.
            # NOTE(review): predicted_image is only freshly assigned above
            # when imgEncodingMode is "imGuess"; otherwise it is undefined
            # on the first iteration and stale afterwards.
            qBot_Target.observe_im(predicted_image)

        with torch.no_grad():
            target_im_state = qBot_Target.predictImage()

    # Replace the starting image by the ranker's nearest neighbour of the
    # initial guess.
    if qBot:
        with torch.no_grad():
            act_index = im_ranker.nearest_neighbor(initialGuess.data)
            predicted_image = all_im_feat[act_index]

    for round in range(numRounds):
        # Loop over rounds of dialog. Training modes described by the
        # original author:
        #   sl-abot :
        #       Supervised pre-training of A-Bot model using cross
        #       entropy loss with ground truth answers
        #   sl-qbot :
        #       Supervised pre-training of Q-Bot model using cross
        #       entropy loss with ground truth questions for the
        #       dialog model and mean squared error loss for image
        #       feature regression (i.e. image prediction)
        #   sl-abot-vse :
        #       Supervised pre-training of A-Bot model using cross
        #       entropy loss with ground truth answers while
        #       simultaneously learning to retrieve the image
        #   rl-full-QAf :
        #       RL-finetuning of A-Bot and Q-Bot in a cooperative
        #       setting where the common reward is the difference
        #       in mean squared error between the current and
        #       previous round of Q-Bot's image prediction.
        #       Annealing: In order to ease in the RL objective,
        #       fine-tuning starts with first N-1 rounds of SL
        #       objective and last round of RL objective - the
        #       number of RL rounds are increased by 1 after
        #       every epoch until only RL objective is used for
        #       all rounds of dialog.
        #   rl-full-QAf-imGuess :
        #       RL-finetuning of Q-Bot in a cooperative setting where
        #       the reward is the ranking of the target image.

        # Tracking components which require a forward pass
        # A-Bot dialog model
        forwardABot = (params['trainMode'] in ['sl-abot']
                       or ('rl' in params['trainMode'] and round < rlRound)
                       or rl_language_training)

        # Q-Bot dialog model
        forwardQBot = (params['trainMode'] in ['sl-qbot']
                       or ('rl' in params['trainMode'] and round < rlRound)
                       or rl_language_training)

        # Q-Bot feature regression network: always active in RL modes
        forwardFeatNet = (forwardQBot or ('rl' in params['trainMode']))

        # Answerer Forward Pass
        if forwardABot:
            # Observe Ground Truth (GT) question
            aBot.observe(
                round,
                ques=gtQuestions[:, round],
                quesLens=gtQuesLens[:, round])
            # Observe GT answer for teacher forcing
            aBot.observe(
                round,
                ans=gtAnswers[:, round],
                ansLens=gtAnsLens[:, round])
            ansLogProbs = aBot.forward()
            # Cross Entropy (CE) Loss for Ground Truth Answers
            aBotLoss += utils.maskedNll(ansLogProbs,
                                        gtAnswers[:, round].contiguous())

        # Questioner Forward Pass (dialog model)
        if forwardQBot:
            if qBot.imgEncodingMode == "imGuess":
                # Observe the currently predicted image
                qBot.observe_im(predicted_image)

            # Observe GT question for teacher forcing
            qBot.observe(
                round,
                ques=gtQuestions[:, round],
                quesLens=gtQuesLens[:, round])

            # Keep the target network's dialog history in step
            if params['trainMode'] == 'rl-full-QAf-imGuess':
                qBot_Target.observe_im(predicted_image)
                qBot_Target.observe(
                    round,
                    ques=gtQuestions[:, round],
                    quesLens=gtQuesLens[:, round])
                qBot_Target.observe(
                    round,
                    ans=gtAnswers[:, round],
                    ansLens=gtAnsLens[:, round])

            quesLogProbs = qBot.forward()
            # Cross Entropy (CE) Loss for Ground Truth Questions
            qBotLoss += utils.maskedNll(quesLogProbs,
                                        gtQuestions[:, round].contiguous())

            # Observe GT answer for updating dialog history
            qBot.observe(
                round,
                ans=gtAnswers[:, round],
                ansLens=gtAnsLens[:, round])

            if params['trainMode'] == "rl-full-QAf-imGuess":
                # Forward pass of the target network; its output is
                # discarded immediately.
                with torch.no_grad():
                    target_quesLogProbs = qBot_Target.forward()
                    del target_quesLogProbs

        # In order to stay true to the original implementation, the feature
        # regression network makes predictions before dialog begins and for
        # the first 9 rounds of dialog. This can be set to 10 if needed.
        MAX_FEAT_ROUNDS = 9

        # Questioner feature regression network forward pass
        if forwardFeatNet and round < MAX_FEAT_ROUNDS:
            # Make an image prediction after each round
            predFeatures = qBot.predictImage()
            if not params["disableFeatForwardLoss"]:
                imageProjection = qBot.forwardImage(image)
                featDist = mse_criterion(predFeatures, imageProjection)
                # Mean over every element of the batch
                featDist = torch.mean(featDist)
                featLoss += featDist

            if qBot.imgEncodingMode == "imGuess" and round < rlRound:
                # Update the predicted_image for the next round
                with torch.no_grad():
                    act_index = im_ranker.nearest_neighbor(predFeatures.data)
                    predicted_image = all_im_feat[act_index]

        # A-Bot and Q-Bot interacting in RL rounds
        if params['trainMode'] == 'rl-full-QAf-imGuess' and round >= rlRound and round < numRounds - 1 and (not rl_language_training):
            qBot.observe_im(predicted_image)
            qBot_Target.observe_im(predicted_image)
            # Run one round of conversation
            questions, quesLens = qBot.forwardDecode(inference='sample')
            qBot.observe(round, ques=questions, quesLens=quesLens)
            qBot_Target.observe(round, ques=questions, quesLens=quesLens)
            aBot.observe(round, ques=questions, quesLens=quesLens)

            answers, ansLens = aBot.forwardDecode(inference='sample')
            aBot.observe(round, ans=answers, ansLens=ansLens)
            qBot.observe(round, ans=answers, ansLens=ansLens)
            qBot_Target.observe(round, ans=answers, ansLens=ansLens)

            # Q-Bot makes a guess at the end of each round
            behavior_im_state = qBot.predictImage()
            with torch.no_grad():
                target_im_state = qBot_Target.predictImage()
            act_index, current_rlloss = rl_rollout_search(
                qBot, qBot_Target, aBot, behavior_im_state, target_im_state, round + 1, numRounds, gtIdx, all_im_feat, im_ranker)
            with torch.no_grad():
                # Update the predicted image
                predicted_image = all_im_feat[act_index]

            rlLoss += current_rlloss
            del current_rlloss
        elif params['trainMode'] == 'rl-full-QAf-imGuess-WordRL' and round >= rlRound and (not rl_language_training):
            qBot.observe_im(predicted_image)
            # Run one round of conversation
            questions, quesLens = qBot.forwardDecode(inference='sample')
            qBot.observe(round, ques=questions, quesLens=quesLens)
            aBot.observe(round, ques=questions, quesLens=quesLens)

            answers, ansLens = aBot.forwardDecode(inference='sample')
            aBot.observe(round, ans=answers, ansLens=ansLens)
            qBot.observe(round, ans=answers, ansLens=ansLens)

            predFeatures = qBot.predictImage()
            imageProjection = qBot.forwardImage(image)
            featDist = mse_criterion(predFeatures, imageProjection)
            # NOTE(review): this is the mean over the whole batch (a single
            # value), whereas the initial prevFeatDist is a per-example
            # mean - confirm that a batch-level reward is intended.
            featDist = torch.mean(featDist)

            # Reward: decrease of the feature distance since the last guess
            reward = prevFeatDist.data - featDist
            prevFeatDist = featDist

            qBotRLLoss = qBot.reinforce(reward)
            if params['rlAbotReward']:
                aBotRLLoss = aBot.reinforce(reward)
                # Added only when it was computed: with the A-Bot reward
                # disabled aBotRLLoss stays the int 0 and torch.mean(0)
                # raises a TypeError.
                rlLoss += torch.mean(aBotRLLoss)
            rlLoss += torch.mean(qBotRLLoss)

            # Guess the Next Image
            with torch.no_grad():
                act_index = im_ranker.nearest_neighbor(predFeatures.data)
                predicted_image = all_im_feat[act_index]

    rlCoeff = params['rlLossCoeff']
    rlLoss = rlLoss * rlCoeff
    featLoss = featLoss * params['featLossCoeff']
    # Averaging over rounds
    qBotLoss = (params['CELossCoeff'] * qBotLoss) / numRounds
    aBotLoss = (params['CELossCoeff'] * aBotLoss) / numRounds
    featLoss = featLoss / numRounds
    rlLoss = rlLoss / numRounds
    # Total loss
    loss = qBotLoss + aBotLoss + rlLoss + featLoss
    loss.backward()
    optimizer.step()

    # Tracking a running average of loss
    if runningLoss is None:
        runningLoss = loss.data.item()
    else:
        runningLoss = 0.95 * runningLoss + 0.05 * loss.data.item()

    # Decay learning rate until it reaches the configured minimum
    if lRate > params['minLRate']:
        for group in optimizer.param_groups:
            group['lr'] *= params['lrDecayRate']
        lRate *= params['lrDecayRate']
        if iterId % 10 == 0:  # Plot learning rate till saturation
            viz.linePlot(iterId, lRate, 'learning rate', 'learning rate')

    # RL Annealing: Every epoch after the first, decrease rlRound
    if iterId % numIterPerEpoch == 0 and iterId > 0:
        if params['trainMode'] in ['rl-full-QAf', 'rl-guess-QAf', 'rl-full-QAf-imGuess']:
            rlRound = max(0, rlRound - 1)
            print('Using rl starting at round {}'.format(rlRound))

    # Print every now and then
    if iterId % 10 == 0:
        end_t = timer()  # Keeping track of iteration(s) time
        curEpoch = float(iterId) / numIterPerEpoch
        timeStamp = strftime('%a %d %b %y %X', gmtime())
        printFormat = '[%s][Ep: %.2f][Iter: %d][Time: %5.2fs][Loss: %.3g]'
        printFormat += '[lr: %.3g]'
        printInfo = [timeStamp, curEpoch, iterId,
                     end_t - start_t, loss.data.item(), lRate]
        start_t = end_t
        print(printFormat % tuple(printInfo))

        # Update line plots; losses that stayed plain numbers (component
        # not active in this mode) are skipped
        if isinstance(aBotLoss, Variable):
            viz.linePlot(iterId, aBotLoss.data.item(), 'aBotLoss', 'train CE')
        if isinstance(qBotLoss, Variable):
            viz.linePlot(iterId, qBotLoss.data.item(), 'qBotLoss', 'train CE')
        if isinstance(rlLoss, Variable):
            viz.linePlot(iterId, rlLoss.data.item(), 'rlLoss', 'train')
        if isinstance(featLoss, Variable):
            viz.linePlot(iterId, featLoss.data.item(), 'featLoss',
                         'train FeatureRegressionLoss')

        viz.linePlot(iterId, loss.data.item(), 'loss', 'train loss')
        viz.linePlot(iterId, runningLoss, 'loss', 'running train loss')

    # Evaluate every epoch
    if iterId % numIterPerEpoch == 0:
        # Keeping track of epochID
        curEpoch = float(iterId) / numIterPerEpoch
        epochId = (1.0 * iterId / numIterPerEpoch) + 1

        # Set eval mode
        if aBot:
            aBot.eval()
        if qBot:
            qBot.eval()

        if params['enableVisdom']:
            # Printing visdom environment name in terminal
            print("Currently on visdom env [%s]" % (params['visdomEnv']))

        # Mapping iteration count to epoch count
        viz.linePlot(iterId, epochId, 'iter x epoch', 'epochs')

        print('Performing validation...')
        if aBot and 'ques' in batch:
            print("aBot Validation:")

            # NOTE: A-Bot validation is slow, so adjust exampleLimit as needed
            rankMetrics = rankABot(
                aBot, dataset, 'val', scoringFunction=utils.maskedNll,
                exampleLimit=25 * params['batchSize'])

            for metric, value in rankMetrics.items():
                viz.linePlot(epochId, value, 'val - aBot',
                             metric, xlabel='Epochs')

            if 'logProbsMean' in rankMetrics:
                logProbsMean = params['CELossCoeff'] * \
                    rankMetrics['logProbsMean']
                viz.linePlot(iterId, logProbsMean, 'aBotLoss', 'val CE')

                if params['trainMode'] == 'sl-abot':
                    valLoss = logProbsMean
                    viz.linePlot(iterId, valLoss, 'loss', 'val loss')

        if qBot:
            print("qBot Validation:")
            if params['trainMode'] == 'rl-guess-QAf':
                rankMetrics, roundMetrics = rankQBot_guess(
                    qBot, dataset, 'val',
                    policy=params['guessPolicy'],
                    im_retrieval_mode=params['imgRetrievalMode'],
                    imgEncodingMode=params['imgEncodingMode'])
            else:
                rankMetrics, roundMetrics = rankQBot(
                    qBot, dataset, 'val',
                    im_retrieval_mode=params['imgRetrievalMode'])

            for metric, value in rankMetrics.items():
                viz.linePlot(epochId, value, 'val - qBot',
                             metric, xlabel='Epochs')

            viz.linePlot(iterId, epochId, 'iter x epoch', 'epochs')

            if 'logProbsMean' in rankMetrics:
                logProbsMean = params['CELossCoeff'] * \
                    rankMetrics['logProbsMean']
                viz.linePlot(iterId, logProbsMean, 'qBotLoss', 'val CE')

            if 'featLossMean' in rankMetrics:
                featLossMean = params['featLossCoeff'] * \
                    (rankMetrics['featLossMean'])
                viz.linePlot(iterId, featLossMean, 'featLoss',
                             'val FeatureRegressionLoss')

            if 'logProbsMean' in rankMetrics and 'featLossMean' in rankMetrics:
                if params['trainMode'] == 'sl-qbot':
                    valLoss = logProbsMean + featLossMean
                    viz.linePlot(iterId, valLoss, 'loss', 'val loss')
            if 'winrateMean' in rankMetrics:
                viz.linePlot(iterId, rankMetrics['winrateMean'],
                             'win_rates', 'val win_rates')
                if rankMetrics['winrateMean'] > best_val_winrate:
                    # New best win rate: save the Q-Bot
                    params['ckpt_iterid'] = iterId
                    params['ckpt_lRate'] = lRate
                    saveFile = os.path.join(params['savePath'],
                                            'qbot_best.vd')
                    print('Saving the best qBot: ' + saveFile)
                    utils.saveModel(qBot, optimizer, saveFile, params)
                    # update the best_val_winrate
                    best_val_winrate = rankMetrics['winrateMean']
            print("Current best val winning rates: {}".format(best_val_winrate))

            if aBot and params['trainMode'] == 'rl-guess-QAf':
                rankMetrics2, roundMetrics2 = rankQABots_guess(
                    qBot, aBot, dataset, 'val',
                    beamSize=params['beamSize'],
                    policy=params['guessPolicy'],
                    im_retrieval_mode=params['imgRetrievalMode'],
                    imgEncodingMode=params['imgEncodingMode'])
                # Plot the winrates using the qBot and aBot
                viz.linePlot(iterId, rankMetrics2['winrateMean'],
                             'QA win_rates', 'val QA win_rates')
                if rankMetrics2['winrateMean'] > best_val_QA_winrate:
                    # New best joint win rate: save both bots
                    params['ckpt_iterid'] = iterId
                    params['ckpt_lRate'] = lRate
                    q_saveFile = os.path.join(params['savePath'],
                                              'qbot_best_QA.vd')
                    print('Saving the best qBot based on QARankGuess: ' + q_saveFile)
                    utils.saveModel(qBot, optimizer, q_saveFile, params)
                    a_saveFile = os.path.join(params['savePath'],
                                              'abot_best_QA.vd')
                    print('Saving the best aBot based on QARankGuess: ' + a_saveFile)
                    utils.saveModel(aBot, optimizer, a_saveFile, params)

                    # update the best_val_QA_winrate
                    best_val_QA_winrate = rankMetrics2['winrateMean']
            print("Current best val winning rates with QABotsRanking: {}".format(
                best_val_QA_winrate))

    # Periodic checkpoints: every 5 epochs in supervised modes, every epoch
    # otherwise
    if "sl" in params["trainMode"]:
        save_every_iter = 5
    else:
        save_every_iter = 1
    # curEpoch was set by the validation block above whenever the first
    # condition holds
    if iterId % numIterPerEpoch == 0 and curEpoch % save_every_iter == 0:
        params['ckpt_iterid'] = iterId
        params['ckpt_lRate'] = lRate

        if aBot:
            saveFile = os.path.join(
                params['savePath'], 'abot_ep_%d.vd' % curEpoch)
            print('Saving model: ' + saveFile)
            utils.saveModel(aBot, optimizer, saveFile, params)
        if qBot:
            saveFile = os.path.join(
                params['savePath'], 'qbot_ep_%d.vd' % curEpoch)
            print('Saving model: ' + saveFile)
            utils.saveModel(qBot, optimizer, saveFile, params)

    # Alternate training: toggle language training at every epoch boundary
    if params['alternateTraining'] and iterId % numIterPerEpoch == 0:
        rl_language_training = curEpoch % 2 == 1  # Switching when curEpoch is odd
        print(rl_language_training)
