import os
import json
import argparse
import re
import random
import csv
import pandas as pd
from sklearn.metrics import classification_report

# Seed the module-level RNG so random.sample in save_for_qualitative_analysis
# picks the same rows for the same input order on every run.
random.seed(a=1234)

def parse_args(argv=None):
    """Parse command-line options for the metrics script.

    Args:
        argv: Argument list to parse. ``None`` (the default) means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``dirname``, ``expt_name`` and ``qual_file``
        (``qual_file`` is ``None`` when not given). Unrecognized arguments are ignored.
    """
    parser = argparse.ArgumentParser(description='Metrics')
    # dirname and expt_name are required: without them main() would list the
    # current directory and look for files ending in 'None.jsonl'.
    parser.add_argument('--dirname', type=str, required=True,
                        help='Directory of files to calculate metrics over')
    parser.add_argument('--expt_name', type=str, required=True,
                        help='Expt name to find corresponding files for metrics')
    parser.add_argument('--qual_file', type=str,
                        help='CSV file name to save data sample for qualitative analysis')
    args, _ = parser.parse_known_args(argv)
    return args

def sanitize_for_sheets_rendering(text):
    """Return ``text`` with every newline replaced by ``chr(10)``.

    NOTE(review): ``chr(10)`` is the newline character itself, so this
    substitution currently leaves the text unchanged. If the goal was to escape
    newlines for spreadsheet rendering, a different replacement string is
    needed -- confirm the intended output with the author.
    """
    line_feed = chr(10)
    return text.replace(line_feed, line_feed)

def load_file_data(fname):
    """Load a JSON Lines file into a list of decoded records.

    Each non-blank line is parsed with ``json.loads``. Blank or whitespace-only
    lines (e.g. a trailing empty line) are skipped instead of raising a decode
    error. The file is read as UTF-8 regardless of the platform locale.

    Args:
        fname: Path to the ``.jsonl`` file.

    Returns:
        List of decoded objects, in file order.
    """
    with open(fname, 'r', encoding='utf-8') as json_file:
        return [json.loads(line) for line in json_file if line.strip()]

def add_binary_gold_label(data, threshold=3):
    """Annotate each record with its step distance, distance bucket and gold label.

    For every record this adds:
      * ``step_difference``: ``abs(idx[1] - idx[0])`` of ``step_pair_idx_asked_about``;
      * ``dependency``: ``'close'`` if that difference is ``<= threshold``, else ``'far'``;
      * ``binary_gold_label``: ``True`` iff ``question_type`` starts with
        ``'dependent_real_'``.

    Args:
        data: List of record dicts; modified in place.
        threshold: Largest step difference still bucketed as ``'close'``.

    Returns:
        The same list, for chaining.
    """
    for info in data:
        pair = info['step_pair_idx_asked_about']
        info['step_difference'] = abs(pair[1] - pair[0])
        info['dependency'] = 'close' if info['step_difference'] <= threshold else 'far'
        info['binary_gold_label'] = info['question_type'].startswith('dependent_real_')
    return data

def parse_basic_must_why_output(data):
    """Use each record's raw ``model_answer`` verbatim as its ``why_explanation``.

    Records are modified in place; the same list is returned.
    """
    for record in data:
        record.update(why_explanation=record['model_answer'])
    return data

def parse_basic_binary_output(data):
    """Map each free-text ``model_answer`` to a boolean ``binary_prediction_label``.

    The label is True if the answer contains the word "yes"; otherwise False.
    Answers containing neither "yes" nor "no" also get False but are not counted
    as parsed. Matching is case-insensitive and on whole words, so words such as
    "dyes"/"eyes" are not read as "yes" and "not"/"know" are not read as "no".
    Prints the total and parsed counts.

    Returns:
        The same list, modified in place.
    """
    parse_count = 0
    for info in data:
        answer = info['model_answer']
        if re.search(r'\byes\b', answer, flags=re.IGNORECASE):
            info['binary_prediction_label'] = True
            parse_count += 1
        elif re.search(r'\bno\b', answer, flags=re.IGNORECASE):
            info['binary_prediction_label'] = False
            parse_count += 1
        else:
            # Unrecognized answer: default to "no" but leave it out of the parsed count.
            info['binary_prediction_label'] = False
    print(f'Total: {len(data)}, Parsed: {parse_count}')
    return data

def parse_json_markdown(json_string: str) -> dict:
    """Decode a JSON object from a model answer that may be wrapped in markdown.

    Candidates are tried in order until one decodes:
      1. the whole (stripped) answer;
      2. the body of the first markdown code fence (```` ``` ```` or ```` ```json ````);
      3. the span from the first ``{`` to the last ``}``; if there is an opening
         brace but no closing one after it, the text from that brace plus a
         closing ``}`` (handles answers truncated just before the final brace).

    Decoding uses ``strict=False`` so raw control characters (e.g. newlines)
    inside JSON strings are accepted.

    Args:
        json_string: Raw model answer text.

    Returns:
        The decoded JSON value (expected to be a dict).

    Raises:
        json.JSONDecodeError: If none of the candidates decode.
    """
    text = json_string.strip()
    candidates = [text]

    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL)
    if fenced:
        candidates.append(fenced.group(1))

    start = text.find('{')
    if start != -1:
        end = text.rfind('}')
        candidates.append(text[start:end + 1] if end > start else text[start:] + '}')

    last_error = None
    for candidate in candidates:
        try:
            return json.loads(candidate, strict=False)
        except json.JSONDecodeError as exc:
            last_error = exc
    raise last_error

def parse_basic_binary_must_why_output(data):
    """Parse JSON-formatted answers into a binary label plus an explanation.

    Each ``model_answer`` is decoded with ``parse_json_markdown``. Its
    ``binary_answer`` value sets ``binary_prediction_label``: True if it contains
    the word "yes", otherwise False. Its ``why_answer`` value (or ``''`` if
    absent) becomes ``why_explanation``. Matching is case-insensitive on whole
    words. Answers with neither "yes" nor "no" get False so every record has a
    label for downstream metrics, but they are not counted as parsed. Prints an
    error banner when some answers could not be parsed.

    Returns:
        The same list, modified in place.
    """
    parse_count = 0
    for info in data:
        loaded_info = parse_json_markdown(info['model_answer'])
        # str() guards against non-string JSON values such as true/false/null.
        binary_answer = str(loaded_info.get('binary_answer', ''))
        info['why_explanation'] = loaded_info.get('why_answer', '')
        if re.search(r'\byes\b', binary_answer, flags=re.IGNORECASE):
            info['binary_prediction_label'] = True
            parse_count += 1
        elif re.search(r'\bno\b', binary_answer, flags=re.IGNORECASE):
            info['binary_prediction_label'] = False
            parse_count += 1
        else:
            # Default instead of leaving the key missing, which would raise a
            # KeyError in the metric functions.
            info['binary_prediction_label'] = False
    if len(data) != parse_count:
        print(f'ERROR ERROR ERROR: Total: {len(data)}, Parsed: {parse_count}')
    else:
        print(f'Total: {len(data)}, Parsed: {parse_count}')
    return data

def parse_basic_binary_must_why_nl_output(data):
    """Parse two-part natural-language answers into a binary label and an explanation.

    Two layouts are handled; blank lines are ignored in both:
      * ``"Answer 1: <yes/no>\\nAnswer 2: <why>"``
      * ``"Answers:\\n\\n1. <yes/no>\\n2. <why>"``

    ``binary_prediction_label`` is True if the first answer contains the word
    "yes", otherwise False. ``why_explanation`` is the second answer, or ``''``
    if missing. Matching is case-insensitive on whole words. Answers with
    neither "yes" nor "no" get False but are not counted as parsed. Prints an
    error banner when some answers could not be parsed.

    Returns:
        The same list, modified in place.
    """
    parse_count = 0
    for info in data:
        answer = info['model_answer']
        if answer.startswith('Answers:'):
            # e.g. 'Answers:\n\n1. Yes.\n2. The chocolate mixture must be cooled to
            # room temperature before it can be folded into the egg whites in Step 12.'
            answer = answer[len('Answers:'):]
            first_prefix, second_prefix = '1. ', '2. '
        else:
            first_prefix, second_prefix = 'Answer 1: ', 'Answer 2: '
        lines = [line for line in answer.split('\n') if line.strip()]
        binary_answer = lines[0].replace(first_prefix, '').strip() if lines else ''
        why_answer = lines[1].replace(second_prefix, '').strip() if len(lines) > 1 else ''

        info['why_explanation'] = why_answer
        if re.search(r'\byes\b', binary_answer, flags=re.IGNORECASE):
            info['binary_prediction_label'] = True
            parse_count += 1
        elif re.search(r'\bno\b', binary_answer, flags=re.IGNORECASE):
            info['binary_prediction_label'] = False
            parse_count += 1
        else:
            # Unrecognized answer: record a default rather than stopping in the debugger.
            info['binary_prediction_label'] = False
    if len(data) != parse_count:
        print(f'ERROR ERROR ERROR: Total: {len(data)}, Parsed: {parse_count}')
    else:
        print(f'Total: {len(data)}, Parsed: {parse_count}')
    return data

def save_for_qualitative_analysis(data, fname, sample_size=20):
    """Write a random per-category sample of records to a CSV file.

    Up to ``sample_size`` records are drawn without replacement (using the
    module-level ``random`` generator) from each question category, matched by
    ``question_type`` prefix: ``dependent_real``, ``nondependent_real`` and
    ``nondependent_switched_real``. A category with fewer records contributes
    all of them.

    The CSV has columns ``model_input``, ``question_type`` and
    ``binary_gold_label``. It also has ``why_explanation`` and/or
    ``binary_prediction_label`` when any sampled record carries them. Missing
    values are written as empty cells.

    Args:
        data: List of record dicts.
        fname: Output CSV path (overwritten).
        sample_size: Maximum number of records per category.
    """
    categories = ['dependent_real', 'nondependent_real', 'nondependent_switched_real']
    qual_data = []
    for category in categories:
        category_data = [x for x in data if x['question_type'].startswith(category)]
        # random.sample raises ValueError when asked for more items than exist.
        qual_data.extend(random.sample(category_data, min(sample_size, len(category_data))))

    fieldnames = ['model_input', 'question_type', 'binary_gold_label']
    for optional_field in ('why_explanation', 'binary_prediction_label'):
        if any(optional_field in info for info in qual_data):
            fieldnames.append(optional_field)

    # newline='' is what the csv module expects; it avoids blank rows on Windows.
    with open(fname, 'w', newline='', encoding='utf-8') as csvfile:
        # Records carry many other keys; only the selected columns are written.
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, extrasaction='ignore')
        writer.writeheader()
        writer.writerows(qual_data)

def calculate_accuracy(data):
    """Compute accuracy and prediction counts for binary predictions.

    Returns:
        Tuple ``(accuracy_pct, total, n_yes, n_no)``. ``accuracy_pct`` is the
        percentage of records whose ``binary_prediction_label`` equals
        ``binary_gold_label``, rounded to 2 decimals (``0`` for empty input).
        ``n_yes``/``n_no`` count truthy/falsy predictions.
    """
    n_total = n_correct = n_yes = 0
    for record in data:
        predicted = record['binary_prediction_label']
        if predicted:
            n_yes += 1
        if predicted == record['binary_gold_label']:
            n_correct += 1
        n_total += 1
    n_no = n_total - n_yes
    accuracy = n_correct / n_total if n_total else 0
    return round(accuracy * 100, 2), n_total, n_yes, n_no

def calculate_precision_recall_f1(data):
    """Print and return precision, recall and F1 for the positive (True) class.

    Compares ``binary_prediction_label`` against ``binary_gold_label`` for every
    record and prints the metrics alongside the TP/FP/TN/FN counts. Any metric
    whose denominator is zero is reported as ``0.0``.

    Returns:
        Tuple ``(precision, recall, f1)`` of floats in ``[0, 1]``.
    """
    tp = fp = tn = fn = 0
    for info in data:
        gold = info['binary_gold_label']
        if gold == info['binary_prediction_label']:
            if gold:
                tp += 1
            else:
                tn += 1
        elif gold:
            fn += 1
        else:
            fp += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if precision + recall else 0.0
    print(f"Precision: {round(precision, 2)}, Recall: {round(recall, 2)}, F1: {round(f1, 2)}; TP: {tp}, FP: {fp}, TN: {tn}, FN: {fn}")
    return precision, recall, f1

def calculate_classification_report(data):
    """Print sklearn's classification report for the binary dependency task.

    Gold labels come from ``binary_gold_label`` and predictions from
    ``binary_prediction_label``. ``False`` is reported as ``nondependent`` and
    ``True`` as ``dependent``. Empty input prints a short notice instead.
    """
    if not data:
        print('No samples to report on')
        return
    preds = [info['binary_prediction_label'] for info in data]
    labels = [info['binary_gold_label'] for info in data]
    # Pin the label set: when a subset contains only one class (e.g. all
    # dependent_real questions answered "yes"), sklearn would otherwise raise
    # because target_names lists two classes.
    print(classification_report(labels, preds, labels=[False, True],
                                target_names=['nondependent', 'dependent']))

def find_steps_and_type_in_question(data):
    """Annotate records with the steps named in their question and parts of their type.

    Adds to each record (in place):
      * ``steps_in_question``: the two smallest numbers found as ``"Step <n>"`` in
        ``binary_question``, formatted ``"<lo>_<hi>"``;
      * ``temporal_relation``: ``'before'`` or ``'after'`` when ``question_type``
        ends with that word (left unset otherwise);
      * ``question_category``: ``question_type`` with its last two
        ``'_'``-separated parts removed.

    Returns:
        The same list.

    Raises:
        IndexError: If fewer than two step numbers appear in ``binary_question``.
    """
    for info in data:
        # Require at least one digit so text like "Step by step" is not taken as
        # a step reference (it would otherwise fail on int('')).
        step_nums = sorted(int(n) for n in re.findall(r'Step ([0-9]+)', info['binary_question']))
        info['steps_in_question'] = f'{step_nums[0]}_{step_nums[1]}'
        if info['question_type'].endswith('before'):
            info['temporal_relation'] = 'before'
        elif info['question_type'].endswith('after'):
            info['temporal_relation'] = 'after'
        info['question_category'] = '_'.join(info['question_type'].split('_')[:-2])
    return data

def temporal_consistency(data):
    """Print how often the model answers the before/after variants of a question alike.

    Records whose ``question_type`` ends in ``before`` or ``after`` are annotated
    with ``find_steps_and_type_in_question``. They are then grouped by
    (``title``, ``steps_in_question``, ``question_category``). A group with
    exactly two rows (presumably the before and after variant -- confirm against
    the data generator) counts as consistent when both rows share the same
    ``binary_prediction_label``. Groups of any other size are skipped and their
    number is reported, instead of dropping into the debugger.
    """
    before_data = find_steps_and_type_in_question(
        [x for x in data if x['question_type'].endswith('before')])
    after_data = find_steps_and_type_in_question(
        [x for x in data if x['question_type'].endswith('after')])
    records = before_data + after_data

    correct, total, skipped = 0, 0, 0
    # An empty DataFrame has no grouping columns, so groupby would raise KeyError.
    if records:
        grouped_df = pd.DataFrame(records).groupby(['title', 'steps_in_question', 'question_category'])
        for _, group in grouped_df:
            if group.shape[0] != 2:
                skipped += 1
                continue
            if group.iloc[0]['binary_prediction_label'] == group.iloc[1]['binary_prediction_label']:
                correct += 1
            total += 1

    consistency = correct / total * 100 if total else 0
    print(f'Temporal consistency: {round(consistency, 2)}% out of {total} samples')
    if skipped:
        print(f'Skipped {skipped} groups that did not contain exactly 2 questions')

def nondependent_switch_consistency(data):
    """Print how often the model answers a nondependent question and its switched variant alike.

    ``nondependent_real*`` and ``nondependent_switched_real*`` records are
    annotated with ``find_steps_and_type_in_question``. They are then grouped by
    (``title``, ``steps_in_question``, ``temporal_relation``). A group with
    exactly two rows counts as consistent when both share the same
    ``binary_prediction_label``. Groups of any other size are skipped and their
    number is reported, instead of dropping into the debugger.
    """
    nondependent_real_data = find_steps_and_type_in_question(
        [x for x in data if x['question_type'].startswith('nondependent_real')])
    nondependent_switched_real_data = find_steps_and_type_in_question(
        [x for x in data if x['question_type'].startswith('nondependent_switched_real')])
    records = nondependent_real_data + nondependent_switched_real_data

    correct, total, skipped = 0, 0, 0
    # An empty DataFrame has no grouping columns, so groupby would raise KeyError.
    if records:
        grouped_df = pd.DataFrame(records).groupby(['title', 'steps_in_question', 'temporal_relation'])
        for _, group in grouped_df:
            if group.shape[0] != 2:
                skipped += 1
                continue
            if group.iloc[0]['binary_prediction_label'] == group.iloc[1]['binary_prediction_label']:
                correct += 1
            total += 1

    consistency = correct / total * 100 if total else 0
    print(f'Nondependent switch consistency: {round(consistency, 2)}% out of {total} samples')
    if skipped:
        print(f'Skipped {skipped} groups that did not contain exactly 2 questions')

def data_split_by_distance(data):
    """Partition records by their ``dependency`` field.

    Records whose ``dependency`` is neither ``'close'`` nor ``'far'`` are left
    out. Prints the size of each partition.

    Returns:
        ``(close_data, far_data)``, each preserving input order.
    """
    close_data = [record for record in data if record['dependency'] == 'close']
    far_data = [record for record in data if record['dependency'] == 'far']
    print(f'Far data: {len(far_data)}; Close data: {len(close_data)}')
    return close_data, far_data

def categorical_accuracy(category_data, category):
    """Print a header naming ``category``, then the classification report for its records."""
    header = f'Working with {category} data'
    print(header)
    calculate_classification_report(category_data)

def calculate_metrics_by_category(data):
    """Print classification, precision/recall/F1 and temporal-consistency metrics per category.

    Categories are ``question_type`` prefixes: ``dependent_real``,
    ``nondependent_real`` and ``nondependent_switched_real``. A blank line is
    printed after each category.
    """
    for prefix in ('dependent_real', 'nondependent_real', 'nondependent_switched_real'):
        subset = [record for record in data if record['question_type'].startswith(prefix)]
        categorical_accuracy(subset, prefix)
        calculate_precision_recall_f1(subset)
        temporal_consistency(subset)
        print()

def calculate_accuracy_by_temporal_relation(data):
    """Print a classification report separately for 'before' and 'after' questions.

    Records are selected by the suffix of their ``question_type``.
    """
    for relation in ('before', 'after'):
        categorical_accuracy(
            [record for record in data if record['question_type'].endswith(relation)],
            relation,
        )

# TODO: Clean up this function or remove it altogether
def data_check(data, modelname='gpt4turbo', out_dir='../data/qualitative', limit=10):
    """Dump up to ``limit`` example records per confusion-matrix cell to CSV files.

    Records are bucketed by comparing ``binary_gold_label`` with
    ``binary_prediction_label`` into ``tp``, ``fp``, ``tn`` and ``fn``. Each
    non-empty bucket is written to ``<out_dir>/<modelname>_<bucket>.csv``, with
    one column per key seen in that bucket. Empty buckets are skipped. The output
    directory must already exist.

    Args:
        data: List of record dicts.
        modelname: Prefix for the output file names.
        out_dir: Directory the CSV files are written to.
        limit: Maximum number of examples kept per bucket.
    """
    # Insertion order fixes the write order: tp, fp, tn, fn.
    buckets = {'tp': [], 'fp': [], 'tn': [], 'fn': []}
    for info in data:
        gold = info['binary_gold_label']
        if gold == info['binary_prediction_label']:
            key = 'tp' if gold else 'tn'
        else:
            key = 'fn' if gold else 'fp'
        if len(buckets[key]) < limit:
            buckets[key].append(info)

    for name, rows in buckets.items():
        if not rows:
            continue
        # Union of keys (first-seen order) so records with differing keys don't
        # make DictWriter raise; missing values are written as empty cells.
        fieldnames = list(dict.fromkeys(key for row in rows for key in row))
        filename = os.path.join(out_dir, f'{modelname}_{name}.csv')
        with open(filename, 'w+', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)

def run_evaluation(full_data):
    """Print classification and consistency metrics for one slice of the data.

    The classification report, temporal consistency and per-temporal-relation
    reports use only records whose ``question_type`` contains
    ``'dependent_real'``. This covers ``dependent_real*`` and
    ``nondependent_real*`` but not ``nondependent_switched_real*``. The
    switch-consistency check needs the switched questions too, so it receives
    ``full_data``.
    """
    real_data = [record for record in full_data if 'dependent_real' in record['question_type']]
    calculate_classification_report(real_data)
    temporal_consistency(real_data)
    nondependent_switch_consistency(full_data)

    calculate_accuracy_by_temporal_relation(real_data)

def _evaluate_with_distance_split(data):
    """Run the evaluation on all records, then separately on far and close step pairs."""
    run_evaluation(data)
    close_data, far_data = data_split_by_distance(data)
    print('\nLooking at far data')
    run_evaluation(far_data)
    print('\nLooking at close data')
    run_evaluation(close_data)


def main(args):
    """Load one experiment's outputs, parse them, print metrics and optionally sample rows.

    Input files are the entries of ``args.dirname`` whose names end with
    ``<expt_name>.jsonl`` and contain ``'real'``. If ``args.qual_file`` is set, a
    per-category sample is written there for qualitative analysis.
    """
    # Sorted so the record order -- and therefore the seeded random.sample in
    # save_for_qualitative_analysis -- does not depend on filesystem listing order.
    fnames = sorted(
        x for x in os.listdir(args.dirname)
        if x.endswith(f'{args.expt_name}.jsonl') and 'real' in x
    )
    data = []
    for fname in fnames:
        data.extend(load_file_data(os.path.join(args.dirname, fname)))
    data = add_binary_gold_label(data)

    # Output parser per experiment; several experiments share an answer format.
    parsers = {
        'basic_binary': parse_basic_binary_output,
        'basic_must_why': parse_basic_must_why_output,
        'basic_binary_must_why': parse_basic_binary_must_why_output,
        'basic_binary_must_why_nl': parse_basic_binary_must_why_nl_output,
        'basic_binary_explain': parse_basic_binary_must_why_output,
        'explain_binary': parse_basic_binary_must_why_output,
        'basic_binary_explain_nl': parse_basic_binary_must_why_nl_output,
    }
    parser = parsers.get(args.expt_name)
    if parser is None:
        print(f'Unknown expt_name {args.expt_name!r}; skipping parsing and evaluation')
    else:
        data = parser(data)
        # parse_basic_must_why_output sets no binary_prediction_label, so there
        # is nothing to score for that experiment.
        if args.expt_name != 'basic_must_why':
            _evaluate_with_distance_split(data)

    if args.qual_file:
        save_for_qualitative_analysis(data, args.qual_file)

if __name__ == '__main__':
    # Script entry point: parse command-line options, then run main() with them.
    args = parse_args()
    main(args)