import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
import torch
import atexit
import os
import sys
import pandas as pd
import json
import numpy as np
import csv
import argparse
import hashlib
import pickle
from tqdm import tqdm
from llama_prompts import templates
from pathlib import Path
import random

# Seed Python's `random` module at import time so any use of it is
# reproducible from run to run.
random.seed(1234)

def parse_args(argv=None):
    """Parse the command-line options for an inference run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave as before.

    Returns:
        argparse.Namespace with the attributes ``model_name``,
        ``output_file_path``, ``max_tokens``, ``cache_path``,
        ``prompt_key`` and ``questions_file_path``.

    Raises:
        SystemExit: raised by argparse on invalid arguments, or when one
            of the two required path options is missing.
    """
    parser = argparse.ArgumentParser(description='Huggingface Llama3-8B')
    parser.add_argument('--model_name', type=str, default='meta-llama/Meta-Llama-3-8B-Instruct', help='Model name.')
    # Both paths are mandatory: without them the run fails later with an
    # opaque TypeError on ``open(None)`` / ``os.path.dirname(None)``.
    parser.add_argument('--output_file_path', type=str, required=True, help='Path to save output')
    parser.add_argument('--max_tokens', type=int, default=150, help='Max tokens to generate in answer')
    parser.add_argument('--cache_path', type=str, default='../data/cache/Meta-Llama-3-8B-Instruct_cache.pkl', help='Cache to use; filename corresponds to model name')
    parser.add_argument('--prompt_key', type=str, default='basic_binary', help='Prompt key to use from llama_prompts.py')
    parser.add_argument('--questions_file_path', type=str, required=True, help='Path to JSONL question data')
    return parser.parse_args(argv)

def load_cache_if_exists(cache_path):
    """Load the pickled response cache, or start an empty one.

    Args:
        cache_path: Path of the pickle file holding the cache.

    Returns:
        The unpickled object when ``cache_path`` exists; otherwise an empty
        dict. In the latter case the parent directory of ``cache_path`` is
        created so the cache can be written there later.
    """
    if os.path.exists(cache_path):
        # NOTE: unpickling can execute arbitrary code; only point
        # cache_path at files produced by trusted runs of this script.
        with open(cache_path, 'rb') as handle:
            return pickle.load(handle)

    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when there is one.
    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    return {}

def create_numbered_procedure(steps):
    """Render ``steps`` as a numbered list, one step per line.

    Numbering starts at 1 and every line, including the last, ends with a
    newline. An empty ``steps`` yields the empty string.
    """
    numbered_lines = [
        f"{number}. {text}\n" for number, text in enumerate(steps, start=1)
    ]
    return "".join(numbered_lines)

def save_data(cache, data, cache_path, output_file_path):
    """Write the run's records as JSONL and merge its cache onto disk.

    Args:
        cache: Mapping of cache entries produced or used by this run.
        data: Iterable of JSON-serialisable records; each is written as one
            line of ``output_file_path``.
        cache_path: Pickle file to update. If it already exists, its
            contents are loaded and updated with ``cache`` (entries from
            this run win on key collisions) before being written back.
        output_file_path: Destination JSONL file; overwritten if present.

    Missing parent directories of both paths are created.
    """
    print('Cleanup and save...')
    Path(os.path.dirname(output_file_path)).mkdir(parents=True, exist_ok=True)
    # Context manager so the handle is closed even if a record fails to
    # serialise part-way through.
    with open(output_file_path, 'w') as out_fp:
        for record in data:
            out_fp.write(json.dumps(record) + '\n')

    if os.path.exists(cache_path):
        # NOTE: unpickling can execute arbitrary code; cache_path must be a
        # trusted file.
        with open(cache_path, 'rb') as file:
            merged_cache = pickle.load(file)
        # merged_cache now also holds the cache values from this run
        merged_cache.update(cache)
    else:
        merged_cache = cache
        # Do not rely on another function having created the cache folder.
        Path(os.path.dirname(cache_path)).mkdir(parents=True, exist_ok=True)
    with open(cache_path, 'wb') as file:
        pickle.dump(merged_cache, file)

def _read_jsonl(path):
    """Return one parsed object per non-blank line of the JSONL file at ``path``."""
    with open(path, 'r') as json_file:
        return [json.loads(line) for line in json_file if line.strip()]


def main(args):
    """Answer every question in the input file with the model and collect results.

    For each record of ``args.questions_file_path`` a prompt is built from
    the template selected by ``args.model_name`` / ``args.prompt_key``, the
    answer is taken from the response cache when present or generated with
    greedy decoding otherwise, and the record (with ``model_input`` and
    ``model_answer`` added) is appended to the output list.

    The output list and the cache are written by ``save_data``, which is
    registered with ``atexit`` so it runs when the interpreter exits, even
    if the run stops early.

    Args:
        args: Namespace as produced by ``parse_args``.
    """
    data = _read_jsonl(args.questions_file_path)

    output_data = []
    cached_responses = load_cache_if_exists(args.cache_path)
    atexit.register(save_data, cached_responses, output_data, args.cache_path, args.output_file_path)

    # Look the template up once, before the expensive model load, so a wrong
    # model name or prompt key fails immediately.
    prompt_template = templates[args.model_name][args.prompt_key]

    local_snapshot = '/shared_archive/yklal95/Meta-Llama-3-8B-Instruct/snapshots/e5e23bbe8e749ef0efcf16cad411a7d23bd23298'
    if args.model_name == 'meta-llama/Meta-Llama-3-8B-Instruct' and os.path.isdir(local_snapshot):
        model_path = local_snapshot
        print('Using local weights')
    else:
        # Load the model that was actually requested: the cache key and the
        # template lookup are both keyed on args.model_name, so loading any
        # other checkpoint would store answers under the wrong name.
        model_path = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # device_map="auto" already places the weights on the available
    # device(s); the model must not be moved again with .to() afterwards.
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
    terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
    # NOTE(review): the original comment labels id 0 as "unk"; confirm that
    # id 0 is a suitable pad id for this tokenizer's vocabulary.
    model.config.pad_token_id = tokenizer.pad_token_id = 0

    model.eval()

    # These values only feed the cache key below; decoding itself is greedy.
    frequency_penalty, temp, topp, presence_penalty = 0, 0.0, 1.0, 0

    for record in tqdm(data):

        numbered_procedure = create_numbered_procedure(record['steps'])
        title = record['title']

        model_input = prompt_template.format(title=title, procedure=numbered_procedure, binary_question=record['binary_question'], why_question=record['why_question'])
        record['model_input'] = model_input

        prompt = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": model_input},
        ]

        # The key combines the prompt hash with the model name and the
        # decoding settings.
        hashed_prompt = hashlib.sha256(str(prompt).encode("utf-8")).hexdigest()
        cache_key = (hashed_prompt, args.model_name, args.max_tokens, temp, topp, frequency_penalty, presence_penalty)
        if cache_key in cached_responses:
            model_answer = cached_responses[cache_key]['text']
        else:
            input_ids = tokenizer.apply_chat_template(
                prompt,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    input_ids,
                    max_new_tokens=args.max_tokens,
                    eos_token_id=terminators,
                    do_sample=False,
                    temperature=0.0,
                )
            # Keep only the newly generated tokens (drop the echoed prompt).
            response = outputs[0][input_ids.shape[-1]:]
            model_answer = tokenizer.decode(response, skip_special_tokens=True)
            cached_responses[cache_key] = {'text': model_answer}

        record['model_answer'] = model_answer
        output_data.append(record)

# Script entry point: parse the command-line options and run inference only
# when this file is executed directly (not when it is imported).
if __name__ == '__main__':
    args = parse_args()
    main(args)
