import tensorflow as tf

from model.features import features
from utils.logger import get_logger

# Module-level logger named after this module (via the project's get_logger helper).
logger = get_logger(__name__)


class PerceptronModel(object):
    """Linear class scorer trained with softmax cross-entropy (TF1 graph mode).

    Each example supplies one feature vector per candidate class. A single
    weight vector, shared by all classes, maps each of those vectors to one
    logit. The logits are then normalised with a softmax over the classes.
    """

    def __init__(self, args):
        """Build variables, placeholders, loss and the training op.

        Args:
            args: object whose ``config`` attribute provides
                ``perceptron_batch_size``, ``perceptron_lr`` and
                ``num_classes``.

        Graph inputs (feed these when running ``self.updates``):
            self.inputs: float32 tensor of shape
                ``[batch_size, num_classes, len(features)]``.
            self.labels: int64 tensor of shape ``[batch_size]`` holding the
                index of the correct class per example. Values must lie in
                ``[0, num_classes)``.

        Graph outputs:
            self.logits / self.softmax: ``[batch_size, num_classes]``.
            self.loss: per-example cross-entropy, ``[batch_size]``.
            self.cost: mean loss over the batch (scalar).
            self.updates: one gradient-descent step on ``self.cost``; also
                increments ``self.global_step``.
        """
        self.config = config = args.config
        self.batch_size = batch_size = config.perceptron_batch_size
        self.lr = lr = config.perceptron_lr

        # Epoch counter; only advances when ``epoch_incr`` is run explicitly.
        self.epoch = tf.Variable(0, trainable=False)
        self.epoch_incr = self.epoch.assign(self.epoch + 1)
        # Advanced by ``opt.minimize`` each time ``self.updates`` runs.
        self.global_step = tf.Variable(0, trainable=False)
        self.num_classes = num_classes = config.num_classes
        num_features = len(features)

        # Learnable parameters: one weight per feature, shared across classes.
        self.weights = weights = tf.get_variable(
            "weight", [num_features],
            initializer=tf.contrib.layers.xavier_initializer()
        )
        # NOTE(review): this scalar bias is added equally to every class logit.
        # Softmax is invariant to such a shift, so the bias gets a zero
        # gradient (up to rounding) and stays at its zero initialisation.
        # The [1] shape is kept so existing checkpoints still restore. A
        # per-class bias of shape [num_classes] would be needed for it to
        # have any effect.
        self.bias = bias = tf.get_variable(
            "bias", [1],
            initializer=tf.zeros_initializer()
        )

        # Input placeholders; the batch dimension is fixed at batch_size.
        self.inputs = tf.placeholder(
            tf.float32, [batch_size, num_classes, num_features])
        self.labels = tf.placeholder(tf.int64, [batch_size])

        # Score every (example, class) pair: dot product of that pair's
        # feature vector with the shared weights -> [batch_size, num_classes].
        self.logits = tf.einsum('ijk,k->ij', self.inputs, weights) + bias
        self.softmax = tf.nn.softmax(self.logits)
        # The sparse variant takes the integer class indices directly. It
        # matches the deprecated dense op fed with one-hot labels, without
        # materialising the one-hot tensor.
        self.loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=self.labels,
            logits=self.logits,
        )
        # The batch dimension is fixed, so this equals sum(loss) / batch_size.
        self.cost = tf.reduce_mean(self.loss)
        opt = tf.train.GradientDescentOptimizer(learning_rate=lr)
        self.updates = opt.minimize(self.cost, global_step=self.global_step)

