import os
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging
)
from peft import LoraConfig, AutoPeftModelForCausalLM
from trl import SFTTrainer
import transformers
import json
from huggingface_hub import login
from tqdm import tqdm
from trl.trainer import ConstantLengthDataset


def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.

    Args:
        dataset: Iterable of examples; each example must expose a ``"text"`` entry.
        tokenizer: Tokenizer exposing ``is_fast`` and either being callable
            (fast tokenizers) or providing ``tokenize`` (slow tokenizers).
        nb_examples: Maximum number of examples to sample from the start of
            ``dataset``.

    Returns:
        The total number of characters divided by the total number of tokens
        over the sampled examples.

    Raises:
        ValueError: If the sampled examples produce no tokens at all (for
            instance when ``dataset`` is empty), since the ratio is undefined.
    """
    total_characters, total_tokens = 0, 0
    # zip() stops at the shorter input, so at most nb_examples examples are read
    # even when the dataset is longer (and fewer when it is shorter).
    for _, example in tqdm(zip(range(nb_examples), dataset), total=nb_examples):
        text = example["text"]
        total_characters += len(text)
        if tokenizer.is_fast:
            total_tokens += len(tokenizer(text).tokens())
        else:
            total_tokens += len(tokenizer.tokenize(text))

    if total_tokens == 0:
        # Fail with an actionable message instead of a bare ZeroDivisionError.
        raise ValueError(
            "Cannot estimate the characters-per-token ratio: "
            "the sampled examples produced no tokens."
        )

    return total_characters / total_tokens


def fine_tune(train_data, validation_data, test_data, test_labels):
    """
    Fine-tune Llama-2-7b with 4-bit quantization and LoRA, then merge and save it.

    Steps performed:
      1. Log in to the Hugging Face Hub and load the 4-bit quantized base model
         and its tokenizer on GPU 0.
      2. Pack the training and validation texts into constant-length sequences.
      3. Train with ``SFTTrainer`` and save the result to
         ``<output_dir>/final_checkpoint``.
      4. Reload that checkpoint in float16, merge the LoRA weights into the base
         model and save model and tokenizer to
         ``<output_dir>/final_merged_checkpoint``.

    Args:
        train_data: Sized collection of training examples, each with a ``"text"``
            entry holding the full prompt.
        validation_data: Validation examples in the same format.
        test_data: Not used by this function; kept so existing callers keep working.
        test_labels: Not used by this function; kept so existing callers keep working.

    Returns:
        None. All results are written to disk under the output directory.
    """
    # Imported locally: only needed to reclaim memory before the merge step.
    import gc

    # Interactive Hugging Face Hub login (presumably required to download the
    # gated Llama-2 weights).
    login()

    # Set the name of the model we'll use for the rest of this function
    model_name = "meta-llama/Llama-2-7b-hf"

    # Load the entire model on the GPU 0
    device_map = {"": 0}

    # Set base model loading in 4-bits
    use_4bit = True

    # Compute dtype for 4-bit base models
    bnb_4bit_compute_dtype = torch.bfloat16

    # Quantization type (fp4 or nf4)
    bnb_4bit_quant_type = "nf4"

    # Activate nested quantization for 4-bit base models (double quantization)
    use_nested_quant = False

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=use_4bit,
        bnb_4bit_quant_type=bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
        bnb_4bit_use_double_quant=use_nested_quant
    )

    # Load the base model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=device_map,
        quantization_config=bnb_config,
        return_dict=True,
        low_cpu_mem_usage=True,
        cache_dir="llama_model_7b"
    )
    # The generation cache is disabled for training.
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    # Load the model tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Define a custom padding token
    # NOTE(review): "<PAD>" is assigned as a plain attribute without adding it to
    # the vocabulary or resizing the model embeddings -- confirm the resulting
    # pad token id is valid for the model before relying on padded batches.
    tokenizer.pad_token = "<PAD>"

    # Set the padding direction to the right
    tokenizer.padding_side = "right"

    # Measured over the whole training set; used by the packed datasets below.
    chars_per_token = chars_token_ratio(train_data, tokenizer, len(train_data))
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

    # Length (in tokens) of every packed sequence, shared by both datasets.
    seq_length = 1024

    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        dataset_text_field="text",
        infinite=False,
        seq_length=seq_length,
        chars_per_token=chars_per_token
    )

    validation_dataset = ConstantLengthDataset(
        tokenizer,
        validation_data,
        dataset_text_field="text",
        infinite=False,
        seq_length=seq_length,
        chars_per_token=chars_per_token
    )

    # LoRA attention dimension
    lora_r = 16
    # Alpha for LoRA scaling
    lora_alpha = 32
    # Dropout probability for LoRA
    lora_dropout = 0.05

    # Create the LoRA configuration
    peft_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        inference_mode=False,
        bias="none",
        task_type="CAUSAL_LM",
    )

    output_dir = "./results_financial_phrasebank_sft_big"
    final_checkpoint_dir = os.path.join(output_dir, "final_checkpoint")

    # Training hyperparameters
    num_train_epochs = 1
    max_steps = -1
    bf16 = True
    fp16 = False
    batch_size = 8
    gradient_accumulation_steps = 1
    max_grad_norm = 0.3
    optim = "paged_adamw_32bit"
    learning_rate = 1e-4
    lr_scheduler_type = "cosine"
    warmup_steps = 100
    weight_decay = 0.05
    gradient_checkpointing = True
    save_steps = 1000
    logging_steps = 10

    # NOTE(review): evaluation_strategy="steps" is set without an explicit
    # eval_steps, so the evaluation interval is whatever the Trainer defaults
    # to -- confirm that cadence is intended.
    training_arguments = TrainingArguments(
        output_dir=output_dir,
        dataloader_drop_last=True,
        evaluation_strategy="steps",
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        gradient_checkpointing=gradient_checkpointing,
        optim=optim,
        save_steps=save_steps,
        logging_steps=logging_steps,
        learning_rate=learning_rate,
        num_train_epochs=num_train_epochs,
        weight_decay=weight_decay,
        fp16=fp16,
        bf16=bf16,
        max_grad_norm=max_grad_norm,
        max_steps=max_steps,
        warmup_steps=warmup_steps,
        lr_scheduler_type=lr_scheduler_type,
        run_name="llama-7b-finetuned_small",
        report_to="wandb",
        ddp_find_unused_parameters=False,
    )

    packing = True

    # Set the supervised fine-tuning parameters
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        peft_config=peft_config,
        dataset_text_field="text",
        tokenizer=tokenizer,
        args=training_arguments,
        packing=packing,
    )

    # Set to a checkpoint path to resume an interrupted run.
    resume_checkpoint = None

    transformers.logging.set_verbosity_info()

    trainer.train(resume_checkpoint)

    trainer.save_model(final_checkpoint_dir)

    # Release the quantized training model before loading a second, float16 copy
    # of the weights onto the same GPU; keeping both alive risks running out of
    # GPU memory during the merge.
    del trainer, model
    gc.collect()
    torch.cuda.empty_cache()

    # Reload the fine-tuned checkpoint (again entirely on GPU 0) for merging
    reloaded_model = AutoPeftModelForCausalLM.from_pretrained(
        final_checkpoint_dir,
        low_cpu_mem_usage=True,
        return_dict=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )
    reloaded_tokenizer = AutoTokenizer.from_pretrained(final_checkpoint_dir)

    # Merge the LoRA and the base model
    merged_model = reloaded_model.merge_and_unload()
    # Save the merged model
    merged_dir = os.path.join(output_dir, "final_merged_checkpoint")
    merged_model.save_pretrained(merged_dir)
    reloaded_tokenizer.save_pretrained(merged_dir)


def generate_prompt(sentence, sentiment):
    """
    Build the labelled training prompt for a single news headline.

    Args:
        sentence: Headline text, inserted between square brackets.
        sentiment: Sentiment label appended after ``" = "``.

    Returns:
        The instruction text followed by ``[sentence] = sentiment``, with
        leading and trailing whitespace of the whole prompt stripped.
    """
    # Continuation lines carry a 12-space indent and the first two lines end
    # with a space; this whitespace is part of the prompt text the model sees,
    # so keep it stable.
    indent = " " * 12
    prompt_lines = [
        "Analyze the sentiment of the news headline enclosed in square brackets, ",
        indent + "determine if it is positive, neutral, or negative, and return the answer as ",
        indent + 'the corresponding sentiment label "positive" or "neutral" or "negative".',
        "",
        indent + f"[{sentence}] = {sentiment}",
    ]
    return "\n".join(prompt_lines).strip()

def generate_test_prompt(sentence):
    """
    Build the unlabelled evaluation prompt for a single news headline.

    Args:
        sentence: Headline text, inserted between square brackets.

    Returns:
        The instruction text followed by ``[sentence] =``; the label is left
        out and the whole prompt is stripped of surrounding whitespace.
    """
    # Continuation lines carry a 12-space indent and the first two lines end
    # with a space; this whitespace is part of the prompt text the model sees,
    # so keep it stable.
    indent = " " * 12
    prompt_lines = [
        "Analyze the sentiment of the news headline enclosed in square brackets, ",
        indent + "determine if it is positive, neutral, or negative, and return the answer as ",
        indent + 'the corresponding sentiment label "positive" or "neutral" or "negative".',
        "",
        indent + f"[{sentence}] = ",
    ]
    return "\n".join(prompt_lines).strip()


def rework_label(label):
    """
    Translate a numeric class id into its sentiment name.

    Args:
        label: Class id; ``0`` is negative, ``1`` is neutral, ``2`` is positive.

    Returns:
        ``"negative"``, ``"neutral"`` or ``"positive"``.

    Raises:
        ValueError: If ``label`` is not one of the three known class ids.
            Returning ``None`` here would silently end up in a prompt as the
            text ``"None"``.
    """
    sentiment_names = ("negative", "neutral", "positive")
    # Compare with == so any value that equals the class id is accepted.
    for class_id, name in enumerate(sentiment_names):
        if label == class_id:
            return name
    raise ValueError(f"Unknown sentiment label: {label!r}")

def get_proper_data_format(data, seed=None):
    """
    Split the dataset into train/validation/test sets and render prompts.

    20% of ``data["train"]`` is held out as the test set; the remaining 80% is
    split again 80/20 into train and validation, giving roughly 64% / 16% / 20%.

    Args:
        data: Mapping with a ``"train"`` entry that supports
            ``train_test_split`` and yields rows with ``"sentence"`` and
            ``"label"`` fields.
        seed: Optional seed forwarded to both ``train_test_split`` calls so the
            splits are reproducible. ``None`` (the default) keeps the splits
            random on every run.

    Returns:
        A tuple ``(final_train, final_validation, final_test, final_test_labels)``:
        the first three are lists of ``{"text": prompt}`` dicts (train and
        validation prompts include the label, test prompts do not) and the last
        is the list of sentiment names aligned with ``final_test``.
    """
    holdout_split = data["train"].train_test_split(test_size=0.2, seed=seed)
    test_set = holdout_split["test"]
    train_validation_split = holdout_split["train"].train_test_split(test_size=0.2, seed=seed)
    train_set = train_validation_split["train"]
    validation_set = train_validation_split["test"]

    # Train and validation prompts carry the answer so the model learns to emit it.
    final_train = [
        {"text": generate_prompt(el["sentence"], rework_label(el["label"]))}
        for el in train_set
    ]
    final_validation = [
        {"text": generate_prompt(el["sentence"], rework_label(el["label"]))}
        for el in validation_set
    ]

    # Test prompts omit the answer; the expected labels are returned separately.
    final_test = []
    final_test_labels = []
    for el in test_set:
        final_test.append({"text": generate_test_prompt(el["sentence"])})
        final_test_labels.append(rework_label(el["label"]))

    return final_train, final_validation, final_test, final_test_labels

if __name__ == "__main__":
    # Script entry point: fetch the 'sentences_50agree' configuration of the
    # financial_phrasebank dataset, turn it into prompt splits, dump the
    # training prompts to disk, then run fine-tuning.
    phrasebank = load_dataset("financial_phrasebank", 'sentences_50agree')
    splits = get_proper_data_format(phrasebank)
    train, validation, test, test_labels = splits

    # Keep a JSON copy of the training prompts next to the other data files.
    train_dump_path = os.path.join("/data", "train_financial.json")
    with open(train_dump_path, "w") as handle:
        json.dump(train, handle)
    print("Adapted data format...")

    fine_tune(train, validation, test, test_labels)
