import re
from datasets import load_dataset

# ================== CONFIG ==================
# Hub repository id: the dataset is loaded from this id and the processed
# result is pushed back to the same id.
DATASET = "preference-agents-experiments/enron-standardized-jeff-dasovich-20-split-llmeval"
# Columns whose raw text is parsed by clean_text; each listed column gets a
# companion "<column>_cleaned" column. The commented-out names are presumably
# other candidate columns that can be re-enabled here -- TODO confirm.
columns_evaluated = [
    # "baseline_generation",
    # "gold_rule_email",
    # "naiveft_baseline",
    # "rulegen_rules_generated_email",
    "llmwinrate_70b_raw"
]

pattern = r'<winner>\s*(\d+)\s*</winner>|<winner>\s*(\d+)\s*$|(\d+)\s*$'
# The alternatives are tried in this order at each position of the text:
# 1. a complete <winner>N</winner> tag
# 2. an opening <winner>N with no closing tag, running to the end of the string
# 3. no tags at all: a bare number at the very end of the string
# ============================================
def clean_text(ls):
    """Extract the integer winner from each raw string in *ls*.

    Args:
        ls: Iterable of raw text entries. Entries that are not strings
            (for example ``None``) are tolerated and treated as unparsable.

    Returns:
        A list with one entry per input, in input order: the parsed winner
        as an ``int``, or ``None`` when the entry is not a string or no
        alternative of ``pattern`` matches it.

    Side effects:
        Prints how many entries ended up as ``None``.
    """
    # Compile once up front rather than resolving the pattern on every row.
    winner_re = re.compile(pattern)
    results = []
    null_counter = 0
    for each in ls:
        # re.search raises TypeError on None / non-string input; count such
        # rows as NULL instead of aborting the whole run.
        match = winner_re.search(each) if isinstance(each, str) else None
        if match:
            # Only the group of the alternative that matched is populated.
            winner = match.group(1) or match.group(2) or match.group(3)
            results.append(int(winner))
        else:
            results.append(None)
            null_counter += 1
    print(f"Number of NULLs: {null_counter}")
    return results

def process_for_dataset_split(df):
    """Add a ``<column>_cleaned`` column for every name in ``columns_evaluated``.

    For each configured column, the raw values are parsed with ``clean_text``
    and stored under the column name suffixed with ``_cleaned``. If a column
    with that name already exists it is dropped first, so the new values
    replace it.

    Args:
        df: Dataset split exposing ``column_names``, ``remove_columns`` and
            ``add_column``, and indexable by column name.

    Returns:
        The split with the cleaned columns attached.
    """
    for source_column in columns_evaluated:
        parsed_values = clean_text(df[source_column])
        target_column = f"{source_column}_cleaned"
        print(f"Adding column {target_column}")
        # Drop a stale copy first so add_column does not clash with it.
        base = (
            df.remove_columns([target_column])
            if target_column in df.column_names
            else df
        )
        df = base.add_column(target_column, parsed_values)
    return df

# Load the dataset, attach the cleaned columns to the train and test splits,
# then push the result back to the same repository id.
dataset = load_dataset(DATASET)
for split_name in ("train", "test"):
    dataset[split_name] = process_for_dataset_split(dataset[split_name])

dataset.push_to_hub(DATASET)
