from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
import torch
import os
import json
from tqdm import tqdm
from datasets import load_dataset
from train_tools.peft_utils import *


def load_model_and_tokenizer(model_name, cache_dir):
    """Load a quantized causal LM in eval mode together with its tokenizer.

    Args:
        model_name: Identifier handed to ``from_pretrained`` for both the
            model and the tokenizer.
        cache_dir: Directory passed as ``cache_dir`` to both
            ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple; ``model.eval()`` has already been
        called on the model.
    """
    model_kwargs = {
        "low_cpu_mem_usage": True,
        "cache_dir": cache_dir,
        "torch_dtype": torch.bfloat16,
        # Not defined in this file; presumably supplied by the star import
        # from train_tools.peft_utils.
        "quantization_config": get_default_quantization_config(),
        "attn_implementation": "flash_attention_2",
        "device_map": "auto",
    }
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    # Inference only: switch off training-mode behaviour.
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    return model, tokenizer


def get_chat_template(model_name):
    """Return the single-turn prompt template for a supported model key.

    Each template contains one ``{}`` placeholder for the user message and
    ends where the model's reply is expected to start.

    Args:
        model_name: One of ``"gemma-2b-it"``, ``"phi3-mini"``, ``"tinyllama"``.

    Returns:
        The template string for ``model_name``.

    Raises:
        KeyError: If ``model_name`` is not one of the supported keys.
    """
    gemma_template = "<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n"
    phi3_template = "<|user|>\n{}<|end|>\n<|assistant|>\n"
    # NOTE(review): this closes the user turn with "<s>" and ends with a
    # trailing space; confirm against the TinyLlama chat format before
    # changing, since already generated data may depend on it.
    tinyllama_template = "<|user|>\n{}<s>\n<|assistant|>\n "

    templates_by_model = {
        "gemma-2b-it": gemma_template,
        "phi3-mini": phi3_template,
        "tinyllama": tinyllama_template,
    }
    return templates_by_model[model_name]


def generate_base_responses(
    dataset_path,
    model,
    tokenizer,
    output_file_path,
    prompt_template,
    batch_size=10,
    max_length=512,
    max_new_tokens=512,
):
    """Generate one response per dataset key and save them to a JSON file.

    The dataset file must contain a JSON object; its keys are used as the
    user inputs (the values are not read here). Each input is inserted into
    ``prompt_template``, the prompts are generated in batches with
    ``do_sample=False``, and only the newly generated tokens are decoded.

    Args:
        dataset_path: Path to the JSON dataset file.
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``. Its ``padding_side`` is set
            to ``"left"`` for the duration of the call and restored
            afterwards; a missing pad token is set to the EOS token.
        output_file_path: Destination passed to ``save_base_responses``.
        prompt_template: Format string with one ``{}`` placeholder.
        batch_size: Number of prompts per ``generate`` call.
        max_length: Prompts are truncated to this many tokens.
        max_new_tokens: Upper bound on generated tokens per prompt.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    with open(dataset_path, "r", encoding="utf-8") as file:
        dataset = json.load(file)

    # .keys() (rather than list(dataset)) so a non-object JSON file fails
    # loudly instead of being iterated as something else.
    user_inputs = list(dataset.keys())
    base_responses = []

    # Decoder-only models continue from the last position of each row, so
    # padding has to sit on the left; with right padding the pad tokens end
    # up between the prompt and the generated continuation.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        # Padding a batch needs a pad token; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token

    try:
        with torch.no_grad():
            for start in tqdm(
                range(0, len(user_inputs), batch_size),
                desc="Generating base responses",
            ):
                sources = [
                    prompt_template.format(user_input)
                    for user_input in user_inputs[start : start + batch_size]
                ]

                inputs = tokenizer(
                    sources,
                    padding="longest",
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt",
                ).to(model.device)

                generated = model.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                    max_new_tokens=max_new_tokens,
                )

                # generate() returns prompt + continuation; keep only the
                # continuation by slicing off the (padded) prompt length.
                prompt_length = inputs.input_ids.shape[1]
                base_responses.extend(
                    tokenizer.batch_decode(
                        generated[:, prompt_length:],
                        skip_special_tokens=True,
                    )
                )
    finally:
        # Do not leak the padding change to other users of the tokenizer.
        tokenizer.padding_side = original_padding_side

    save_base_responses(user_inputs, base_responses, output_file_path)


def save_base_responses(user_inputs, base_responses, file_path):
    """Write the inputs and their responses to ``file_path`` as JSON.

    The file holds a single object with two lists, ``"user_input"`` and
    ``"base_response"``, in that order. A confirmation line is printed once
    the file has been written.

    Args:
        user_inputs: List stored under ``"user_input"``.
        base_responses: List stored under ``"base_response"``.
        file_path: Output path; the file is opened in write mode.
    """
    with open(file_path, "w") as sink:
        json.dump(
            dict(user_input=user_inputs, base_response=base_responses),
            sink,
        )
    print(f"Base responses saved to {file_path}")


def main(args):
    """Generate the base-response file for one model/dataset pair if missing.

    The output path is the dataset path with its trailing ``.json`` replaced
    by ``_alpaca_base_<model>.json`` (when ``"alpaca"`` occurs in the dataset
    name) or ``_dsp_base_<model>.json`` (otherwise). If that file already
    exists nothing is loaded or generated.

    Args:
        args: Namespace with ``model_name``, ``data_root``, ``dataset_name``
            and ``cache_dir`` attributes.

    Raises:
        KeyError: If ``args.model_name`` is not a supported model key.
    """
    MODEL_DICT = {
        "gemma-2b-it": "google/gemma-2b-it",
        "phi3-mini": "microsoft/Phi-3-mini-128k-instruct",
        "tinyllama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    }

    # Resolve the checkpoint up front so an unsupported model name still
    # fails immediately, before any file-system work.
    model_id = MODEL_DICT[args.model_name]

    data_path = os.path.join(args.data_root, args.dataset_name)
    prefix = "alpaca_base" if "alpaca" in args.dataset_name else "dsp_base"
    response_file_name = f"{prefix}_{args.model_name}.json"

    # Strip only a trailing ".json". A blanket str.replace would also rewrite
    # ".json" inside directory names and, when the suffix is absent, would
    # leave the output path equal to the input path.
    suffix = ".json"
    stem = data_path[: -len(suffix)] if data_path.endswith(suffix) else data_path
    response_file_path = f"{stem}_{response_file_name}"

    if os.path.exists(response_file_path):
        print(f"Base response file already exists: {response_file_path}")
        return

    print(f"Generating base response file: {response_file_path}")
    # Load the model only when there is work to do; loading is the expensive
    # step and is pointless when the output is already present.
    model, tokenizer = load_model_and_tokenizer(model_id, args.cache_dir)
    prompt_template = get_chat_template(args.model_name)
    generate_base_responses(
        data_path, model, tokenizer, response_file_path, prompt_template
    )


if __name__ == "__main__":
    # Command-line entry point: every option is a string flag with a default.
    cli_defaults = (
        ("--model_name", "gemma-2b-it"),
        ("--data_root", "./data/psoups"),
        ("--dataset_name", "alpaca_gpt4_P1A_10k.json"),
        ("--cache_dir", "./"),
    )

    parser = argparse.ArgumentParser(
        description="Generate base responses for training data"
    )
    for flag, default_value in cli_defaults:
        parser.add_argument(flag, type=str, default=default_value)
    args = parser.parse_args()

    main(args)
