from generator.vanilla_trainer import train
from utils.util import load_config
from datasets import load_dataset, Dataset
from tqdm import tqdm
from eval_fn.general_eval import judge_router

def prepare_orm_data(config: dict) -> "Dataset":
    """Build a text/label dataset for reward-model training from judged generations.

    Loads the JSON file named by ``config['reward_model']['dataset_name']``
    (split ``config['reward_model']['split']``), runs the judge selected by
    ``config['task']`` on every record, and renders each record as a
    two-line string: ``Q: <question>`` followed by ``A: <prediction>``.

    Args:
        config: Configuration mapping. Must provide ``config['task']`` and
            ``config['reward_model']`` with ``dataset_name`` and ``split``.

    Returns:
        A ``datasets.Dataset`` in ``"torch"`` format with two columns:
        ``"label"`` (the judge's label for each record) and ``"text"``
        (the formatted question/prediction string).

    Notes:
        Each record is expected to carry ``question``, ``answer`` and
        ``generation`` fields. The judge is called as
        ``judge(answer, generation)`` and must return a
        ``(prediction, label)`` pair; the exact label values depend on the
        judge returned by ``judge_router`` (presumably 1 for correct and 0
        for incorrect -- confirm against ``eval_fn.general_eval``).
    """
    rm_config = config['reward_model']
    dataset = load_dataset('json', data_files=rm_config['dataset_name'], split=rm_config['split'])

    # Debug aid: show one raw record. Guarded so that an empty split does not
    # abort with an IndexError before any processing happens.
    if len(dataset) > 0:
        print(dataset[0])

    template = "Q: {question}\nA: {answer}"
    judge = judge_router(config['task'])

    samples = []
    labels = []
    # Iterating the dataset yields the same row dicts as indexing it, and tqdm
    # still picks up the total from len(dataset).
    for sample in tqdm(dataset):
        # The judge returns the (possibly normalised) prediction text together
        # with its label; the text shown to the reward model is the judge's
        # prediction, not the raw generation.
        prediction, label = judge(sample['answer'], sample['generation'])
        samples.append(template.format(question=sample['question'], answer=prediction))
        labels.append(label)

    # Wrap the parallel lists in a Hugging Face dataset that yields torch tensors.
    return Dataset.from_dict({"label": labels, "text": samples}).with_format("torch")
    