# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.  
# SPDX-License-Identifier: CC-BY-NC-4.0
from typing import Optional
import torch
import csv
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM, MBartForConditionalGeneration, MBart50TokenizerFast
)
import tensorflow as tf
import random
from tqdm import tqdm
from tensorflow.keras import layers
import time
import os 
import matplotlib.pyplot as plt
from tensorflow.compat.v1.keras.layers import CuDNNLSTM 

# Pretrained mBART-50: only its tokenizer and the encoder's token-embedding
# table are used by this script; the model itself is never run or trained.
tokenizer = AutoTokenizer.from_pretrained('facebook/mbart-large-50')
model = AutoModelForSeq2SeqLM.from_pretrained('facebook/mbart-large-50')
model.eval()

# Encoder input-embedding matrix; row k is the embedding of token id k.
encoder_embed = model.state_dict()['model.encoder.embed_tokens.weight']

dimension = 24  # max tokens kept per sentence (columns per embedding matrix)
hidden = 1024   # must equal the embedding width (rows are copied into length-`hidden` slices)

# BART-score thresholds (bart_score_low / bart_score_high).
low, high = 0.5, 0.9
accept_prob = high


# Tab-separated STS files; every row holds two sentences plus numeric score field(s).
tasks = ['STS.ar-ar.txt', 'STS.de-de.txt', 'STS.en-en.txt', 'STS.es-es.txt', 'STS.fr-fr.txt', 'STS.it-it.txt', 'STS.nl-nl.txt', 'STS.tr-tr.txt']

split_data = []  # [sentence1, sentence2] pairs, in file order
sent_score = []  # score fields (as strings) of the rows kept in split_data


def _is_score(field):
    """Return True if *field* looks like a non-negative decimal number (e.g. '3.8')."""
    return field.replace('.', '', 1).isdigit()


for task in tasks:
    with open(task, 'r', encoding="utf-8") as file:
        # QUOTE_NONE: the files hold raw text, so a sentence starting with '"'
        # must not be parsed as a quoted CSV field (which strips the quotes and
        # can swallow tabs/newlines of following rows).
        for row in csv.reader(file, delimiter="\t", quoting=csv.QUOTE_NONE):
            sentences = [field for field in row if not _is_score(field)]
            if len(sentences) < 2:
                # Blank (csv yields []) or malformed line: skip it instead of
                # re-appending the previous row's sentences.
                continue
            # First non-numeric field is sentence 1, the last one is sentence 2.
            split_data.append([sentences[0], sentences[-1]])
            sent_score.extend(field for field in row if _is_score(field))

print("Example of data:", split_data[0])


def _embed_sentences(sentences):
    """Return a (len(sentences), hidden, dimension) float64 array of token embeddings.

    Column j of sample i is the mBART encoder input embedding of the j-th
    token of sentence i.  Sentences longer than ``dimension`` tokens are
    truncated; shorter ones are zero-padded on the right.
    """
    embeds = np.zeros((len(sentences), hidden, dimension))
    for i, txt in enumerate(sentences):
        # Exclude first and last ids (the special tokens added around the
        # sentence), then keep at most `dimension` tokens.
        input_ids = tokenizer.encode(txt, add_special_tokens=True)[1:-1][:dimension]
        if input_ids:
            # (n_tokens, hidden) embedding rows -> (hidden, n_tokens) columns.
            rows = encoder_embed[input_ids].detach().cpu().numpy()
            embeds[i, :, :len(input_ids)] = rows.T
    return embeds


train_embed = _embed_sentences([pair[0] for pair in split_data])   # first sentences
train_embed2 = _embed_sentences([pair[1] for pair in split_data])  # second sentences





## GAN Part ##
# Real samples: embeddings of all first sentences followed by those of all
# second sentences, stacked on the sample axis as float64.
train_data = np.concatenate(
    [np.float64(train_embed), np.float64(train_embed2)], axis=0
)
embedding_size = train_data.shape[1]
ori_syn_data = {}

# Training hyper-parameters.
epochs = 50
batch_size = 16  # 4 was tried earlier on half of TREC
learning_rate = 1e-4
loss_fig = 'GAN_Loss.png'
dim = dimension

# Keep only whole batches, then add a singleton axis so each sample is
# (embedding_size, 1, dim).
expected_batches = len(train_data) // batch_size
train_data = train_data[:expected_batches * batch_size]
train_data = train_data.reshape(len(train_data), embedding_size, 1, dim)
train_dataset = tf.data.Dataset.from_tensor_slices(train_data).batch(batch_size)

# Placeholder labels (all ones), batched like the data; their values are unused.
dummy_labels = tf.data.Dataset.from_tensor_slices(
    np.ones((len(train_data), 1))
).batch(batch_size)

# Identical Adam settings (beta_1=0, beta_2=0.9) for both networks.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0, beta_2=0.9)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0, beta_2=0.9)

def make_generator_model():
    """Build the generator: (embedding_size, dim) input -> (embedding_size, 1, dim) output.

    A dense projection produces a 32x1x16 feature map, which four
    Conv2DTranspose layers upsample along the first axis only
    (32 -> 64 -> 256 -> 512 -> 1024 rows); the final assert therefore
    requires embedding_size == 1024.
    """
    net = tf.keras.Sequential()

    # Projection of the (noisy) embedded text input to a 32x1x16 feature map.
    net.add(layers.Flatten(input_shape=(embedding_size, dim)))
    net.add(layers.Dense(32 * 1 * 16, use_bias=False))
    net.add(layers.BatchNormalization())
    net.add(layers.LeakyReLU())
    net.add(layers.Reshape((32, 1, 16)))
    assert net.output_shape == (None, 32, 1, 16)

    # (filters, stride along axis 1, expected output height) of the hidden
    # transposed-conv stages; each is followed by BN + LeakyReLU.
    hidden_stages = ((128, 2, 64), (64, 4, 256), (32, 2, 512))
    for filters, stride, height in hidden_stages:
        net.add(layers.Conv2DTranspose(filters, (5, 1), strides=(stride, 1),
                                       padding='same', use_bias=False))
        assert net.output_shape == (None, height, 1, filters)
        net.add(layers.BatchNormalization())
        net.add(layers.LeakyReLU())

    # Output stage: linear transposed conv back to `dim` channels.
    net.add(layers.Conv2DTranspose(dim, (5, 1), strides=(2, 1),
                                   padding='same', use_bias=False))
    assert net.output_shape == (None, embedding_size, 1, dim)
    return net


generator = make_generator_model()

def make_discriminator_model():
    """Build the critic: (embedding_size, 1, dim) input -> one unbounded score.

    Four Conv2D stages (5x1 kernels), each followed by LeakyReLU, average
    pooling along the first axis and 50% dropout, then a linear Dense(1) head.
    """
    critic = tf.keras.Sequential()
    # (filters, pooling factor along axis 1) for each conv stage, in order.
    stages = [(32, 2), (64, 2), (128, 3), (256, 2)]
    for position, (filters, pool) in enumerate(stages):
        if position == 0:
            conv = layers.Conv2D(filters, (5, 1), strides=(1, 1), padding='same',
                                 input_shape=[embedding_size, 1, dim])
        else:
            conv = layers.Conv2D(filters, (5, 1), strides=(1, 1), padding='same')
        critic.add(conv)
        critic.add(layers.LeakyReLU())
        critic.add(layers.AveragePooling2D((pool, 1), padding='same'))
        critic.add(layers.Dropout(0.5))

    # No output activation: the raw score feeds the WGAN losses directly.
    critic.add(layers.Flatten())
    critic.add(layers.Dense(1))
    return critic


discriminator = make_discriminator_model()

def generator_loss(fake_output):
    """WGAN generator loss: the negated mean critic score of generated samples (float64)."""
    mean_score = tf.reduce_mean(fake_output)
    return tf.cast(tf.negative(mean_score), dtype=tf.float64)

def discriminator_loss(real_data, fake_data):
    """WGAN-GP critic loss for one batch, as a float64 scalar.

    loss = mean(D(fake)) - mean(D(real)) + LAMBDA * GP, where
    GP = mean((||grad D(x_hat)||_2 - 1)^2) over random interpolates x_hat
    between paired real and fake samples.

    Args:
        real_data: batch of real samples, rank 4 (the norm reduces axes 1..3);
            presumably (batch, embedding_size, 1, dim).
        fake_data: generator output with the same shape as ``real_data``.
    """
    LAMBDA = 10  # gradient-penalty weight
    real_data = tf.cast(real_data, dtype=tf.float64)
    fake_data = tf.cast(fake_data, dtype=tf.float64)

    # One interpolation coefficient per sample (broadcast over the other axes)
    # so every x_hat lies on the straight line between a real/fake pair.
    # The batch size comes from the data, not the global `batch_size`.
    n_samples = tf.shape(real_data)[0]
    alpha = tf.random.uniform([n_samples, 1, 1, 1], 0., 1., dtype=tf.float64)
    interpolates = alpha * real_data + (1 - alpha) * fake_data

    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolates)
        # training=True: same dropout behaviour as the real/fake passes below.
        pred = discriminator(interpolates, training=True)
    gradients = gp_tape.gradient(pred, interpolates)

    # Per-sample L2 norm; the epsilon keeps sqrt differentiable at exactly 0.
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-12)
    gradient_penalty = tf.reduce_mean((slopes - 1.0) ** 2)

    real_output = discriminator(real_data, training=True)
    fake_output = discriminator(fake_data, training=True)
    wasserstein_dist = tf.cast(tf.reduce_mean(fake_output) - tf.reduce_mean(real_output),
                               dtype=tf.float64)

    return wasserstein_dist + LAMBDA * gradient_penalty


def train_step(images, epoch):
    """Run one WGAN-GP iteration on a batch of real samples.

    The critic is updated 5 times, then the generator once. The generator
    input is Gaussian noise added to the real batch reshaped to
    (batch, embedding_size, dim).

    Args:
        images: batch of real samples, reshapeable to (batch, embedding_size, dim).
        epoch: current epoch index (unused; kept for the caller's signature).

    Returns:
        (gen_loss, disc_loss): generator loss and the critic loss of the last
        critic update, as float64 scalar tensors.
    """
    n_critic = 5  # critic updates per generator update
    # -1 takes the real batch size instead of assuming the global `batch_size`.
    images_noise = tf.cast(tf.reshape(images, [-1, embedding_size, dim]), tf.float64)
    noise_shape = tf.shape(images_noise)

    def _generator_input():
        # Fresh Gaussian noise centred on the real samples.
        noise = tf.cast(tf.random.normal(noise_shape), tf.float64)
        return tf.add(noise, images_noise)

    # Only the critic is trained inside this loop; the generator update and
    # the return must come after it, otherwise the loop stops after one pass.
    for _ in range(n_critic):
        generated_images = generator(_generator_input(), training=True)  # Synthetic data
        with tf.GradientTape() as disc_tape:
            disc_loss = discriminator_loss(images, generated_images)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator,
                                                    discriminator.trainable_variables))

    with tf.GradientTape() as gen_tape:
        generated_images = generator(_generator_input(), training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))

    return gen_loss, disc_loss


def train(dataset, labels, expected_batches):
    """Train the GAN for `epochs` epochs, checkpointing every 10 epochs.

    Args:
        dataset: tf.data.Dataset of real-sample batches.
        labels: dataset zipped with ``dataset``; its values are not used.
        expected_batches: batches per epoch (only used as the progress-bar total).

    Returns:
        (plot_losses_gen, plot_losses_disc): per-epoch mean generator and
        critic losses as floats (nan for an epoch with no batches).
    """
    checkpoint_dir = './training_checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                     discriminator_optimizer=discriminator_optimizer,
                                     generator=generator,
                                     discriminator=discriminator)

    plot_losses_gen = []
    plot_losses_disc = []
    for epoch in tqdm(range(epochs), desc="Epoch"):
        start = time.time()  # reset every epoch so the printed time is per-epoch
        epoch_losses_gen = []
        epoch_losses_disc = []
        for image_batch, _ in tqdm(zip(dataset, labels), desc="Batch", total=expected_batches):
            gen_loss, disc_loss = train_step(image_batch, epoch)
            epoch_losses_gen.append(float(gen_loss))
            epoch_losses_disc.append(float(disc_loss))

        # Average over the whole epoch, not just its last batch.
        plot_losses_gen.append(float(np.mean(epoch_losses_gen)) if epoch_losses_gen else np.nan)
        plot_losses_disc.append(float(np.mean(epoch_losses_disc)) if epoch_losses_disc else np.nan)

        if epoch % 10 == 0:
            print('Time for epoch {} is {:.2f} sec'.format(epoch + 1, time.time() - start))
            # Save the model every 10 epochs
            checkpoint.save(file_prefix=checkpoint_prefix)

    return plot_losses_gen, plot_losses_disc

print('Learning Started!')
start_time = time.time()

# Per-epoch mean generator / critic losses.
gen_loss, disc_loss = train(train_dataset, dummy_labels, expected_batches)  # Train start!

print('Learning Finished!')
print('Total training time: {:.2f} sec'.format(time.time() - start_time))

# Persist both trained networks under saved_model/.
generator.save('saved_model/generator')
discriminator.save('saved_model/discriminator')


