import google.generativeai as genai
import argparse
import json
import os
import atexit
from pathlib import Path
from tqdm import tqdm
from config import config
import pickle
import hashlib
import re
import time
from gemini_prompts import templates
import random

# Seed Python's global RNG with a fixed value. No use of `random` is visible
# in this file, so this is presumably kept for reproducibility of imported
# code -- TODO confirm it is still needed.
random.seed(1234)

# Configure the Gemini client library with the API key read from the local
# `config` module. This runs at import time, before any model is created.
genai.configure(api_key=config["GEMINI_API_KEY"])

def parse_args():
    """Parse the command-line options of the Gemini answer-generation script.

    Unrecognised command-line arguments are ignored rather than rejected,
    because ``parse_known_args`` is used.

    Returns:
        argparse.Namespace with the attributes ``questions_file_path``,
        ``output_file_path``, ``model_name``, ``cache_path``, ``max_tokens``
        (int) and ``prompt_key``.
    """
    parser = argparse.ArgumentParser(description='Answer Generation Using Gemini Models')
    parser.add_argument('--questions_file_path', type=str, help='Path to JSONL question data')
    parser.add_argument('--output_file_path', type=str, help='Path to save output')
    parser.add_argument('--model_name', type=str, default='gemini-1.0-pro-latest', help='Gemini model code to use')
    parser.add_argument('--cache_path', type=str, default='../data/cache/gemini-1.0-pro-latest_cache.pkl', help='Cache to use; filename corresponds to model name')
    # The default is a real int so it does not rely on argparse converting a
    # string default through ``type``.
    parser.add_argument('--max_tokens', type=int, default=150, help='Max tokens to generate in answer using Gemini model')
    parser.add_argument('--prompt_key', type=str, default='basic_binary', help='Prompt key to use from gemini_prompts.py')
    args, _ = parser.parse_known_args()
    return args

def load_cache_if_exists(cache_path):
    """Load the pickled response cache, or prepare for a fresh one.

    Args:
        cache_path: Path of the pickle file holding cached model responses.

    Returns:
        The unpickled object when ``cache_path`` exists. Otherwise an empty
        dict, after making sure the parent directory (if any) exists.
    """
    if os.path.exists(cache_path):
        # NOTE: unpickling can execute arbitrary code. Only point this at
        # cache files that were written by this script.
        with open(cache_path, 'rb') as handle:
            return pickle.load(handle)

    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    return {}

def create_numbered_procedure(steps):
    """Render ``steps`` as a numbered list.

    Each step becomes one line of the form ``"<n>. <step>"`` terminated by a
    newline, with numbering starting at 1. An empty iterable yields ``""``.
    """
    numbered_lines = (
        f"{number}. {text}\n" for number, text in enumerate(steps, start=1)
    )
    return "".join(numbered_lines)

def get_available_models():
    """Print each listed model that supports ``generateContent``.

    For every such model, one line with its name and its ``top_p`` / ``top_k``
    attributes is written to stdout. Nothing is returned.
    """
    for model_info in genai.list_models():
        if 'generateContent' not in model_info.supported_generation_methods:
            continue
        print(f'{model_info.name}: top_p={model_info.top_p}, top_k={model_info.top_k}')

    # Example output, as recorded by the original author:
    #   models/gemini-1.0-pro: top_p=1.0, top_k=1
    #   models/gemini-1.0-pro-001: top_p=1.0, top_k=1
    #   models/gemini-1.0-pro-latest: top_p=1.0, top_k=1
    #   models/gemini-1.0-pro-vision-latest: top_p=1.0, top_k=32
    #   models/gemini-1.5-pro-latest: top_p=0.95, top_k=None
    #   models/gemini-pro: top_p=1.0, top_k=1
    #   models/gemini-pro-vision: top_p=1.0, top_k=32

def save_data(cache, data, cache_path, output_file_path):
    """Write the run's records as JSONL and merge the cache into the file on disk.

    Args:
        cache: Mapping of cache entries produced or loaded during this run.
        data: Iterable of JSON-serialisable records; one is written per line.
        cache_path: Pickle file that holds the persistent cache.
        output_file_path: Destination JSONL file; it is overwritten.

    The on-disk cache is re-read and updated with ``cache`` instead of being
    replaced, so entries already in the file are kept unless ``cache`` has
    the same key.
    """
    print('Cleanup and save...')
    Path(os.path.dirname(output_file_path)).mkdir(parents=True, exist_ok=True)
    # The context manager closes the handle even if a record fails to serialise.
    with open(output_file_path, 'w') as out_fp:
        for record in data:
            out_fp.write(json.dumps(record) + '\n')

    if os.path.exists(cache_path):
        # NOTE: unpickling can execute arbitrary code. Only use cache files
        # that were written by this script.
        with open(cache_path, 'rb') as file:
            merged_cache = pickle.load(file)
    else:
        # First run: there is no cache file yet. Start from an empty cache
        # instead of failing on the read and losing this run's responses.
        merged_cache = {}
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

    merged_cache.update(cache)
    with open(cache_path, 'wb') as file:
        pickle.dump(merged_cache, file)

def call_gemini(model, gemini_prompt, idx, retry_delay=60, max_retries=None):
    """Call ``model.generate_content`` and retry after a pause on failure.

    Args:
        model: Object exposing ``generate_content(prompt)``.
        gemini_prompt: Prompt passed unchanged to ``generate_content``.
        idx: Record index; used only in the progress messages.
        retry_delay: Seconds to sleep before each retry (default 60).
        max_retries: Maximum number of retries after the first attempt.
            ``None`` (the default) retries indefinitely.

    Returns:
        Whatever ``model.generate_content`` returns.

    Raises:
        Exception: The last error from ``generate_content``, once
            ``max_retries`` retries have been used up.
    """
    retries_done = 0
    # A loop rather than recursion, so a long outage cannot exhaust the
    # interpreter's recursion limit.
    while True:
        try:
            return model.generate_content(gemini_prompt)
        except Exception as e:
            print(f'Exception: {e} when trying record {idx}')
            if max_retries is not None and retries_done >= max_retries:
                raise
            retries_done += 1
            time.sleep(retry_delay)
            print(f'Retrying for record {idx}...')

def parse_model_response_text(response, model, gemini_prompt, idx, max_retries=None):
    """Return ``response.text``, asking the model again when it is unavailable.

    Args:
        response: Model response whose ``text`` attribute is read.
        model: Model passed to ``call_gemini`` when a new response is needed.
        gemini_prompt: Prompt that is re-sent on retry.
        idx: Record index, forwarded to ``call_gemini`` for its messages.
        max_retries: Maximum number of fresh responses to request.
            ``None`` (the default) keeps retrying until text is available.

    Returns:
        The text of the first response whose ``text`` can be read.

    Raises:
        Exception: The last error raised by reading ``text``, once
            ``max_retries`` retries have been used up.
    """
    retries_done = 0
    while True:
        try:
            return response.text
        except Exception:
            # Sometimes the model considers the text unsafe and generates
            # nothing; asking again usually fixes that. Only Exception is
            # caught, so Ctrl-C and SystemExit still stop the script.
            if max_retries is not None and retries_done >= max_retries:
                raise
            retries_done += 1
            response = call_gemini(model, gemini_prompt, idx)

def main(args):
    """Generate a Gemini answer for every question record and collect the results.

    Reads JSONL records from ``args.questions_file_path``, builds a prompt for
    each from the template chosen by ``args.model_name`` and
    ``args.prompt_key``, and reuses a cached answer when one exists for the
    same prompt and generation settings; otherwise the model is queried.
    Results and the cache are written by ``save_data``, which is registered
    with ``atexit`` so it also runs if the script stops early.

    Args:
        args: Namespace with ``questions_file_path``, ``output_file_path``,
            ``model_name``, ``cache_path``, ``max_tokens`` and ``prompt_key``.
    """
    with open(args.questions_file_path, 'r') as json_file:
        # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
        data = [json.loads(line) for line in json_file if line.strip()]

    cached_responses = load_cache_if_exists(args.cache_path)
    output_data = []

    candidatecount, temp, topp, topk = 1, 0.0, 1.0, 1
    generation_config = genai.GenerationConfig(
        candidate_count=candidatecount,
        max_output_tokens=args.max_tokens,
        temperature=temp,
        top_p=topp,
        top_k=topk,
    )
    # Hand the config to the model so the sampling settings recorded in the
    # cache key are the ones actually used. Previously the config was built
    # but never applied, so older cache entries may have been generated with
    # the library's default settings.
    model = genai.GenerativeModel(args.model_name, generation_config=generation_config)

    atexit.register(save_data, cached_responses, output_data, args.cache_path, args.output_file_path)

    # Loop-invariant: the same template is used for every record.
    prompt_template = templates[args.model_name][args.prompt_key]

    # Wrapping the list (not the enumerate object) lets tqdm show a total.
    for idx, record in enumerate(tqdm(data)):
        numbered_procedure = create_numbered_procedure(record['steps'])
        title = record['title']
        # title = f'Make {title.replace('-', ' ')}'

        model_input = prompt_template.format(
            title=title,
            procedure=numbered_procedure,
            binary_question=record['binary_question'],
            why_question=record['why_question'],
        )
        record['model_name'] = args.model_name
        record['prompt_key'] = args.prompt_key
        record['model_input'] = model_input

        gemini_prompt = model_input

        # The key covers the prompt and every generation setting, so changing
        # any of them results in a fresh query rather than a stale cache hit.
        hashed_prompt = hashlib.sha256(str(gemini_prompt).encode("utf-8")).hexdigest()
        cache_key = (hashed_prompt, args.model_name, args.max_tokens, temp, topp, topk, candidatecount)
        if cache_key in cached_responses:
            model_answer = cached_responses[cache_key]['text']
        else:
            response = call_gemini(model, gemini_prompt, idx)
            model_answer = parse_model_response_text(response, model, gemini_prompt, idx)
            cached_responses[cache_key] = {'text': model_answer}

        record['model_answer'] = model_answer
        output_data.append(record)

# Script entry point: parse the command line and run the generation loop.
if __name__ == '__main__':
    main(parse_args())