import argparse
import csv
import json
import os
import random
from pathlib import Path
from tqdm import tqdm

from gpt import OpenAICommunicator
from gpt_prompts import templates

def parse_args():
    """Parse the command-line options for the answer-generation script.

    Arguments that are not recognised are ignored instead of causing an
    error, because ``parse_known_args`` is used.

    Returns:
        argparse.Namespace: The parsed options. ``questions_file_path``,
        ``output_file_path``, ``few_shot_data_dir`` and ``temp`` are ``None``
        when not supplied on the command line.
    """
    parser = argparse.ArgumentParser(description='Answer Generation Using OpenAI GPT Models')
    parser.add_argument('--questions_file_path', type=str, help='Path to JSONL question data')
    parser.add_argument('--output_file_path', type=str, help='Path to save output')
    parser.add_argument('--model_name', type=str, default='gpt-4-turbo-2024-04-09', help='OpenAI model code to use')
    parser.add_argument('--cache_path', type=str, default='../data/cache/gpt-4-turbo-2024-04-09_cache.pkl', help='Cache to use; filename corresponds to model name')
    # Integer options use real int defaults rather than relying on argparse
    # converting a string default through ``type``.
    parser.add_argument('--max_tokens', type=int, default=150, help='Max tokens to generate in answer using OpenAI model')
    parser.add_argument('--prompt_key', type=str, default='basic_binary', help='Prompt key to use from gpt_prompts.py')
    parser.add_argument('--few_shot_n', type=int, default=5, help='Max number of few-shot examples to use from training data')
    parser.add_argument('--few_shot_data_dir', type=str, help='Path to JSONL train question directory for few-shot examples')
    parser.add_argument('--temp', type=float, help='Temperature for self-consistency experiments')
    parser.add_argument('--self_consistency', dest='self_consistency', action='store_true', help='Flag to run self-consistency experiments')
    parser.add_argument('--num_sc_trials', type=int, default=5, help='Number of self-consistency trials to run to get majority')
    args, _ = parser.parse_known_args()
    return args

def create_numbered_procedure(steps):
    """Render *steps* as a 1-based numbered list.

    Each step becomes one line of the form ``"<number>. <step>"`` terminated
    by a newline; an empty iterable yields an empty string.
    """
    numbered_lines = (
        f"{position}. {text}\n" for position, text in enumerate(steps, start=1)
    )
    return "".join(numbered_lines)

def create_few_shot_examples(data, model_name, prompt_key):
    """Turn labelled examples into alternating user/assistant chat messages.

    For every example, the prompt template selected by ``model_name`` and
    ``prompt_key`` is filled in and emitted as a ``user`` message, followed by
    an ``assistant`` message holding the gold answer: ``'Yes.'`` when the
    example's ``question_type`` starts with ``'dependent'``, ``'No.'``
    otherwise.

    Returns:
        list[dict]: Messages with ``role`` and ``content`` keys, two per
        example, in input order.
    """
    messages = []
    for item in data:
        steps_text = create_numbered_procedure(item['steps'])
        heading = item['title']
        user_content = templates[model_name][prompt_key].format(
            title=heading,
            procedure=steps_text,
            binary_question=item['binary_question'],
            why_question=item['why_question'],
        )
        messages.append({"role": "user", "content": user_content})
        gold_answer = 'Yes.' if item['question_type'].startswith('dependent') else 'No.'
        messages.append({"role": "assistant", "content": gold_answer})
    return messages

def load_jsonl_file_or_dir(file_path):
    """Load JSON Lines records from a file, or from the files in a directory.

    When ``file_path`` is a directory, every regular file directly inside it
    is read, except files whose *name* contains ``'switched'`` or ``'fake'``.
    Files are visited in sorted name order so the result is reproducible
    (``os.listdir`` gives no ordering guarantee). Otherwise ``file_path`` is
    read as a single JSONL file, with no name filtering.

    Args:
        file_path: Path to a JSONL file or to a directory of JSONL files.

    Returns:
        list: One decoded JSON value per non-blank line, in file order.

    Raises:
        OSError: If a file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if os.path.isdir(file_path):
        # Filter on the bare file name: testing the joined path would discard
        # every file whenever the directory path itself contains one of the
        # excluded words.
        kept_names = [
            name for name in sorted(os.listdir(file_path))
            if 'switched' not in name and 'fake' not in name
        ]
        candidates = [os.path.join(file_path, name) for name in kept_names]
        # Skip sub-directories; they cannot be opened as files.
        fnames = [path for path in candidates if os.path.isfile(path)]
    else:
        fnames = [file_path]

    data = []
    for fname in fnames:
        with open(fname, 'r', encoding='utf-8') as json_file:
            for json_str in json_file:
                # Tolerate blank lines (e.g. a trailing empty line).
                if json_str.strip():
                    data.append(json.loads(json_str))
    return data

def _build_why_question(record):
    """Return the step-ordering explanation prompt for *record*'s step pair."""
    # Stored indices are zero-based; the prompt refers to 1-based step numbers.
    if record['question_type'].endswith('before'):
        temporal_key = 'before'
        pre_step = record['step_pair_idx_asked_about'][0]+1
        post_step = record['step_pair_idx_asked_about'][1]+1
    else:
        temporal_key = 'after'
        pre_step = record['step_pair_idx_asked_about'][1]+1
        post_step = record['step_pair_idx_asked_about'][0]+1
    return f"Explain why or why not Step {pre_step} must happen {temporal_key} Step {post_step}. Think step by step."


def _majority_answer(trial_answers):
    """Return 'Yes.' if more answers contain 'yes' than 'no', else 'No.'."""
    true_cnt, false_cnt = 0, 0
    for trial_answer in trial_answers:
        lowered = trial_answer.lower()
        if 'yes' in lowered:
            true_cnt += 1
        elif 'no' in lowered:
            false_cnt += 1
    # Ties (including zero trials) resolve to 'No.'.
    return 'Yes.' if true_cnt > false_cnt else 'No.'


def main(args):
    """Query the model for every question record and write the results as JSONL.

    Each record loaded from ``args.questions_file_path`` is turned into a chat
    prompt (optionally preceded by randomly sampled few-shot examples), sent
    to the model, and written to ``args.output_file_path`` with the added keys
    ``model_name``, ``prompt_key``, ``model_input`` and ``model_answer``. With
    ``args.self_consistency`` set, the model is queried
    ``args.num_sc_trials`` times, each raw answer is stored under
    ``trial_<i>_answer``, and ``model_answer`` is the majority vote.

    Args:
        args: Parsed options as returned by ``parse_args``.
    """
    data = load_jsonl_file_or_dir(args.questions_file_path)

    options = {
        'model_name': args.model_name,
        'max_tokens': args.max_tokens,
        'cache_path': args.cache_path,
        'temperature': args.temp if args.temp else 0.0,
    }
    openai_communicator = OpenAICommunicator(options)

    # The few-shot pool does not depend on the record, so read it once here
    # instead of re-reading it from disk for every question.
    few_shot_data = []
    few_shot_k = 0
    if args.few_shot_data_dir:
        few_shot_data = load_jsonl_file_or_dir(args.few_shot_data_dir)
        # --few_shot_n is documented as a maximum: cap it at the pool size so
        # random.sample does not raise when fewer examples are available.
        few_shot_k = min(args.few_shot_n, len(few_shot_data))

    Path(os.path.dirname(args.output_file_path)).mkdir(parents=True, exist_ok=True)
    # The context manager closes the output even if a request fails mid-run.
    with open(args.output_file_path, 'w') as out_fp:
        for record in tqdm(data):
            numbered_procedure = create_numbered_procedure(record['steps'])
            title = record['title']
            if args.prompt_key.endswith('explain_binary'):
                record['why_question'] = _build_why_question(record)
            gpt_prompt = []
            if args.few_shot_data_dir:
                # A fresh random subset of examples is drawn for each record.
                few_shot_samples = random.sample(few_shot_data, few_shot_k)
                few_shot_example_prompt = create_few_shot_examples(few_shot_samples, args.model_name, args.prompt_key)
                gpt_prompt.extend(few_shot_example_prompt)
            model_input = templates[args.model_name][args.prompt_key].format(title=title, procedure=numbered_procedure, binary_question=record['binary_question'], why_question=record['why_question'])
            record['model_name'] = args.model_name
            record['prompt_key'] = args.prompt_key
            record['model_input'] = model_input
            gpt_prompt.append({"role": "user", "content": model_input})
            if args.self_consistency:
                trial_answers = []
                for trial_idx in range(args.num_sc_trials):
                    # NOTE(review): use_cache=False presumably keeps repeated
                    # identical prompts from being served from the cache;
                    # confirm against OpenAICommunicator.
                    trial_answer = openai_communicator.run_inference(gpt_prompt, use_cache=False)
                    record[f'trial_{trial_idx}_answer'] = trial_answer
                    trial_answers.append(trial_answer)
                # Votes are tallied across all trials, not just the last one.
                model_answer = _majority_answer(trial_answers)
            else:
                model_answer = openai_communicator.run_inference(gpt_prompt)
            record['model_answer'] = model_answer
            out_fp.write(json.dumps(record) + '\n')

if __name__ == '__main__':
    # Script entry point: parse the command-line options and run the pipeline.
    main(parse_args())