from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import torch
import argparse
import json
from tqdm import tqdm
import time

def main(args):
    """Generate greedy responses for every prompt in a JSON dataset and save them.

    Reads ``args.dataset_path`` (a JSON list whose items carry a ``'prompt'``
    string), wraps each prompt in a Phi-3 style chat template, generates a
    response with ``args.base_model`` (optionally with the adapter at
    ``args.lora_model`` attached), and writes a JSON object mapping the
    stringified dataset index to the decoded response to ``args.output_path``.

    Args:
        args: Namespace with ``dataset_path``, ``base_model``, ``lora_model``,
            ``cache_dir``, ``batch_size`` and ``output_path`` attributes.

    Raises:
        ValueError: If ``args.batch_size`` is not a positive integer.
    """
    start_time = time.time()

    batch_size = args.batch_size
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # load dataset
    with open(args.dataset_path, 'r', encoding='utf-8') as f:
        eval_data = json.load(f)
    total_data_len = len(eval_data)

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, cache_dir=args.cache_dir, trust_remote_code=True)
    # Decoder-only models must be left-padded for batched generation so each
    # row's new tokens follow its own prompt instead of a run of pad tokens.
    tokenizer.padding_side = "left"
    # Padding a batch needs a pad token; fall back to EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Chat template; "{input}" is filled with each sample's prompt below.
    prompt = "<|user|>\n{input}<|end|>\n<|assistant|>\n"

    # load base model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(args.base_model,
                                                 low_cpu_mem_usage=True,
                                                 cache_dir=args.cache_dir,
                                                 torch_dtype=torch.bfloat16,
                                                 attn_implementation=None,
                                                 trust_remote_code=True)
    model.to(device)
    model.eval()

    if args.lora_model:
        # load LoRA
        model.load_adapter(args.lora_model)

    # iterate through dataset in batches and generate responses; stepping the
    # range by batch_size never produces an empty trailing batch
    result = {}
    for start_idx in tqdm(range(0, total_data_len, batch_size)):
        end_idx = min(start_idx + batch_size, total_data_len)
        batch_prompts = [prompt.format(input=sample['prompt']) for sample in eval_data[start_idx:end_idx]]

        # NOTE(review): truncation keeps the left side, so prompts longer than
        # 512 tokens lose the trailing "<|assistant|>" marker.
        inputs = tokenizer(
            batch_prompts,
            padding="longest",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        # Every row is padded to the same length, so generated tokens begin
        # at this column for all rows.
        prompt_len = inputs['input_ids'].shape[1]

        # generate response (greedy decoding)
        outputs = model.generate(
            input_ids=inputs.input_ids.to(model.device),
            attention_mask=inputs.attention_mask.to(model.device),
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=512,
        )

        decoded_outputs = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
        result.update({str(i): text for i, text in enumerate(decoded_outputs, start=start_idx)})

    # os.path.dirname returns '' for a bare filename, and os.makedirs('') raises.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output_path, 'w', encoding='utf-8') as w:
        json.dump(result, w)

    print(f"Finished generating... Took {time.time() - start_time} seconds.")


if __name__ == "__main__":
    # Command-line entry point. Options are declared as (flag, type, default)
    # triples; every option has a default, so the script runs with no flags.
    cli = argparse.ArgumentParser(description="Generate base responses for training data")
    option_specs = (
        ("--base_model", str, "microsoft/Phi-3-mini-128k-instruct"),
        ("--lora_model", str, None),
        ("--dataset_path", str, "alpaca_gpt4_P1A_10k.json"),
        ("--output_path", str, "./generations/baseline.json"),
        ("--batch_size", int, 1),
        ("--cache_dir", str, "./"),
    )
    for flag, value_type, default_value in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value)

    args = cli.parse_args()
    main(args)