# Example: python human_eval.py --folder_path ../data/mturk/100percent/allresponses/ --input_files ../data/mturk/input/zs_gt_go_l3_dependent_real_binary_explain_split_0.csv ../data/mturk/input/zs_gt_go_l3_nondependent_real_binary_explain_split_0.csv ../data/mturk/input/zs_g1.5_dependent_real_binary_explain_nl_split_0.csv ../data/mturk/input/zs_g1.5_nondependent_real_binary_explain_nl_split_0.csv --model_output_dir ../data/mturk/model_outputs_to_use/
import argparse
import os
import pandas as pd
from collections import defaultdict
import numpy as np
import math
import re
import json
from collections import defaultdict

def parse_args():
    """Parse the command-line options of the evaluation script.

    Unknown options are tolerated (``parse_known_args``) and discarded.

    Returns:
        argparse.Namespace with ``input_files`` (list of paths),
        ``folder_path`` and ``model_output_dir``.
    """
    parser = argparse.ArgumentParser(description="Load all xlsx files from a folder into a single pandas DataFrame.")
    option_specs = (
        ("--input_files", {"nargs": '+', "help": "Path to the corresponding input csv file"}),
        ("--folder_path", {"help": "Path to the folder containing xlsx files"}),
        ("--model_output_dir", {"help": "Path to the folder containing model predictions"}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    known_args, _unknown = parser.parse_known_args()
    return known_args

def _question_type_from_filename(file_name):
    """Return the question type encoded in a response file name, or None if no known marker is present."""
    for question_type in ('dependent_real', 'nondependent_real', 'nondependent_switched_real'):
        if f'_{question_type}_' in file_name:
            return question_type
    return None


def load_excel_files(folder_path, input_df):
    """Collect annotator responses from every ``.xlsx`` file in a folder.

    For each response file the uuid is taken from the file-name token that
    follows ``Responses`` and the question type from the file name. The
    matching row of ``input_df`` supplies the source model of each question
    column and the goal/steps text.

    Args:
        folder_path: directory that holds the response spreadsheets.
        input_df: DataFrame with ``uuid``, ``question_type``,
            ``goal_and_steps`` and ``question_<i>_source_model`` columns.

    Returns:
        defaultdict(list) keyed by
        ``(source_model, question_column, uuid, question_type, goal_and_steps)``
        whose values are the first (at most three) responses in that column.

    Raises:
        ValueError: if a file name has no ``Responses`` token, or if
            ``input_df`` does not contain exactly one row for the file's
            uuid and question type.
    """
    max_responses = 3  # only the first three annotator rows are used (trial 1)
    scores = defaultdict(list)

    for file_name in os.listdir(folder_path):
        if not file_name.endswith('.xlsx'):
            continue

        # Decide the question type before touching the spreadsheet so that
        # files we are going to skip are never opened.
        question_type = _question_type_from_filename(file_name)
        if question_type is None:
            # Previously this fell through with an unbound (or stale) type.
            print(f'Skipping {file_name}: unrecognised question type')
            continue
        if question_type == 'nondependent_switched_real':
            # switched questions are excluded from the evaluation
            continue

        df = pd.read_excel(os.path.join(folder_path, file_name))
        filename_parts = file_name.split('_')
        uuid = filename_parts[filename_parts.index('Responses') + 1]

        if df.shape[0] < 3:
            print(f'File {file_name} has {df.shape[0]} rows')

        input_uuid_row = input_df.loc[input_df['uuid'] == uuid]
        input_row = input_uuid_row.loc[input_uuid_row['question_type'] == question_type]
        if len(input_row) != 1:
            raise ValueError(
                f'Expected exactly one input row for uuid={uuid!r}, '
                f'question_type={question_type!r} (file {file_name}); found {len(input_row)}'
            )

        goal_and_steps = input_row['goal_and_steps'].iloc[0]
        question_columns = [col for col in df.columns if col != 'Timestamp']
        for count, col in enumerate(question_columns):
            # All rows of one column rate the same answer, hence share one source model.
            source_model = input_row[f'question_{count}_source_model'].iloc[0]
            key = (source_model, col, uuid, question_type, goal_and_steps)
            # Append one by one so that no key is created for an empty column.
            for score in df[col].iloc[:max_responses]:
                scores[key].append(score)
    return scores

def transform_to_binary(scores):
    """Map each score to 1 when it is strictly greater than 3, otherwise to 0.

    Args:
        scores: iterable of numeric scores.

    Returns:
        list of 0/1 ints, one per input score, in the same order.
    """
    return [1 if score > 3 else 0 for score in scores]

def get_majority_binary_vote(scores):
    """Return 1 when strictly more votes equal 1 than not, otherwise 0.

    Any value other than 1 counts as a negative vote; ties (including an
    empty input) resolve to 0.
    """
    votes = [bool(score == 1) for score in scores]
    positive = sum(votes)
    negative = len(votes) - positive
    return 1 if positive > negative else 0

def value2num(value):
    """Convert a Likert label to its numeric score.

    'Strongly Agree' -> 5, 'Agree' -> 4, 'Disagree' -> 2,
    'Strongly Disagree' -> 1; every other value maps to 3.
    """
    label_scores = (
        ('Strongly Agree', 5),
        ('Agree', 4),
        ('Disagree', 2),
        ('Strongly Disagree', 1),
    )
    for label, number in label_scores:
        if value == label:
            return number
    # anything that is not one of the four labels is scored as the midpoint
    return 3

def create_df(scores):
    """Build one DataFrame row per scored question.

    Args:
        scores: mapping whose keys are
            ``(source_model, question, uuid, question_type, goal_and_steps)``
            tuples and whose values are lists of Likert labels.

    Returns:
        DataFrame with the key fields plus the numeric scores, their mean,
        the binarised scores, the mean of the binarised scores, that mean
        thresholded at 0.5, and the majority binary vote.
    """
    records = []
    for key, labels in scores.items():
        numeric = [value2num(label) for label in labels]
        binary = transform_to_binary(numeric)
        mean_binary = sum(binary) / len(binary)
        records.append({
            'source_model': key[0],
            'question': key[1],
            'uuid': key[2],
            'question_type': key[3],
            'goal_and_steps': key[4],
            'scores': numeric,
            'avg_score': sum(numeric) / len(numeric),
            'binary_scores': binary,
            'avg_binary_score': mean_binary,
            'binary_avg_score': 1 if mean_binary > 0.5 else 0,
            'majority_binary_vote': get_majority_binary_vote(binary),
        })
    return pd.DataFrame(records)

def _print_extreme_example(candidates, qtype, header, empty_message, ascending):
    """Print the first row of ``candidates`` (sorted by avg_score) whose question type starts with ``qtype``."""
    print(header)
    if candidates.empty:
        print(empty_message)
        return
    matches = candidates[candidates['question_type'].str.startswith(qtype, na=False)]
    if matches.empty:
        print(empty_message)
        return
    chosen = matches.sort_values('avg_score', ascending=ascending).iloc[0]
    print(chosen.goal_and_steps)
    print(chosen.question)
    print(f'Avg Score: {chosen.avg_score}')


def get_cherries_lemons_per_model(df):
    """Print, per model and question type, the best and worst rated example.

    A row is a "cherry" when ``avg_score >= 4`` and a "lemon" when
    ``avg_score <= 2.67``. For every question-type prefix the lowest-scored
    lemon and highest-scored cherry are printed, or 'No lemons' /
    'No cherries' when none exists.

    Missing examples are detected with explicit emptiness checks instead of
    a bare ``except:``, so unrelated errors are no longer swallowed.

    Args:
        df: DataFrame with ``source_model``, ``avg_score``, ``question_type``,
            ``goal_and_steps`` and ``question`` columns.
    """
    # Prefix match: 'nondependent' also covers 'nondependent_switched...' types.
    qtypes = ['dependent', 'nondependent', 'nondependent_switched']
    for model in df.source_model.unique():
        model_df = df[df['source_model'] == model]
        print(f"Model: {model}")
        print(f"{model_df['avg_score'].describe()}")
        cherries_df = model_df[model_df['avg_score'] >= 4]
        lemons_df = model_df[model_df['avg_score'] <= 2.67]
        for qtype in qtypes:
            print(f'{qtype} questions')
            _print_extreme_example(lemons_df, qtype, 'Lemon', 'No lemons', ascending=True)
            print()
            _print_extreme_example(cherries_df, qtype, 'Cherry', 'No cherries', ascending=False)
            print()
        print('-------------------------------------------------')

def get_histogram_of_scores(df):
    """Print how often each raw score and each binary score occurs.

    Args:
        df: DataFrame whose ``scores`` and ``binary_scores`` cells hold
            iterables of scores.
    """
    histograms = {'scores': defaultdict(int), 'binary_scores': defaultdict(int)}
    for _, row in df.iterrows():
        for column, counter in histograms.items():
            for score in row[column]:
                counter[score] += 1
    print(f"Histogram counts: {histograms['scores']}")
    print(f"Binary histogram counts: {histograms['binary_scores']}\n")

def get_model_scores(df, full):
    """Print summary statistics and score histograms for every source model.

    Args:
        df: DataFrame as produced by ``create_df``.
        full: when truthy at most the first 480 rows per model are used,
            otherwise at most the first 240.
    """
    # 480 rows when all question types are evaluated together, 240 for a single type
    row_cap = 480 if full else 240
    for model in df.source_model.unique():
        model_df = df[df['source_model'] == model]
        if model_df.shape[0] > row_cap:
            model_df = model_df.iloc[:row_cap]
        votes = model_df['majority_binary_vote']
        majority_pct = round(100*sum(votes)/len(votes), 2)
        print(f"Model: {model} on {model_df.shape[0]} data points")
        print(f"Average Score: {round(model_df['avg_score'].mean(), 2)}")
        print(f"Std Dev Score: {round(model_df['avg_score'].std(), 2)}")
        print(f"Average Binary Score: {round(model_df['avg_binary_score'].mean(), 2)}")
        print(f"Binarized Average Score: {round(model_df['binary_avg_score'].mean(), 2)}")
        print(f"Majority Binary Vote: {majority_pct}%")
        get_histogram_of_scores(model_df)

def num_annotators_check(new_final_labels, num_annotators):
    """
    Report the rows whose category counts do not add up to ``num_annotators``.

    Each offending row is printed; the list of their indices is returned
    (empty when every row has the expected number of annotators).
    """
    error_rows = []
    for idx, record in enumerate(new_final_labels):
        annotators_in_row = sum(record)
        if annotators_in_row == num_annotators:
            continue
        print(f'Row {idx} has {annotators_in_row} annotators')
        error_rows.append(idx)
    return error_rows


def weighted_fleiss_kappa(new_final_labels, weights, weighted=True):
    """
    Compute a (optionally weighted) Fleiss' kappa, rounded to 4 decimals.

    Code from Mohaddeseh's COLING paper.

    Args:
        new_final_labels: subjects x categories table of rating counts.
        weights: categories x categories agreement-weight matrix; only used
            when ``weighted`` is true.
        weighted: use ``weights`` when true, plain squared counts otherwise.

    Returns:
        The kappa value, or None (after printing a message) when some row
        does not contain exactly 3 ratings.
    """
    bad_rows = num_annotators_check(new_final_labels, 3)
    if bad_rows:
        print(f"Not enough annotators in the following rows: {bad_rows}")
        print("IAA cannot be calculated")
        return None

    counts = 1 * np.asarray(new_final_labels)   # avoid integer division
    n_sub, n_cat = counts.shape

    n_total = counts.sum()
    n_rat = counts.sum(1).max()
    # every subject must have been rated by the same number of raters
    assert n_total == n_sub * n_rat

    # marginal frequency of each category
    p_cat = counts.sum(0) / n_total

    if weighted:
        weight_table = 1 * np.asarray(weights)
        agreement = np.multiply(counts @ weight_table, counts)
    else:
        agreement = counts * counts

    # observed agreement per subject, then averaged
    p_rat = (agreement.sum(1) - n_rat) / (n_rat * (n_rat - 1.))
    p_mean = p_rat.mean()

    # agreement expected by chance
    p_mean_exp = (p_cat*p_cat).sum()

    kappa = float(p_mean - p_mean_exp) / (1- p_mean_exp)
    return round(kappa, 4)

def fleiss_kappa(M):
    """
    Compute Fleiss' kappa for a group of annotators.

    From: https://towardsdatascience.com/inter-annotator-agreement-2f46c6d37bf3

    :param M: matrix of shape (N, k) with N subjects and k categories;
        ``M[i, j]`` is the number of raters who put subject i in category j.
        The number of annotators is read from the first row.
    :return: Fleiss' kappa rounded to 4 decimals.
    """
    counts = np.array(M)
    n_items, _n_categories = counts.shape
    n_annotators = float(np.sum(counts[0, :]))
    total_annotations = n_items * n_annotators

    # chance agreement: squared share of each category over all annotations
    category_share = np.sum(counts, axis=0) / total_annotations
    expected = np.sum(category_share * category_share)

    # observed agreement per item, averaged over the items
    per_item = (np.sum(counts * counts, axis=1) - n_annotators) / (n_annotators * (n_annotators - 1))
    observed = np.sum(per_item) / n_items

    return round((observed - expected) / (1 - expected), 4)

def create_kappa_matrix(df, score_col_name):
    """Build the subjects x categories count table used for Fleiss' kappa.

    Args:
        df: DataFrame whose ``score_col_name`` cells are lists of labels.
        score_col_name: column to read; when the name contains 'binary' the
            categories are 0..1, otherwise 1..5.

    Returns:
        numpy array with one row per DataFrame row, holding how often each
        category occurs in that row's list.
    """
    categories = range(0, 2) if 'binary' in score_col_name else range(1, 6)
    table = [
        [row[score_col_name].count(category) for category in categories]
        for _, row in df.iterrows()
    ]
    return np.array(table)

def read_input_files(input_files):
    """Read the MTurk input CSV files into one DataFrame.

    Each file's rows get a ``question_type`` column derived from a marker in
    the file path (``_dependent_real_``, ``_nondependent_real_`` or
    ``_nondependent_switched_real_``). Files without a marker are included
    without that column being set.

    The frames are collected and concatenated once, rather than growing a
    DataFrame inside the loop, which re-copies all earlier rows on every
    iteration and starts from an empty frame.

    Args:
        input_files: iterable of CSV paths.

    Returns:
        The concatenated DataFrame (original per-file indices are kept), or
        an empty DataFrame when ``input_files`` is empty.
    """
    known_types = ('dependent_real', 'nondependent_real', 'nondependent_switched_real')
    frames = []
    for file in input_files:
        file_df = pd.read_csv(file)
        for question_type in known_types:
            if f'_{question_type}_' in file:
                file_df['question_type'] = question_type
                break
        frames.append(file_df)
    if not frames:
        # pd.concat refuses an empty list; keep the old "empty frame" result
        return pd.DataFrame()
    return pd.concat(frames)

def parse_json_markdown(json_string: str) -> dict:
    """Parse a JSON object that may be wrapped in a markdown code fence.

    Candidates are tried in order and the first that parses is returned:

    1. the string with leading backticks, an optional ``json`` tag,
       trailing backticks and surrounding whitespace removed;
    2. the text from the first ``{`` to the last ``}``, which handles prose
       around the object;
    3. the original string with a ``}`` appended, which handles output
       truncated just before the closing brace.

    Parsing uses ``strict=False`` so control characters inside strings are
    accepted.

    Args:
        json_string: raw model output.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: if none of the candidates is valid JSON.
    """
    # Candidate 1: strip the fence. Trailing backticks are removed too; the
    # old regex left them in place, so fenced output never parsed directly.
    unfenced = json_string.lstrip('`')
    if unfenced.startswith('json'):
        unfenced = unfenced[len('json'):]
    unfenced = unfenced.strip().rstrip('`').strip()
    candidates = [unfenced]

    # Candidate 2: outermost braces only (never includes stray backticks).
    first_brace = json_string.find('{')
    last_brace = json_string.rfind('}')
    if first_brace != -1 and last_brace > first_brace:
        candidates.append(json_string[first_brace:last_brace + 1])

    # Candidate 3: assume the closing brace was cut off.
    candidates.append(json_string + '}')

    last_error = None
    for candidate in candidates:
        try:
            return json.loads(candidate, strict=False)
        except json.JSONDecodeError as err:
            last_error = err
    raise last_error

def parse_basic_binary_must_why_output(data):
    """Extract the yes/no label and explanation from JSON-formatted answers.

    Each ``model_answer`` is decoded with ``parse_json_markdown``. A
    ``binary_answer`` containing 'yes' gives True and one containing 'no'
    gives False, both with the ``why_answer`` text; anything else gives
    False with a placeholder reason and is not counted as parsed.

    Args:
        data: DataFrame with a ``model_answer`` column. It is modified in
            place: ``binary_prediction_label`` and ``why_explanation`` are
            added.

    Returns:
        The same DataFrame.
    """
    labels, explanations = [], []
    parsed_total = 0
    for _, record in data.iterrows():
        decoded = parse_json_markdown(record['model_answer'])
        verdict = decoded['binary_answer'].lower()
        if 'yes' in verdict or 'no' in verdict:
            # 'yes' takes precedence when both substrings occur
            labels.append('yes' in verdict)
            explanations.append(decoded['why_answer'])
            parsed_total += 1
        else:
            labels.append(False)
            explanations.append('Could not think of a reason')

    data['binary_prediction_label'] = labels
    data['why_explanation'] = explanations
    if len(data) == parse_total if False else len(data) == parsed_total:
        print(f'Total: {len(data)}, Parsed: {parsed_total}')
    else:
        print(f'ERROR ERROR ERROR: Total: {len(data)}, Parsed: {parsed_total}')
    return data

def parse_basic_binary_must_why_nl_output(data):
    """Extract the yes/no label and explanation from natural-language answers.

    The first line of ``model_answer`` is the binary answer and the second
    line the explanation. Two layouts are handled: ``Answer 1: ...`` /
    ``Answer 2: ...`` and the ``Answers:`` form with ``1. `` / ``2. ``
    prefixes, e.g.
    'Answers:\\n\\n1. Yes.\\n2. The chocolate mixture must be cooled ...'.

    Args:
        data: DataFrame with a ``model_answer`` column. It is modified in
            place: ``binary_prediction_label`` and ``why_explanation`` are
            added.

    Returns:
        The same DataFrame.
    """
    labels, explanations = [], []
    parsed_total = 0
    for _, record in data.iterrows():
        raw_answer = record['model_answer']
        if raw_answer.startswith('Answers:'):
            answer_lines = raw_answer.replace('Answers:\n\n', '').split('\n')
            binary_prefix, why_prefix = '1. ', '2. '
        else:
            answer_lines = raw_answer.split('\n')
            binary_prefix, why_prefix = 'Answer 1: ', 'Answer 2: '
        binary_answer = answer_lines[0].replace(binary_prefix, '').strip()
        why_answer = answer_lines[1].replace(why_prefix, '').strip()

        verdict = binary_answer.lower()
        if 'yes' in verdict or 'no' in verdict:
            # 'yes' takes precedence when both substrings occur
            labels.append('yes' in verdict)
            explanations.append(why_answer)
            parsed_total += 1
        else:
            labels.append(False)
            explanations.append('Could not think of a reason')

    data['binary_prediction_label'] = labels
    data['why_explanation'] = explanations
    if len(data) == parsed_total:
        print(f'Total: {len(data)}, Parsed: {parsed_total}')
    else:
        print(f'ERROR ERROR ERROR: Total: {len(data)}, Parsed: {parsed_total}')
    return data

def get_clean_answer(text):
    """Pull the rated answer, step pair and temporal relation out of an MTurk question.

    Example text:
        'Question: Explain why or why not Step 1 must happen before Step 13.\\n\\n'
        'Answer:\\nStep 1 involves preheating the oven ...\\n\\n'
        'Do you think the answer contain all the relevant details ...?'
    gives ('Step 1 involves preheating the oven ...', '0_12', 'before').

    Args:
        text: question text whose blank-line separated parts are the
            question and the ``Answer:`` block.

    Returns:
        Tuple of (answer line, ``'<i>_<j>'`` built from the two smallest
        step numbers mentioned in the question, each minus one,
        'before' or 'after').

    Raises:
        ValueError: if the text lacks the answer block, the question names
            fewer than two steps, or it contains neither 'before' nor
            'after' (this used to surface as an unbound-variable error).
    """
    components = text.split('\n\n')
    if len(components) < 2:
        raise ValueError(f'Question text has no answer block: {text!r}')

    question = components[0].replace('Question: ', '')
    if 'before' in question:
        temporal = 'before'
    elif 'after' in question:
        temporal = 'after'
    else:
        raise ValueError(f"Question mentions neither 'before' nor 'after': {question!r}")

    step_numbers = sorted(int(number) for number in re.findall(r'Step (\d+)', question))
    if len(step_numbers) < 2:
        raise ValueError(f'Question does not mention two steps: {question!r}')

    # The answer block is 'Answer:' on its own line followed by the answer.
    answer_lines = components[1].split('\n')
    if len(answer_lines) < 2:
        raise ValueError(f'Answer block has no answer line: {components[1]!r}')
    answer = answer_lines[1]

    # Steps are 1-based in the question text; shift by one for the index key.
    return answer, f'{step_numbers[0]-1}_{step_numbers[1]-1}', temporal

def get_raw_model_outputs(row, model_output_df):
    """Find the raw model-output record that produced an MTurk question.

    The record is matched on model name, explanation text, goal title, step
    pair and ``<question_type>_<before|after>``.

    Args:
        row: record with ``question``, ``goal_and_steps``, ``source_model``
            and ``question_type`` fields.
        model_output_df: DataFrame with ``model_name``, ``why_explanation``,
            ``title``, ``step_pair_idx_asked_about_str`` and
            ``question_type`` columns.

    Returns:
        dict of the first matching row.

    Raises:
        ValueError: if no row matches. This replaces a leftover
            ``breakpoint()`` that was followed by an IndexError.
    """
    model_answer, relevant_steps_idx, temporal = get_clean_answer(row['question'])
    # First line of goal_and_steps is 'Goal: <title>'.
    goal_mturk = row['goal_and_steps'].split('\n')[0].replace('Goal: ', '')
    expected_question_type = f"{row['question_type']}_{temporal}"

    match_mask = (
        (model_output_df['model_name'] == row['source_model'])
        & (model_output_df['why_explanation'] == model_answer)
        & (model_output_df['title'] == goal_mturk)
        & (model_output_df['step_pair_idx_asked_about_str'] == relevant_steps_idx)
        & (model_output_df['question_type'] == expected_question_type)
    )
    matches = model_output_df.loc[match_mask]

    if matches.empty:
        raise ValueError(
            f"No model output found for model={row['source_model']!r}, title={goal_mturk!r}, "
            f"steps={relevant_steps_idx!r}, question_type={expected_question_type!r}"
        )
    if matches.shape[0] > 1:
        # Ambiguous match: keep going with the first row, as before, but say so.
        print(
            f"Warning: {matches.shape[0]} model outputs match model={row['source_model']!r}, "
            f"title={goal_mturk!r}, steps={relevant_steps_idx!r}; using the first"
        )
    return matches.iloc[0].to_dict()

def modify_df(df, model_output_df):
    """Attach raw model outputs to every row and zero the scores of wrong predictions.

    For each row the matching raw model output is looked up and copied into
    ``raw_*`` fields. When the model's binary prediction differs from the
    gold label, the row's scores are replaced by ``[0, 0, 0]`` and all
    derived score fields are recomputed from that.

    Args:
        df: DataFrame as produced by ``create_df``.
        model_output_df: DataFrame of raw model outputs.

    Returns:
        A new DataFrame with the extra fields and adjusted scores.
    """
    records = []
    unfaithful_count = defaultdict(int)
    for _, row in df.iterrows():
        raw_info = get_raw_model_outputs(row, model_output_df)
        info = row.to_dict()
        info.update({
            'raw_title': raw_info['title'],
            'raw_steps': raw_info['steps'],
            'raw_model_answer': raw_info['model_answer'],
            'raw_why_explanation': raw_info['why_explanation'],
            'raw_binary_prediction_label': raw_info['binary_prediction_label'],
            'raw_gold_label': raw_info['gold_label'],
            'step_pair_idx_asked_about': raw_info['step_pair_idx_asked_about'],
        })
        if raw_info['binary_prediction_label'] != raw_info['gold_label']:
            unfaithful_count[info['source_model']] += 1
            # this is where the scores are replaced
            zeroed_scores = [0,0,0]
            zeroed_binary = transform_to_binary(zeroed_scores)
            mean_binary = sum(zeroed_binary)/len(zeroed_binary)
            info['scores'] = zeroed_scores
            info['avg_score'] = sum(zeroed_scores)/len(zeroed_scores)
            info['binary_scores'] = zeroed_binary
            info['avg_binary_score'] = mean_binary
            info['binary_avg_score'] = 1 if mean_binary > 0.5 else 0
            info['majority_binary_vote'] = get_majority_binary_vote(zeroed_binary)
        records.append(info)
    print(f'Unfaithful count: {unfaithful_count}')
    print(f'Total records: {df.source_model.value_counts()}')
    print()
    return pd.DataFrame(records)

def run_evaluation(df, model_output_df, full):
    """Print model scores, modified scores and inter-annotator agreement.

    Steps: drop rows without a source model, print per-model scores, print
    them again after ``modify_df`` has zeroed wrongly predicted rows, then
    print the weighted Fleiss' kappa for the 5-point and the binary scores
    (both computed on the unmodified scores).

    Args:
        df: DataFrame as produced by ``create_df``.
        model_output_df: DataFrame of raw model outputs.
        full: forwarded to ``get_model_scores`` to select the row cap.
    """
    df = df.dropna(subset=['source_model']).reset_index(drop=True)
    # get_model_scores takes (df, full); the former third argument raised TypeError.
    get_model_scores(df, full)
    modified_df = modify_df(df, model_output_df)
    print('Modified Scoring')
    get_model_scores(modified_df, full)

    # Agreement weights: cos(d * pi/8) for categories d apart, with the
    # largest distance (d = 4) set to exactly 0.
    Mn = create_kappa_matrix(df, 'scores')
    k_fleiss_weights = np.array([[1, math.cos(math.pi/8), math.cos(math.pi/4), math.cos(3*math.pi/8), 0],
                                        [math.cos(math.pi/8), 1, math.cos(math.pi/8), math.cos(math.pi/4), math.cos(3*math.pi/8)],
                                        [math.cos(math.pi/4), math.cos(math.pi/8), 1, math.cos(math.pi/8), math.cos(math.pi/4)],
                                        [math.cos(3*math.pi/8), math.cos(math.pi/4), math.cos(math.pi/8), 1, math.cos(math.pi/8)],
                                        [0, math.cos(3*math.pi/8), math.cos(math.pi/4), math.cos(math.pi/8), 1]])
    wfk = weighted_fleiss_kappa(Mn, k_fleiss_weights)
    print(f"Weighted Fleiss Kappa: {wfk}")

    Mb = create_kappa_matrix(df, 'binary_scores')
    b_fleiss_weights = np.array([[1, math.cos(math.pi/8)],
                                        [math.cos(math.pi/8), 1]])
    binary_wfk = weighted_fleiss_kappa(Mb, b_fleiss_weights)
    print(f"Binary Weighted Fleiss Kappa: {binary_wfk}")

def load_raw_model_outputs(folder_name):
    """Load and parse every ``.jsonl`` model-output file in a folder.

    The model name is the part of the file name before the first
    underscore. Files whose model name contains 'gemini' are parsed with the
    natural-language parser, all others with the JSON parser. Two columns
    are then added: ``step_pair_idx_asked_about_str`` (the two smallest step
    indices joined by '_') and ``gold_label`` (False when the question type
    contains 'nondependent', True otherwise).

    Args:
        folder_name: directory that holds the ``.jsonl`` files.

    Returns:
        DataFrame with the rows of all files.
    """
    combined = pd.DataFrame()
    for entry in os.listdir(folder_name):
        if not entry.endswith('.jsonl'):
            continue
        model_name = entry.split('_')[0]
        file_df = pd.read_json(os.path.join(folder_name, entry), lines=True)
        file_df['model_name'] = model_name
        if 'gemini' in model_name:
            parsed_df = parse_basic_binary_must_why_nl_output(file_df)
        else:
            parsed_df = parse_basic_binary_must_why_output(file_df)
        combined = pd.concat([combined, parsed_df])

    pair_keys, gold_labels = [], []
    for _, row in combined.iterrows():
        ordered_steps = sorted(row['step_pair_idx_asked_about'])
        pair_keys.append(f"{ordered_steps[0]}_{ordered_steps[1]}")
        gold_labels.append('nondependent' not in row['question_type'])
    combined['step_pair_idx_asked_about_str'] = pair_keys
    combined['gold_label'] = gold_labels
    return combined

def main(args):
    """Run the evaluation on all questions, then separately per question type.

    Args:
        args: namespace with ``model_output_dir``, ``input_files`` and
            ``folder_path`` attributes.
    """
    model_outputs = load_raw_model_outputs(args.model_output_dir)
    input_df = read_input_files(args.input_files)
    df = create_df(load_excel_files(args.folder_path, input_df))

    print(f'Running on all types of questions available: {df.question_type.unique()}')
    run_evaluation(df, model_outputs, True)

    # 'nondependent_switched_real' is intentionally left out of the per-type runs
    for question_type in ['dependent_real', 'nondependent_real']:
        print(f'Running on {question_type} questions')
        subset = df[df['question_type'] == question_type]
        run_evaluation(subset, model_outputs, False)
    # get_cherries_lemons_per_model(df) can be called here to print example answers

# Script entry point: parse the command line and run the evaluation.
if __name__ == "__main__":
    main(parse_args())
