# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: CC-BY-NC-4.0

import tensorflow as tf
import os, time, argparse
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import layers


# GAN model structure
from tqdm import tqdm


class GANModel:
    """Wasserstein GAN with gradient penalty trained on embeddings read from an ``.npz`` file.

    The constructor loads the ``train_data`` array, builds the generator and
    the discriminator (critic) networks and one Adam optimizer for each.
    ``GAN_train`` runs the full training loop; afterwards the trained
    generator is available as ``self.generator``.
    """

    def __init__(self, file, gan, epoch, batch_size=10, learning_rate=1e-4, dimension=27, loss_fig='GAN_Loss.png', max_samples=None):
        """Load the training data and build both networks.

        Args:
            file: Path (or file object) of an ``.npz`` archive containing a
                ``train_data`` array.
            gan: GAN variant name; forwarded to ``make_generator_model``.
            epoch: Number of training epochs.
            batch_size: Number of samples per training batch.
            learning_rate: Learning rate shared by both Adam optimizers.
            dimension: Size of the last axis of the data (``self.dim``).
            loss_fig: Path the loss plot is written to after every epoch.
            max_samples: If not ``None``, only the first ``max_samples`` rows
                of ``train_data`` are kept.
        """
        with np.load(file) as openfile:
            # train_data: embedding of original text, converted to float64 so
            # that it matches the dtype used by the loss computations.
            self.train_data = np.float64(openfile['train_data'])
        if max_samples is not None:
            self.train_data = self.train_data[:max_samples]
        self.embedding_size = np.shape(self.train_data)[1]
        self.ori_syn_data = {}

        self.epochs = epoch
        self.batch_size = batch_size
        self.learning_rate = learning_rate

        # The same `dimension` is used regardless of the `gan` variant.
        self.dim = dimension

        self.generator = self.make_generator_model(gan)
        self.discriminator = self.make_discriminator_model()
        self.generator_optimizer = tf.keras.optimizers.Adam(self.learning_rate, beta_1=0, beta_2=0.9)
        self.discriminator_optimizer = tf.keras.optimizers.Adam(self.learning_rate, beta_1=0, beta_2=0.9)
        self.loss_fig = loss_fig
        # Number of full batches per epoch; used to scale the loss plot width.
        self.expected_batches = int(len(self.train_data)/self.batch_size)

    @staticmethod
    def _smoothing_window(expected_batches):
        """Return how many recent batch losses are averaged per plotted point.

        The window is 1% of the batches in an epoch but never smaller than 1:
        a window of 0 would turn ``values[-0:]`` into "the whole history".
        """
        return max(1, int(0.01 * expected_batches))

    def GAN_train(self):
        """Batch the loaded data and run the complete training loop.

        Raises:
            ValueError: If there are fewer samples than ``batch_size``, i.e.
                not even one full batch can be formed.
        """
        # Only full batches are used; the trailing remainder is dropped.
        usable_samples = len(self.train_data) // self.batch_size * self.batch_size
        if usable_samples == 0:
            raise ValueError(
                'Not enough training samples ({}) to form a single batch of size {}'.format(
                    len(self.train_data), self.batch_size))

        train_data = self.train_data[:usable_samples]
        train_data = np.reshape(train_data, [len(train_data), self.embedding_size, 1, self.dim])
        train_dataset = tf.data.Dataset.from_tensor_slices(train_data).batch(self.batch_size)
        expected_batches = int(len(train_data)/self.batch_size)

        # Dummy labels: iterated alongside the data by `train` but never used.
        dummy_labels = np.ones((len(train_data), 1))
        dummy_labels = tf.data.Dataset.from_tensor_slices(dummy_labels).batch(self.batch_size)

        print('Learning Started!')
        start_time = time.time()

        self.train(train_dataset, dummy_labels, expected_batches)  # Train start!

        print('Learning Finished!')
        print('Building time for GAN: {:.2f} seconds'.format(time.time() - start_time))

    def _save_loss_plot(self, batch_ids, gen_losses, disc_losses):
        """Write the generator/discriminator loss curves to ``self.loss_fig``.

        Args:
            batch_ids: X coordinates of the plotted points.
            gen_losses: Generator loss per point.
            disc_losses: Discriminator loss per point.
        """
        if not batch_ids:
            # Nothing to draw yet; a zero-width figure cannot be rendered.
            return

        SCALING_FACTOR = 5
        # The figure gets wider as more points are collected.
        fig = plt.figure(figsize=(
            SCALING_FACTOR*0.75*len(batch_ids)/self.expected_batches,
            SCALING_FACTOR
        ))
        try:
            ax = fig.add_subplot(111)
            ax.plot(batch_ids, gen_losses, 'r', label='Generator')
            ax.plot(batch_ids, disc_losses, 'g', label='Discriminator')
            ax.legend()
            fig.savefig(self.loss_fig)
        finally:
            # This method runs once per epoch; without closing, every figure
            # would stay registered with pyplot for the whole training run.
            plt.close(fig)

    def make_generator_model(self, gan):  # Put embedded input for text
        """Build the generator network.

        The input of shape ``(embedding_size, dim)`` is flattened, projected
        to a ``(32, 1, 16)`` tensor and upsampled along the first axis by four
        transposed convolutions (32 -> 64 -> 192 -> 384 -> 768). The final
        shape assertion therefore only holds when ``embedding_size`` is 768.

        Args:
            gan: GAN variant name; currently not used by the architecture.

        Returns:
            A ``tf.keras.Sequential`` whose output shape is
            ``(None, embedding_size, 1, dim)``.
        """
        model = tf.keras.Sequential()

        model.add(layers.Flatten(input_shape=(self.embedding_size, self.dim)))
        model.add(layers.Dense(32 * 1 * 16, use_bias=False))
        model.add(layers.BatchNormalization())  # BN
        model.add(layers.LeakyReLU())  # LeakyReLu
        model.add(layers.Reshape((32, 1, 16)))
        assert model.output_shape == (None, 32, 1, 16)
        # Convolutional-Transpose Layer #1
        model.add(layers.Conv2DTranspose(128, (5, 1), strides=(2, 1), padding='same', use_bias=False))  # Conv-T
        assert model.output_shape == (None, 64, 1, 128)
        model.add(layers.BatchNormalization())  # BN
        model.add(layers.LeakyReLU())  # LeakyReLu
        # Convolutional-Transpose Layer #2
        model.add(layers.Conv2DTranspose(64, (5, 1), strides=(3, 1), padding='same', use_bias=False))  # Conv-T
        assert model.output_shape == (None, 192, 1, 64)
        model.add(layers.BatchNormalization())  # BN
        model.add(layers.LeakyReLU())  # LeakyReLu
        # Convolutional-Transpose Layer #3
        model.add(layers.Conv2DTranspose(32, (5, 1), strides=(2, 1), padding='same', use_bias=False))  # Conv-T
        assert model.output_shape == (None, 384, 1, 32)
        model.add(layers.BatchNormalization())  # BN
        model.add(layers.LeakyReLU())  # LeakyReLu
        # Convolutional-Transpose Layer #4
        model.add(
            layers.Conv2DTranspose(self.dim, (5, 1), strides=(2, 1), padding='same', use_bias=False))  # Conv-T
        # Output
        assert model.output_shape == (None, self.embedding_size, 1, self.dim)

        return model

    def make_discriminator_model(self):
        """Build the discriminator (critic) network.

        Four strided convolutions, each followed by LeakyReLU and 50% dropout,
        then a flatten and a single linear output unit (no activation).

        Returns:
            A ``tf.keras.Sequential`` taking inputs of shape
            ``(embedding_size, 1, dim)``.
        """
        model = tf.keras.Sequential()

        # Convolutional Layer #1
        model.add(layers.Conv2D(32, (5, 1), strides=(2, 1), padding='same',
                                input_shape=[self.embedding_size, 1, self.dim]))  # Conv
        model.add(layers.LeakyReLU())  # LeakyReLu
        model.add(layers.Dropout(0.5))  # Dropout
        # Convolutional Layer #2
        model.add(layers.Conv2D(64, (5, 1), strides=(2, 1), padding='same'))  # Conv
        model.add(layers.LeakyReLU())  # LeakyReLu
        model.add(layers.Dropout(0.5))  # Dropout
        # Convolutional Layer #3
        model.add(layers.Conv2D(128, (5, 1), strides=(3, 1), padding='same'))  # Conv
        model.add(layers.LeakyReLU())  # LeakyReLu
        model.add(layers.Dropout(0.5))  # Dropout
        # Convolutional Layer #4
        model.add(layers.Conv2D(256, (5, 1), strides=(2, 1), padding='same'))  # Conv
        model.add(layers.LeakyReLU())  # LeakyReLu
        model.add(layers.Dropout(0.5))  # Dropout
        # FC Layer + Output
        model.add(layers.Flatten())
        model.add(layers.Dense(1))

        return model

    def generator_loss(self, fake_output):
        """Return the WGAN generator loss: minus the mean critic score of the fakes (float64)."""
        original_loss = tf.cast(-tf.reduce_mean(fake_output), dtype=tf.float64)  # Loss for WGAN

        return original_loss

    def discriminator_loss(self, real_data, fake_data):
        """Return the WGAN-GP critic loss for one batch.

        The loss is ``mean(D(fake)) - mean(D(real))`` plus ``LAMBDA`` times the
        gradient penalty, which pushes the norm of the critic's gradient at
        random interpolations of real and fake samples towards 1.

        Args:
            real_data: Batch of real samples.
            fake_data: Batch of generated samples with the same shape.

        Returns:
            Scalar float64 tensor.
        """
        # Calculate the gradient penalty term
        LAMBDA = 10
        # One mixing coefficient per sample and per position along the
        # embedding axis, broadcast over the last axis.
        alpha = tf.cast(tf.random.uniform([self.batch_size, self.embedding_size, 1, 1], 0., 1.), dtype=tf.float64)
        real_data = tf.cast(real_data, dtype=tf.float64)
        fake_data = tf.cast(fake_data, dtype=tf.float64)
        interpolates = alpha * real_data + (1 - alpha) * fake_data

        with tf.GradientTape() as gp_tape:
            # `interpolates` is a plain tensor, so it must be watched explicitly.
            gp_tape.watch(interpolates)
            # NOTE(review): unlike the two calls below, `training=True` is not
            # passed here, so dropout presumably behaves differently for the
            # penalty term -- confirm whether this is intended.
            pred = self.discriminator(interpolates)

        gradients = gp_tape.gradient(pred, [interpolates])[0]

        # Per-sample L2 norm of the gradient.
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))

        gradient_penalty = tf.reduce_mean((slopes - 1) ** 2)  # Gradient penalty term

        real_output = self.discriminator(real_data, training=True)
        fake_output = self.discriminator(fake_data, training=True)

        wasserstein_dist = tf.cast(tf.reduce_mean(fake_output) - tf.reduce_mean(real_output),
                                   dtype=tf.float64)  # Loss for WGAN

        return wasserstein_dist + LAMBDA * gradient_penalty  # Loss with gradient penalty term

    def train_step(self, images, epoch):
        """Run five discriminator updates followed by one generator update.

        The generator input is the real batch plus standard-normal noise.

        Args:
            images: One batch of real samples holding
                ``batch_size * embedding_size * dim`` values.
            epoch: Current epoch index (not used).

        Returns:
            Tuple ``(gen_loss, disc_loss)``; ``disc_loss`` is the loss of the
            last of the five discriminator updates.
        """
        # Real batch in the layout expected by the generator input.
        images_noise = tf.reshape(images, [self.batch_size, self.embedding_size, self.dim])

        for _ in range(5):  # Train discriminator 5 times more than generator
            noise = tf.random.normal([self.batch_size, self.embedding_size, self.dim])  # Input for generator
            noise = tf.cast(noise, tf.float64)
            noise = tf.add(noise, images_noise)

            with tf.GradientTape() as disc_tape:
                generated_images = self.generator(noise, training=True)  # Synthetic data
                disc_loss = self.discriminator_loss(images, generated_images)  # Loss from discriminator

            # Gradients are taken outside the tape context so the backward
            # pass itself is not recorded.
            gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
            self.discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator,
                                                             self.discriminator.trainable_variables))

        noise = tf.random.normal([self.batch_size, self.embedding_size, self.dim])  # Input for generator
        noise = tf.cast(noise, tf.float64)
        noise = tf.add(noise, images_noise)

        with tf.GradientTape() as gen_tape:
            generated_images = self.generator(noise, training=True)  # Synthetic data
            fake_output = self.discriminator(generated_images, training=True)
            gen_loss = self.generator_loss(fake_output)  # Loss from generator

        gradients_of_generator = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        self.generator_optimizer.apply_gradients(
            zip(gradients_of_generator, self.generator.trainable_variables))

        return gen_loss, disc_loss

    def train(self, dataset, labels, expected_batches):  # Train during epoch
        """Train for ``self.epochs`` epochs, plotting losses and checkpointing.

        After every epoch the smoothed loss curves are written to
        ``self.loss_fig``. Whenever the 0-based epoch index is a multiple of 10
        (i.e. after the 1st, 11th, 21st, ... epoch) a checkpoint is saved under
        ``./training_checkpoints``.

        Args:
            dataset: Iterable of real-sample batches.
            labels: Iterable iterated in lockstep with ``dataset``; its
                values are ignored.
            expected_batches: Number of batches per epoch; used for the
                progress bar, the x axis of the plot and the smoothing window.
        """
        losses_gen = []
        losses_disc = []
        batch_ids = []
        plot_losses_gen = []
        plot_losses_disc = []
        checkpoint_dir = './training_checkpoints'
        checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
        checkpoint = tf.train.Checkpoint(generator_optimizer=self.generator_optimizer,
                                         discriminator_optimizer=self.discriminator_optimizer,
                                         generator=self.generator,
                                         discriminator=self.discriminator)

        start = time.time()
        # Number of initial batches for which no point is plotted.
        warmup_batches = int(0.01*expected_batches)
        # Number of most recent losses averaged per plotted point (>= 1).
        window = self._smoothing_window(expected_batches)
        batch_id = 0
        for epoch in tqdm(range(self.epochs), desc="Epoch"):
            for image_batch, _ in tqdm(zip(dataset, labels), desc="Batch", total=expected_batches):
                gen_loss, disc_loss = self.train_step(image_batch, epoch)
                # Keep plain floats rather than tensors in the history.
                losses_gen.append(float(gen_loss))
                losses_disc.append(float(disc_loss))

                batch_id += 1
                if batch_id > warmup_batches:
                    # X coordinate is measured in epochs.
                    batch_ids.append(float(batch_id)/expected_batches)
                    plot_losses_gen.append(np.mean(losses_gen[-window:]))
                    plot_losses_disc.append(np.mean(losses_disc[-window:]))

            self._save_loss_plot(batch_ids, plot_losses_gen, plot_losses_disc)

            if epoch % 10 == 0:
                # The reported time covers everything since the previous
                # checkpoint (or since training started).
                print('Time for epoch {} is {:.2f} sec'.format(epoch + 1, time.time() - start))
                # Save the model every 10 epochs
                checkpoint.save(file_prefix=checkpoint_prefix)
                start = time.time()


def main(args):
    """Train a GAN as described by the parsed command-line arguments.

    Builds a ``GANModel`` from ``args``, runs its training loop and saves the
    trained generator to ``args.output``.

    Args:
        args: Namespace providing ``input``, ``gan``, ``epoch``,
            ``batch_size``, ``learning_rate``, ``dimension``, ``loss_fig``,
            ``max_samples`` and ``output``.
    """
    gan_model = GANModel(
        file=args.input,
        gan=args.gan,
        epoch=args.epoch,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        dimension=args.dimension,
        loss_fig=args.loss_fig,
        max_samples=args.max_samples,
    )
    gan_model.GAN_train()
    # Only the generator is persisted; the discriminator is discarded.
    gan_model.generator.save(args.output)


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='GAN training and generating synthetic data')

    # (flag, add_argument keyword options), registered in --help order.
    option_specs = [
        ('--input', dict(type=str, required=True,
                         help='Original embedding data. Only receive npz file.')),
        ('--gan', dict(type=str, required=True, help='Choose GAN methods: bart.')),
        ('--output', dict(type=str, required=True, help='Output trained generator model')),
        ('--batch-size', dict(type=int, default=128, help='Batch size used in GAN')),
        ('--epoch', dict(type=int, default=30, help='Number of epochs used in GAN')),
        ('--learning-rate', dict(type=float, default=1e-4, help='Learning rate used in GAN')),
        ('--dimension', dict(type=int, default=6, help='Number of tokens considered in BART-GAN')),
        ('--loss-fig', dict(type=str, default='GAN_Loss.png', help='Name of loss figure')),
        ('--max-samples', dict(type=int, help='Maximum number of samples')),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    main(cli.parse_args())
