# Usage example: python create_mturk_files.py --question_type nondependent_switched_real --outfile_prefix ../data/mturk/input/all_models_nondependent_switched_real_binary_explain

import argparse
import csv
import glob
import pandas as pd
import json
import random
import re
import shortuuid

# Seed Python's global RNG so the random.shuffle() call in main() orders the
# HITs the same way on every run.
random.seed(1234)

def parse_args():
    """Read the command-line flags of this script.

    Both flags are optional strings (``None`` when omitted). Unrecognised
    arguments are tolerated and discarded.

    Returns:
        argparse.Namespace with ``question_type`` and ``outfile_prefix``.
    """
    cli = argparse.ArgumentParser(description='Metrics')
    string_flags = (
        ('--question_type', 'Type of questions to create HIT data for'),
        ('--outfile_prefix', 'CSV file name prefix to save data'),
    )
    for flag, help_text in string_flags:
        cli.add_argument(flag, type=str, help=help_text)

    known_args, _unknown = cli.parse_known_args()
    return known_args

def parse_json_markdown(json_string: str) -> dict:
    """Parse the JSON payload of a model response that may be wrapped in markdown.

    Candidate substrings are tried in order until one of them decodes:

    1. the response with surrounding whitespace, backtick fences and an
       optional leading ``json`` language tag removed;
    2. the text from the first ``{`` to the last ``}`` (handles prose or
       fences around the object);
    3. the unfenced text with a closing ``}`` appended (handles output that
       was cut off before the object was closed).

    Control characters inside strings are accepted (``strict=False``).

    Args:
        json_string: Raw model output.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If none of the candidates is valid JSON.
    """
    # Remove the fence markers explicitly; a greedy regex capture would keep
    # the closing backticks and make the first attempt fail on fenced input.
    unfenced = json_string.strip().strip('`')
    if unfenced.startswith('json'):
        unfenced = unfenced[len('json'):]
    unfenced = unfenced.strip()

    candidates = [unfenced]

    first_brace = json_string.find('{')
    last_brace = json_string.rfind('}')
    if first_brace != -1 and last_brace > first_brace:
        candidates.append(json_string[first_brace:last_brace + 1])

    # Last resort for truncated generations.
    candidates.append(unfenced + '}')

    last_error = None
    for candidate in candidates:
        try:
            return json.loads(candidate, strict=False)
        except json.JSONDecodeError as err:
            last_error = err
    raise last_error

def parse_basic_binary_must_why_output(data):
    """Add yes/no labels and explanations parsed from JSON model answers.

    Each row's ``model_answer`` is decoded with ``parse_json_markdown`` and
    its ``binary_answer`` / ``why_answer`` entries are read. An answer
    containing "yes" is labelled True, one containing "no" is labelled False;
    anything else is labelled False with a placeholder explanation and is not
    counted as parsed.

    Args:
        data: DataFrame with a ``model_answer`` column.

    Returns:
        The same DataFrame, with ``binary_prediction_label`` and
        ``why_explanation`` columns set. A parse summary is printed.
    """
    labels = []
    explanations = []
    parsed_rows = 0

    for _, row in data.iterrows():
        answer = parse_json_markdown(row['model_answer'])
        verdict = answer['binary_answer'].lower()
        said_yes = 'yes' in verdict
        if said_yes or 'no' in verdict:
            # "yes" takes precedence when both words occur.
            labels.append(said_yes)
            explanations.append(answer['why_answer'])
            parsed_rows += 1
            continue
        labels.append(False)
        explanations.append('Could not think of a reason')

    data['binary_prediction_label'] = labels
    data['why_explanation'] = explanations

    total_rows = len(data)
    if total_rows == parsed_rows:
        print(f'Total: {total_rows}, Parsed: {parsed_rows}')
    else:
        print(f'ERROR ERROR ERROR: Total: {total_rows}, Parsed: {parsed_rows}')
    return data

def parse_basic_binary_must_why_nl_output(data):
    """Add yes/no labels and explanations parsed from plain-text model answers.

    Two layouts are recognised in ``model_answer``:

    * ``Answer 1: <yes/no>`` followed by ``Answer 2: <explanation>``;
    * ``Answers:`` followed by ``1. <yes/no>`` and ``2. <explanation>``.

    Blank lines are ignored, so the two answers may be separated by empty
    lines. The first non-blank line is the binary answer and the second is
    the explanation. A row counts as parsed only when the binary answer
    contains "yes" or "no" and a non-empty explanation was found; otherwise
    the placeholder explanation is stored. A response without a recognisable
    yes/no is labelled False.

    Args:
        data: DataFrame with a ``model_answer`` column of strings.

    Returns:
        The same DataFrame, with ``binary_prediction_label`` and
        ``why_explanation`` columns set. A parse summary is printed.
    """
    default_reason = 'Could not think of a reason'
    parse_count = 0
    bpl, why_ans = [], []

    for _, info in data.iterrows():
        model_answer = info['model_answer']
        if model_answer.startswith('Answers:'):
            # e.g. 'Answers:\n\n1. Yes.\n2. The chocolate mixture must be cooled ...'
            body = model_answer.replace('Answers:\n\n', '')
            binary_prefix, why_prefix = '1. ', '2. '
        else:
            body = model_answer
            binary_prefix, why_prefix = 'Answer 1: ', 'Answer 2: '

        # Skip blank lines so that a missing or empty line cannot shift the
        # explanation out of position or raise IndexError.
        lines = [line for line in body.split('\n') if line.strip()]
        binary_answer = lines[0].replace(binary_prefix, '').strip() if lines else ''
        why_answer = lines[1].replace(why_prefix, '').strip() if len(lines) > 1 else ''

        verdict = binary_answer.lower()
        if 'yes' in verdict:
            label = True
        elif 'no' in verdict:
            label = False
        else:
            label = None

        bpl.append(bool(label))
        if label is None or not why_answer:
            why_ans.append(default_reason)
        else:
            why_ans.append(why_answer)
            parse_count += 1

    data['binary_prediction_label'] = bpl
    data['why_explanation'] = why_ans
    if len(data) != parse_count:
        print(f'ERROR ERROR ERROR: Total: {len(data)}, Parsed: {parse_count}')
    else:
        print(f'Total: {len(data)}, Parsed: {parse_count}')
    return data

def create_numbered_procedure(steps):
    """Render ``steps`` as a 1-based numbered list, one newline-terminated line per step."""
    numbered_lines = (
        f"{number}. {step}\n" for number, step in enumerate(steps, start=1)
    )
    return "".join(numbered_lines)

def load_expt_data(args):
    """Load and parse the generated answers of every model for one question type.

    For each model, files matching
    ``../data/generated_answers/<model>/test_must_why/<question_type>*_<suffix>.jsonl``
    are read; when none exist, the ``*_<suffix>_nl.jsonl`` variant is used and
    parsed with the plain-text parser instead of the JSON parser.

    Args:
        args: Namespace providing ``question_type``.

    Returns:
        One DataFrame with the rows of all models, extended with
        ``source_model``, ``source_question_type`` and ``step_pair_idx_str``
        (``"<first>_<second>"`` built from ``step_pair_idx_asked_about``).

    Raises:
        FileNotFoundError: If a model has no answer file for the question
            type. Previously this surfaced later as an opaque ``KeyError``
            on the empty per-model frame.
    """
    model_names = ['gpt-4o-2024-05-13', 'gpt-4-turbo-2024-04-09', 'Meta-Llama3-8B-Instruct']
    answer_base_dir = '../data/generated_answers'
    multitask = True
    suffix = 'explain' if multitask else 'basic_must_why'

    model_frames = []
    for model in model_names:
        pattern_prefix = f'{answer_base_dir}/{model}/test_must_why/{args.question_type}*_{suffix}'
        json_pattern = f'{pattern_prefix}.jsonl'
        nl_pattern = f'{pattern_prefix}_nl.jsonl'

        answer_files = glob.glob(json_pattern)
        nl_parse = False
        if not answer_files:
            answer_files = glob.glob(nl_pattern)
            nl_parse = True
        if not answer_files:
            raise FileNotFoundError(
                f'No answer files for model {model}: tried {json_pattern} and {nl_pattern}'
            )

        file_frames = []
        # glob order is filesystem dependent; sort so the row order (and the
        # seeded sampling/shuffling done later) is reproducible.
        for answer_file in sorted(answer_files):
            file_df = pd.read_json(answer_file, lines=True)
            if multitask:
                if nl_parse:
                    file_df = parse_basic_binary_must_why_nl_output(file_df)
                else:
                    file_df = parse_basic_binary_must_why_output(file_df)
            else:
                file_df['why_explanation'] = file_df['model_answer']
            file_frames.append(file_df)

        model_df = pd.concat(file_frames, ignore_index=True)
        model_df['source_model'] = model
        model_df['source_question_type'] = args.question_type
        model_df['step_pair_idx_str'] = model_df['step_pair_idx_asked_about'].apply(lambda x: f'{x[0]}_{x[1]}')
        model_frames.append(model_df)

    return pd.concat(model_frames, ignore_index=True)

# TODO: this is a sanity check; remove this later
def check_for_duplicates_and_issues(data, num_questions_per_hit):
    """Print sanity statistics about the generated HIT rows.

    Output, in order:

    * the number of HITs;
    * a warning for every HIT whose keys differ from the expected schema
      (``title``, ``goal_and_steps``, ``idx``, the four per-question keys for
      each slot, and ``all_<n>_questions_present``);
    * ``c`` / ``na`` / ``d``: over the ``question_0`` and ``question_1``
      slots, how many hold a real question, how many hold ``'N/A'``, and how
      many HITs have ``'N/A'`` in both (the original ``d`` check compared a
      bool with ``'N/A'`` and was always 0; this is the presumed intent);
    * a warning for every HIT with a question index >= ``num_questions_per_hit``;
    * the number of (title, question) pairs, then the number of distinct pairs.

    Problems are reported with ``print`` instead of ``breakpoint()`` so the
    script also works in non-interactive runs.

    Args:
        data: List of HIT dicts as produced by ``create_csv_data``.
        num_questions_per_hit: Number of question slots per HIT.

    Returns:
        None.
    """
    # Build the expected schema for any HIT size rather than only 2 or 6.
    expected_keys = {'title', 'goal_and_steps', 'idx', f'all_{num_questions_per_hit}_questions_present'}
    for slot in range(num_questions_per_hit):
        expected_keys.update({
            f'step_pair_idx_asked_about_str_pair_{slot}',
            f'question_{slot}_source_model',
            f'question_{slot}',
            f'answer_{slot}',
        })

    print(len(data))

    for hit in data:
        if set(hit) != expected_keys:
            mismatched = sorted(set(hit) ^ expected_keys)
            print(f"WARNING: HIT {hit.get('idx')} has missing/unexpected keys: {mismatched}")

    c = 0
    na = 0
    d = 0
    first_slots = ('question_0', 'question_1')
    for hit in data:
        for key in first_slots:
            if key not in hit:
                continue
            if hit[key] != 'N/A':
                c += 1
            else:
                na += 1
        if all(hit.get(key) == 'N/A' for key in first_slots):
            d += 1
    print(f'c = {c}; na = {na}; d = {d}')

    questions = []
    for info in data:
        question_indices = [
            int(key.split('_')[-1])
            for key in info
            if key.startswith('question_') and key.split('_')[-1].isdigit()
        ]
        if max(question_indices, default=0) >= num_questions_per_hit:
            print(f"WARNING: HIT {info.get('idx')} has a question index >= {num_questions_per_hit}")

        for i in range(num_questions_per_hit):
            if f'question_{i}' not in info:
                break
            if info[f'question_{i}'] == 'N/A':
                continue
            questions.append((info['title'], info[f'question_{i}']))

    print(len(questions))
    print(len(set(questions)))

def remove_rows_with_all_na_questions(data, num_questions_per_hit):
    """Drop HITs in which every ``question_<i>`` slot is ``'N/A'``.

    Args:
        data: List of HIT dicts.
        num_questions_per_hit: Number of ``question_<i>`` slots to inspect.

    Returns:
        The remaining HITs as a list of dicts (round-tripped through a
        DataFrame). The number of remaining HITs is printed.
    """
    frame = pd.DataFrame(data)
    question_columns = [f'question_{slot}' for slot in range(num_questions_per_hit)]

    # A row is empty when none of its question slots holds a real question.
    empty_row_labels = [
        label
        for label, record in frame.iterrows()
        if not any(record[column] != 'N/A' for column in question_columns)
    ]

    kept_records = frame.drop(empty_row_labels).to_dict(orient='records')
    print(f'Total HITs to MTurk: {len(kept_records)}')
    return kept_records

def _question_text(row):
    """Return the MTurk prompt for a before/after row, or None for other question types."""
    question_type = row['question_type']
    if question_type.endswith('before'):
        first_idx, second_idx = row['step_pair_idx_asked_about'][0], row['step_pair_idx_asked_about'][1]
        return f"Explain why or why not Step {first_idx+1} must happen before Step {second_idx+1}."
    if question_type.endswith('after'):
        first_idx, second_idx = row['step_pair_idx_asked_about'][0], row['step_pair_idx_asked_about'][1]
        return f"Explain why or why not Step {second_idx+1} must happen after Step {first_idx+1}."
    return None


def create_csv_data(df, num_questions_per_hit):
    """Pack the questions of every plan (title) into HIT rows.

    Rows are grouped by ``title`` and then by ``step_pair_idx_str``; each step
    pair group is shuffled with a fixed seed. Questions whose
    ``question_type`` ends in ``before`` or ``after`` are written into
    consecutive slots of the current HIT. Once ``num_questions_per_hit`` slots
    are filled the HIT is emitted and a new one is started for the same plan.
    After the last row of a plan, the unfinished HIT is padded with ``'N/A'``
    up to ``num_questions_per_hit`` slots and emitted with
    ``all_<n>_questions_present`` set to False. This can produce a HIT that
    contains only ``'N/A'`` questions.

    Args:
        df: DataFrame with ``title``, ``steps``, ``question_type``,
            ``step_pair_idx_asked_about``, ``step_pair_idx_str``,
            ``source_model`` and ``why_explanation`` columns.
        num_questions_per_hit: Number of question slots per HIT.

    Returns:
        List of HIT dicts.
    """
    full_data = []
    all_present_key = f'all_{num_questions_per_hit}_questions_present'

    # Group by the bare column name: a one-element list key makes pandas >= 2.0
    # yield 1-tuples as group names, which would leak into the idx string.
    for title, plan_rows in df.groupby('title'):
        count = 0
        question_count = 0
        write_info = {}
        for _, step_pair_group in plan_rows.groupby('step_pair_idx_str'):
            step_pair_group = step_pair_group.sample(frac=1, random_state=1234)
            for _, row in step_pair_group.iterrows():
                local_procedure = create_numbered_procedure(row['steps'])
                goal_and_steps = f"Goal: {row['title']}\nSteps:\n{local_procedure}"
                write_info['title'] = row['title']
                write_info['goal_and_steps'] = goal_and_steps
                write_info['idx'] = f'plan_idx_{title}_hit_{count}'

                question = _question_text(row)
                if question is not None:
                    write_info[f'step_pair_idx_asked_about_str_pair_{question_count}'] = row['step_pair_idx_str']
                    write_info[f'question_{question_count}_source_model'] = row['source_model']
                    write_info[f'question_{question_count}'] = question
                    write_info[f'answer_{question_count}'] = row['why_explanation']
                    question_count += 1

                if question_count >= num_questions_per_hit:
                    write_info['idx'] = f'plan_idx_{title}_hit_{count}'
                    write_info[all_present_key] = True
                    full_data.append(write_info)
                    # Start the next HIT of this plan with the shared header fields.
                    write_info = {
                        'title': row['title'],
                        'goal_and_steps': goal_and_steps,
                    }
                    count += 1
                    question_count = 0

        if question_count < num_questions_per_hit:
            # Pad one slot at a time. Padding in pairs wrote keys for slot
            # num_questions_per_hit whenever question_count was odd.
            for slot in range(question_count, num_questions_per_hit):
                write_info[f'step_pair_idx_asked_about_str_pair_{slot}'] = 'N/A'
                write_info[f'question_{slot}'] = 'N/A'
                write_info[f'answer_{slot}'] = 'N/A'
                write_info[f'question_{slot}_source_model'] = 'N/A'
            write_info['idx'] = f'plan_idx_{title}_hit_{count}'
            write_info[all_present_key] = False
            full_data.append(write_info)
            count += 1

    return full_data

def write_to_csv(data, fpath):
    """Write HIT rows to ``fpath`` as CSV, adding two identifier columns.

    The header is ``uuid``, ``plan_uuid`` followed by the keys of the first
    row. ``uuid`` is a fresh ``shortuuid`` per row; ``plan_uuid`` is
    ``<file stem>_row_<row number>``. Columns whose name contains ``uuid``
    are never copied from the input rows.

    Args:
        data: Non-empty list of dicts sharing the keys of ``data[0]``.
        fpath: Destination path; ``/`` is treated as the path separator when
            deriving the file stem.

    Raises:
        ValueError: If ``data`` is empty (no header can be derived).
    """
    if not data:
        raise ValueError(f'No rows to write to {fpath}')

    fieldnames = ['uuid', 'plan_uuid'] + list(data[0].keys())
    copied_fields = [name for name in fieldnames if 'uuid' not in name]
    uuid_key = fpath.split('/')[-1].split('.')[0]

    # newline='' is what the csv module requires: fields such as
    # goal_and_steps contain embedded newlines, which must not be translated.
    with open(fpath, mode='w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for idx, info in enumerate(data):
            write_data = {
                'plan_uuid': f"{uuid_key}_row_{idx}",
                'uuid': shortuuid.uuid(),
            }
            for fieldname in copied_fields:
                write_data[fieldname] = info[fieldname]
            writer.writerow(write_data)

def main(args):
    """Build the HIT rows for ``args.question_type`` and write them to CSV.

    Writes ``<outfile_prefix>_full_data.csv`` with all shuffled HITs, then
    ``<outfile_prefix>_split_<k>.csv`` files holding consecutive slices of
    the same shuffled list.
    """
    df = load_expt_data(args)
    print(len(df))

    is_switched_real = args.question_type == 'nondependent_switched_real'
    num_questions_per_hit = 2 if is_switched_real else 6

    csv_data = create_csv_data(df, num_questions_per_hit)
    print(f'Total HITs to possibly MTurk: {len(csv_data)}; needs cleaning')
    check_for_duplicates_and_issues(csv_data, num_questions_per_hit)

    clean_csv_data = remove_rows_with_all_na_questions(csv_data, num_questions_per_hit)
    random.shuffle(clean_csv_data)
    write_to_csv(clean_csv_data, f'{args.outfile_prefix}_full_data.csv')

    # TODO: Make the variables below dynamic; unable to select unique recipes currently
    num_models, num_plans = 3, 120
    split = int(num_plans*num_models/(num_questions_per_hit/2))

    chunk_starts = range(0, len(clean_csv_data), split)
    for file_count, start in enumerate(chunk_starts):
        chunk = clean_csv_data[start:start + split]
        write_to_csv(chunk, f'{args.outfile_prefix}_split_{file_count}.csv')

# Script entry point: parse the command-line flags and generate the CSV files.
if __name__ == '__main__':
    args = parse_args()
    main(args)