#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import random

import numpy as np
from six.moves import xrange
import tensorflow as tf

import data_utils2
from seq2seq_model_utils import *

class S2SModel(object):
    def __init__(self, source_vocab_size, target_vocab_size, buckets, size,
                 num_layers, max_gradient_norm, batch_size, learning_rate,
                 num_samples=0, forward_only=False, beam_search=True,
                 beam_size=10, attention=True, gamma=None, bidirection=True):
        """Build the bucketed sequence-to-sequence graph.

        Args:
          source_vocab_size: size of the source vocabulary.
          target_vocab_size: size of the target vocabulary.
          buckets: a list of pairs (I, O), where I specifies maximum input
            length that will be processed in that bucket, and O specifies
            maximum output length. We assume that the list is sorted,
            e.g., [(2, 4), (8, 16)]; the last bucket determines how many
            placeholders are created.
          size: number of units of the LSTM cell; also used as the embedding
            size. With `bidirection` the decoder cell gets `size * 2` units.
          num_layers: if greater than 1, the cell is stacked with
            MultiRNNCell this many times.
          max_gradient_norm: gradients will be clipped to maximally this norm.
          batch_size: number of sequences per batch; used by `step` and
            `get_batch`. Placeholders are created with an unknown batch
            dimension.
          learning_rate: learning rate handed to the Adam optimizer.
          num_samples: number of samples for sampled softmax. Sampled softmax
            (and an output projection) is only used when
            0 < num_samples < target_vocab_size.
          forward_only: if set, no gradients/updates are built and dropout
            is disabled.
          beam_search: only consulted when `forward_only` is set. If true the
            graph is built with `decode_model_with_buckets` and exposes
            `beam_path`, `beam_symbol` and `log_beam_probs`; otherwise it is
            built with `model_with_buckets` and exposes `losses`.
          beam_size: forwarded to the seq2seq builder when decoding.
          attention: if true use `embedding_attention_seq2seq`, otherwise
            `embedding_rnn_seq2seq`.
          gamma: forwarded unchanged to the seq2seq builder.
          bidirection: if true a (decoder_cell, encoder_cell) tuple is passed
            as the cell, and the flag is forwarded to the attention builder.
        """
        self.source_vocab_size = source_vocab_size
        self.target_vocab_size = target_vocab_size
        self.buckets = buckets
        self.batch_size = batch_size
        # NOTE(review): the optimizer below is built from the Python float
        # `learning_rate`, not from this variable, so assigning a new value to
        # `self.learning_rate` does not change the step size.
        self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
        self.gamma = gamma

        # If we use sampled softmax, we need an output projection.
        output_projection = None
        softmax_loss_function = None
        if 0 < num_samples < self.target_vocab_size:
            print('mappsing: {}'.format(num_samples))
            # Dedicated names: `sampled_loss` closes over these, so they must
            # not be reused as loop variables further down.
            proj_w_t = tf.get_variable("proj_w", [self.target_vocab_size, size])
            proj_w = tf.transpose(proj_w_t)
            proj_b = tf.get_variable("proj_b", [self.target_vocab_size])
            output_projection = (proj_w, proj_b)

            def sampled_loss(inputs, labels):
                labels = tf.reshape(labels, [-1, 1])
                return tf.nn.sampled_softmax_loss(
                    proj_w_t, proj_b, inputs, labels,
                    num_samples, self.target_vocab_size)

            softmax_loss_function = sampled_loss

        # LSTM cells. Dropout is a training-time regulariser: keep every unit
        # when only the forward pass is built so that decoding is not
        # randomly perturbed. The wrapper itself is kept in both modes.
        output_keep_prob = 1.0 if forward_only else 0.5
        cell = tf.nn.rnn_cell.BasicLSTMCell(size, state_is_tuple=False)
        cell = tf.nn.rnn_cell.DropoutWrapper(
            cell, output_keep_prob=output_keep_prob)
        if bidirection:
            decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(
                size * 2, state_is_tuple=False)
            cell = (decoder_cell, cell)

        if num_layers > 1:
            # NOTE(review): with bidirection=True `cell` is a tuple here, so
            # MultiRNNCell receives a list of tuples rather than of cells --
            # confirm that this combination is actually supported.
            cell = tf.nn.rnn_cell.MultiRNNCell(
                [cell] * num_layers, state_is_tuple=False)

        def seq2seq_f(encoder_inputs, decoder_inputs, do_decode,
                      beam_search=False, beam_size=0):
            """Build one bucket's graph with the configured seq2seq builder."""
            if attention:
                print('Attention Model')
                return embedding_attention_seq2seq(
                    encoder_inputs, decoder_inputs, cell,
                    num_encoder_symbols=source_vocab_size,
                    num_decoder_symbols=target_vocab_size,
                    embedding_size=size,
                    output_projection=output_projection,
                    feed_previous=do_decode,
                    dtype=tf.float32,
                    beam_search=beam_search,
                    beam_size=beam_size,
                    gamma=self.gamma,
                    bidrectional=bidirection
                )
            else:
                print('Simple model')
                return embedding_rnn_seq2seq(
                    encoder_inputs, decoder_inputs, cell,
                    num_encoder_symbols=source_vocab_size,
                    num_decoder_symbols=target_vocab_size,
                    embedding_size=size,
                    output_projection=output_projection,
                    feed_previous=do_decode,
                    dtype=tf.float32,
                    beam_search=beam_search,
                    beam_size=beam_size,
                    gamma=self.gamma)

        # Placeholders, one per time step of the largest bucket.
        self.encoder_inputs = []
        self.decoder_inputs = []
        self.decoder_weights = []

        for i in xrange(buckets[-1][0]):
            self.encoder_inputs.append(tf.placeholder(
                tf.int32, shape=[None], name='encoder_input_{}'.format(i)))
        # One extra decoder step so the last real step still has a target.
        for i in xrange(buckets[-1][1] + 1):
            self.decoder_inputs.append(tf.placeholder(
                tf.int32, shape=[None], name="decoder{0}".format(i)))
            self.decoder_weights.append(tf.placeholder(
                tf.float32, shape=[None], name="weight{0}".format(i)))

        # Our targets are decoder inputs shifted by one.
        targets = [self.decoder_inputs[i + 1] for i in range(buckets[-1][1])]

        print("forward_only:", forward_only)
        print("size:",size)
        print('buckets:', buckets)
        print("beam search:",beam_search)
        print('self.decoder_inputs:', len(self.decoder_inputs))
        print('self.decoder_weight', len(self.decoder_weights))
        print('self.encoder_inputs:', len(self.encoder_inputs))
        print('targets:', len(targets))
        print('softmax_loss_function', softmax_loss_function)
        print("bidirection:",bidirection)

        if forward_only:
            if beam_search:
                # Beam decoding: no `self.losses` is produced in this mode.
                (self.outputs, self.beam_path, self.beam_symbol,
                 self.log_beam_probs) = decode_model_with_buckets(
                    self.encoder_inputs, self.decoder_inputs, targets,
                    self.decoder_weights, buckets,
                    lambda x, y: seq2seq_f(x, y, do_decode=True,
                                           beam_search=beam_search,
                                           beam_size=beam_size),
                    softmax_loss_function=softmax_loss_function)
            else:
                self.outputs, self.losses = model_with_buckets(
                    self.encoder_inputs, self.decoder_inputs, targets,
                    self.decoder_weights, buckets,
                    lambda x, y: seq2seq_f(x, y, do_decode=True,
                                           beam_search=beam_search,
                                           beam_size=beam_size),
                    softmax_loss_function=softmax_loss_function)

                # With an output projection, project the raw outputs onto the
                # vocabulary for decoding.
                if output_projection is not None:
                    for bucket_idx in range(len(buckets)):
                        self.outputs[bucket_idx] = [
                            tf.matmul(output, output_projection[0])
                            + output_projection[1]
                            for output in self.outputs[bucket_idx]]

        else:
            self.outputs, self.losses = model_with_buckets(
                self.encoder_inputs, self.decoder_inputs, targets,
                self.decoder_weights, buckets,
                lambda x, y: seq2seq_f(x, y, do_decode=False,
                                       beam_search=False, beam_size=0),
                softmax_loss_function=softmax_loss_function)

        if not forward_only:
            # One clipped-gradient Adam update op per bucket.
            params = tf.trainable_variables()
            self.gradient_norms = []
            self.updates = []
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
            for bucket_idx in xrange(len(buckets)):
                gradients = tf.gradients(self.losses[bucket_idx], params)
                clipped_gradients, norm = tf.clip_by_global_norm(
                    gradients, max_gradient_norm)
                self.gradient_norms.append(norm)
                self.updates.append(opt.apply_gradients(
                    zip(clipped_gradients, params)))
        self.saver = tf.train.Saver(tf.global_variables())



    def step(self,session, encoder_inputs, decoder_inputs, target_weights,
             bucket_id, forward_only, beam_search=None):
        """Run one step of the model for a single bucket.

        Args:
          session: object whose `run(fetches, feed)` executes the graph.
          encoder_inputs: per-time-step encoder values; its length must equal
            the bucket's encoder size.
          decoder_inputs: per-time-step decoder values; its length must equal
            the bucket's decoder size.
          target_weights: per-time-step weights; its length must equal the
            bucket's decoder size.
          bucket_id: index into `self.buckets`.
          forward_only: if false, run the update op; otherwise only decode.
          beam_search: when decoding, fetch the beam-search tensors instead
            of the loss.

        Returns:
          Training: (gradient norm, loss, None).
          Decoding without beam search: (None, loss, outputs).
          Decoding with beam search:
            (beam path, beam symbols, log beam probs, outputs).

        Raises:
          ValueError: if any of the three sequences does not match the
            bucket's size.
        """
        encoder_size, decoder_size = self.buckets[bucket_id]

        # Validate the three sequences against the bucket in a fixed order.
        expectations = (
            ("Encoder", encoder_inputs, encoder_size),
            ("Decoder", decoder_inputs, decoder_size),
            ("Weights", target_weights, decoder_size),
        )
        for label, sequence, expected in expectations:
            if len(sequence) != expected:
                raise ValueError(
                    "%s length must be equal to the one in bucket, %d != %d."
                    % (label, len(sequence), expected))

        # Feed dictionary keyed by placeholder name.
        feed = {}
        feed.update((self.encoder_inputs[t].name, encoder_inputs[t])
                    for t in range(encoder_size))
        for t in range(decoder_size):
            feed[self.decoder_inputs[t].name] = decoder_inputs[t]
            feed[self.decoder_weights[t].name] = target_weights[t]

        # The targets are the decoder inputs shifted by one, so one extra
        # decoder placeholder has to be fed; it is filled with zeros.
        extra_target = self.decoder_inputs[decoder_size].name
        feed[extra_target] = np.zeros([self.batch_size], dtype=np.int32)

        if not forward_only:
            results = session.run(
                [self.updates[bucket_id],
                 self.gradient_norms[bucket_id],
                 self.losses[bucket_id]],
                feed)
            return results[1], results[2], None

        if beam_search:
            fetches = [self.beam_path[bucket_id],
                       self.beam_symbol[bucket_id],
                       self.log_beam_probs[bucket_id]]
        else:
            print("simple argmax")
            fetches = [self.losses[bucket_id]]
        fetches.extend(self.outputs[bucket_id][t] for t in range(decoder_size))

        results = session.run(fetches, feed)
        if beam_search:
            return results[0], results[1], results[2], results[3:]
        return None, results[0], results[1:]




    def get_batch(self, data, bucket_id,en_word2index,de_word2index):
        """Turn (source, target) sentence pairs into time-major batch arrays.

        Each sentence is split on single spaces and converted with
        `data_utils2.sentence_indice` (presumably to a list of token ids --
        verify against that helper). Sources are right-padded with PAD up to
        the bucket's encoder size; targets are framed as GO + ids + EOS and
        right-padded with PAD up to the bucket's decoder size.

        Only the first `self.batch_size` pairs of `data` are batched; fewer
        pairs than that raise an IndexError.

        Args:
          data: iterable of (source_sentence, target_sentence) strings.
          bucket_id: index into `self.buckets`.
          en_word2index: source-side token -> id mapping.
          de_word2index: target-side token -> id mapping.

        Returns:
          (batch_encoder_inputs, batch_decoder_inputs, batch_weights): lists
          with one array of length `self.batch_size` per time step. A weight
          is 0.0 at the last decoder step and wherever the next decoder token
          is PAD, otherwise 1.0.
        """
        encoder_size, decoder_size = self.buckets[bucket_id]

        padded_sources = []
        padded_targets = []
        for source_text, target_text in data:
            source_ids = data_utils2.sentence_indice(
                source_text.split(' '), en_word2index)
            target_ids = data_utils2.sentence_indice(
                target_text.split(' '), de_word2index)

            # Source side: ids followed by PAD up to the bucket length.
            source_fill = encoder_size - len(source_ids)
            padded_sources.append(
                source_ids + [en_word2index[data_utils2.PAD]] * source_fill)

            # Target side: GO and EOS take two of the bucket's slots.
            target_fill = decoder_size - len(target_ids) - 2
            padded_targets.append(
                [de_word2index[data_utils2.GO]] + target_ids
                + [de_word2index[data_utils2.EOS]]
                + [de_word2index[data_utils2.PAD]] * target_fill)

        rows = range(self.batch_size)

        # Re-index from one list per sentence to one array per time step.
        batch_encoder_inputs = [
            np.array([padded_sources[row][t] for row in rows], dtype=np.int32)
            for t in range(encoder_size)
        ]

        batch_decoder_inputs = []
        batch_weights = []
        final_step = decoder_size - 1
        for t in range(decoder_size):
            batch_decoder_inputs.append(np.array(
                [padded_targets[row][t] for row in rows], dtype=np.int32))

            # The target of step t is the decoder input of step t + 1, so the
            # final step has no target and padded targets carry no weight.
            step_weights = np.ones(self.batch_size, dtype=np.float32)
            for row in rows:
                if (t == final_step or
                        padded_targets[row][t + 1]
                        == de_word2index[data_utils2.PAD]):
                    step_weights[row] = 0.0
            batch_weights.append(step_weights)

        return batch_encoder_inputs, batch_decoder_inputs, batch_weights
