import hashlib
import pickle
import argparse
import json
import os
from tqdm import tqdm
import glob

def parse_args(argv=None):
    """Parse the command-line options for the cache tool.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``mode``, ``cache_path``, ``max_tokens``,
        ``dirname``, ``expt_name`` and ``model_name``. Unrecognised flags are
        ignored (``parse_known_args``) rather than treated as errors.
    """
    parser = argparse.ArgumentParser(description='Cache operations')
    # Restrict --mode to the operations main() implements so a typo fails
    # loudly instead of silently doing nothing.
    parser.add_argument('--mode', type=str, choices=['create', 'update', 'check'],
                        help='Operation to perform: create a new cache, update an existing one, or check its size')
    parser.add_argument('--cache_path', type=str, default='../data/cache/openai_cache.pkl',
                        help='Pickle cache file to read from and/or write to')
    parser.add_argument('--max_tokens', type=int, help='Number of tokens used in generation')
    parser.add_argument('--dirname', type=str, help='Directory of files to use for cache operations')
    parser.add_argument('--expt_name', type=str, help='Experiment key to identify files')
    parser.add_argument('--model_name', type=str, help='Model name corresponding to cache')
    args, _ = parser.parse_known_args(argv)
    return args

def load_data_to_add_to_cache(fname):
    """Read a JSON Lines file and return its decoded records.

    Each non-blank line of ``fname`` is parsed with ``json.loads``. Empty or
    whitespace-only lines (e.g. a stray blank line at the end of the file) are
    skipped instead of raising ``json.JSONDecodeError``.

    Args:
        fname: Path to the ``.jsonl`` file.

    Returns:
        list: The decoded records, in file order.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(fname, 'r', encoding='utf-8') as json_file:
        # Stream line by line instead of materialising the whole file first.
        return [json.loads(line) for line in json_file if line.strip()]

def convert_to_gpt_cache(data, cache, cache_path, max_tokens, temp=0.0, top_p=1.0, frequency_penalty=0.0, presence_penalty=0.0):
    """Add OpenAI-model answers from ``data`` to ``cache`` and pickle it to disk.

    Each record's ``model_input`` is converted with ``str`` and SHA-256 hashed.
    The cache key is the 7-tuple
    ``(digest, record['model_name'], max_tokens, temp, top_p,
    frequency_penalty, presence_penalty)`` and the stored value is
    ``{'text': record['model_answer']}``. ``cache`` is mutated in place, then
    written in full to ``cache_path`` (overwriting that file).
    """
    # Sampling settings are identical for every record; build them once.
    sampling = (max_tokens, temp, top_p, frequency_penalty, presence_penalty)
    for record in tqdm(data):
        digest = hashlib.sha256(str(record['model_input']).encode("utf-8")).hexdigest()
        key = (digest, record['model_name']) + sampling
        cache[key] = {'text': record['model_answer']}
    with open(cache_path, 'wb') as out_file:
        pickle.dump(cache, out_file)
        
def convert_to_llama_cache(data, cache, cache_path, max_tokens, temp=0.0, topp=0.0, frequency_penalty=1.0, presence_penalty=0.0):
    """Add Llama answers from ``data`` to ``cache`` and pickle it to disk.

    The model-name slot of every key is the fixed string
    ``'Meta-Llama3-8B-Instruct'``; any ``model_name`` field in the records is
    not read. Keys are
    ``(sha256(str(model_input)), 'Meta-Llama3-8B-Instruct', max_tokens, temp,
    topp, frequency_penalty, presence_penalty)`` and values are
    ``{'text': model_answer}``. ``cache`` is mutated in place, then written in
    full to ``cache_path``.

    NOTE(review): the defaults (``topp=0.0``, ``frequency_penalty=1.0``) differ
    from the GPT converter. They presumably mirror the key used by whatever
    code reads this cache; confirm before changing them, or lookups will miss.
    """
    llama_id = 'Meta-Llama3-8B-Instruct'
    for row in tqdm(data):
        prompt_bytes = str(row['model_input']).encode("utf-8")
        prompt_hash = hashlib.sha256(prompt_bytes).hexdigest()
        entry_key = (prompt_hash, llama_id, max_tokens, temp, topp, frequency_penalty, presence_penalty)
        cache[entry_key] = {'text': row['model_answer']}
    with open(cache_path, 'wb') as sink:
        pickle.dump(cache, sink)

def convert_to_gemini_cache(data, cache, model_name, cache_path, max_tokens, temp=0.0, topp=1.0, topk=1, candidatecount=1):
    """Add Gemini answers from ``data`` to ``cache`` and pickle it to disk.

    Keys are ``(sha256(str(model_input)), model_name, max_tokens, temp, topp,
    topk, candidatecount)``; the ``model_name`` argument (not a per-record
    field) fills the model slot. Values are ``{'text': model_answer}``.
    ``cache`` is mutated in place, then written in full to ``cache_path``.
    """
    def prompt_digest(prompt):
        # Hash the string form of the prompt so arbitrary JSON values work.
        return hashlib.sha256(str(prompt).encode("utf-8")).hexdigest()

    for entry in tqdm(data):
        key = (prompt_digest(entry['model_input']), model_name, max_tokens, temp, topp, topk, candidatecount)
        cache[key] = {'text': entry['model_answer']}
    with open(cache_path, 'wb') as fh:
        pickle.dump(cache, fh)

def main(args):
    """Run the cache operation selected by ``args.mode``.

    Modes:
        create: start from an empty dict, add every record found in
            ``args.dirname``, and write it to ``args.cache_path``, replacing
            any existing file.
        update: load the pickled dict at ``args.cache_path``, add the records,
            and write it back.
        check: load the pickled dict at ``args.cache_path`` and print its size.

    For create/update, records come from every file in ``args.dirname`` whose
    name ends with ``f'{args.expt_name}.jsonl'``. Files are processed in sorted
    order, so when two files produce the same key the winner is deterministic.
    The number of cache entries is printed at the end.

    Args:
        args: Namespace as produced by ``parse_args``.

    Raises:
        ValueError: if ``args.mode`` is not one of create/update/check, or (for
            create/update) ``args.model_name`` is not a supported model. Both
            checks run before the existing cache file is touched.
    """
    if args.mode == 'check':
        with open(args.cache_path, 'rb') as handle:
            cache = pickle.load(handle)
        print(len(cache))
        return

    if args.mode not in ('create', 'update'):
        raise ValueError(f"Unsupported mode: {args.mode!r} (expected 'create', 'update' or 'check')")

    gpt_models = ('gpt-3.5-turbo-1106', 'gpt-4-turbo-2024-04-09')
    llama_models = ('Meta-Llama3-8B-Instruct',)
    gemini_models = ('gemini-1.0-pro-latest', 'gemini-1.5-pro-latest', 'gemini-1.5-flash-latest')
    # Validate up front. Previously an unknown model in 'create' mode deleted
    # the cache file and then wrote nothing back.
    if args.model_name not in gpt_models + llama_models + gemini_models:
        raise ValueError(f'Unsupported model_name: {args.model_name!r}')

    if args.mode == 'create':
        # No explicit delete needed: the converter reopens cache_path with
        # 'wb', which truncates it. Skipping the delete also keeps the old
        # file intact if reading the input data fails.
        cache = {}
    else:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # cache files from a trusted source.
        with open(args.cache_path, 'rb') as handle:
            cache = pickle.load(handle)

    suffix = f'{args.expt_name}.jsonl'
    data = []
    for fname in sorted(os.listdir(args.dirname)):
        if fname.endswith(suffix):
            data.extend(load_data_to_add_to_cache(os.path.join(args.dirname, fname)))

    if args.model_name in gpt_models:
        convert_to_gpt_cache(data, cache, args.cache_path, args.max_tokens)
    elif args.model_name in llama_models:
        convert_to_llama_cache(data, cache, args.cache_path, args.max_tokens)
    else:
        # All supported Gemini variants share the same key layout.
        convert_to_gemini_cache(data, cache, args.model_name, args.cache_path, args.max_tokens)
    print(len(cache))

if __name__ == '__main__':
    # Script entry point: read CLI flags, then run the requested cache operation.
    cli_args = parse_args()
    main(cli_args)