import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


class Policy(nn.Module):
    """Two-action policy network (guess vs. ask_questions) trained with REINFORCE.

    The network maps a state vector to a probability distribution over two
    actions. Log-probabilities of sampled actions are accumulated with
    ``update_action_history`` and turned into a REINFORCE loss by
    ``reinforce_guess``; ``reset_reinforce`` clears them between episodes.
    """

    def __init__(self, state_space, *, l1_dimension=128, dropout_p=0.6, gamma=0.99):
        """Build the policy MLP.

        Args:
            state_space: Size of the last dimension of the state input.
            l1_dimension: Width of the hidden layer (default 128).
            dropout_p: Dropout probability applied right after the first linear
                layer (default 0.6).
            gamma: Discount factor used by ``reinforce_guess(..., discount=True)``
                (default 0.99).
        """
        super().__init__()
        self.state_space = state_space  # dimension of the state input
        self.action_space = 2  # guess or ask_questions
        self.l1_dimension = l1_dimension
        self.dropout = nn.Dropout(p=dropout_p)

        # Linear -> Dropout -> ReLU -> Linear -> Softmax over the action axis.
        self.policy_out = nn.Sequential(
            nn.Linear(self.state_space, self.l1_dimension),
            self.dropout,
            nn.ReLU(),
            nn.Linear(self.l1_dimension, self.action_space),
            nn.Softmax(dim=-1),
        )
        # Log-probabilities of sampled actions; one entry per update_action_history call.
        self.action_log_probs = []
        self.gamma = gamma

    def forward(self, diag_state_var):
        """Return action probabilities for ``diag_state_var``.

        The last dimension of ``diag_state_var`` must equal ``state_space``. The
        output keeps the leading dimensions and has a last dimension of size 2
        (softmax over the actions, so it sums to 1).
        """
        return self.policy_out(diag_state_var)

    def init_model_weights(self):
        """Re-initialize parameters selected by name.

        * names containing ``weight_ih`` -> Xavier-uniform
        * names containing ``weight_hh`` -> orthogonal
        * names containing ``bias``      -> 0

        Note: ``__init__`` only creates ``nn.Linear`` layers, whose parameters are
        named ``...weight`` / ``...bias``. For this class as written, only the
        biases are therefore affected. The ``weight_ih`` / ``weight_hh`` rules
        match the parameter names of PyTorch recurrent layers.
        """
        named = list(self.named_parameters())
        ih = [p.data for name, p in named if 'weight_ih' in name]
        hh = [p.data for name, p in named if 'weight_hh' in name]
        biases = [p.data for name, p in named if 'bias' in name]
        # Use the in-place ``*_`` initializers. The un-suffixed names are
        # deprecated aliases that emit a warning on every call.
        for t in ih:
            nn.init.xavier_uniform_(t)
        for t in hh:
            nn.init.orthogonal_(t)
        for t in biases:
            nn.init.constant_(t, 0)

    def update_action_history(self, action_log_prob):
        """Append the log-probabilities of one step's sampled actions to ``action_log_probs``."""
        self.action_log_probs.append(action_log_prob)

    def reset_reinforce(self):
        """Clear the stored action log-probabilities (call between episodes)."""
        self.action_log_probs = []

    def _discounted_returns(self, reward_list):
        """Return ``R_i = sum_{j >= i} gamma**(j - i) * reward_list[j]`` for every i."""
        returns = [0] * len(reward_list)
        running = 0
        # Accumulate from the last step backwards, so this is O(n) rather than O(n^2).
        for j in reversed(range(len(reward_list))):
            running = reward_list[j] + self.gamma * running
            returns[j] = running
        return returns

    def reinforce_guess(self, reward_list, discount=False):
        """Compute the REINFORCE loss over the stored action log-probabilities.

        Every element of ``self.action_log_probs[i]`` is scaled by the return R_i
        of step i:

            loss = sum_i sum_{lp in action_log_probs[i]} -lp * R_i

        Each stored entry is iterated. It is presumably a 1-D tensor of per-sample
        log-probs, or a list of scalar tensors (TODO: confirm against the caller).

        Args:
            reward_list: Indexable rewards, where ``reward_list[i]`` belongs to
                ``action_log_probs[i]``. It must have at least as many entries as
                there are stored steps.
            discount: If False (default), ``R_i = reward_list[i]``. If True, R_i
                is the return from step i to the end of ``reward_list``,
                discounted by ``self.gamma``.

        Returns:
            The summed loss (``0`` if every stored entry is empty).

        Raises:
            RuntimeError: If no log-probabilities have been stored.
        """
        # TODO: Implement the new reinforcement
        if len(self.action_log_probs) == 0:
            raise RuntimeError("Reinforce called without sampling in Decoder")
        returns = self._discounted_returns(reward_list) if discount else reward_list
        loss = 0
        for i, saved_log_probs in enumerate(self.action_log_probs):
            for log_prob in saved_log_probs:
                loss += -1 * log_prob * returns[i]
        return loss
