from time import time
import numpy as np
import tensorflow as tf
import keras.backend as K
import keras
from keras.engine.topology import Layer, InputSpec
from keras.models import Model, load_model
from keras.utils.vis_utils import plot_model
from keras.callbacks import ModelCheckpoint, EarlyStopping
from sklearn.cluster import KMeans
import pandas as pd
import pickle
import os
import logging
import argparse
import metrics
import utils
from E2EAE import SAE_LSTM, SAE_LSTM_tf
import pdb


class ClusteringLayer(Layer):
    """
    Clustering layer converts input sample (feature) to soft label, i.e. a vector that represents the probability of the
    sample belonging to each cluster. The probability is calculated with Student's t-distribution.

    # Example
    ```
        model.add(ClusteringLayer(n_clusters=10))
    ```
    # Arguments
        n_clusters: number of clusters.
        weights: list holding one Numpy array of shape `(n_clusters, n_features)` which represents the initial
            cluster centers (passed to `set_weights` when the layer is built).
        alpha: degrees-of-freedom parameter of Student's t-distribution. Defaults to 1.0.
    # Input shape
        2D tensor with shape: `(n_samples, n_features)`.
    # Output shape
        2D tensor with shape: `(n_samples, n_clusters)`.
    """
    def __init__(self, n_clusters, weights=None, alpha=1.0, **kwargs):
        # Accept the `input_dim` keyword as shorthand for `input_shape=(input_dim,)`.
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(ClusteringLayer, self).__init__(**kwargs)
        self.n_clusters = n_clusters
        self.alpha = alpha
        self.initial_weights = weights
        self.input_spec = InputSpec(ndim=2)

    def build(self, input_shape):
        """Create the trainable `(n_clusters, n_features)` cluster-center matrix."""
        if len(input_shape) != 2:
            raise ValueError('ClusteringLayer expects a 2D input (n_samples, n_features), '
                             'got shape %s' % (input_shape,))
        input_dim = input_shape[1]
        self.input_spec = InputSpec(dtype=K.floatx(), shape=(None, input_dim))
        # Keyword arguments are valid for both the name-first and shape-first add_weight signatures.
        self.clusters = self.add_weight(name='clusters', shape=(self.n_clusters, input_dim),
                                        initializer='glorot_uniform')
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
        self.built = True

    def call(self, inputs, **kwargs):
        """ Student's t-distribution kernel, the same one used in t-SNE:
                 q_ij = (1 + ||x_i - u_j||^2 / alpha) ^ (-(alpha + 1) / 2), then each row is normalized to sum to 1.
        Arguments:
            inputs: the variable containing data, shape=(n_samples, n_features)
        Return:
            q: Student's t-distribution, or soft labels for each sample. shape=(n_samples, n_clusters)
        """
        # (n_samples, 1, n_features) - (n_clusters, n_features) broadcasts to (n_samples, n_clusters, n_features).
        sq_dist = K.sum(K.square(K.expand_dims(inputs, axis=1) - self.clusters), axis=2)
        q = 1.0 / (1.0 + (sq_dist / self.alpha))
        q **= (self.alpha + 1.0) / 2.0
        # Divide every row by its sum (transposes make the (n_samples,) row sums broadcast correctly).
        q = K.transpose(K.transpose(q) / K.sum(q, axis=1))
        return q

    def compute_output_shape(self, input_shape):
        if not input_shape or len(input_shape) != 2:
            raise ValueError('ClusteringLayer expects a 2D input (n_samples, n_features), '
                             'got shape %s' % (input_shape,))
        return input_shape[0], self.n_clusters

    def get_config(self):
        # alpha must be serialized as well; otherwise a reloaded model silently falls back to alpha=1.0.
        config = {'n_clusters': self.n_clusters, 'alpha': self.alpha}
        base_config = super(ClusteringLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class DSEC(object):
    """Deep Sentence Embedding Clustering.

    Couples an LSTM sentence auto-encoder (``SAE_LSTM``) with a ``ClusteringLayer`` attached to the
    auto-encoder's ``'sentence_embedding'`` layer. ``self.model`` has two outputs, the soft cluster
    assignments ``q`` and the auto-encoder reconstruction, so it is trained jointly on a clustering
    loss and a reconstruction loss (see ``compile``).

    Attributes:
        sae: the sentence auto-encoder built by ``SAE_LSTM``.
        encoder: sub-model mapping the auto-encoder input to the sentence embedding.
        model: joint model with outputs ``[q, reconstruction]``.
        y_pred: most recent hard cluster assignment of each training sample.
        n2w: token id -> word mapping used to render sentences; w2n is presumably its inverse.
        file_name: path prefix of the pretrained SAE weights (set by ``pretrain``).
        args: parsed command-line arguments; only logged.
    """

    def __init__(self, max_seq_len, enc_size, dropout, vocab_size, emb_dim, emb_matrix,
                 w2n, n2w, n_clusters=10, alpha=1.0, pretrained=True, input_args=None):
        super(DSEC, self).__init__()

        self.n_clusters = n_clusters
        self.alpha = alpha
        self.pretrained = pretrained
        self.y_pred = []
        self.n2w = n2w
        self.w2n = w2n
        self.file_name = None
        self.args = input_args

        # LSTM auto-encoder; its 'sentence_embedding' layer is the feature space that gets clustered.
        self.sae = SAE_LSTM(max_seq_len, enc_size, dropout, vocab_size, emb_dim, emb_matrix, use_pretrain=1)
        hidden = self.sae.get_layer(name='sentence_embedding').output
        self.encoder = Model(inputs=self.sae.input, outputs=hidden)

        # Joint model: soft cluster assignments on the embedding + the auto-encoder reconstruction.
        clustering_layer = ClusteringLayer(self.n_clusters, alpha=self.alpha, name='clustering')(hidden)
        self.model = Model(inputs=self.sae.input, outputs=[clustering_layer, self.sae.output])

    def _ids_to_text(self, token_ids):
        """Render a sequence of token ids as a space-separated sentence via ``self.n2w``."""
        return ' '.join(self.n2w[token_id] for token_id in token_ids)

    @staticmethod
    def _setup_logging(log_file):
        """Send root-logger records (DEBUG and up) to ``log_file`` (append mode) and to the console.

        ``logging.basicConfig`` does nothing once the root logger already has handlers, so calling it
        a second time (``pretrain`` then ``fit`` in one process) would never open the second log file,
        and re-adding a StreamHandler would duplicate console output. Handlers are therefore attached
        explicitly and only when missing.

        Returns:
            The ``FileHandler`` added by this call, or None if one for ``log_file`` already existed.
        """
        root = logging.getLogger()
        root.setLevel(logging.DEBUG)
        # FileHandler subclasses StreamHandler, hence the exact type check for the console handler.
        if not any(type(h) is logging.StreamHandler for h in root.handlers):
            root.addHandler(logging.StreamHandler())
        log_file = os.path.abspath(log_file)
        if any(isinstance(h, logging.FileHandler) and h.baseFilename == log_file for h in root.handlers):
            return None
        file_handler = logging.FileHandler(log_file, mode='a')
        file_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))  # same layout basicConfig uses
        root.addHandler(file_handler)
        return file_handler

    def pretrain(self, x, y, batch_size=256, epochs=200, save_dir='results/temp'):
        """Pretrain the sentence auto-encoder on its own, calling ``fit`` once per epoch.

        After every epoch one random training sentence and its reconstruction are logged. The best
        weights by ``val_loss`` (10% validation split) go to ``<save_dir>/pretrain-sae-weights.h5``,
        the last-epoch weights to ``<save_dir>/pretrain-sae-weights-final.h5``.

        Args:
            x: encoder input token ids, one row per sentence.
            y: reconstruction targets for ``sparse_categorical_crossentropy``.
            batch_size: mini-batch size.
            epochs: number of pretraining epochs.
            save_dir: output directory (created if missing) for weights and ``pretrain-sae-model.log``.
        """
        from keras.callbacks import CSVLogger
        print('...Pretraining...')
        self.sae.compile(optimizer='adam',
                         loss='sparse_categorical_crossentropy',
                         metrics=['accuracy'])
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        # Paths come from `save_dir` (not a script-level global) so DSEC also works when imported.
        # NOTE: CSVLogger and the logging FileHandler both append to this same file.
        log_path = save_dir + '/pretrain-sae-model.log'
        csv_logger = CSVLogger(log_path, append=True)
        file_handler = self._setup_logging(log_path)
        try:
            t0 = time()
            self.file_name = save_dir + '/' + 'pretrain-sae-weights'
            # NOTE(review): fit() runs one epoch per call and Keras 2.x EarlyStopping resets its state in
            # on_train_begin, so this callback cannot stop the outer loop; kept for parity.
            early_stop = EarlyStopping(monitor='val_loss', patience=0, min_delta=0, mode='auto')
            checkpoint = ModelCheckpoint(self.file_name + '.h5', monitor='val_loss', save_best_only=True,
                                         save_weights_only=True, mode='auto', period=1)
            n_samples = x.shape[0]
            for epoch in range(int(epochs)):
                self.sae.fit(x, y, validation_split=0.1, batch_size=batch_size, epochs=1,
                             callbacks=[early_stop, checkpoint, csv_logger])
                # Spot-check the auto-encoder on one random training sentence.
                x_test = x[np.random.choice(n_samples, 1)[0]]
                logging.info('Testing: the source sentence is:')
                logging.info(self._ids_to_text(x_test))
                logging.info('The predicted sentence is:')
                y_test = np.argmax(self.sae.predict(np.array(x_test)[None, :]), -1)[0]
                logging.info(self._ids_to_text(y_test))
                logging.info('End of Epoch %s' % epoch)

            logging.info('Pretraining time: %s', time() - t0)
            self.sae.save_weights(self.file_name + '-final.h5')
            logging.info('The best pretrained model is saved to %s' % (self.file_name + '.h5'))
            self.pretrained = True
        finally:
            # Detach the pretraining log so later records (e.g. from fit) are not appended to it.
            if file_handler is not None:
                logging.getLogger().removeHandler(file_handler)
                file_handler.close()

    def load_weights(self, weights_path):
        """Load weights into the joint model (e.g. a ``dsec_model_*.h5`` file written by ``fit``)."""
        self.model.load_weights(weights_path)

    def load_models(self, models_path):
        """Replace ``self.model`` with a complete model saved via ``Model.save``.

        ``ClusteringLayer`` is a custom layer and must be registered for deserialization.
        NOTE(review): custom layers inside ``SAE_LSTM``, if any, would need registering too.
        """
        self.model = load_model(models_path, custom_objects={'ClusteringLayer': ClusteringLayer})

    def extract_feature(self, x):
        """Return the sentence embeddings (encoder output, i.e. the clustering-layer input) for ``x``."""
        return self.encoder.predict(x)

    def predict(self, x):
        """Return the hard cluster index (argmax of ``q``) for every row of ``x``."""
        q, _ = self.model.predict(x, verbose=0)
        return q.argmax(1)

    @staticmethod
    def predict_cluster_id_by_batch(keras_model, x, batch_size=1000, output_size=20, main_model=True):
        """Run ``keras_model.predict`` over ``x`` in consecutive chunks and stack the results.

        Args:
            keras_model: model to run. With ``main_model=True`` it must return two outputs and only the
                first is kept (for ``self.model`` that is the soft assignment ``q``).
            x: input array, split along its first axis into chunks of ``batch_size`` rows.
            batch_size: rows per ``predict`` call.
            output_size: width of the kept output; fixes the shape of an empty result.
            main_model: whether ``keras_model`` returns two outputs.
        Returns:
            Array of shape ``(x.shape[0], output_size)``. For an empty ``x``, ``predict`` is not called.
        """
        chunks = [np.array([], dtype=np.float32).reshape(0, output_size)]
        for start in range(0, x.shape[0], batch_size):
            out = keras_model.predict(x[start:start + batch_size])
            if main_model:
                out, _ = out
            chunks.append(out)
        # A single concatenate avoids re-copying a growing accumulator for every chunk.
        return np.concatenate(chunks, axis=0)

    @staticmethod
    def target_distribution(q):
        """Auxiliary target p_ij = (q_ij^2 / f_j) normalized per row, with f_j = sum_i q_ij.

        Squaring sharpens confident assignments; dividing by f_j down-weights large clusters.
        """
        weight = q ** 2 / q.sum(0)
        return (weight.T / weight.sum(1)).T

    @staticmethod
    def _batch_bounds(index, batch_size, n_samples):
        """Return ``(start, stop, next_index)`` of mini-batch ``index`` in a cyclic pass over the data.

        The batch that reaches the end of the data is truncated to the remaining rows and wraps the
        index back to 0, so no empty batch is produced even when ``n_samples`` is a multiple of
        ``batch_size``.
        """
        start = index * batch_size
        if start + batch_size >= n_samples:
            return start, n_samples, 0
        return start, start + batch_size, index + 1

    def get_layer_output(self, x, layer_name='sentence_embedding'):
        """Return the activations of ``self.model``'s layer ``layer_name`` for ``x`` (chunks of 1000)."""
        layer_output = self.model.get_layer(layer_name).output
        sub_model = Model(inputs=self.model.input, outputs=layer_output)
        output_size = layer_output.get_shape().as_list()[1]
        return self.predict_cluster_id_by_batch(sub_model, x, 1000, output_size, False)

    def compile(self, loss=None, loss_weights=None, optimizer='adam'):
        """Compile the joint model.

        Args:
            loss: losses for ``[q, reconstruction]``; defaults to ``['kld', 'sparse_categorical_crossentropy']``.
            loss_weights: weights of those losses; defaults to ``[1, 1]``.
            optimizer: Keras optimizer name or instance.
        Raises:
            Whatever ``Model.compile`` raises. The failure is reported and re-raised, because an
            uncompiled model cannot be trained.
        """
        if loss is None:
            loss = ['kld', 'sparse_categorical_crossentropy']
        if loss_weights is None:
            loss_weights = [1, 1]
        try:
            self.model.compile(loss=loss, loss_weights=loss_weights, optimizer=optimizer)
        except Exception:
            print('Fails to compile the model!')
            raise

    def _load_pretrained_sae(self, sae_dir, nrows=None):
        """Load SAE weights and the data/vocabulary pickle stored in ``sae_dir``.

        Reads ``<sae_dir>/pretrain-sae-weights.h5`` into ``self.sae`` and ``<sae_dir>/sae-model-param.pickle``,
        replacing ``self.w2n`` and ``self.n2w``.

        Returns:
            ``(x, y, cluster_id)`` from the pickle, truncated to the first ``nrows`` rows when given.
        """
        sae_params_path = sae_dir + '/' + 'sae-model-param.pickle'
        self.sae.load_weights(sae_dir + '/' + 'pretrain-sae-weights' + '.h5')
        # SECURITY: pickle.load can execute arbitrary code; only load files produced by this script.
        with open(sae_params_path, 'rb') as param_file:
            sae_params = pickle.load(param_file)
        self.w2n = sae_params['w2n']
        self.n2w = sae_params['n2w']
        x, y, cluster_id = sae_params['x_train'], sae_params['y_train'], sae_params['cluster_id']
        if nrows is not None:
            print('Train clustering on a sample of size %s' % nrows)
            x, y, cluster_id = x[:nrows], y[:nrows], cluster_id[:nrows]
        return x, y, cluster_id

    def _print_sample_reconstructions(self, x, n_examples=10):
        """Print ``n_examples`` random sentences from ``x`` together with the SAE reconstruction of each."""
        for i in range(n_examples):
            x_test = x[np.random.choice(x.shape[0], 1)[0]]
            print('Testing: the source sentence %s is:' % i)
            print(self._ids_to_text(x_test))
            print('The target sentence %s is:' % i)
            y_test = np.argmax(self.sae.predict(np.array(x_test)[None, :]), -1)[0]
            print(self._ids_to_text(y_test))
            print('=' * 50)

    @staticmethod
    def _log_metrics(ite, cluster_id, y_pred, eval_idx, loss, logwriter):
        """Log acc/MI/NMI/ARI of ``y_pred`` against ``cluster_id`` on rows ``eval_idx`` plus the last losses."""
        y_true, y_hat = cluster_id[eval_idx], y_pred[eval_idx]
        acc = np.round(metrics.acc(y_true, y_hat), 5)
        mi = np.round(metrics.mi(y_true, y_hat), 5)
        nmi = np.round(metrics.nmi(y_true, y_hat), 5)
        ari = np.round(metrics.ari(y_true, y_hat), 5)
        loss = np.round(loss, 5)  # [total, clustering (Lc), reconstruction (Lr)]
        logging.info('Iter ' + str(ite) + ': Acc=' + str(acc) + ', mi=' + str(mi) +
                     ', nmi=' + str(nmi) + ',ari=' + str(ari) + '; loss=' +
                     str(loss[0]) + ', lc=' + str(loss[1]) + ', lr=' + str(loss[2]))
        logwriter.writerow(dict(iter=ite, acc=acc, nmi=nmi, L=loss[0], Lc=loss[1], Lr=loss[2]))

    def _save_snapshot(self, x, cluster_id, q, q_max, save_dir, ite):
        """Write gold/predicted labels, max q and embeddings to a TSV, and the joint weights to HDF5."""
        cluster_pred = pd.Series(np.argmax(q, -1))
        se_layer_output = self.get_layer_output(x, 'sentence_embedding')
        clust_df = pd.concat([pd.Series(cluster_id), cluster_pred, pd.Series(q_max),
                              pd.DataFrame(se_layer_output)], axis=1)
        clust_df.columns = ['Gold', 'Pred', 'q_ij'] + list(clust_df.columns[3:])
        tsv_path = save_dir + '/models/cluster_prediction_' + str(ite) + '.tsv'
        print('saving cluster results to:', tsv_path)
        clust_df.to_csv(tsv_path, sep='\t', index=False)
        weights_path = save_dir + '/models/dsec_model_' + str(ite) + '.h5'
        print('saving model to:', weights_path)
        self.model.save_weights(weights_path)

    def fit(self, x, y, cluster_id, batch_size=256, maxiter=2e4, tol=1e-3,
            update_interval=100, sae_weights_path=None, save_dir='./results/temp', nrows=None):
        """Jointly fine-tune the auto-encoder and the clustering layer (DEC-style self-training).

        1. Pretrain the SAE when ``self.pretrained`` is False or ``sae_weights_path`` is None. Otherwise
           load the SAE weights AND the training data/vocabulary from ``sae_weights_path``; the given
           ``x``/``y``/``cluster_id`` are then replaced by the pickled ones.
        2. Initialize the cluster centers with k-means on the sentence embeddings.
        3. Run ``maxiter`` mini-batch steps on targets ``[p, y]``, recomputing ``p`` from ``q`` every
           ``update_interval`` steps; snapshots are written every 100 steps.

        Args:
            x: encoder input token ids.
            y: reconstruction targets (a trailing axis is added when 2-D).
            cluster_id: gold labels with NaN for unlabelled rows, or None to skip evaluation.
            batch_size: mini-batch size.
            maxiter: number of training steps.
            tol: label-change tolerance; ``delta_label`` is logged next to it, early stopping is disabled.
            update_interval: steps between target-distribution updates and evaluations.
            sae_weights_path: directory with ``pretrain-sae-weights.h5`` and ``sae-model-param.pickle``.
            save_dir: output directory for logs, snapshots and the final weights.
            nrows: when loading from ``sae_weights_path``, keep only the first ``nrows`` rows.
        """
        import csv
        print('Update interval', update_interval)
        save_interval = 100
        print('Save interval', save_interval)

        # Step 1: pretrain the auto-encoder, or load a pretrained one together with its data.
        t0 = time()
        if self.pretrained is False or sae_weights_path is None:
            if y.ndim == 2:
                y = y.reshape(y.shape + (1,))
            print('...pretraining auto-encoder using default hyper-parameters:')
            print('   optimizer=\'adam\';   epochs=100')
            self.pretrain(x, y, batch_size, epochs=100, save_dir=save_dir)
            self.pretrained = True
        else:
            x, y, cluster_id = self._load_pretrained_sae(sae_weights_path, nrows)
            if y.ndim == 2:
                y = y.reshape(y.shape + (1,))
            print('sae_weights is loaded successfully.')
            print('Evaluating the loaded sae model performance:')
            self._print_sample_reconstructions(x, 10)

        # Step 2: initialize cluster centers with k-means on the sentence embeddings.
        if cluster_id is not None:
            cluster_id = np.array(cluster_id)
            print('The evaluation sample size is: ' + str(len(np.where(~np.isnan(cluster_id))[0])))
        t1 = time()
        print('Initializing cluster centers with k-means.')
        kmeans = KMeans(n_clusters=self.n_clusters, n_init=20)
        self.y_pred = kmeans.fit_predict(self.encoder.predict(x))
        self.model.get_layer(name='clustering').set_weights([kmeans.cluster_centers_])

        # Step 3: deep clustering.
        if not os.path.exists(save_dir + '/models'):
            os.makedirs(save_dir + '/models')  # also creates save_dir itself
        self._setup_logging(save_dir + '/DSEC-cluster.log')
        logging.info(self.args)

        # Only labelled samples whose top assignment probability exceeds 3/n_clusters are evaluated
        # (low-confidence "outliers" are skipped).
        # NOTE(review): for n_clusters <= 3 this threshold is >= 1, so no sample is ever evaluated.
        q_threshold = 3 / self.n_clusters
        loss = [0, 0, 0]
        index = 0
        with open(save_dir + '/dsec_log.csv', 'w') as logfile:
            logwriter = csv.DictWriter(logfile, fieldnames=['iter', 'acc', 'nmi', 'L', 'Lc', 'Lr'])
            logwriter.writeheader()
            for ite in range(int(maxiter)):
                if ite % update_interval == 0:
                    q = self.predict_cluster_id_by_batch(self.model, x, 1000, self.n_clusters)
                    q_max = np.max(q, axis=-1)
                    p = self.target_distribution(q)  # update the auxiliary target distribution p

                    y_pred_last = self.y_pred  # k-means labels on the first update
                    self.y_pred = q.argmax(1)
                    if cluster_id is not None:
                        eval_idx = np.where(~np.isnan(cluster_id) & (q_max > q_threshold))[0]
                        self._log_metrics(ite, cluster_id, self.y_pred, eval_idx, loss, logwriter)

                    # Fraction of samples whose hard assignment changed since the previous update.
                    # Stopping on delta_label < tol is intentionally left disabled; it is only logged.
                    delta_label = np.mean(self.y_pred != y_pred_last)
                    logging.info('delta_label=%s (tol=%s)', delta_label, tol)

                # Train on the next mini-batch of a cyclic pass over the data.
                start, stop, index = self._batch_bounds(index, batch_size, x.shape[0])
                loss = self.model.train_on_batch(x=x[start:stop], y=[p[start:stop], y[start:stop]])

                # Save intermediate predictions and weights.
                if ite % save_interval == 0:
                    self._save_snapshot(x, cluster_id, q, q_max, save_dir, ite)

        # Save the trained model.
        print('saving model to:', save_dir + '/dsec_model_final.h5')
        self.model.save_weights(save_dir + '/dsec_model_final.h5')
        t3 = time()
        print('Pretrain time:  ', t1 - t0)
        print('Clustering time:', t3 - t1)
        print('Total time:     ', t3 - t0)


if __name__ == "__main__":
    # Script entry point: parse hyper-parameters, vectorize the training sentences, save the
    # vocabulary/data pickle, then build and train DSEC.
    # setting the hyper parameters, default path for cancel prime
    parser = argparse.ArgumentParser(description='train')
    parser.add_argument('--train_file', default='', type=str,
                        help="tab-separated file with 'Template' and 'Gold' columns")
    parser.add_argument('--emb_file', default='', type=str)
    parser.add_argument('--n_clusters', default=20, type=int)
    # argparse does not apply `type` to non-string defaults, so the default is given as an int.
    parser.add_argument('--maxiter', default=20000, type=int)
    parser.add_argument('--gamma', default=0.1, type=float,
                        help='coefficient of clustering loss')
    parser.add_argument('--update_interval', default=5, type=int)
    parser.add_argument('--tol', default=0.001, type=float)
    parser.add_argument('--max_seq_len', default=20, type=int)
    parser.add_argument('--encoder_size', default=50, type=int)
    parser.add_argument('--emb_dim', default=100, type=int)
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--dropout', default=0.25, type=float)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--gpu', default='4,5,6,7', type=str)
    parser.add_argument('--sae_weights_path', default=None,
                        help='Directory with pretrain-sae-weights.h5 and sae-model-param.pickle '
                             '(strongly recommended); if omitted, the auto-encoder is pretrained first')
    parser.add_argument('--save_dir', default='results/glove_lstm_20190923')
    parser.add_argument('--nrows', default=None, type=int, help='Choose a subset for clustering')
    parser.add_argument('--rm_stop_words', action="store_true", dest='rm_stop_words', help='Whether remove stop words')
    args = parser.parse_args()

    print(args)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    max_seq_len = args.max_seq_len
    enc_size = args.encoder_size
    train_fname = args.train_file
    dropout = args.dropout
    emb_file = args.emb_file
    emb_dim = args.emb_dim
    lr = args.lr

    # train_df should have 'Template' (sentence) and 'Gold' (label) columns.
    train_df = pd.read_csv(train_fname, sep='\t')

    if args.rm_stop_words:
        from nltk.corpus import stopwords
        stop_words = set(stopwords.words('english'))
        train_df['Template'] = train_df['Template'].apply(
            lambda s: ' '.join(w for w in str(s).split() if w not in stop_words))
        # Drop sentences that became empty and renumber the remaining rows 0..n-1.
        train_df = train_df[train_df['Template'] != ''].reset_index(drop=True)
        print('Training data has {} rows.'.format(train_df.shape[0]))

    # Cap the sequence length at the longest whitespace-tokenized sentence.
    max_seq_len = min(max_seq_len, max(train_df['Template'].apply(lambda s: len(str(s).split()))))
    print('max_seq_len=' + str(max_seq_len))
    # assume that the input train_df['Template'] is cleaned, rm punctuations
    x_train, y_train, emb_matrix, w2n, n2w, vocab_size = utils.get_data(train_df['Template'].to_frame(),
                                                                        max_seq_len, emb_file)

    # Save vocabulary, hyper-parameters and vectorized data; DSEC.fit reloads this pickle from
    # the --sae_weights_path directory.
    with open(args.save_dir + "/sae-model-param.pickle", 'wb') as param_file:
        pickle.dump({'w2n': w2n, 'n2w': n2w, 'vocab_size': vocab_size, 'max_seq_len': max_seq_len,
                     'dropout': dropout, 'emb_dim': emb_dim, 'emb_matrix': emb_matrix,
                     'x_train': x_train, 'y_train': y_train, 'cluster_id': train_df['Gold']}, param_file)

    # prepare the DSEC model
    dsec = DSEC(max_seq_len, enc_size, dropout, vocab_size, emb_dim, emb_matrix, w2n, n2w,
                n_clusters=args.n_clusters, alpha=1.0, input_args=args)
    plot_model(dsec.model, to_file=args.save_dir + '/dsec_model.png', show_shapes=True)
    dsec.model.summary()

    # Begin clustering; gamma weights the clustering (KL) loss against the reconstruction loss.
    optimizer = keras.optimizers.Adam(lr=lr)
    dsec.compile(loss=['kld', 'sparse_categorical_crossentropy'], loss_weights=[args.gamma, 1], optimizer=optimizer)
    dsec.fit(x=x_train, y=y_train, cluster_id=train_df['Gold'], tol=args.tol, maxiter=args.maxiter,
             batch_size=args.batch_size, update_interval=args.update_interval, save_dir=args.save_dir,
             sae_weights_path=args.sae_weights_path, nrows=args.nrows)




