import numpy as np
import tensorflow as tf

# Seed NumPy's global RNG and TensorFlow's graph-level RNG with the same
# value so repeated runs start from the same random state.
_SEED = 2
np.random.seed(_SEED)
tf.set_random_seed(_SEED)

# NOTE(review): neither constant is referenced in this section of the file.
GAMMA = 0.9  # presumably the reward discount factor -- confirm against caller
LR_A = 0.001  # presumably the actor learning rate -- Agent takes its own `lr`

class Agent(object):
    """Softmax policy network trained with a TD-error-weighted policy gradient.

    The graph is a single ReLU hidden layer followed by a linear layer that
    produces one logit per action. Actions are sampled from the categorical
    distribution defined by those logits, and training maximises
    ``mean(log pi(a|s) * td_error)``.

    Attributes set by ``__init__``:
        s: float32 placeholder of shape ``[None, n_features]`` for states.
        a: int32 placeholder of shape ``[None]`` for taken action indices.
        td_error: float32 placeholder (unconstrained shape) for the weights
            applied to each sample's log-probability.
        acts_prob: softmax action probabilities, shape ``[None, n_actions]``.
        actions: one sampled action index per input row.
        exp_v: scalar objective being maximised.
        train_op: Adam update step that minimises ``-exp_v``.
    """

    def __init__(self, sess, n_features, n_actions, n_units, lr=0.001):
        """Build the policy graph in the current default TensorFlow graph.

        Args:
            sess: TensorFlow session used by ``learn`` and ``choose_action``.
            n_features: width of each state vector.
            n_actions: number of discrete actions (output units).
            n_units: number of hidden units in the first dense layer.
            lr: learning rate passed to ``tf.train.AdamOptimizer``.
        """
        self.sess = sess

        with tf.variable_scope('Agent'):
            self.s = tf.placeholder(tf.float32, [None, n_features], "state")
            self.a = tf.placeholder(tf.int32, [None,], "act")
            self.td_error = tf.placeholder(tf.float32, None, "td_error")  # TD_error

            # reuse=tf.AUTO_REUSE means a second Agent built in the same graph
            # shares these layer variables instead of raising a name clash.
            l1 = tf.layers.dense(
                inputs=self.s,
                units=n_units,    # number of hidden units
                activation=tf.nn.relu,
                kernel_initializer=tf.random_normal_initializer(0., .1),    # weights
                bias_initializer=tf.constant_initializer(0.1),  # biases
                name='l1',
                reuse=tf.AUTO_REUSE
            )

            # Keep the raw logits: both tf.multinomial and the sparse softmax
            # cross-entropy op expect unnormalised log-probabilities, not the
            # output of a softmax. The layer keeps the name 'acts_prob' so its
            # kernel/bias variable names are unchanged.
            logits = tf.layers.dense(
                inputs=l1,
                units=n_actions,    # output units
                activation=None,
                kernel_initializer=tf.random_normal_initializer(0., .1),  # weights
                bias_initializer=tf.constant_initializer(0.1),  # biases
                name='acts_prob',
                reuse=tf.AUTO_REUSE
            )
            self.acts_prob = tf.nn.softmax(logits, name='acts_prob_softmax')  # action probabilities
            self.actions = tf.multinomial(logits, 1, name='action_sampling')

        with tf.variable_scope('exp_v'):
            # The cross-entropy op returns -log pi(a|s); negate it to obtain
            # the log-probability of the action that was actually taken.
            neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=self.a)
            log_prob = -neg_log_prob
            # Flatten so a (batch, 1) TD error lines up element-wise with the
            # [batch] log-probabilities instead of broadcasting to
            # [batch, batch]; a scalar becomes [1] and still broadcasts.
            td_error = tf.reshape(self.td_error, [-1])
            self.exp_v = tf.reduce_mean(log_prob * td_error)

        with tf.variable_scope('train', reuse=tf.AUTO_REUSE):
            self.train_op = tf.train.AdamOptimizer(lr).minimize(-self.exp_v)  # minimize(-exp_v) = maximize(exp_v)

    def learn(self, s, a, td):
        """Run one optimisation step and return the objective value.

        Args:
            s: batch of states, fed to the ``[None, n_features]`` placeholder.
            a: action indices; any array-like, flattened to 1-D before feeding.
            td: TD error(s) weighting each sample; scalar, ``(batch,)`` or
                ``(batch, 1)``.

        Returns:
            The value of ``exp_v`` evaluated in the same ``sess.run`` call as
            the training step.
        """
        # np.asarray lets callers pass lists as well as ndarrays.
        flat_actions = np.asarray(a).reshape(-1)
        feed_dict = {self.s: s, self.a: flat_actions, self.td_error: td}
        _, exp_v = self.sess.run([self.train_op, self.exp_v], feed_dict)
        return exp_v

    def choose_action(self, s):
        """Sample one action per state from the current policy.

        Args:
            s: batch of states, fed to the ``[None, n_features]`` placeholder.

        Returns:
            The evaluated ``self.actions`` tensor: one sampled action index
            per input row (``tf.multinomial`` with ``num_samples=1``).
        """
        return self.sess.run(self.actions, feed_dict={self.s: s})


