import os
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

import torch
import wandb
from datasets import load_dataset
from peft import (
    PeftModel,
    LoraConfig,
    get_peft_model,
)
from transformers import (
    CodeLlamaTokenizer,
    LlamaForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)

@dataclass
class CompilerLoraConfig:
    """Hyper-parameters for LoRA fine-tuning a code LLM on a compiler corpus.

    ``gradient_accumulation_steps`` defaults to ``None`` and is derived in
    ``__post_init__`` as ``batch_size // per_device_train_batch_size``.
    Overriding either batch-size field at construction time therefore keeps
    the accumulation factor in sync. Pass an explicit value to bypass the
    derivation.

    Raises:
        ValueError: if the derived accumulation factor would be < 1, i.e.
            ``batch_size < per_device_train_batch_size``.
    """

    # Source/target language pair and optimisation level (used in run names),
    # plus the dataset id handed to ``datasets.load_dataset``.
    src_language: str = "c"
    tgt_language: str = "x86"
    opt_level: str = "O0"
    corpus: str = "aligned_code_c_x86"  # anonymous_data

    # Tokenizer truncation length and LoRA adapter settings.
    max_length: int = 2048
    r: int = 128
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    bias: str = "none"

    # Base checkpoint to fine-tune, and the short name used in run/output names.
    base_model: str = "codellama/CodeLlama-13b-Instruct-hf"
    base_model_name: str = "CodeLlama-13b-Instruct-hf"

    # Optimisation settings; most are forwarded to ``TrainingArguments``.
    # ``batch_size`` is the target number of samples per optimizer step.
    batch_size: int = 16
    per_device_train_batch_size: int = 4
    # None -> derived from batch_size // per_device_train_batch_size.
    gradient_accumulation_steps: Optional[int] = None
    num_train_epochs: int = 1
    warmup_steps: int = 5000
    learning_rate: float = 1e-4
    lr_scheduler_type: str = "cosine_with_restarts"
    fp16: bool = True
    fp16_opt_level: str = "O1"
    logging_steps: int = 100
    optim: str = "adamw_torch"
    evaluation_strategy: str = "steps"
    save_strategy: str = "steps"
    eval_steps: int = 10000
    save_steps: int = 5000
    save_total_limit: int = 3

    def __post_init__(self):
        # Derive per instance so the value reflects constructor overrides.
        # A class-level default expression would be evaluated only once.
        if self.gradient_accumulation_steps is None:
            steps = self.batch_size // self.per_device_train_batch_size
            if steps < 1:
                raise ValueError(
                    f"batch_size ({self.batch_size}) must be >= "
                    f"per_device_train_batch_size ({self.per_device_train_batch_size})"
                )
            self.gradient_accumulation_steps = steps

def train(config, save_path, merged_path, log_path, cache_path):
    """Fine-tune ``config.base_model`` with LoRA on ``config.corpus``, then merge.

    Steps:

    1. Load the corpus and hold out 5% as an evaluation split.
    2. Tokenize both splits, caching the results under ``cache_path``.
    3. Wrap the base model with LoRA adapters on the attention projections.
    4. Train with ``Trainer``, reporting to Weights & Biases.
    5. Save the adapter to ``<save_path>/<run name>/final_checkpoint``.
    6. Merge it into the base model with ``merge_lora``, writing the result
       to ``merged_path``.

    Args:
        config: a ``CompilerLoraConfig`` (or an object with the same fields).
        save_path: directory under which adapter checkpoints are written.
        merged_path: directory the merged model and tokenizer are saved to.
        log_path: currently unused; kept for interface compatibility.
        cache_path: directory for the tokenized-dataset Arrow cache files;
            created if it does not exist.
    """
    import gc

    print(config)
    corpora = load_dataset(config.corpus, split="train")
    # Hold out 5% of the corpus as the evaluation split.
    sep_dataset = corpora.train_test_split(
        test_size=0.05, train_size=0.95, shuffle=True
    )
    train_dataset = sep_dataset["train"]
    test_dataset = sep_dataset["test"]
    print("train_dataset", len(train_dataset))
    print("test_dataset", len(test_dataset))

    base_model = config.base_model
    model = LlamaForCausalLM.from_pretrained(
        base_model, torch_dtype=torch.float16, device_map="auto"
    )
    tokenizer = CodeLlamaTokenizer.from_pretrained(base_model)

    # Append the EOS token to every encoded sample.
    tokenizer.add_eos_token = True
    # NOTE(review): id 2 is presumably the Llama EOS id, reused as the pad
    # id — confirm before switching to a different base model.
    tokenizer.pad_token_id = 2
    tokenizer.padding_side = "left"

    def tokenize(prompt):
        # Encode without padding; the data collator pads each batch.
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=config.max_length,
            padding=False,
            return_tensors=None,
        )
        # "self-supervised learning" means the labels are also the inputs:
        result["labels"] = result["input_ids"].copy()
        return result

    def generate_and_tokenize_prompt(data_point):
        text = data_point["text"]
        # NOTE(review): the template appends "\n" plus four spaces after the
        # text (likely an indentation artifact). It is kept unchanged so the
        # tokenized data matches earlier runs.
        full_prompt = f"{text}\n    "
        return tokenize(full_prompt)

    # ``Dataset.map`` loads an existing ``cache_file_name`` without checking
    # its contents. Name the cache files after every setting that affects the
    # tokenized output, so a config change never silently reuses stale data.
    os.makedirs(cache_path, exist_ok=True)
    cache_tag = (
        f"{config.corpus}_{config.base_model_name}_len{config.max_length}"
    ).replace("/", "_")
    tokenized_train_dataset = train_dataset.map(
        generate_and_tokenize_prompt,
        cache_file_name=os.path.join(
            cache_path, f"tokenized_train_dataset_{cache_tag}.arrow"
        ),
    )
    tokenized_test_dataset = test_dataset.map(
        generate_and_tokenize_prompt,
        cache_file_name=os.path.join(
            cache_path, f"tokenized_test_dataset_{cache_tag}.arrow"
        ),
    )
    print("tokenize dataset done.")
    model.train()

    # LoRA adapters on the four attention projections only.
    lora_config = LoraConfig(
        r=config.r,
        lora_alpha=config.lora_alpha,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=config.lora_dropout,
        bias=config.bias,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    wandb_project = (
        f"train_{config.base_model_name}_{config.src_language}_{config.tgt_language}"
        f"_{config.opt_level}_lora{config.r}_{config.lora_alpha}"
        f"_{config.lora_dropout}_{config.bias}"
    )
    os.environ["WANDB_PROJECT"] = wandb_project
    if torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    # NOTE: the "gpu" suffix is batch_size // per_device_train_batch_size
    # (the accumulation factor), not a GPU count. It is kept so output paths
    # stay stable across runs.
    output_name = (
        f"{wandb_project}_b{config.batch_size}"
        f"_gpu{config.batch_size // config.per_device_train_batch_size}"
    )
    output_dir = os.path.join(save_path, output_name)
    training_args = TrainingArguments(
        per_device_train_batch_size=config.per_device_train_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        num_train_epochs=config.num_train_epochs,
        warmup_steps=config.warmup_steps,
        learning_rate=config.learning_rate,
        lr_scheduler_type=config.lr_scheduler_type,
        fp16=config.fp16,
        fp16_opt_level=config.fp16_opt_level,
        logging_steps=config.logging_steps,
        optim=config.optim,
        evaluation_strategy=config.evaluation_strategy,
        save_strategy=config.save_strategy,
        eval_steps=config.eval_steps,
        save_steps=config.save_steps,
        output_dir=output_dir,
        save_total_limit=config.save_total_limit,
        # Disabled; True would batch sequences of similar length together.
        group_by_length=False,
        hub_strategy="checkpoint",
        report_to="wandb",
        run_name=f"{wandb_project}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
    )
    trainer = Trainer(
        model=model,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_test_dataset,
        args=training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )

    # Generation-time KV caching is not needed during training.
    model.config.use_cache = False

    # To compile the model, pass ``torch_compile=True`` to TrainingArguments.
    # Compiling the local ``model`` after the Trainer is built has no effect,
    # because the Trainer keeps its own reference to the uncompiled module.
    trainer.train()

    # save final checkpoint (adapter weights only)
    final_dir = os.path.join(output_dir, "final_checkpoint")
    trainer.model.save_pretrained(final_dir)
    print("train done")

    # Release the training model (and optimizer state held by the trainer)
    # before merge_lora loads a second full copy of the base model.
    del trainer, model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    merge_lora(final_dir, base_model, merged_path)
    print("merge done")

def merge_lora(lora_path, base_model, save_path):
    """Fold a trained LoRA adapter into its base model and save the result.

    The base model is reloaded in float16 with ``device_map="auto"``. The
    adapter at ``lora_path`` is applied and merged in. The merged model is
    written to ``save_path`` together with the base model's tokenizer.

    Args:
        lora_path: directory holding the adapter produced by ``save_pretrained``.
        base_model: model id or path of the checkpoint the adapter was trained on.
        save_path: output directory for the merged model and tokenizer.
    """
    base = LlamaForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tok = CodeLlamaTokenizer.from_pretrained(base_model)

    # Attach the adapter, then bake its weights into the base layers.
    merged = PeftModel.from_pretrained(base, lora_path).merge_and_unload()

    tok.save_pretrained(save_path)
    merged.save_pretrained(save_path)
    print("save done")

if __name__ == "__main__":
    # Run one LoRA training + merge with the default configuration.
    config = CompilerLoraConfig()
    workspace_path = "workspace" # anonymous_data
    # Join the path components properly. The trailing "" keeps the trailing
    # separator, so the values match the former "<workspace>/<dir>/" strings
    # even where a name is concatenated directly onto them.
    cache_path = os.path.join(workspace_path, ".cache", "")
    save_path = os.path.join(workspace_path, "lora_adapters", "")
    merged_path = os.path.join(workspace_path, "models", "")
    log_path = os.path.join(workspace_path, "logs", "")
    train(config, save_path, merged_path, log_path, cache_path)