import torch
import torch.nn as nn
from visdial.models.agent import Agent
import visdial.models.encoders.hre as hre_enc
import visdial.models.encoders.hre_20im as hre_20_enc
import visdial.models.encoders.hre_20im_vse as hre_20_vse_enc
import visdial.models.encoders.hre_20im_vse_2 as hre_20_vse_enc_2
import visdial.models.encoders.hre_20im_vse_3 as hre_20_vse_enc_3
import visdial.models.encoders.hre_20im_vse_4 as hre_20_vse_enc_4
import visdial.models.encoders.hre_vse as hre_vse_enc
import visdial.models.decoders.gen as gen_dec
import visdial.models.reinforce.policy_gradient_agent as Policy_Agent
import visdial.loss.loss_utils as loss_utils
from utils import utilities as utils
from torch.distributions import Categorical


class Questioner(Agent):
    '''
    Q-Bot model.

    Combines an encoder over the dialog (and, for the "20imguess" encoder
    variants, the candidate images), a generative decoder that produces
    questions, and a policy network that decides each round whether to make
    a guess or to ask another question.
    '''

    def __init__(self, encoderParam, decoderParam, imgFeatureSize=0,
                 verbose=1, shared_embedding_space=512):
        '''
        Build the Q-Bot.

        Uses an encoder network for input sequences (questions, answers and
        history) and a decoder network for generating a response (question).

        Arguments:
            encoderParam : dict of encoder options. Must contain 'type',
                'dropout', 'rnnHiddenSize', 'imgEncodingMode' and 'fuseType'.
                A copy of it (with extra keys added) is passed to the selected
                encoder's constructor; the caller's dict is not modified.
            decoderParam : dict of decoder options; must contain 'type'.
            imgFeatureSize : size of the image feature vector. When 0, no
                feature regressor (featureNet) is created.
            verbose : when truthy, print the selected encoder/decoder types.
            shared_embedding_space : stored on the instance; not otherwise
                used by this class.

        Raises:
            ValueError : unknown encoder type, unknown vse fuseType, or
                unknown decoder type.
        '''
        super(Questioner, self).__init__()
        self.encType = encoderParam['type']
        self.decType = decoderParam['type']
        self.dropout = encoderParam['dropout']
        self.rnnHiddenSize = encoderParam['rnnHiddenSize']
        self.imgFeatureSize = imgFeatureSize
        self.shared_embedding_space = shared_embedding_space
        self.new_questioner = False
        self.imgEncodingMode = encoderParam['imgEncodingMode']
        self.gamma = 0.99  # discount factor used by reinforce_guess
        self.fuseType = encoderParam['fuseType']
        # Work on a copy so the keys added below do not leak to the caller.
        encoderParam = encoderParam.copy()
        encoderParam['isAnswerer'] = False

        if verbose:
            print('Encoder: ' + self.encType)
            print('Decoder: ' + self.decType)

        # ----- Encoder -----
        if '20imguess' in self.encType:
            # Q-Bot variant that observes the candidate images.
            encoderParam['im_questioner'] = True
            self.new_questioner = True
            if encoderParam['imgEncodingMode'] == "state-adapt":
                encoderParam['useIm'] = 'single-concate'
            elif encoderParam['imgEncodingMode'] == "dual-view":
                encoderParam['useIm'] = 'dual-concate'
            if 'vse' in self.encType:
                vse_encoders = {
                    1: hre_20_vse_enc,
                    2: hre_20_vse_enc_2,
                    3: hre_20_vse_enc_3,
                    4: hre_20_vse_enc_4,
                }
                if encoderParam['fuseType'] not in vse_encoders:
                    # Previously fell through silently, leaving self.encoder
                    # unset and failing later with an obscure AttributeError.
                    raise ValueError('Unknown fuseType {} for encoder {}'.format(
                        encoderParam['fuseType'], self.encType))
                self.encoder = vse_encoders[encoderParam['fuseType']].Encoder(
                    **encoderParam)
                # fuseType 4 projects with the encoder's own featureNet (see
                # multimodalpredictText), so it gets no text_embedding layer.
                if encoderParam['fuseType'] != 4:
                    self.text_embedding = nn.Linear(
                        self.rnnHiddenSize * 2, self.imgFeatureSize)
            else:
                self.encoder = hre_20_enc.Encoder(**encoderParam)
        elif 'hre' in self.encType:
            if 'vse' in self.encType:
                self.encoder = hre_vse_enc.Encoder(**encoderParam)
                self.text_embedding = nn.Linear(
                    self.rnnHiddenSize * 2, self.imgFeatureSize)
            else:
                self.encoder = hre_enc.Encoder(**encoderParam)
        else:
            raise ValueError('Unknown encoder {}'.format(self.encType))

        # ----- Decoder -----
        if 'gen' == self.decType:
            self.decoder = gen_dec.Decoder(**decoderParam)
        else:
            raise ValueError('Unknown decoder {}'.format(self.decType))

        # ----- Guess/ask policy; its input is the encoder's hidden state -----
        if self.imgEncodingMode == "state-adapt":
            # Dialog state concatenated with the raw image embedding.
            self.policy_agent = Policy_Agent.Policy(
                2 * encoderParam['rnnHiddenSize'] + self.imgFeatureSize)
            # Dialog-state-only policy, used by determine_action when
            # fuseType == 2.
            self.policy_agent_2 = Policy_Agent.Policy(
                2 * encoderParam['rnnHiddenSize'])
        elif self.imgEncodingMode == "dual-view":
            self.policy_agent = Policy_Agent.Policy(
                3 * encoderParam['rnnHiddenSize'] + self.imgFeatureSize)
        else:
            self.policy_agent = Policy_Agent.Policy(
                2 * encoderParam['rnnHiddenSize'])

        # Share word embedding parameters between encoder and decoder
        self.decoder.wordEmbed = self.encoder.wordEmbed

        # Setup feature regressor
        if self.imgFeatureSize:
            self.featureNet = nn.Linear(self.rnnHiddenSize,
                                        self.imgFeatureSize)
            self.featureNetInputDropout = nn.Dropout(0.5)

        # Initialize weights
        utils.initializeWeights(self.encoder)
        utils.initializeWeights(self.decoder)
        utils.initializeWeights(self.policy_agent)
        if self.imgEncodingMode == "state-adapt":
            utils.initializeWeights(self.policy_agent_2)

        # TODO: Initialize im_embedding and text_embedding layer
        self.reward_list = []
        self.final_rewards = None
        self.reset()

    def reset(self):
        '''Delete dialog history and any per-game reward bookkeeping.'''
        self.questions = []
        self.encoder.reset()
        self.reward_list = []
        self.final_rewards = None

    def freezeFeatNet(self):
        '''
        Disable gradients for the feature regressor self.featureNet.

        NOTE(review): self.featureNet only exists when imgFeatureSize is
        non-zero, and in 'dual-view'/'state-adapt' modes predictImage uses
        self.encoder.featureNet instead — confirm which net should be frozen.
        '''
        for param in self.featureNet.parameters():
            param.requires_grad = False

    def observe(self, round, ques=None, **kwargs):
        '''
        Update Q-Bot percepts. See self.encoder.observe() in the corresponding
        encoder class definition (hre).

        A question passed in `ques` is also appended to self.questions; the
        round number must equal the number of questions observed so far.
        '''
        assert 'image' not in kwargs, "Q-Bot does not see image"
        if ques is not None:
            assert round == len(self.questions), \
                "Round number does not match number of questions observed"
            self.questions.append(ques)

        self.encoder.observe(round, ques=ques, **kwargs)

    def observe_im(self, image):
        '''
        Allow the agent to see the 20 images for the game.
        '''
        self.encoder.observe_im(image)

    def forward(self):
        '''
        Forward pass the last observed question to compute its log
        likelihood under the current decoder RNN state.

        Raises:
            Exception : if no question has been observed yet.
        '''
        encStates = self.encoder()
        if len(self.questions) == 0:
            raise Exception('Must provide question if not sampling one.')
        decIn = self.questions[-1]

        logProbs = self.decoder(encStates, inputSeq=decIn)
        return logProbs

    def forwardDecode(self, inference='sample', beamSize=1, maxSeqLen=20):
        '''
        Decode a sequence (question) using either sampling or greedy inference.
        A question is decoded given current state (dialog history). This can
        be called at round 0 after the caption is observed, and at end of every
        round (after a response from A-Bot is observed).

        Arguments:
            inference : Inference method for decoding
                'sample' - Sample each word from its softmax distribution
                'greedy' - Always choose the word with highest probability
                           if beam size is 1, otherwise use beam search.
            beamSize  : Beam search width
            maxSeqLen : Maximum length of token sequence to generate

        Returns:
            (questions, quesLens) as produced by the decoder's forwardDecode.
        '''
        encStates = self.encoder()
        questions, quesLens = self.decoder.forwardDecode(
            encStates,
            maxSeqLen=maxSeqLen,
            inference=inference,
            beamSize=beamSize)
        return questions, quesLens

    def predictImage(self):
        '''
        Predict/guess an fc7 vector given the current conversation history.
        This can be called at round 0 after the caption is observed, and at
        end of every round (after a response from A-Bot is observed).

        The top-layer LSTM hidden state h[-1] is passed through dropout and
        then through the encoder's featureNet ('dual-view'/'state-adapt'
        modes) or this model's own featureNet (all other modes).
        '''
        encState = self.encoder()
        # h, c from lstm
        h, c = encState
        if self.imgEncodingMode in ['dual-view', 'state-adapt']:
            return self.encoder.featureNet(self.featureNetInputDropout(h[-1]))
        return self.featureNet(self.featureNetInputDropout(h[-1]))

    def multimodalEmbedding(self, image, criterion_vse):
        """
        Project the dialog history and image to the shared embedding space and
        compute the embedding retrieval (pairwise ranking) loss.

        Input:
            image: image features; presumably (batchSize, imageFeaturesize)
                — TODO confirm against callers.
            criterion_vse: loss function called as
                criterion_vse(imageEmbedding, dialogEmbedding).
        Output:
            featLoss: whatever criterion_vse returns for the two embeddings.
        """
        dialogEmbedding = self.multimodalpredictText()
        imageEmbedding = self.multimodalpredictIm(image)
        featLoss = criterion_vse(imageEmbedding, dialogEmbedding)
        return featLoss

    def multimodalpredictText(self):
        '''
        Project the fused dialog/caption state into the image-feature space
        and L2-normalize it.

        For encType "hre-ques-20imguess-vse" with fuseType 1, 2 or 4 the
        encoder's own featureNet is used; otherwise self.text_embedding.
        NOTE(review): a vse encoder other than "hre-ques-20imguess-vse" with
        fuseType 4 has no text_embedding and would fail here.
        '''
        # Run the encoder first; fuseDialogCap() presumably reads the state
        # this call computes (TODO confirm in the encoder classes).
        self.encoder()
        dialogState = self.encoder.fuseDialogCap()
        if self.fuseType in [1, 2, 4] and self.encType == "hre-ques-20imguess-vse":
            projection = self.encoder.featureNet
        else:
            projection = self.text_embedding
        return loss_utils.l2norm(projection(dialogState))

    def multimodalpredictIm(self, image):
        '''
        Image side of the shared space: features are used as-is (identity).
        '''
        return image

    def reinforce(self, reward):
        '''Propagate the REINFORCE call (with `reward`) to the decoder.'''
        return self.decoder.reinforce(reward)

    def reinforce_guess(self, policy='policy_gradient', discount=True, normalize=True):
        '''
        Turn the per-round rewards collected via update_reward into returns
        and compute the REINFORCE losses for the guess policy and decoder.

        Arguments:
            policy : when 'policy_gradient', the policy agent's loss is
                included; otherwise only the decoder's loss is returned.
            discount : when True, each round's return is its reward plus
                self.gamma times the next round's return; when False, the raw
                reward is used.
            normalize : currently unused; kept for API compatibility.

        Returns:
            The sum of the policy loss (0 if not computed) and decoder loss.
            Also stores the stacked returns in self.final_rewards, with
            dim 0 indexing the round.

        Raises:
            RuntimeError : if no rewards have been recorded.
        '''
        if not self.reward_list:
            raise RuntimeError(
                'reinforce_guess called with no rewards recorded; '
                'call update_reward first')

        # Accumulate returns from the last round backwards, then restore
        # chronological order (avoids repeated list.insert(0, ...)).
        returns = []
        running = 0
        for r in reversed(self.reward_list):
            running = r + self.gamma * running if discount else r
            returns.append(running)
        returns.reverse()
        self.final_rewards = torch.stack(returns)

        loss_rl_action = 0
        if policy == 'policy_gradient':
            loss_rl_action = self.policy_agent.reinforce_guess(
                self.final_rewards)
        loss_rl_question = self.decoder.reinforce_guess(self.final_rewards)

        return loss_rl_action + loss_rl_question

    def reset_reinforce(self):
        '''Clear collected rewards and the policy/decoder REINFORCE history.'''
        self.reward_list = []
        self.policy_agent.reset_reinforce()  # Reset the action_log_probs_history
        self.decoder.reset_reinforce()

    def update_reward(self, reward):
        '''Record the reward for the current round.'''
        self.reward_list.append(reward)

    def compute_game_rewards(self):
        """
        Compute the masked mean game rewards, for logging purposes.

        For each column of self.final_rewards (presumably one game per
        column — TODO confirm), the sum over rounds is divided by the number
        of non-zero entries. The result is then averaged over columns.
        Columns with no non-zero entry contribute 0 rather than NaN.

        Raises:
            RuntimeError : if reinforce_guess has not produced final_rewards.
        """
        if self.final_rewards is None:
            raise RuntimeError(
                'No final rewards available; call reinforce_guess first')
        nonzero_mask = (self.final_rewards != 0).to(self.final_rewards.dtype)
        # Clamp so that all-zero columns give 0/1 = 0 instead of 0/0 = NaN.
        nonzero_counts = torch.sum(nonzero_mask, dim=0).clamp(min=1)
        game_rewards = torch.sum(self.final_rewards, dim=0) / nonzero_counts
        return game_rewards.mean().item()

    def determine_action(self, done, policy="random", imgEncodingMode="lstm",
                         guess_epsilon=0.9):
        """
        Determine whether to make a guess (1) or raise a question (0).

        Arguments:
            done : tensor; entries equal to 1 force the action to 0 via the
                (1 - done) mask. The result is placed on done's device.
            policy : "random" or "policy_gradient".
            imgEncodingMode : selects how the policy input is built for
                "policy_gradient".
            guess_epsilon : for "random", the action is 1 when a uniform
                sample exceeds this threshold (default 0.9).

        Output:
            A float tensor shaped like the (1 - done) * action product.

        Raises:
            ValueError : for an unknown policy (previously UnboundLocalError).
        """
        if policy == "random":
            # Sampled with the CPU RNG (as before), then moved to done's
            # device instead of a hard-coded .cuda().
            guess_result = (torch.rand(done.size()) > guess_epsilon).float()
            guess_result = guess_result.to(done.device)
            guess_result = (1 - done) * guess_result
        elif policy == "policy_gradient":
            vis_state, _ = self.encoder()
            flat_state = vis_state.view(-1, 2 * self.rnnHiddenSize)
            if imgEncodingMode in ["state-adapt", "dual-view"]:
                if self.fuseType == 2:
                    # NOTE(review): policy_agent_2 is only built for
                    # "state-adapt"; "dual-view" with fuseType 2 fails here.
                    action_prob = self.policy_agent_2(flat_state)
                else:
                    action_prob = self.policy_agent(
                        torch.cat([flat_state, self.encoder.raw_imageEmbed], -1))
            else:
                action_prob = self.policy_agent(flat_state)
            dist = Categorical(action_prob)
            # Float actions on done's device (replaces torch.cuda.FloatTensor).
            guess_result = dist.sample().float().to(done.device)
            guess_result = (1 - done) * guess_result

            # Log-probability of the (masked) action is recorded on
            # policy_agent, whose reinforce_guess consumes it later.
            log_probs_guess_result = dist.log_prob(guess_result)
            self.policy_agent.update_action_history(log_probs_guess_result)
        else:
            raise ValueError('Unknown action policy {}'.format(policy))

        return guess_result