from keras.layers import Dense, Embedding, Activation
from keras.models import Sequential, Model
from keras.layers import LSTM, Bidirectional
from keras.layers import RepeatVector
from keras.layers import TimeDistributed
import keras.backend as K
from keras.engine.topology import Layer, InputSpec
from keras.utils.vis_utils import plot_model
import numpy as np
from keras.callbacks import CSVLogger
import pandas as pd
import utils
import pdb


def SAE_LSTM_tf(x, w2n, n2w, dropout, vocab_size, emb_dim, emb_matrix, use_pretrain=0):
    """
    Train a TensorFlow 1.x (graph-mode) LSTM sequence auto-encoder on ``x``.

    An encoder LSTM reads the token-id rows of ``x``; its final state
    initialises a decoder LSTM that is trained with teacher forcing
    (decoder input ``y[:, :-1]``, target ``y[:, 1:]``) to reproduce each
    row's non-padding tokens. Per-epoch loss/accuracy of the last batch is
    printed.

    :param x: sequence of equal-length rows of integer token ids; id 0 is
        treated as padding and is masked out of the loss.
    :param w2n: word -> id mapping; only ``len(w2n)`` is used, to size the
        embedding tables and the output projection (``len(w2n) + 2`` ids).
    :param n2w: unused (kept for signature compatibility).
    :param dropout: unused (kept for signature compatibility).
    :param vocab_size: unused (kept for signature compatibility).
    :param emb_dim: width of the encoder and decoder embeddings.
    :param emb_matrix: unused (kept for signature compatibility).
    :param use_pretrain: unused (kept for signature compatibility).
    :return: None.

    NOTE(review): no <GO> token is prepended to the decoder input, so the
    first non-padding token of each row is never predicted. Also consider
    trying a bidirectional encoder.
    """
    # Local imports: keeps the heavy TF/sklearn dependencies optional for
    # users of the Keras model below, and avoids the module-level ``time``
    # name, which the ``__main__`` block may rebind to a function.
    import time

    import tensorflow as tf
    from sklearn.model_selection import train_test_split

    pad_id = 0  # w2n['padding'] = 0

    x = np.array(x)
    x_seq_length = len(x[0])

    # Targets: each row with its padding removed, then right-padded back into
    # a rectangular array so it can be shuffled/sliced like ``x`` (a ragged
    # list cannot be fancy-indexed or sliced as ``[:, 1:]``).
    stripped = [row[row != pad_id] for row in x]
    y_seq_length = max([len(row) for row in stripped] or [0])
    y = np.full((len(stripped), y_seq_length), pad_id, dtype=x.dtype)
    for i, row in enumerate(stripped):
        y[i, :len(row)] = row

    def batch_data(x, y, batch_size):
        """Yield shuffled full-size (x, y) batches; a trailing partial batch is dropped."""
        shuffle = np.random.permutation(len(x))
        x, y = x[shuffle], y[shuffle]
        for start in range(0, len(x) - batch_size + 1, batch_size):
            yield x[start:start + batch_size], y[start:start + batch_size]

    epochs = 10
    batch_size = 128
    nodes = 128
    # Embedding tables, decoder input ids and softmax classes must all agree.
    n_tokens = len(w2n) + 2

    tf.reset_default_graph()
    sess = tf.InteractiveSession()
    try:
        # Tensors where we feed the data into the graph.
        inputs = tf.placeholder(tf.int32, (None, x_seq_length), 'inputs')
        outputs = tf.placeholder(tf.int32, (None, None), 'output')
        targets = tf.placeholder(tf.int32, (None, None), 'targets')

        # Separate embedding tables for the encoder and the decoder.
        input_embedding = tf.Variable(tf.random_uniform((n_tokens, emb_dim), -1.0, 1.0), name='enc_embedding')
        output_embedding = tf.Variable(tf.random_uniform((n_tokens, emb_dim), -1.0, 1.0), name='dec_embedding')
        date_input_embed = tf.nn.embedding_lookup(input_embedding, inputs)
        date_output_embed = tf.nn.embedding_lookup(output_embedding, outputs)

        with tf.variable_scope("encoding"):
            lstm_enc = tf.contrib.rnn.BasicLSTMCell(nodes)
            _, last_state = tf.nn.dynamic_rnn(lstm_enc, inputs=date_input_embed, dtype=tf.float32)

        with tf.variable_scope("decoding"):
            lstm_dec = tf.contrib.rnn.BasicLSTMCell(nodes)
            dec_outputs, _ = tf.nn.dynamic_rnn(lstm_dec, inputs=date_output_embed, initial_state=last_state)

        # Project decoder outputs onto the token vocabulary.
        logits = tf.contrib.layers.fully_connected(dec_outputs, num_outputs=n_tokens, activation_fn=None)
        with tf.name_scope("optimization"):
            # Mask padding targets out of the loss; the mask takes its shape
            # from the fed targets, so any batch size / length works.
            masks = tf.cast(tf.not_equal(targets, pad_id), tf.float32, name='masks')
            loss = tf.contrib.seq2seq.sequence_loss(logits, targets, masks)
            optimizer = tf.train.RMSPropOptimizer(1e-3).minimize(loss)

        X_train, _, y_train, _ = train_test_split(x, y, test_size=0.1, random_state=42)

        sess.run(tf.global_variables_initializer())
        for epoch_i in range(epochs):
            start_time = time.time()
            batch_loss = accuracy = float('nan')
            batch_logits = target_batch = None
            for source_batch, target_batch in batch_data(X_train, y_train, batch_size):
                _, batch_loss, batch_logits = sess.run([optimizer, loss, logits],
                                                       feed_dict={inputs: source_batch,
                                                                  outputs: target_batch[:, :-1],
                                                                  targets: target_batch[:, 1:]})
            # Accuracy of the last batch; stays NaN when there was no full batch.
            if batch_logits is not None:
                accuracy = np.mean(batch_logits.argmax(axis=-1) == target_batch[:, 1:])
            print('Epoch {:3} Loss: {:>6.3f} Accuracy: {:>6.4f} Epoch duration: {:>6.3f}s'.format(epoch_i, batch_loss,
                                                                                                  accuracy,
                                                                                                  time.time() - start_time))
    finally:
        sess.close()

    return None


def SAE_LSTM(max_seq_len, enc_size, dropout, vocab_size, emb_dim, emb_matrix, use_pretrain=1):
    """
    Build (without compiling) a Keras bidirectional-LSTM sequence auto-encoder.

    Layer stack:
        Embedding(mask_zero=True)
        -> Bidirectional LSTM(128), full sequence          ('en-lstm1')
        -> Bidirectional LSTM(enc_size), last step only    ('sentence_embedding')
        -> RepeatVector(max_seq_len)
        -> Bidirectional LSTM(128), full sequence          ('de-lstm1')
        -> TimeDistributed(Dense(vocab_size, softmax))

    :param max_seq_len: number of steps the bottleneck code is repeated for
        (and the embedding input length when a pre-trained matrix is used).
    :param enc_size: units of the bottleneck LSTM (Keras' default
        ``merge_mode='concat'`` makes the code 2 * enc_size wide).
    :param dropout: input and recurrent dropout of the two 128-unit LSTMs;
        the bottleneck LSTM has none.
    :param vocab_size: number of embedding rows and softmax outputs.
    :param emb_dim: embedding width.
    :param emb_matrix: initial embedding weights, used only when
        ``use_pretrain`` is not 0 (presumably (vocab_size, emb_dim) — confirm).
    :param use_pretrain: 0 -> randomly initialised embedding; any other value
        -> embedding initialised from ``emb_matrix`` (still trainable) with
        its input length fixed to ``max_seq_len``.
    :return: the Sequential model; its summary is printed before returning.
    """
    def bi_lstm(units, layer_name, keep_sequence, **regularisation):
        """Wrap a ReLU LSTM in a Bidirectional layer named ``layer_name``."""
        return Bidirectional(LSTM(units=units, activation='relu', return_sequences=keep_sequence,
                                  **regularisation), name=layer_name)

    model = Sequential()

    embedding_args = {'input_dim': vocab_size, 'output_dim': emb_dim, 'mask_zero': True}
    if use_pretrain == 0:
        embedding = Embedding(**embedding_args)
    else:
        print('Using pre-trained embedding ...')
        embedding = Embedding(weights=[emb_matrix], input_length=max_seq_len, trainable=True,
                              **embedding_args)
    model.add(embedding)

    # Encoder: a sequence-level BiLSTM feeding the bottleneck BiLSTM.
    lstm_regularisation = {'dropout': dropout, 'recurrent_dropout': dropout}
    model.add(bi_lstm(128, 'en-lstm1', True, **lstm_regularisation))
    model.add(bi_lstm(enc_size, 'sentence_embedding', False))
    print('The auto-encoder dim is %s.' % enc_size)

    # Decoder: repeat the code over time and predict a token at every step.
    model.add(RepeatVector(max_seq_len))
    model.add(bi_lstm(128, 'de-lstm1', True, **lstm_regularisation))
    model.add(TimeDistributed(Dense(vocab_size, activation='softmax')))

    model.summary()
    return model


def main(argv=None):
    """
    Parse the command-line options, print them, and ensure ``save_dir`` exists.

    :param argv: list of arguments to parse; None means ``sys.argv[1:]``.
    :return: the parsed ``argparse.Namespace``.
    :raises OSError: if ``save_dir`` cannot be created (e.g. a file is in the way).
    """
    import argparse
    import os

    # setting the hyper parameters
    parser = argparse.ArgumentParser(description='train')
    parser.add_argument('--dataset', default='usps', choices=['mnist', 'usps'])
    parser.add_argument('--n_clusters', default=10, type=int)
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--epochs', default=200, type=int)
    parser.add_argument('--save_dir', default='results/temp', type=str)
    args = parser.parse_args(argv)
    print(args)

    # Create the output directory without the exists()/makedirs() race:
    # tolerate it already existing, re-raise any other failure.
    try:
        os.makedirs(args.save_dir)
    except OSError:
        if not os.path.isdir(args.save_dir):
            raise
    return args


if __name__ == "__main__":
    main()



