# sources
# https://colab.research.google.com/drive/1oVJAIAIyPoCqaqo4LonOKBZhkfVPV-UM?usp=sharing#scrollTo=8zpJOCJhmSno
# https://huggingface.co/docs/transformers/tasks/translation
# https://github.com/trusthlt/privacy-legal-nlp-lm
# https://github.com/huggingface/notebooks/blob/main/examples/causal_language_modeling_flax.ipynb
# https://github.com/deepmind/optax/blob/master/examples/differentially_private_sgd.py


# Settings
import os
import warnings

import dp_optimizers

# --- Model and data -------------------------------------------------------
# Checkpoint name handed to the tokenizer and Flax seq2seq model loaders.
MODEL_CHECKPOINT = "facebook/mbart-large-cc25"
# Dataset name and language pair; the pair also forms the dataset config
# name ("<source>-<target>") and keys each example's "translation" dict.
DATASET_NAME = "wmt16"
SOURCE_LANG = "de"
TARGET_LANG = "en"
# Every source/target sequence is padded or truncated to this many tokens.
MAX_LEN = 128
# Nominal training-set size (the N of the privacy accounting).
NUM_EXAMPLES = 1000

# --- Optimisation ---------------------------------------------------------
# Per-device batch size; multiplied by the local device count at setup time.
PER_DEVICE_BATCH_SIZE = 4
NUM_TRAIN_EPOCHS = 1
LEARNING_RATE = 0.01
# Shared seed for model initialisation, the JAX PRNG key and the optimizer.
SEED = 0

# --- Differential privacy ---------------------------------------------------
# Clipping norm and noise multiplier passed to the DP optimizer.
L2_NORM_CLIP = 1.0
NOISE_MULTIPLIER = 0.81

# --- Output -----------------------------------------------------------------
# Prefix prepended to every per-epoch checkpoint directory name.
CHECKPOINT_PATH = "checkpoints/"

# Imports
from opacus.data_loader import DPDataLoader
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM
import jax
import jax.numpy as jnp
import flax
from flax.training.common_utils import shard
from flax.training import train_state
from flax import traverse_util  # mdda
import optax
from tqdm.notebook import tqdm
from typing import Callable
import evaluate
import numpy as np
import torch
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy
import logging
# Root logging: timestamped messages at INFO level and above.
logging.basicConfig(
    format='%(asctime)s %(message)s',
    datefmt='%m/%d/%Y %I:%M:%S %p',
    level=logging.INFO,
)
# Module-level logger used by every function and by the training script below.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Environment flag read by the tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "true"


# Functions
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int) -> jnp.ndarray:
    """
    Build decoder inputs by rotating each row's last non-pad token to the front.

    Every row is shifted one position to the right and the token that sat at
    the last non-pad position (the <LID> token for MBart) is written into
    column 0. MBart has no single `decoder_start_token_id`, unlike other
    Bart-like models, which is why the start token is taken from the row itself.

    Args:
        input_ids: 2-D array of token ids; entries equal to -100 are treated
            as padding.
        pad_token_id: id of the padding token; must not be None.

    Returns:
        A NumPy array of the same shape as `input_ids` holding the shifted ids.

    Raises:
        ValueError: if `pad_token_id` is None.
    """
    raw_ids = np.array(input_ids).copy()

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    # Map the -100 "ignore" label value back to the real padding id.
    shifted = np.where(raw_ids == -100, pad_token_id, input_ids)

    # Column of the last non-pad token in each row (count of non-pad tokens - 1).
    last_token_col = (shifted != pad_token_id).sum(axis=-1) - 1
    # Pick that token for every row before the shift overwrites anything.
    wrapped_tokens = shifted[np.arange(shifted.shape[0]), last_token_col]

    shifted[:, 1:] = shifted[:, :-1].copy()
    shifted[:, 0] = wrapped_tokens

    return shifted

def preprocess_function(examples):
    """
    Tokenize a batch of translation pairs into model-ready NumPy features.

    Reads the module-level `tokenizer`, `model`, `SOURCE_LANG`, `TARGET_LANG`
    and `MAX_LEN`.

    Args:
        examples: batch dict whose "translation" entry is a list of dicts
            keyed by language code.

    Returns:
        The tokenizer output for the source texts, extended with "labels",
        "decoder_input_ids" and "decoder_attention_mask" derived from the
        target texts.
    """
    pairs = examples["translation"]
    source_texts = [pair[SOURCE_LANG] for pair in pairs]
    target_texts = [pair[TARGET_LANG] for pair in pairs]

    # Source and target sides are tokenized with identical settings:
    # fixed-length padding/truncation to MAX_LEN and NumPy outputs.
    tokenizer_kwargs = dict(
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    model_inputs = tokenizer(source_texts, **tokenizer_kwargs)
    target_encoding = tokenizer(target_texts, **tokenizer_kwargs)

    target_ids = target_encoding["input_ids"]
    model_inputs["labels"] = target_ids
    # Decoder inputs are the labels shifted right (see shift_tokens_right).
    model_inputs["decoder_input_ids"] = np.asarray(
        shift_tokens_right(target_ids, model.config.pad_token_id)
    )
    # Attention mask the tokenizer produced for the target side.
    model_inputs["decoder_attention_mask"] = target_encoding["attention_mask"]

    return model_inputs


def compute_epsilon(N, batch_size, noise_multiplier, epochs, delta=1e-8):
    """
    Compute the privacy budget spent by DP-SGD training.

    Thin wrapper around `compute_dp_sgd_privacy`.

    Args:
        N: number of training examples.
        batch_size: total batch size used per optimisation step.
        noise_multiplier: noise multiplier of the DP optimizer.
        epochs: number of training epochs.
        delta: target delta; should stay below 1/N. Defaults to 1e-8, the
            value that was previously hard-coded.

    Returns:
        The first element of the result of `compute_dp_sgd_privacy`
        (presumably the epsilon value — confirm against the installed
        tensorflow_privacy version).
    """
    # Rule of thumb: delta should be smaller than 1/N.
    if N * delta > 1.:
        logger.warning('Your delta might be too high.')
    epsilon = compute_dp_sgd_privacy(N, batch_size, noise_multiplier, epochs, delta)
    return epsilon[0]


def postprocess_text(preds, labels):
    """
    Strip surrounding whitespace from decoded predictions and references.

    Args:
        preds: iterable of decoded prediction strings.
        labels: iterable of decoded reference strings.

    Returns:
        A `(predictions, references)` tuple: a flat list of stripped
        predictions, and a list in which every stripped reference is wrapped
        in its own single-element list.
    """
    return (
        [prediction.strip() for prediction in preds],
        [[reference.strip()] for reference in labels],
    )


def numpy_collate(batch):
    """
    Collate a list of example dicts into device-sharded NumPy arrays.

    The examples are stacked per key. If the number of rows is not a multiple
    of the local device count, the last row of every array is repeated until
    it is, so that `shard` can split the batch evenly.

    Args:
        batch: non-empty list of dicts sharing the same keys (must include
            'input_ids').

    Returns:
        The result of `shard` applied to the stacked (and possibly padded)
        dict of arrays.
    """
    stacked = {name: np.array([example[name] for example in batch]) for name in batch[0]}

    # Number of filler rows needed to reach the next multiple of the device count.
    num_rows = stacked['input_ids'].shape[0]
    missing_rows = (-num_rows) % jax.local_device_count()

    if missing_rows:
        stacked = {
            name: np.vstack([column] + [column[-1]] * missing_rows)
            for name, column in stacked.items()
        }
    return shard(stacked)


class TrainState(train_state.TrainState):
    """
    Flax train state extended with the two callables used by the step functions.

    `logits_function` is applied to the model logits in `eval_step`;
    `loss_function` is called with `(logits, labels)` in `train_step`.
    """
    # Both fields are declared with pytree_node=False — presumably so the
    # callables are carried as static metadata instead of traced array leaves.
    logits_function: Callable = flax.struct.field(pytree_node=False)
    loss_function: Callable = flax.struct.field(pytree_node=False)


@jax.jit
def loss_function(logits, labels):
    """
    Mean softmax cross-entropy between `logits` and integer `labels`.

    The labels are one-hot encoded over `model.config.vocab_size` classes
    (read from the module-level `model`). The mean is taken over every
    element of the resulting loss array; no mask is applied here.

    Args:
        logits: model logits whose last axis indexes the vocabulary.
        labels: integer token ids with the shape of `logits` minus its last axis.

    Returns:
        Scalar mean cross-entropy.
    """
    vocab_size = model.config.vocab_size
    one_hot_targets = jax.nn.one_hot(labels, num_classes=vocab_size)
    per_token_loss = optax.softmax_cross_entropy(logits, one_hot_targets)
    return per_token_loss.mean()


@jax.jit
def eval_function(logits):
    """Return the index of the largest logit along the last axis."""
    return jnp.argmax(logits, axis=-1)


@jax.jit
def train_step(state, batch, dropout_rng):
    """
    Run one optimisation step on per-example gradients.

    Must be executed under a mapped axis named "batch" (e.g. via
    `jax.pmap(..., axis_name="batch")`), because it uses collectives over
    that axis.

    Args:
        state: train state providing `apply_fn`, `params`, `loss_function`,
            `step` and `apply_gradients`.
        batch: dict of per-device arrays; must contain "labels", and every
            other entry is forwarded to `state.apply_fn` as a keyword argument.
        dropout_rng: PRNG key for dropout on this device.

    Returns:
        `(new_state, metrics, new_dropout_rng)` where `metrics` holds the
        cross-device mean "loss" and the reported "learning_rate".
    """
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def calc_loss_function(params, example):
        # Split labels from model inputs without mutating the incoming dict.
        inputs = {name: value for name, value in example.items() if name != "labels"}
        logits = state.apply_fn(**inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
        return state.loss_function(logits, example["labels"])

    # Insert dummy dimension in axis 1 so that, after jax.vmap strips the
    # leading axis, each example still looks like a batch of size one.
    batch = jax.tree_util.tree_map(lambda x: x[:, None], batch)
    # Use jax.vmap across the batch to extract per-example losses and gradients.
    per_example_grad_fn = jax.vmap(jax.value_and_grad(calc_loss_function), in_axes=(None, 0))
    loss, grad = per_example_grad_fn(state.params, batch)
    # average the loss
    loss = jnp.mean(loss)

    # The DP optimizer clips each entry along the leading axis individually,
    # so it must see every per-example gradient. Averaging across devices
    # first (pmean) would blend examples from different devices before
    # clipping and break the per-example sensitivity bound. Concatenate all
    # devices' per-example gradients instead; every device then applies the
    # identical update and the replicas stay in sync.
    # NOTE(review): this keeps (devices x per-device batch) gradient copies
    # per device in memory — confirm it fits for the chosen model.
    grad = jax.lax.all_gather(grad, "batch", axis=0, tiled=True)
    new_state = state.apply_gradients(grads=grad)

    # NOTE(review): the reported learning rate comes from the module-level
    # `learning_rate_function`; whether the optimizer follows the same
    # schedule depends on how `state.tx` was built — confirm.
    metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_function(state.step)}, axis_name="batch")
    return new_state, metrics, new_dropout_rng


@jax.jit
def eval_step(state, batch):
    """
    Run the model in inference mode and post-process its logits.

    Args:
        state: train state providing `apply_fn`, `params` and `logits_function`.
        batch: dict of arrays forwarded to `state.apply_fn` as keyword arguments.

    Returns:
        `state.logits_function` applied to the first element of the model output.
    """
    model_outputs = state.apply_fn(**batch, params=state.params, train=False)
    return state.logits_function(model_outputs[0])


# MAIN CODE
# Loading: the first 100 examples of the train and validation splits, plus the
# tokenizer, the Flax model (converted from PyTorch weights) and the metric.
dataset_config = f"{SOURCE_LANG}-{TARGET_LANG}"
train_dataset = load_dataset(DATASET_NAME, dataset_config, split='train').select(range(100))
eval_dataset = load_dataset(DATASET_NAME, dataset_config, split='validation').select(range(100))
# NOTE(review): no source/target language is passed to the tokenizer here —
# confirm its default language codes match SOURCE_LANG / TARGET_LANG.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT, seed=SEED, from_pt=True)
metric = evaluate.load("sacrebleu")

# Preprocessing: tokenize both splits and drop the raw text columns.
train_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset.column_names,
)
eval_dataset = eval_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=eval_dataset.column_names,
)

# Training hyperparameters
num_local_devices = jax.local_device_count()
total_batch_size = PER_DEVICE_BATCH_SIZE * num_local_devices
steps_per_epoch = len(train_dataset) // total_batch_size
num_train_steps = steps_per_epoch * NUM_TRAIN_EPOCHS
# Linear decay from LEARNING_RATE to 0; read by train_step when it reports
# the learning rate.
# NOTE(review): the optimizer below is given the constant LEARNING_RATE, not
# this schedule — confirm which of the two is intended.
learning_rate_function = optax.linear_schedule(
    init_value=LEARNING_RATE,
    end_value=0,
    transition_steps=num_train_steps,
)
rng = jax.random.PRNGKey(SEED)
# One dropout key per local device.
dropout_rngs = jax.random.split(rng, num_local_devices)
optimizer = optax.dpsgd(
    learning_rate=LEARNING_RATE,
    l2_norm_clip=L2_NORM_CLIP,
    noise_multiplier=NOISE_MULTIPLIER,
    seed=SEED,
)
#optimizer = dp_optimizers.dp_adam(learning_rate=LEARNING_RATE, l2_norm_clip=L2_NORM_CLIP, noise_multiplier=NOISE_MULTIPLIER, seed=SEED)
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=optimizer,  # mdda
    logits_function=eval_function,
    loss_function=loss_function,
)

# Data loaders: the training loader is wrapped by Opacus' DPDataLoader.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=total_batch_size,
    collate_fn=numpy_collate,
)
dp_train_loader = DPDataLoader.from_data_loader(train_loader)
eval_loader = torch.utils.data.DataLoader(
    eval_dataset,
    batch_size=total_batch_size,
    collate_fn=numpy_collate,
)
# Number of real examples in the final, possibly partial, eval batch
# (0 when the eval set divides evenly).
last_eval_batch_length = len(eval_dataset) % total_batch_size

# Multiple Device support
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
state = flax.jax_utils.replicate(state)  # copy state on all devices

logger.info(
    'Batch size: %d, max_len %d, learning rate: %f, num_train_steps: %d , dataset size: %d, eval dataset size: %d dataset name: %s',
    total_batch_size,
    MAX_LEN,
    LEARNING_RATE,
    num_train_steps,
    len(train_dataset),
    len(eval_dataset),
    DATASET_NAME,
)
# Training Loop
# Each epoch: train on the DP loader, score the eval set with the metric,
# log the results and save a checkpoint. Privacy accounting follows the loop.
num_eval_batches = len(eval_loader)
for epoch in tqdm(range(1, NUM_TRAIN_EPOCHS + 1), desc="Epoch ...", position=0, leave=True):
    # train
    with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for batch in dp_train_loader:
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch_index, batch in enumerate(eval_loader):
            labels = batch.pop("labels")
            predictions = parallel_eval_step(state, batch)
            # implementation from https://huggingface.co/docs/transformers/tasks/translation
            # Flatten the leading device axis so rows line up with examples.
            predictions = [x for sublist in predictions for x in sublist]
            labels = [x for sublist in labels for x in sublist]
            # Drop the filler rows that numpy_collate appended to the last batch.
            if batch_index == num_eval_batches - 1 and last_eval_batch_length != 0:
                predictions = predictions[:last_eval_batch_length]
                labels = labels[:last_eval_batch_length]
            decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
            decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
            metric.add_batch(predictions=decoded_preds, references=decoded_labels)
            progress_bar_eval.update(1)
        result = metric.compute()

    result = {"bleu": result["score"]}
    loss = round(flax.jax_utils.unreplicate(train_metrics)['loss'].item(), 3)
    metric_name, metric_value = next(iter(result.items()))
    eval_score = round(metric_value, 3)
    # Log the rounded score rather than the whole result dict.
    logger.info(f"{epoch}/{NUM_TRAIN_EPOCHS} | Train loss: {loss} | Eval {metric_name}: {eval_score}")

    # Save the parameters of the first replica (all replicas are identical).
    checkpoint_dir = os.path.join(CHECKPOINT_PATH, f"mbart-dpsgd-wmt16-checkpoint-epoch-{epoch}")
    model.save_pretrained(checkpoint_dir, params=jax.device_get(flax.jax_utils.unreplicate(state.params)))
    logger.info("Saved model to %s", checkpoint_dir)

# Privacy accounting must be based on the number of examples that were
# actually trained on, which may differ from the NUM_EXAMPLES setting.
num_train_examples = len(train_dataset)
if num_train_examples != NUM_EXAMPLES:
    logger.warning(
        "NUM_EXAMPLES (%d) differs from the actual training-set size (%d); "
        "using the actual size for privacy accounting.",
        NUM_EXAMPLES,
        num_train_examples,
    )
epsilon = compute_epsilon(num_train_examples, total_batch_size, NOISE_MULTIPLIER, NUM_TRAIN_EPOCHS)
logger.info(f"Final epsilon: {epsilon}")
