import os
import gc
import random
import pprint
from six.moves import range
from markdown2 import markdown
from time import gmtime, strftime
from timeit import default_timer as timer

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

import options
from dataloader import VisDialDataset, VisDialDataset_20ImGuess, FashionDataset_20ImGuess, RoundSampler
from torch.utils.data import DataLoader
from eval_utils.rank_answerer import rankABot
from eval_utils.rank_questioner import rankQBot, rankQBot_guess, rankQABots_guess
from eval_utils.rank_questioner_fashion import fashion_rankQBot
from utils import utilities as utils
from utils.visualize import VisdomVisualize
import visdial.loss.loss_utils as loss_utils
from utils import utilities_fashion as utils_fashion
#-------------------------------------------------------------------------
# Setup
#-------------------------------------------------------------------------

# Read the command line options.
# Option parsing lives in the separate `options` module; `params` is indexed
# by option name throughout the rest of this script.
params = options.readCommandLine()

# Seed rng for reproducibility: Python's `random`, torch's default generator
# and, when running on GPU, the generators of every CUDA device.
random.seed(params['randomSeed'])
torch.manual_seed(params['randomSeed'])
if params['useGPU']:
    torch.cuda.manual_seed_all(params['randomSeed'])

# Setup dataloader: the split names handed to the dataset constructor
splits = ['train', 'val', 'test']

# Load the Fashion Data
dataset = FashionDataset_20ImGuess(params, splits)


# Params to transfer from dataset: each listed attribute that the dataset
# actually exposes is copied into `params`; missing ones are left untouched.
transfer = ['vocabSize', 'numOptions', 'numRounds']
for key in transfer:
    if hasattr(dataset, key):
        params[key] = getattr(dataset, key)

# Create save path and checkpoints folder. `exist_ok=True` makes both calls
# no-ops when the directory is already there (e.g. when resuming a run).
os.makedirs('checkpoints', exist_ok=True)
# os.mkdir(params['savePath'])
os.makedirs(params['savePath'], exist_ok=True)

# Preparation for warm-starting from a model pretrained on VisDial.
# The VisDial dataset is loaded only to obtain its vocabulary; the indices of
# the words that occur in BOTH vocabularies are collected as two parallel
# lists, which are later passed to utils_fashion.loadModel_Fashion.
if params['pretrainedVisdialModel']:
    # Same options as the fashion run, but pointing at the VisDial data files
    visdial_params = params.copy()
    visdial_params['inputImg'] = 'data/visdial/data_img.h5'
    visdial_params['inputQues'] = 'data/visdial/chat_processed_data.h5'
    visdial_params['inputJson'] = 'data/visdial/chat_processed_params.json'
    visdial_params['cocoDir'] = ''
    visdial_params['cocoInfo'] = ''
    visdial_params['fashionData'] = False

    # Load the VisDial data
    splits = ['train', 'val', 'test']
    visdial_dataset = VisDialDataset_20ImGuess(visdial_params, splits)

    # Params to transfer from dataset (only the attributes it exposes)
    transfer = ['vocabSize', 'numOptions', 'numRounds']
    for key in transfer:
        if hasattr(visdial_dataset, key):
            visdial_params[key] = getattr(visdial_dataset, key)

    # Invert the fashion vocabulary: word -> fashion index
    visdial_dataset_dict = visdial_dataset.ind2word
    fashion_word2ind = {}
    for k, v in dataset.ind2word.items():
        fashion_word2ind[v] = k

    # Entry i of both lists refers to the same word, once by its VisDial
    # index and once by its fashion index.
    overlapwords_visdial_index = []
    overlapwords_fashion_index = []
    for k, v in visdial_dataset_dict.items():
        # Membership test rather than truthiness of the looked-up index:
        # a shared word stored at fashion index 0 must not be dropped.
        if v in fashion_word2ind:
            overlapwords_visdial_index.append(k)
            overlapwords_fashion_index.append(fashion_word2ind[v])

# Model loading. `parameters` accumulates everything the optimizer will
# update; a bot that the current training mode does not need stays None.
parameters = []
aBot = None
qBot = None

# A-Bot: needed by its own supervised modes and by both RL modes
if params['trainMode'] in ('sl-abot', 'sl-abot-vse', 'rl-full-QAf', 'rl-guess-QAf'):
    aBot, loadedParams, optim_state = utils.loadModel(params, 'abot')
    # Options returned by the loader take precedence over the current ones
    for key in loadedParams:
        params[key] = loadedParams[key]
    parameters.extend(aBot.parameters())

# Q-Bot: needed by its own supervised modes and by both RL modes
if params['trainMode'] in ('sl-qbot', 'sl-qbot-vse', 'rl-full-QAf', 'rl-guess-QAf'):
    if not params['pretrainedVisdialModel']:
        qBot, loadedParams, optim_state = utils.loadModel(params, 'qbot')
    else:
        # Warm start from VisDial: the loader receives the index lists of
        # the words shared by the two vocabularies.
        qBot, loadedParams, optim_state = utils_fashion.loadModel_Fashion(
            params,
            'qbot',
            overlap_fashion=overlapwords_fashion_index,
            overlap_visdial=overlapwords_visdial_index)

    for key in loadedParams:
        params[key] = loadedParams[key]

    # Freezing the feature network is only considered in the RL modes
    if params['trainMode'] in ('rl-full-QAf', 'rl-guess-QAf'):
        if params['freezeQFeatNet']:
            qBot.freezeFeatNet()

    # Hand the optimizer only the Q-Bot parameters that still require a
    # gradient update.
    parameters.extend(p for p in qBot.parameters() if p.requires_grad)

# Setup pytorch dataloader over the training split.
# Batching is delegated to RoundSampler (built from the dataset and the batch
# size) and collation to the dataset's own collate_fn.
dataset.split = 'train'
# if params['fashionData']:
dataloader = DataLoader(
    dataset,
    num_workers=params['numWorkers'],
    collate_fn=dataset.collate_fn,
    pin_memory=False,
    batch_sampler=RoundSampler(dataset, params['batchSize']))

# Initializing visdom environment for plotting data
viz = VisdomVisualize(
    enable=bool(params['enableVisdom']),
    env_name=params['visdomEnv'],
    server=params['visdomServer'],
    port=params['visdomServerPort'])
# Record the effective options both on stdout and in the visdom environment
pprint.pprint(params)
viz.addText(pprint.pformat(params, indent=4))

# Setup optimizer
if params['continue']:
    # Continuing from a loaded checkpoint restores the following
    startIterID = params['ckpt_iterid'] + 1  # Iteration ID
    lRate = params['ckpt_lRate']  # Learning rate
    print("Continuing training from iterId[%d]" % startIterID)
else:
    # Beginning training normally, without any checkpoint
    lRate = params['learningRate']
    startIterID = 0

# The optimizer has to exist before its checkpointed state can be restored
optimizer = optim.Adam(parameters, lr=lRate)
if params['continue']:  # Restoring optimizer state
    print("Restoring optimizer state dict from checkpoint")
    optimizer.load_state_dict(optim_state)
runningLoss = None

# Element-wise (unreduced) MSE; callers reduce it themselves
mse_criterion = nn.MSELoss(reduction='none')
# Initialize a new loss, modified by Mingyang Zhou
pairwiseRanking_criterion = loss_utils.PairwiseRankingLoss(
    margin=0.1)  # Need to be Tuned
pairwiseRanking_score = loss_utils.PairwiseRankingScore(margin=0.1)

numIterPerEpoch = dataset.numDataPoints['train'] // params['batchSize']
print('\n%d iter per epoch.' % numIterPerEpoch)

# First dialog round handled by the RL objective. With the curriculum it
# starts at the last round and moves one round earlier per epoch; the training
# loop performs that decrement.
if params['useCurriculum']:
    rlRound = params['numRounds'] - 1
    if params['continue']:
        # Replay the per-epoch decrements that happened before the
        # checkpoint. Starting from numRounds - 1 (not a hard-coded 9) keeps
        # a resumed run in step with a fresh one for any round count.
        rlRound = max(0, rlRound - (startIterID // numIterPerEpoch))
else:
    rlRound = 0

#-------------------------------------------------------------------------
# Training
#-------------------------------------------------------------------------


def batch_iter(dataloader):
    """Yield ``(epochId, idx, batch)`` for every batch of every epoch.

    The dataloader is iterated from the start once per epoch, for
    ``params['numEpochs']`` epochs; ``idx`` is the position of the batch
    within its epoch.
    """
    for epochId in range(params['numEpochs']):
        yield from ((epochId, idx, batch)
                    for idx, batch in enumerate(dataloader))

# #Manually set rl_round
# rlRound = 7

def _selfRetrievalOutcome(scores, higherIsBetter):
    """Score each row's top-1 guess as +1 (correct) or -1 (wrong).

    Row ``i`` of ``scores`` holds the scores of example ``i`` against every
    candidate; the guess counts as correct when candidate ``i`` itself is
    ranked first.

    Args:
        scores: 2-D numpy array indexed as ``[example, candidate]``.
        higherIsBetter: True when a larger score is a better match
            (similarities), False when a smaller one is (distances).

    Returns:
        A list with one int per row: 1 for a correct top-1 guess, else -1.
    """
    outcome = []
    for rowId in range(scores.shape[0]):
        order = scores[rowId, :].argsort()  # ascending
        topCandidate = order[-1] if higherIsBetter else order[0]
        outcome.append(1 if topCandidate == rowId else -1)
    return outcome


# Helpers that are only needed to score guesses in 'rl-guess-QAf' mode are
# imported once here instead of on every dialog round.
if params['trainMode'] == 'rl-guess-QAf':
    import numpy as np
    if params['imgRetrievalMode'] != "cosine_similarity":
        from sklearn.metrics.pairwise import pairwise_distances

# In order to stay true to the original implementation, the feature
# regression network makes predictions before dialog begins and for
# the first 9 rounds of dialog. This can be set to 10 if needed.
MAX_FEAT_ROUNDS = 9

start_t = timer()
best_val_winrate = 0
best_val_QA_winrate = 0

for epochId, idx, batch in batch_iter(dataloader):
    # Keeping track of iterId and epoch
    iterId = startIterID + idx + (epochId * numIterPerEpoch)
    epoch = iterId // numIterPerEpoch
    gc.collect()

    # Moving current batch to GPU, if available
    if dataset.useGPU:
        batch = {key: v.cuda() if hasattr(v, 'cuda')
                 else v for key, v in batch.items()}

    image = Variable(batch['img_feat'], requires_grad=False)
    caption = Variable(batch['cap'], requires_grad=False)
    captionLens = Variable(batch['cap_len'], requires_grad=False)
    gtQuestions = Variable(batch['ques'], requires_grad=False)
    gtQuesLens = Variable(batch['ques_len'], requires_grad=False)
    gtAnswers = Variable(batch['ans'], requires_grad=False)
    gtAnsLens = Variable(batch['ans_len'], requires_grad=False)

    # Initializing optimizer and losses
    optimizer.zero_grad()
    loss = 0
    qBotLoss = 0
    aBotLoss = 0
    rlLoss = 0
    featLoss = 0
    qBotRLLoss = 0
    aBotRLLoss = 0
    predFeatures = None
    initialGuess = None

    # The number of dialog rounds is dynamic: it is read from the first
    # entry of the batch.
    # NOTE(review): assumes every dialog in a batch has the same number of
    # rounds (presumably what RoundSampler guarantees) - confirm.
    numRounds = batch['ques_numRounds'][0].item()
    if numRounds == 0:
        # Nothing to train on; this also skips decay/eval/saving below
        continue

    # Setting training modes for both bots and observing captions, images
    # where needed
    if aBot:
        aBot.train()
        aBot.reset()
        if params['rlAbotReward']:
            aBot.reset_reinforce()
        aBot.observe(-1, image=image, caption=caption, captionLens=captionLens)
    if qBot:
        qBot.train()
        qBot.reset()
        qBot.reset_reinforce()
        qBot.observe(-1, caption=caption, captionLens=captionLens)
        # Added by Mingyang Zhou, observe the group of images every round
        if qBot.new_questioner:
            qBot.observe_im(image)

    # Initial Q-Bot guess, made before any dialog round. Besides
    # contributing to featLoss it provides prevFeatDist, the baseline the
    # first RL reward is measured against.
    if (params['trainMode'] in ['sl-qbot-vse', 'rl-full-QAf', 'rl-guess-QAf']
            and params['imgRetrievalMode'] == "cosine_similarity"):
        # Joint text/image embedding scored with the pairwise ranking loss
        initialembeddingtext = qBot.multimodalpredictText()
        initialembeddingim = qBot.multimodalpredictIm(image)
        prevFeatDist = pairwiseRanking_score(
            initialembeddingim, initialembeddingtext)
        prevFeatDist = torch.mean(prevFeatDist, 1)
        featDist = pairwiseRanking_criterion(
            initialembeddingim, initialembeddingtext)
        featLoss += torch.sum(featDist)
    elif params['trainMode'] in ['sl-qbot', 'rl-full-QAf', 'rl-guess-QAf']:
        # Image feature regression scored with (unreduced) MSE
        initialGuess = qBot.predictImage()
        prevFeatDist = mse_criterion(initialGuess, image)
        featLoss += torch.mean(prevFeatDist)
        prevFeatDist = torch.mean(prevFeatDist, 1)

    if params['trainMode'] == 'rl-guess-QAf':
        # rl-guess-QAf:
        #   Hierarchical RL-finetuning of A-Bot and Q-Bot in a cooperative
        #   setting where the reward comes from making the correct guess of
        #   the target image and from the change in the image prediction
        #   score between consecutive rounds.
        winRewards = params['winRewards']
        game_rewards = 0

        # Done[i] becomes non-zero once game i has committed to a guess.
        # Created on the batch's device instead of unconditionally on CUDA
        # so that CPU runs work as well.
        Done = torch.zeros(params['batchSize'], device=image.device)

        for roundId in range(numRounds):
            # Rounds before rlRound are supervised (teacher forced)
            # A-Bot dialog model
            forwardABot = (roundId < rlRound)
            # Q-Bot dialog model
            forwardQBot = (roundId < rlRound)
            # Q-Bot feature regression network
            forwardFeatNet = True

            # Answerer Forward Pass
            if forwardABot:
                # Observe Ground Truth (GT) question
                aBot.observe(
                    roundId,
                    ques=gtQuestions[:, roundId],
                    quesLens=gtQuesLens[:, roundId])
                # Observe GT answer for teacher forcing
                aBot.observe(
                    roundId,
                    ans=gtAnswers[:, roundId],
                    ansLens=gtAnsLens[:, roundId])
                ansLogProbs = aBot.forward()
                # Cross Entropy (CE) Loss for Ground Truth Answers
                aBotLoss += utils.maskedNll(
                    ansLogProbs, gtAnswers[:, roundId].contiguous())

            # Questioner Forward Pass (dialog model)
            if forwardQBot:
                # Observe GT question for teacher forcing
                qBot.observe(
                    roundId,
                    ques=gtQuestions[:, roundId],
                    quesLens=gtQuesLens[:, roundId])
                quesLogProbs = qBot.forward()
                # Cross Entropy (CE) Loss for Ground Truth Questions
                qBotLoss += utils.maskedNll(
                    quesLogProbs, gtQuestions[:, roundId].contiguous())
                # Observe GT answer for updating dialog history
                qBot.observe(
                    roundId,
                    ans=gtAnswers[:, roundId],
                    ansLens=gtAnsLens[:, roundId])

            # Questioner feature network forward pass. Besides the loss it
            # records, per game, whether the current top-1 retrieval is the
            # target (+1) or not (-1).
            if forwardFeatNet and roundId < MAX_FEAT_ROUNDS:
                if params['imgRetrievalMode'] == "cosine_similarity":
                    initialembeddingtext = qBot.multimodalpredictText()
                    embeddingim = qBot.multimodalpredictIm(image)
                    featDist = pairwiseRanking_criterion(
                        embeddingim, initialembeddingtext)
                    featLoss += torch.sum(featDist)
                    # Text-vs-image dot products: larger means closer
                    round_dists = np.matmul(
                        initialembeddingtext.data.cpu().numpy(),
                        embeddingim.data.cpu().numpy().transpose())
                    current_guess_result = _selfRetrievalOutcome(
                        round_dists, higherIsBetter=True)
                else:
                    # Make an image prediction after each round
                    predFeatures = qBot.predictImage()
                    featDist = mse_criterion(predFeatures, image)
                    featDist = torch.mean(featDist)
                    featLoss += featDist
                    # Distance of every prediction to every image in the
                    # batch: smaller means closer
                    round_dists = pairwise_distances(
                        predFeatures.data.cpu().numpy(),
                        image.data.cpu().numpy())
                    current_guess_result = _selfRetrievalOutcome(
                        round_dists, higherIsBetter=False)
                current_guess_result = torch.tensor(
                    current_guess_result,
                    dtype=torch.float,
                    device=image.device)

            # A-Bot and Q-Bot interacting in RL rounds, as long as at least
            # one game in the batch has not been decided yet.
            # NOTE(review): for roundId >= MAX_FEAT_ROUNDS this reuses
            # current_guess_result (and embeddingim) from an earlier round.
            if roundId >= rlRound and torch.nonzero(Done).size()[0] < params['batchSize']:
                # Determine whether to make a guess
                guess_action = qBot.determine_action(
                    Done,
                    policy="policy_gradient",
                    imgEncodingMode=params['imgEncodingMode'])
                # Games that guess in this round are marked as done
                Done = torch.where(guess_action > 0, guess_action, Done)

                # Run one round of sampled conversation
                questions, quesLens = qBot.forwardDecode(inference='sample')
                qBot.observe(roundId, ques=questions, quesLens=quesLens)
                aBot.observe(roundId, ques=questions, quesLens=quesLens)
                answers, ansLens = aBot.forwardDecode(inference='sample')
                aBot.observe(roundId, ans=answers, ansLens=ansLens)
                qBot.observe(roundId, ans=answers, ansLens=ansLens)

                if params['imgRetrievalMode'] == "cosine_similarity":
                    # Re-score the updated text embedding against the image
                    # embedding computed earlier
                    embeddingtext = qBot.multimodalpredictText()
                    featDist = pairwiseRanking_score(
                        embeddingim, embeddingtext)
                    featDist = torch.mean(featDist, 1)
                else:
                    # Q-Bot makes a guess at the end of each round
                    predFeatures = qBot.predictImage()
                    featDist = mse_criterion(predFeatures, image)
                    featDist = torch.mean(featDist, 1)

                # Improvement over the previous round, masked by (1 - Done),
                # plus +/- winRewards for the games that guessed this round.
                reward = (prevFeatDist.detach() - featDist).mul(1 - Done) + \
                    guess_action.mul(current_guess_result * winRewards)
                prevFeatDist = featDist

                # Update the reward
                qBot.update_reward(reward)
                if params['rlAbotReward']:
                    aBot.update_reward(reward)

        # Policy-gradient loss over the whole game
        qBotRLLoss = qBot.reinforce_guess()
        if params['rlAbotReward']:
            aBotRLLoss = aBot.reinforce_guess()
            # Only average the A-Bot term when it exists: without
            # rlAbotReward, aBotRLLoss is still the int 0 and torch.mean
            # would raise a TypeError.
            rlLoss += torch.mean(aBotRLLoss)
        rlLoss += torch.mean(qBotRLLoss)
        game_rewards = qBot.compute_game_rewards()
    else:
        # Training modes handled by this branch:
        #   sl-abot :
        #       Supervised pre-training of A-Bot model using cross
        #       entropy loss with ground truth answers
        #   sl-qbot :
        #       Supervised pre-training of Q-Bot model using cross
        #       entropy loss with ground truth questions for the
        #       dialog model and mean squared error loss for image
        #       feature regression (i.e. image prediction)
        #   sl-abot-vse :
        #       Supervised pre-training of A-Bot model using cross
        #       entropy loss with ground truth answers while
        #       simultaneously learning to retrieve the image
        #   rl-full-QAf :
        #       RL-finetuning of A-Bot and Q-Bot in a cooperative
        #       setting where the common reward is the difference
        #       in mean squared error between the current and
        #       previous round of Q-Bot's image prediction.
        #       Annealing: In order to ease in the RL objective,
        #       fine-tuning starts with first N-1 rounds of SL
        #       objective and last round of RL objective - the
        #       number of RL rounds are increased by 1 after
        #       every epoch until only RL objective is used for
        #       all rounds of dialog.
        game_rewards = 0
        # Iterating over dialog rounds
        for roundId in range(numRounds):
            # Tracking components which require a forward pass
            # A-Bot dialog model
            forwardABot = (params['trainMode'] in ['sl-abot', 'sl-abot-vse']
                           or (params['trainMode'] == 'rl-full-QAf'
                               and roundId < rlRound))
            # Q-Bot dialog model
            forwardQBot = (params['trainMode'] in ['sl-qbot', 'sl-qbot-vse']
                           or (params['trainMode'] == 'rl-full-QAf'
                               and roundId < rlRound))
            # Q-Bot feature regression network
            forwardFeatNet = (forwardQBot or params['trainMode'] in [
                              'rl-full-QAf'])

            # Answerer Forward Pass
            if forwardABot:
                # Observe Ground Truth (GT) question
                aBot.observe(
                    roundId,
                    ques=gtQuestions[:, roundId],
                    quesLens=gtQuesLens[:, roundId])
                # Observe GT answer for teacher forcing
                aBot.observe(
                    roundId,
                    ans=gtAnswers[:, roundId],
                    ansLens=gtAnsLens[:, roundId])
                ansLogProbs = aBot.forward()
                # Cross Entropy (CE) Loss for Ground Truth Answers
                aBotLoss += utils.maskedNll(
                    ansLogProbs, gtAnswers[:, roundId].contiguous())

            # Questioner Forward Pass (dialog model)
            if forwardQBot:
                # Observe GT question for teacher forcing
                qBot.observe(
                    roundId,
                    ques=gtQuestions[:, roundId],
                    quesLens=gtQuesLens[:, roundId])
                quesLogProbs = qBot.forward()
                # Cross Entropy (CE) Loss for Ground Truth Questions
                qBotLoss += utils.maskedNll(
                    quesLogProbs, gtQuestions[:, roundId].contiguous())
                # Observe GT answer for updating dialog history
                qBot.observe(
                    roundId,
                    ans=gtAnswers[:, roundId],
                    ansLens=gtAnsLens[:, roundId])

            # Questioner feature regression network forward pass
            if forwardFeatNet and roundId < MAX_FEAT_ROUNDS:
                if params['imgRetrievalMode'] == "cosine_similarity":
                    initialembeddingtext = qBot.multimodalpredictText()
                    embeddingim = qBot.multimodalpredictIm(image)
                    featDist = pairwiseRanking_criterion(
                        embeddingim, initialembeddingtext)
                    featLoss += torch.sum(featDist)
                else:
                    # Make an image prediction after each round
                    predFeatures = qBot.predictImage()
                    featDist = mse_criterion(predFeatures, image)
                    featDist = torch.mean(featDist)
                    featLoss += featDist

            # A-Bot and Q-Bot interacting in RL rounds
            if params['trainMode'] == 'rl-full-QAf' and roundId >= rlRound:
                # Run one round of conversation
                questions, quesLens = qBot.forwardDecode(inference='sample')
                qBot.observe(roundId, ques=questions, quesLens=quesLens)
                aBot.observe(roundId, ques=questions, quesLens=quesLens)
                answers, ansLens = aBot.forwardDecode(inference='sample')
                aBot.observe(roundId, ans=answers, ansLens=ansLens)
                qBot.observe(roundId, ans=answers, ansLens=ansLens)

                if params['imgRetrievalMode'] == "cosine_similarity":
                    # Re-score the updated text embedding against the image
                    # embedding computed earlier
                    embeddingtext = qBot.multimodalpredictText()
                    featDist = pairwiseRanking_score(
                        embeddingim, embeddingtext)
                    featDist = torch.mean(featDist, 1)
                else:
                    # Q-Bot makes a guess at the end of each round
                    predFeatures = qBot.predictImage()

                    # Computing reward based on Q-Bot's predicted image
                    featDist = mse_criterion(predFeatures, image)
                    featDist = torch.mean(featDist, 1)

                # Reward: how much the guess improved over the last round
                reward = prevFeatDist.detach() - featDist
                prevFeatDist = featDist

                qBotRLLoss = qBot.reinforce(reward)
                if params['rlAbotReward']:
                    aBotRLLoss = aBot.reinforce(reward)
                    # See above: only average the A-Bot term when it is an
                    # actual tensor.
                    rlLoss += torch.mean(aBotRLLoss)
                rlLoss += torch.mean(qBotRLLoss)
                # Append the game_rewards
                game_rewards += torch.mean(reward).item()

    # Loss coefficients
    rlCoeff = params['rlLossCoeff']
    rlLoss = rlLoss * rlCoeff
    featLoss = featLoss * params['featLossCoeff']
    # Averaging over rounds
    qBotLoss = (params['CELossCoeff'] * qBotLoss) / numRounds
    aBotLoss = (params['CELossCoeff'] * aBotLoss) / numRounds
    featLoss = featLoss / numRounds  # / (numRounds+1)
    rlLoss = rlLoss / numRounds
    # Total loss
    loss = qBotLoss + aBotLoss + rlLoss + featLoss
    loss.backward()
    optimizer.step()

    # Tracking a running average of loss
    if runningLoss is None:
        runningLoss = loss.item()
    else:
        runningLoss = 0.95 * runningLoss + 0.05 * loss.item()

    # Decay learning rate (every iteration, until minLRate is reached)
    if lRate > params['minLRate']:
        for group in optimizer.param_groups:
            group['lr'] *= params['lrDecayRate']
        lRate *= params['lrDecayRate']
        if iterId % 10 == 0:  # Plot learning rate till saturation
            viz.linePlot(iterId, lRate, 'learning rate', 'learning rate')

    # RL Annealing: Every epoch after the first, decrease rlRound
    if iterId % numIterPerEpoch == 0 and iterId > 0:
        if params['trainMode'] in ['rl-full-QAf', 'rl-guess-QAf']:
            rlRound = max(0, rlRound - 1)
            print('Using rl starting at round {}'.format(rlRound))

    # Print every now and then
    if iterId % 10 == 0:
        end_t = timer()  # Keeping track of iteration(s) time
        curEpoch = float(iterId) / numIterPerEpoch
        timeStamp = strftime('%a %d %b %y %X', gmtime())
        printFormat = '[%s][Ep: %.2f][Iter: %d][Time: %5.2fs][Loss: %.3g]'
        printFormat += '[lr: %.3g]'
        printInfo = [
            timeStamp, curEpoch, iterId, end_t - start_t, loss.item(), lRate
        ]
        start_t = end_t
        print(printFormat % tuple(printInfo))

        # Update line plots; a loss that was never touched is still a plain
        # number and is skipped.
        if isinstance(aBotLoss, Variable):
            viz.linePlot(iterId, aBotLoss.item(), 'aBotLoss', 'train CE')
        if isinstance(qBotLoss, Variable):
            viz.linePlot(iterId, qBotLoss.item(), 'qBotLoss', 'train CE')
        if isinstance(rlLoss, Variable):
            viz.linePlot(iterId, rlLoss.item(), 'rlLoss', 'train')
        if isinstance(featLoss, Variable):
            viz.linePlot(iterId, featLoss.item(), 'featLoss',
                         'train FeatureRegressionLoss')
        if params['trainMode'] in ["rl-full-QAf", "rl-guess-QAf"]:
            # Plot the game_rewards
            viz.linePlot(iterId, game_rewards, 'rewards', 'game_rewards')

        viz.linePlot(iterId, loss.item(), 'loss', 'train loss')
        viz.linePlot(iterId, runningLoss, 'loss', 'running train loss')

    # Evaluate every epoch
    if iterId % numIterPerEpoch == 0:
        # Keeping track of epochID. Note that this rebinds the loop variable
        # to a 1-based float until the next iteration assigns it again.
        curEpoch = float(iterId) / numIterPerEpoch
        epochId = (1.0 * iterId / numIterPerEpoch) + 1

        # Set eval mode (train mode is restored at the top of the loop)
        if aBot:
            aBot.eval()
        if qBot:
            qBot.eval()

        if params['enableVisdom']:
            # Printing visdom environment name in terminal
            print("Currently on visdom env [%s]" % (params['visdomEnv']))

        # Mapping iteration count to epoch count
        viz.linePlot(iterId, epochId, 'iter x epoch', 'epochs')

        print('Performing validation...')
        if aBot and 'ques' in batch:
            print("aBot Validation:")

            # NOTE: A-Bot validation is slow, so adjust exampleLimit as needed
            rankMetrics = rankABot(
                aBot,
                dataset,
                'val',
                scoringFunction=utils.maskedNll,
                exampleLimit=25 * params['batchSize'])

            for metric, value in rankMetrics.items():
                viz.linePlot(
                    epochId, value, 'val - aBot', metric, xlabel='Epochs')

            if 'logProbsMean' in rankMetrics:
                logProbsMean = params['CELossCoeff'] * rankMetrics[
                    'logProbsMean']
                viz.linePlot(iterId, logProbsMean, 'aBotLoss', 'val CE')

                if params['trainMode'] == 'sl-abot':
                    valLoss = logProbsMean
                    viz.linePlot(iterId, valLoss, 'loss', 'val loss')

        if qBot:
            print("qBot Validation:")
            if params['trainMode'] == 'rl-guess-QAf':
                rankMetrics, roundMetrics = rankQBot_guess(
                    qBot,
                    dataset,
                    'val',
                    policy=params['guessPolicy'],
                    im_retrieval_mode=params['imgRetrievalMode'],
                    imgEncodingMode=params['imgEncodingMode'])
            else:
                # NOTE(review): this evaluates on the 'test' split although
                # the results are plotted and used as validation metrics.
                rankMetrics, roundMetrics = fashion_rankQBot(
                    qBot,
                    dataset,
                    'test',
                    im_retrieval_mode=params['imgRetrievalMode'])

            for metric, value in rankMetrics.items():
                viz.linePlot(
                    epochId, value, 'val - qBot', metric, xlabel='Epochs')

            viz.linePlot(iterId, epochId, 'iter x epoch', 'epochs')

            if 'logProbsMean' in rankMetrics:
                logProbsMean = params['CELossCoeff'] * rankMetrics[
                    'logProbsMean']
                viz.linePlot(iterId, logProbsMean, 'qBotLoss', 'val CE')

            if 'featLossMean' in rankMetrics:
                featLossMean = params['featLossCoeff'] * (
                    rankMetrics['featLossMean'])
                viz.linePlot(iterId, featLossMean, 'featLoss',
                             'val FeatureRegressionLoss')

            if 'logProbsMean' in rankMetrics and 'featLossMean' in rankMetrics:
                if params['trainMode'] == 'sl-qbot':
                    valLoss = logProbsMean + featLossMean
                    viz.linePlot(iterId, valLoss, 'loss', 'val loss')
            if 'winrateMean' in rankMetrics:
                viz.linePlot(iterId, rankMetrics['winrateMean'],
                             'win_rates', 'val win_rates')
                if rankMetrics['winrateMean'] > best_val_winrate:
                    # New best Q-Bot: checkpoint it
                    params['ckpt_iterid'] = iterId
                    params['ckpt_lRate'] = lRate
                    saveFile = os.path.join(params['savePath'],
                                            'qbot_best.vd')
                    print('Saving the best qBot: ' + saveFile)
                    utils.saveModel(qBot, optimizer, saveFile, params)
                    # update the best_val_winrate
                    best_val_winrate = rankMetrics['winrateMean']
            print("Current best val winning rates: {}".format(best_val_winrate))

            if aBot and params['trainMode'] == 'rl-guess-QAf':
                # Win rate when both bots play against each other
                rankMetrics2, roundMetrics2 = rankQABots_guess(
                    qBot,
                    aBot,
                    dataset,
                    'val',
                    beamSize=params['beamSize'],
                    policy=params['guessPolicy'],
                    im_retrieval_mode=params['imgRetrievalMode'],
                    imgEncodingMode=params['imgEncodingMode'])
                viz.linePlot(iterId, rankMetrics2['winrateMean'],
                             'QA win_rates', 'val QA win_rates')
                if rankMetrics2['winrateMean'] > best_val_QA_winrate:
                    # New best pair: checkpoint both bots
                    params['ckpt_iterid'] = iterId
                    params['ckpt_lRate'] = lRate
                    q_saveFile = os.path.join(params['savePath'],
                                              'qbot_best_QA.vd')
                    print('Saving the best qBot based on QARankGuess: ' + q_saveFile)
                    utils.saveModel(qBot, optimizer, q_saveFile, params)
                    a_saveFile = os.path.join(params['savePath'],
                                              'abot_best_QA.vd')
                    print('Saving the best aBot based on QARankGuess: ' + a_saveFile)
                    utils.saveModel(aBot, optimizer, a_saveFile, params)

                    # update the best_val_winrate
                    best_val_QA_winrate = rankMetrics2['winrateMean']

    # Save the model after every 5 epoch. curEpoch is always set here: the
    # evaluation block above runs on the same iterations.
    if iterId % numIterPerEpoch == 0 and curEpoch % 5 == 0:
        params['ckpt_iterid'] = iterId
        params['ckpt_lRate'] = lRate

        if aBot:
            saveFile = os.path.join(params['savePath'],
                                    'abot_ep_%d.vd' % curEpoch)
            print('Saving model: ' + saveFile)
            utils.saveModel(aBot, optimizer, saveFile, params)
        if qBot:
            saveFile = os.path.join(params['savePath'],
                                    'qbot_ep_%d.vd' % curEpoch)
            print('Saving model: ' + saveFile)
            utils.saveModel(qBot, optimizer, saveFile, params)
    # Save the qBot if it reaches the best overall winning rates
