''' Fine-tune T5 or TinyLlama on paraphrase data, conditioned on precomputed style embeddings. '''

#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.


from transformers import (
    AutoTokenizer,
    AutoModel,
    get_scheduler,
    T5ForConditionalGeneration,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    Seq2SeqTrainer,
    BitsAndBytesConfig,
    LlamaTokenizer,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
import os
import sys
import random
from datetime import datetime
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import json
import copy
from torch.nn.utils.rnn import pad_sequence

from torch import nn
import torch.nn.functional as F

from torch.nn import CrossEntropyLoss

# from accelerate import Accelerator


from datasets import load_from_disk
import wandb
import click

# Maps each supported Hugging Face checkpoint name to the architecture family
# ('t5' or 'llama'); the rest of this file uses the family to pick the model
# class and the batch collator. A checkpoint must be listed here to be usable.
MODEL_TO_MODEL_TYPE = {
        'google/t5-large-lm-adapt': 't5',
        'google/t5-v1_1-large': 't5',
        'TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T': 'llama',

}

class StyleLlamaForCausalLM(LlamaForCausalLM):
    """Llama causal LM whose loss is shifted by one extra position.

    The loss slicing assumes the inputs carry exactly one more leading
    position than ``labels`` (TinyStyle in this file prepends one style
    embedding to ``inputs_embeds``). Everything else follows the usual
    causal-LM forward pass.
    """

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        """Run the decoder, project to vocabulary logits and optionally compute the loss.

        ``cache_position`` is accepted for signature compatibility but is not
        forwarded to the decoder.

        Returns a ``CausalLMOutputWithPast`` when ``return_dict`` resolves to
        true; otherwise a tuple ``(loss?, logits, *decoder_extras)`` where the
        loss entry is present only if ``labels`` was given.
        """
        # Fall back to the config for any flag the caller left unset.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        tp_degree = self.config.pretraining_tp
        if tp_degree > 1:
            # Apply the LM head slice by slice and stitch the vocab axis back together.
            head_slices = self.lm_head.weight.split(self.vocab_size // tp_degree, dim=0)
            partial_logits = [F.linear(hidden_states, head_slices[idx]) for idx in range(tp_degree)]
            logits = torch.cat(partial_logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Logits have one extra leading position relative to labels, so
            # logits[1:-1] line up with labels[1:] for next-token prediction.
            flat_logits = logits[..., 1:-1, :].contiguous().view(-1, self.config.vocab_size)
            flat_labels = labels[..., 1:].contiguous().view(-1)
            # Enable model parallelism
            flat_labels = flat_labels.to(flat_logits.device)
            criterion = CrossEntropyLoss()
            loss = criterion(flat_logits, flat_labels)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[1:]
        if loss is None:
            return output
        return (loss,) + output


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict,
    tokenizer,
    model,
    non_special_tokens=None,
):
    """Add tokens to the tokenizer and grow the model's embeddings to match.

    Newly created embedding rows (input and, when the model has one, output)
    are initialised to the mean of the pre-existing rows.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.

    Args:
        special_tokens_dict: mapping passed to ``tokenizer.add_special_tokens``.
        tokenizer: object providing ``add_special_tokens``, ``add_tokens`` and ``__len__``.
        model: object providing ``resize_token_embeddings``,
            ``get_input_embeddings`` and ``get_output_embeddings``.
        non_special_tokens: optional list of regular tokens to add; ``None``
            or an empty list adds nothing.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    # Only call add_tokens when there is something to add, instead of relying
    # on the tokenizer to tolerate a None argument.
    if non_special_tokens:
        num_new_tokens += tokenizer.add_tokens(non_special_tokens)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        with torch.no_grad():
            for embedding in (model.get_input_embeddings(), model.get_output_embeddings()):
                # Some models expose no output embedding layer; skip it then.
                if embedding is None:
                    continue
                weight = embedding.weight
                # The new rows sit at the end; seed them with the old rows' mean.
                weight[-num_new_tokens:] = weight[:-num_new_tokens].mean(dim=0, keepdim=True)
    print(f"Resized tokenizer and embedding to {len(tokenizer)} tokens.")

class TinyStyle(torch.nn.Module):
    """Wrap a T5 or Llama language model and optionally prepend a style vector.

    When ``use_style`` is true, a linear layer projects the style vector to
    the model's hidden width and the result is inserted as one extra position
    in front of the token embeddings; the attention mask gets a matching
    leading column of ones.
    """

    def __init__(
        self, base_model, use_style=False, ctrl_embed_dim=768
    ):
        """Load ``base_model`` and, if ``use_style``, create the style projection.

        Args:
            base_model: checkpoint name; must be a key of MODEL_TO_MODEL_TYPE.
            use_style: whether to prepend a projected style embedding.
            ctrl_embed_dim: width of the incoming style vectors.

        Raises:
            ValueError: if ``base_model`` is not a supported checkpoint.
        """
        super().__init__()

        model_type = MODEL_TO_MODEL_TYPE.get(base_model)
        if model_type == 't5':
            self.model = T5ForConditionalGeneration.from_pretrained(base_model)
        elif model_type == 'llama':
            self.model = StyleLlamaForCausalLM.from_pretrained(base_model)
        else:
            raise ValueError(
                f"Unsupported base model {base_model!r}; expected one of {sorted(MODEL_TO_MODEL_TYPE)}"
            )
        self.use_style = use_style
        if self.use_style:
            self.ctrl_embed_dim = ctrl_embed_dim
            config = self.model.config
            # The hidden width lives under `d_model` on some configs and `hidden_size` on others.
            hidden_dim = config.d_model if hasattr(config, 'd_model') else config.hidden_size
            self.proj = torch.nn.Linear(self.ctrl_embed_dim, hidden_dim)

    def _embed_with_style(self, input_ids, attention_mask, style):
        """Embed ``input_ids`` and, when enabled, prepend the projected style position.

        Returns the ``(inputs_embeds, attention_mask)`` pair to hand to the wrapped model.
        """
        input_embeds = self.model.get_input_embeddings()(input_ids)
        if not self.use_style:
            return input_embeds, attention_mask

        if style is None:
            raise ValueError("`style` is required when the model was built with use_style=True")

        style_embed = self.proj(style).unsqueeze(1)
        input_embeds = torch.cat([style_embed, input_embeds], dim=1)
        # Build the extra mask column directly on the mask's device. It keeps
        # torch's default float dtype so the concatenated mask is the same as
        # the one this class produced before.
        style_mask = torch.ones((input_embeds.shape[0], 1), device=attention_mask.device)
        attention_mask = torch.cat([style_mask, attention_mask], dim=1)
        return input_embeds, attention_mask

    def forward(self, input_ids, attention_mask, labels=None, style=None):
        """Run the wrapped model on the (optionally style-prefixed) embeddings.

        Raises:
            ValueError: if ``use_style`` is set and ``style`` is None.
        """
        input_embeds, attention_mask = self._embed_with_style(input_ids, attention_mask, style)
        return self.model(inputs_embeds=input_embeds, attention_mask=attention_mask, labels=labels)

    def generate(self, input_ids, attention_mask, style=None, **kwargs):
        """Call the wrapped model's ``generate`` with the same input preparation as ``forward``.

        Extra keyword arguments are passed through unchanged.

        Raises:
            ValueError: if ``use_style`` is set and ``style`` is None.
        """
        input_embeds, attention_mask = self._embed_with_style(input_ids, attention_mask, style)
        return self.model.generate(
            inputs_embeds=input_embeds, attention_mask=attention_mask, **kwargs
        )

def data_collator(*, batch, tokenizer, max_length_src, max_length_tgt, ctr_embed_key, ignore_idx=-100, input_key='paraphrase', output_key='text', do_lower=False):
    """Collate examples into a batch of tokenized inputs, labels and style vectors.

    Args:
        batch: list of examples (mappings) holding ``input_key``, ``output_key``
            and ``ctr_embed_key`` entries.
        tokenizer: callable tokenizer with a ``pad_token_id`` attribute.
        max_length_src: truncation length for the input text.
        max_length_tgt: truncation length for the target text.
        ctr_embed_key: key of the per-example style vector; also the key under
            which the stacked vectors are returned.
        ignore_idx: value written over padding positions in the labels.
        input_key: key of the input text.
        output_key: key of the target text.
        do_lower: lowercase the input text before tokenizing. Only the text
            fed to the tokenizer is lowercased; the examples in ``batch`` are
            left untouched.

    Returns:
        The tokenizer's output for the inputs, extended with ``'labels'`` and
        ``ctr_embed_key`` entries.
    """
    sources = [example[input_key] for example in batch]
    if do_lower:
        # Lowercase a copy so the caller's examples (which may be reused
        # across epochs) are not modified in place.
        sources = [text.lower() for text in sources]
    targets = [example[output_key] for example in batch]

    inputs = tokenizer(sources, max_length=max_length_src, padding=True, truncation=True, return_tensors='pt')
    labels = tokenizer(targets, max_length=max_length_tgt, padding=True, truncation=True, return_tensors='pt')['input_ids']
    # Padding positions must not contribute to the loss.
    labels[labels == tokenizer.pad_token_id] = ignore_idx

    inputs['labels'] = labels
    inputs[ctr_embed_key] = torch.stack([torch.tensor(example[ctr_embed_key]) for example in batch], dim=0)
    return inputs

def data_collator_causal_fixed(*, batch, tokenizer, max_length_src, max_length_tgt, ctr_embed_key, ignore_idx=-100, input_key='paraphrase', output_key='text', do_lower=False):
    """Collate examples for causal-LM training as ``<bos> input ||| target``.

    Each example becomes one token sequence: the prompt (BOS token, input
    text, ``|||`` separator) followed by the target tokens. Labels are
    ``ignore_idx`` over the prompt and over padding, and the target token ids
    elsewhere. Sequences are right-padded to the longest one in the batch, so
    a row can be up to ``max_length_src + max_length_tgt`` tokens long.

    Args:
        batch: list of examples (mappings) holding ``input_key``, ``output_key``
            and ``ctr_embed_key`` entries.
        tokenizer: callable tokenizer with ``bos_token`` and ``pad_token_id``.
        max_length_src: truncation length for the prompt.
        max_length_tgt: truncation length for the target.
        ctr_embed_key: key of the per-example style vector; also the key under
            which the stacked vectors are returned.
        ignore_idx: label value for positions excluded from the loss.
        input_key: key of the input text.
        output_key: key of the target text.
        do_lower: lowercase the input text before building the prompt; the
            examples in ``batch`` are left untouched.

    Returns:
        dict with ``'input_ids'``, boolean ``'attention_mask'``, ``'labels'``
        and the stacked style vectors under ``ctr_embed_key``.
    """
    input_texts = [example[input_key] for example in batch]
    if do_lower:
        # Work on a lowercased copy rather than rewriting the caller's examples.
        input_texts = [text.lower() for text in input_texts]

    sources = [f"{tokenizer.bos_token} {text} |||" for text in input_texts]
    targets = [f"{example[output_key]}" for example in batch]

    # Special tokens are disabled because the BOS token is already spelled out in the prompt.
    tokenized_sources = tokenizer(
        sources,
        max_length=max_length_src,
        truncation=True,
        add_special_tokens=False,
    )
    tokenized_targets = tokenizer(
        targets,
        max_length=max_length_tgt,
        truncation=True,
        add_special_tokens=False,
    )

    # Build the input and labels for causal LM
    input_ids = []
    labels = []
    for source_ids, target_ids in zip(tokenized_sources['input_ids'], tokenized_targets['input_ids']):
        input_ids.append(torch.tensor(source_ids + target_ids))
        labels.append(torch.tensor([ignore_idx] * len(source_ids) + list(target_ids)))

    # Record true lengths before padding so the mask does not depend on the
    # pad id: comparing against pad_token_id would also mask real tokens
    # whenever the pad id doubles as a regular token id.
    lengths = torch.tensor([sequence.numel() for sequence in input_ids])

    # Apply padding
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    labels = pad_sequence(labels, batch_first=True, padding_value=ignore_idx)
    attention_mask = torch.arange(input_ids.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)

    data_dict = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
    }
    data_dict[ctr_embed_key] = torch.stack([torch.tensor(example[ctr_embed_key]) for example in batch], dim=0)

    return data_dict

# def data_collator_causal(*, batch, tokenizer, max_length_src, max_length_tgt, ctr_embed_key, ignore_idx=-100):
#     # Extract elements
#     sources = [f"{tokenizer.bos_token}{example['paraphrase']}" for example in batch]
#     targets = [f"{example['text']}{tokenizer.eos_token}" for example in batch]
#     # Tokenize
#     tokenized_sources_with_prompt = tokenizer(
#         sources,
#         max_length=max_length_src,
#         truncation=True,
#         add_special_tokens=False,
#     )
#     tokenized_targets = tokenizer(
#         targets,
#         max_length=max_length_tgt,
#         truncation=True,
#         add_special_tokens=False,
#     )
#     # Build the input and labels for causal LM
#     input_ids = []
#     labels = []
#     for tokenized_source, tokenized_target in zip(
#         tokenized_sources_with_prompt['input_ids'],
#         tokenized_targets['input_ids']
#     ):
#             input_ids.append(torch.tensor(tokenized_source + tokenized_target))
#             labels.append(
#                 torch.tensor([ignore_idx for _ in range(len(tokenized_source))] + copy.deepcopy(tokenized_target))
#             )
        

#     # Apply padding
#     input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
#     labels = pad_sequence(labels, batch_first=True, padding_value=ignore_idx)
#     data_dict = {
#         'input_ids': input_ids,
#         'attention_mask':input_ids.ne(tokenizer.pad_token_id),

#     }
#     if labels is not None:
#         data_dict['labels'] = labels


#     style_embeddings = torch.stack([torch.tensor(x[ctr_embed_key]) for x in batch], dim=0)
#     data_dict[ctr_embed_key] = style_embeddings
    
#     return data_dict


def run_model(*, batch, model, device, style_embed):
    """Move one batch to ``device``, run ``model`` on it and return ``(logits, loss)``.

    ``batch`` must provide 'input_ids', 'attention_mask', 'labels' and the
    style tensor stored under the key named by ``style_embed``.
    """
    tensors = {
        name: batch[name].to(device)
        for name in ('input_ids', 'attention_mask', 'labels')
    }
    tensors['style'] = batch[style_embed].to(device)
    output = model(**tensors)
    return output.logits, output.loss




# Example usage (--model_name must be a key of MODEL_TO_MODEL_TYPE):
# python style_enc_dec.py \
#   --learning_rate 1e-5 \
#   --batch_size 64 \
#   --accumulation_steps 2 \
#   --out_dir /path/to/outdir \
#   --device cuda \
#   --warmup_steps 2000 \
#   --max_steps 10000000 \
#   --ctrl_embed_dim 768 \
#   --style_embed style_embedding \
#   --model_name google/t5-v1_1-large \
#   --data_file_path /path/to/data \
#   --checkpoint /path/to/checkpoint \
#   --seed 42 \
#   --max_encoder_len 80 \
#   --max_decoder_len 80 \
#   --max_val_batch 200


def _mean_validation_loss(*, model, dataloader, device, style_embed, max_batches):
    """Return the mean loss over at most ``max_batches`` batches, or None if none ran.

    The model is switched to eval mode for the pass and back to train mode afterwards.
    """
    losses = []
    model.eval()
    with torch.no_grad():
        for j, val_data in enumerate(dataloader):
            if j >= max_batches:
                break
            _, loss = run_model(
                batch=val_data, model=model, device=device, style_embed=style_embed
            )
            losses.append(loss.item())
    model.train()
    return sum(losses) / len(losses) if losses else None


def _save_checkpoint(*, out_dir, fname, model, optimizer, scheduler, info):
    """Write checkpoint metadata plus model, optimizer and scheduler states to ``out_dir``."""
    with open(os.path.join(out_dir, 'checkpoint_info.json'), 'w+') as out_:
        json.dump(info, out_)
    torch.save(model.state_dict(), os.path.join(out_dir, fname))
    torch.save(optimizer.state_dict(), os.path.join(out_dir, 'optimizer.pt'))
    torch.save(scheduler.state_dict(), os.path.join(out_dir, 'scheduler.pt'))


@click.command()
@click.option('--learning_rate', type=float, default=1e-5)
@click.option('--batch_size', type=int, default=64)
@click.option('--accumulation_steps', type=int, default=1)
@click.option('--out_dir', type=str, required=True)
@click.option('--device', type=str, default='cuda')
@click.option('--warmup_steps', type=int, default=2000)
@click.option('--max_steps', type=int, default=10000000)
@click.option('--max_num_epochs', type=int, default=10000)
@click.option('--eval_freq', type=int, default=1000)
@click.option('--ctrl_embed_dim', type=int, default=768)
@click.option('--style_embed', type=str, default='style_embedding')
@click.option('--model_name', type=str, default='t5-large')
@click.option('--data_file_path', type=str, required=True)
@click.option('--checkpoint', type=str, default=None)
@click.option('--seed', type=int, default=42)
@click.option('--max_encoder_len', type=int, default=80)
@click.option('--max_decoder_len', type=int, default=80)
@click.option('--max_val_batch', type=int, default=200)
@click.option('--input_key', type=str, default='paraphrase')
@click.option('--output_key', type=str, default='text')
@click.option('--skip_load_optimizer', is_flag=True)
@click.option('--do_lower', is_flag=True)
def main(learning_rate, batch_size, accumulation_steps, out_dir, device, warmup_steps, max_steps, max_num_epochs, eval_freq, ctrl_embed_dim, style_embed, model_name, data_file_path, checkpoint, seed, max_encoder_len, max_decoder_len, max_val_batch, input_key, output_key, skip_load_optimizer, do_lower):
    """Fine-tune a TinyStyle model and keep the checkpoint with the best validation loss.

    Loads the dataset from ``data_file_path`` (expects 'train' and 'val'
    splits), trains with Adam and a constant-with-warmup schedule using
    gradient accumulation, and every ``eval_freq`` optimizer steps evaluates
    on at most ``max_val_batch`` validation batches. Whenever the validation
    loss improves, the model, optimizer and scheduler states are written to a
    timestamped sub-directory of ``out_dir``. Training ends after
    ``max_steps`` optimizer steps or ``max_num_epochs`` epochs, whichever
    comes first.

    When ``checkpoint`` is given, model weights are loaded from it and,
    unless ``skip_load_optimizer`` is set, 'optimizer.pt' and 'scheduler.pt'
    are loaded from the same directory.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if model_name not in MODEL_TO_MODEL_TYPE:
        raise click.BadParameter(
            f"unsupported model {model_name!r}; choose one of {sorted(MODEL_TO_MODEL_TYPE)}",
            param_hint='--model_name',
        )
    model_type = MODEL_TO_MODEL_TYPE[model_name]

    # The device comes from --device (default 'cuda') instead of being hard-coded.
    print(device)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = TinyStyle(
        base_model=model_name,
        use_style=True,
        ctrl_embed_dim=ctrl_embed_dim,
    )
    model.to(device)

    # Batching needs a pad token; add one (and grow the embeddings) if the tokenizer has none.
    if tokenizer.pad_token is None:
        special_tokens_dict = dict(pad_token='[PAD]')
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=tokenizer,
            model=model.model
        )

    # load from checkpoint
    optimizer_path = None
    scheduler_path = None
    if checkpoint is not None:
        if not skip_load_optimizer:
            checkpoint_dir = os.path.dirname(checkpoint)
            optimizer_path = os.path.join(checkpoint_dir, 'optimizer.pt')
            scheduler_path = os.path.join(checkpoint_dir, 'scheduler.pt')
            for required_path in (optimizer_path, scheduler_path):
                if not os.path.exists(required_path):
                    raise FileNotFoundError(
                        f"{required_path} not found; pass --skip_load_optimizer to resume without it"
                    )

        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
        # Overlay the saved weights on the current state so keys absent from
        # the checkpoint keep their freshly initialised values.
        current_state = model.state_dict()
        saved_state_dict = torch.load(checkpoint, map_location=device)
        current_state.update(saved_state_dict)
        model.load_state_dict(current_state)

    tokenized_datasets = load_from_disk(data_file_path)

    train_dataset = tokenized_datasets["train"]
    eval_dataset = tokenized_datasets["val"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        print(f"Sample {index} of the training set: {train_dataset[index]}.")

    collator_fn = data_collator if model_type == 't5' else data_collator_causal_fixed
    if model_type not in ('t5', 'llama'):
        raise ValueError(f"No collator defined for model type {model_type!r}")

    collator_args = {
        'tokenizer': tokenizer,
        'max_length_src': max_encoder_len,
        'max_length_tgt': max_decoder_len,
        'ctr_embed_key': style_embed,
        'ignore_idx': -100,
        'input_key': input_key,
        'output_key': output_key,
        'do_lower': do_lower,
    }

    def collate(examples):
        return collator_fn(batch=examples, **collator_args)

    # DataLoaders creation:
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, batch_size=batch_size, collate_fn=collate
    )
    # NOTE(review): the validation loader is shuffled, so each evaluation sees
    # a different random subset of at most max_val_batch batches; confirm this
    # is intended for best-checkpoint selection.
    eval_dataloader = DataLoader(
        eval_dataset, shuffle=True, batch_size=batch_size, collate_fn=collate
    )

    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

    if optimizer_path is not None:
        optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))

    scheduler = get_scheduler(
        name='constant_with_warmup',
        optimizer=optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=max_steps,
    )

    if scheduler_path is not None:
        scheduler.load_state_dict(torch.load(scheduler_path, map_location=device))

    # Each run writes into its own timestamped sub-directory.
    cur_date = datetime.now().strftime("%Y-%m-%d-%H.%M.%S")
    out_dir = os.path.join(out_dir, cur_date)
    os.makedirs(out_dir, exist_ok=True)

    wandb.init(
        project='tinystyle_emnlp',
        config={
            'learning_rate': learning_rate,
            'batch_size': batch_size,
            'accumulation_steps': accumulation_steps,
            'outdir': out_dir,
            'device': device,
            'warmup_steps': warmup_steps,
            'max_steps': max_steps,
            'max_num_epochs': max_num_epochs,
            # Key spelling kept as-is so runs stay comparable with earlier ones.
            'eval_feq': eval_freq,
            'ctrl_embed_dim': ctrl_embed_dim,
            'style_embed': style_embed,
            'model': model_name,
            'data_file_path': data_file_path,
            'checkpoint': checkpoint,
            'seed': seed,
            'max_encoder_len': max_encoder_len,
            'max_decoder_len': max_decoder_len,
            'max_val_batch': max_val_batch,
            'input_key': input_key,
            'output_key': output_key,
            'skip_load_optimizer': skip_load_optimizer,
            'do_lower': do_lower,
        },
    )

    fname = f'best_model_{model_name.replace("/","_")}_{learning_rate}_{batch_size*accumulation_steps}.pt'

    best_val_loss = None
    optimizer.zero_grad()
    steps = 0      # optimizer steps taken
    counter = 0    # batches processed, counted across epochs
    reached_max_steps = False
    for epoch in range(max_num_epochs):
        model.train()
        wandb.log({"epoch": epoch})

        with tqdm(total=len(train_dataloader)) as pbar:
            for i, data in enumerate(train_dataloader):
                _, loss = run_model(batch=data, model=model, device=device, style_embed=style_embed)
                if counter % 100 == 0:
                    print('Epoch: ', epoch, ', Train Loss: ', loss.item())

                wandb.log({"train_loss": loss.item()})
                # Scale so the accumulated gradient is the mean over the micro-batches.
                loss = loss / accumulation_steps

                loss.backward()
                pbar.update(1)

                if (counter + 1) % accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()
                    steps += 1

                if (counter + 1) % (eval_freq * accumulation_steps) == 0:
                    val_loss = _mean_validation_loss(
                        model=model,
                        dataloader=eval_dataloader,
                        device=device,
                        style_embed=style_embed,
                        max_batches=max_val_batch,
                    )
                    if val_loss is None:
                        print('Validation skipped: no validation batches were evaluated.')
                    else:
                        wandb.log({"val_loss": val_loss})
                        print('Epoch: ', epoch, ', Val Loss: ', val_loss)

                        if best_val_loss is None or val_loss < best_val_loss:
                            best_val_loss = val_loss
                            print(epoch, i, 'New best val loss: ', best_val_loss)
                            _save_checkpoint(
                                out_dir=out_dir,
                                fname=fname,
                                model=model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                info={
                                    'epoch': epoch,
                                    'i': i,
                                    'counter': counter,
                                    'steps': steps,
                                    'loss': best_val_loss,
                                },
                            )

                counter += 1

                if steps >= max_steps:
                    reached_max_steps = True
                    break

        # Leave the epoch loop as well; breaking only the batch loop would
        # keep running one more batch per remaining epoch.
        if reached_max_steps:
            break

    wandb.finish()



# Run the click command only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()