import sys
import os
import re

# Absolute path two directories above this file's directory. It is appended to
# sys.path so the project-local `src` package imported below can be found when
# this file is run directly as a script (presumably the repository root —
# confirm).
DIRPATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
print(DIRPATH)
sys.path.append(DIRPATH)

# These imports intentionally come after the sys.path tweak above; `src.util`
# presumably only resolves once DIRPATH is on sys.path.
from datasets import load_dataset
from src.util import get_config, print_config

config = get_config()
print_config(config)

# Dataset identifier taken from the project config. The same id is passed to
# load_dataset and push_to_hub below, so the source dataset is overwritten.
DATASET = config["dataset"]
# Text column whose contents are cleaned in place by clean_intent/parse_rules.
COLUMN_NAME = "with_8b_baseline_rules"


def parse_rules(text):
    """Extract the rules section from a model completion.

    If ``text`` contains ``<rules>``, everything up to and including the first
    ``<rules>`` is dropped. Otherwise, if it contains ``</thinking>``,
    everything up to and including the first ``</thinking>`` is dropped. If
    neither tag is present, the text is kept unchanged. Finally, every
    ``</rules>`` closing tag is removed from what remains.

    Args:
        text: Raw completion text.

    Returns:
        The text that follows the opening tag, with all ``</rules>`` tags
        removed. Surrounding whitespace is preserved.
    """
    # An explicit <rules> section takes priority. Fall back to the end of the
    # reasoning block only when no <rules> tag was emitted.
    _, sep, rest = text.partition("<rules>")
    if not sep:
        _, sep, rest = text.partition("</thinking>")
    if sep:
        # partition splits immediately after the tag, so no offset arithmetic
        # is needed and no characters of the actual content are skipped.
        text = rest

    # Remove the closing </rules> tag(s), if any.
    return text.replace("</rules>", "")


def clean_intent(row):
    """Build the column update for one dataset row.

    Reads the raw completion stored under ``COLUMN_NAME``, reduces it to its
    rules section via ``parse_rules``, and returns a one-key mapping under the
    same column name. It is used with ``dataset.map`` below, which merges that
    mapping back into the row.
    """
    raw_completion = row[COLUMN_NAME]
    rules_only = parse_rules(raw_completion)
    return {COLUMN_NAME: rules_only}


# Load the train split of the configured dataset.
dataset = load_dataset(DATASET, split="train")
# Snapshot the split as loaded to a sibling repo before DATASET is overwritten.
# NOTE(review): this push runs on every invocation. Re-running the script after
# DATASET has already been cleaned would replace this backup with cleaned data,
# losing the original text. Consider skipping it when the backup already exists.
dataset.push_to_hub(DATASET + "_original_rules")
# Strip each row's COLUMN_NAME value down to its rules section.
dataset = dataset.map(clean_intent)
# Push the cleaned split back under the original dataset id.
dataset.push_to_hub(DATASET)
