"""
Define the reinforcement-learning (RL) loss (``rl_rollout_search``) and the
image ``Ranker`` it relies on.
"""
import torch
from torch.autograd import Variable
import math

"""
The Ranker below is borrowed from https://github.com/XiaoxiaoGuo/fashion-retrieval
"""
#train_all_size = print(len(dataset))
# all_image_feat = dataset.data['%s_img_fv' % dataset.split]
# print(all_image_feat.size())


class Ranker():
    """Nearest-neighbour search and ranking over a cached gallery of image embeddings.

    ``update_rep`` must be called first: it projects every row of the dataset
    through ``qBot.forwardImage`` and caches the result in ``self.feat`` (one
    row per gallery image). All distances are squared L2.
    """

    def __init__(self):
        super(Ranker, self).__init__()

    def update_rep(self, qBot, all_input, batch_size=64):
        """Embed every row of ``all_input`` with ``qBot`` and cache it in ``self.feat``.

        Args:
            qBot: model exposing ``shared_embedding_space`` (the embedding
                width) and ``forwardImage(x)``.
            all_input: gallery image features, one row per image.
            batch_size: number of rows passed to ``forwardImage`` per call.
        """
        num_images = all_input.size(0)
        use_cuda = torch.cuda.is_available()
        device = torch.device('cuda') if use_cuda else torch.device('cpu')
        # torch.empty tensors do not require grad, so the cache is never
        # part of an autograd graph.
        self.feat = torch.empty(
            num_images, qBot.shared_embedding_space, device=device)

        with torch.no_grad():
            # Slices are [0:bs], [bs:2bs], ..., with a shorter final slice when
            # num_images is not a multiple of batch_size; every row is filled.
            for start in range(0, num_images, batch_size):
                end = min(start + batch_size, num_images)
                x = all_input[start:end]
                if use_cuda:
                    x = x.cuda()
                self.feat[start:end].copy_(qBot.forwardImage(x))

    def _squared_distances(self, query):
        """Return squared L2 distances of shape (query rows, cached rows)."""
        query = query.to(self.feat.device)
        diff = self.feat.unsqueeze(0) - query.unsqueeze(1)
        return torch.pow(diff, 2).sum(dim=-1)

    def k_nearest_neighbors(self, target, K=10):
        """Return indices of the ``K`` cached embeddings closest to each target row.

        Returns:
            CPU LongTensor of shape ``(target.size(0), K)``, ordered from
            nearest to farthest.
        """
        _, nearest = torch.topk(
            self._squared_distances(target), k=K, dim=1, largest=False)
        # Always handed back on the CPU, regardless of where self.feat lives.
        return nearest.cpu()

    def nearest_neighbor(self, target):
        """Return the index of the closest cached embedding for each target row.

        The result lives on the same device as ``self.feat``.
        """
        _, nearest = self._squared_distances(target).min(1)
        return nearest

    def compute_rank(self, input, target_idx):
        """Rank of the ground-truth image among all gallery images, per query.

        Args:
            input: a batch of query embeddings.
            target_idx: ground-truth gallery index for each query.

        Returns:
            For each query, the number of gallery embeddings strictly closer to
            it than its ground-truth embedding (0 means the ground truth wins).
        """
        target_idx = target_idx.to(self.feat.device)
        input = input.to(self.feat.device)
        gt_dist = torch.pow(self.feat[target_idx] - input, 2).sum(1)
        gallery_dist = self._squared_distances(input)
        return gallery_dist.lt(gt_dist.unsqueeze(1)).sum(dim=1)

# class Ranker():

#     def __init__(self):
#         super(Ranker, self).__init__()
#         return

#     def k_nearest_neighbors(self, target, all_feat, K=10):
#         idx = torch.LongTensor(target.size(0), K)
#         if torch.cuda.is_available():
#             target = target.cuda()
#             all_feat = all_feat.cuda()

#         for i in range(target.size(0)):
#             val = all_feat - \
#                 target[i].expand(all_feat.size(0), all_feat.size(1))
#             val = val ** 2
#             val = val.sum(1)
#             v, id = torch.topk(val, k=K, dim=0, largest=False)
#             idx[i].copy_(id.view(-1))
#         return idx

#     def nearest_neighbor(self, target, all_feat):
#         # L2 case
#         idx = torch.LongTensor(target.size(0))
#         if torch.cuda.is_available():
#             target = target.cuda()
#             # self.feat = self.feat.cuda()
#             all_feat = all_feat.cuda()
#         for i in range(target.size(0)):
#             val = all_feat - \
#                 target[i].expand(all_feat.size(0), all_feat.size(1))
#             val = val ** 2
#             val = val.sum(1)
#             v, id = val.min(0)
#             idx[i] = id.item()
#         return idx

#     def compute_rank(self, input, target_idx, all_feat):
#         # input <---- a batch of vectors
#         # targetIdx <----- ground truth index
#         # all_feat <----- all the image features
#         # return rank of input vectors in terms of rankings in distance to the
#         # ground truth

#         if torch.cuda.is_available():
#             # input = input.cuda()
#             target_idx = target_idx.cuda()
#             # self.feat = self.feat.cuda()
#             all_feat = all_feat.cuda()
#         target = all_feat[target_idx]

#         value = target - input
#         value = value ** 2
#         value = value.sum(1)
#         rank = torch.LongTensor(value.size(0))
#         for i in range(value.size(0)):
#             val = all_feat - \
#                 input[i].expand(all_feat.size(0), all_feat.size(1))
#             val = val ** 2
#             val = val.sum(1)
#             rank[i] = val.lt(value[i]).sum()
#         return rank


# Dialog-history lists kept on each bot's encoder. They are snapshotted before
# the look-ahead rollouts and restored after each one.
_QUESTION_HISTORY = ("questionTokens", "questionLens", "questionEmbeds")
_ANSWER_HISTORY = ("answerTokens", "answerLengths", "answerEmbeds")
_ENCODER_RNN_HISTORY = ("factEmbeds", "questionRNNStates", "dialogRNNInputs",
                        "dialogHiddens")


def _snapshot_lists(fields):
    """Shallow-copy ``getattr(owner, name)`` for every ``(owner, name)`` pair."""
    return [(owner, name, getattr(owner, name)[:]) for owner, name in fields]


def _restore_lists(snapshot):
    """Empty each current list, then rebind the attribute to a fresh copy of the saved list."""
    for owner, name, saved in snapshot:
        getattr(owner, name).clear()
        setattr(owner, name, list(saved))


def rl_rollout_search(behavior_model, target_model, aBot, behavior_im_state, target_im_state, curr_turn, max_turns, gt_img_idx, all_input, ranker, top_k=3, neg_num=5, lookAhead_Window=3, tau=0.2, rank_comparison=False):
    """Pick the best image action by dialog rollouts and compute the RL loss.

    Args:
        behavior_model: unused; kept for interface compatibility.
        target_model: target questioner used for the look-ahead rollouts.
        aBot: answerer bot that replies during the rollouts.
        behavior_im_state: image embedding predicted by the behaviour policy;
            the loss is differentiated through it.
        target_im_state: image embedding predicted by ``target_model``; its
            ``top_k`` nearest gallery images are the candidate actions.
        curr_turn: current dialog turn.
        max_turns: maximum number of dialog turns.
        gt_img_idx: ground-truth gallery index per batch element.
        all_input: gallery image features (indexed by gallery index).
        ranker: a ``Ranker`` whose ``feat`` cache covers ``all_input``.
        top_k: number of candidate actions rolled out.
        neg_num: number of random negative images in the softmax.
        lookAhead_Window: maximum number of turns simulated per rollout.
        tau: softmax temperature applied to negative squared distances.
        rank_comparison: if True, score a rollout by the change in rank per
            turn rather than by the rank itself.

    Returns:
        ``(act_opt, loss)``: the chosen gallery index per batch element and
        the contrastive loss plus an L2 regulariser toward the ground truth.
    """
    batch_size = target_im_state.size(0)
    with torch.no_grad():
        # 1. Candidate actions: top-k gallery images nearest the target state.
        top_k_act_img_idx = ranker.k_nearest_neighbors(
            target_im_state.data, K=top_k)

        # Snapshot both bots' dialog history so each rollout starts from the
        # same state and leaves no trace afterwards.
        abot_state = _snapshot_lists(
            [(aBot.encoder, n) for n in _QUESTION_HISTORY + _ANSWER_HISTORY]
            + [(aBot, "answers")]
            + [(aBot.encoder, n) for n in _ENCODER_RNN_HISTORY])
        target_state = _snapshot_lists(
            [(target_model.encoder, n) for n in _QUESTION_HISTORY]
            + [(target_model, "questions")]
            + [(target_model.encoder, n)
               for n in _ANSWER_HISTORY + _ENCODER_RNN_HISTORY])

        upperbound = min(curr_turn + lookAhead_Window, max_turns)
        rollout_values = []
        for i in range(top_k):
            # 2. Show candidate i, then simulate the dialog forward and
            #    accumulate the ground-truth rank after every turn.
            act_img_idx = top_k_act_img_idx[:, i]
            target_model.observe_im(all_input[act_img_idx])

            # A zero tensor (not the int 0) keeps torch.stack below working
            # when no look-ahead turns remain.
            score = torch.zeros(batch_size, dtype=torch.long,
                                device=ranker.feat.device)
            if rank_comparison:
                prev_rank = ranker.compute_rank(
                    target_model.predictImage().data, gt_img_idx)
            for j in range(curr_turn, upperbound):
                questions, quesLens = target_model.forwardDecode(
                    inference='sample')
                target_model.observe(j, ques=questions, quesLens=quesLens)
                aBot.observe(j, ques=questions, quesLens=quesLens)
                answers, ansLens = aBot.forwardDecode(inference='sample')
                aBot.observe(j, ans=answers, ansLens=ansLens)
                target_model.observe(j, ans=answers, ansLens=ansLens)

                action = target_model.predictImage()
                act_img_idx = ranker.nearest_neighbor(action.data)
                target_model.observe_im(all_input[act_img_idx])

                ranking_candidate = ranker.compute_rank(action.data, gt_img_idx)
                if rank_comparison:
                    score = score + (ranking_candidate - prev_rank)
                    prev_rank = ranking_candidate
                else:
                    score = score + ranking_candidate
            rollout_values.append(score)

            _restore_lists(abot_state)
            _restore_lists(target_state)

        # 3. Greedy action: the candidate whose rollout gave the lowest score.
        rollout_values = torch.stack(rollout_values, dim=1)
        _, greedy_idx = rollout_values.min(dim=1)
        act_opt = torch.gather(
            top_k_act_img_idx, 1,
            greedy_idx.to(top_k_act_img_idx.device).unsqueeze(1)).view(-1)

    # 4. Contrastive loss: logits are negative squared distances / tau between
    #    the behaviour prediction and [chosen image, neg_num random images].
    act_input = all_input[act_opt]
    if torch.cuda.is_available():
        act_input = act_input.cuda()
    # NOTE(review): forwardImage runs outside no_grad, so if target_model's
    # parameters require grad they also receive gradients — confirm intended.
    act_input = target_model.forwardImage(act_input)
    dist_action = [
        (-torch.sum((behavior_im_state - act_input) ** 2, dim=1) / tau).unsqueeze(1)]

    for _ in range(neg_num):
        # randint's upper bound is exclusive, so every gallery index can be drawn.
        neg_img_idx = torch.randint(
            0, all_input.size(0), (batch_size,), dtype=torch.long)
        neg_input = all_input[neg_img_idx]
        if torch.cuda.is_available():
            neg_input = neg_input.cuda()
        neg_input = target_model.forwardImage(neg_input)
        dist = -torch.sum((behavior_im_state - neg_input) ** 2, dim=1) / tau
        dist_action.append(dist.unsqueeze(1))

    logits = torch.cat(dist_action, dim=1)
    # The chosen action is always column 0.
    label_idx = torch.zeros(batch_size, dtype=torch.long, device=logits.device)
    loss = torch.nn.functional.cross_entropy(input=logits, target=label_idx)

    # Regulariser: pull the behaviour prediction toward the cached
    # ground-truth embedding.
    target_emb = ranker.feat[gt_img_idx.to(ranker.feat.device)]
    reg = torch.sum((behavior_im_state - target_emb) ** 2, dim=1).mean()

    return act_opt, reg + loss
