from generator.vanilla_trainer import train
from utils.util import load_config
from datasets import load_dataset, Dataset
from tqdm import tqdm
import json
from eval_fn.general_eval import judge_router

def decompose_samples(template, question, prediction, label=None):
    """Expand one multi-line prediction into cumulative, step-wise prompts.

    ``prediction`` is split on newlines. For each resulting line, the text
    accumulated so far is rendered with
    ``template.format(question=question, answer=<accumulated text>)``, and
    the same ``label`` is recorded once per rendered prefix.

    Args:
        template: Format string rendered with ``question`` and ``answer``
            keyword arguments.
        question: Value substituted for the ``question`` field.
        prediction: Newline-separated text to be expanded prefix by prefix.
        label: Value repeated for every produced prefix (default ``None``).

    Returns:
        A ``(step_wise_preds, step_wise_labels)`` tuple of two lists with one
        entry per line of ``prediction``.
    """
    lines = prediction.split('\n')
    rendered = []
    running = ''
    for line in lines:
        # While the buffer is still empty (first line, or only empty lines so
        # far) the current line replaces it; no '\n' separator is inserted.
        running = line if running == '' else running + '\n' + line
        rendered.append(template.format(question=question, answer=running))
    return rendered, [label] * len(lines)

def prepare_prm_data(config):
    """Build the step-wise text/label dataset used to train the reward model.

    Loads the JSON dataset named by ``config['reward_model']['dataset_name']``
    (split ``config['reward_model']['split']``), obtains a judge from
    ``judge_router(config['task'])``, and calls it on each row's ``answer``
    and ``generation``. The judge's ``(prediction, label)`` result is expanded
    by ``decompose_samples`` into cumulative step-wise texts that all carry
    that label.

    Args:
        config: Mapping providing ``reward_model.dataset_name``,
            ``reward_model.split`` and ``task``.

    Returns:
        A ``Dataset`` with ``label`` and ``text`` columns, set to the
        ``"torch"`` format.
    """
    rm_cfg = config['reward_model']
    raw = load_dataset('json', data_files=rm_cfg['dataset_name'], split=rm_cfg['split'])
    # Show the first record as a quick sanity check of the loaded schema.
    print(raw[0])

    prompt_template = "Q: {question}\nA: {answer}"
    judge_fn = judge_router(config['task'])
    texts, targets = [], []
    for row in tqdm(raw):
        judged_pred, judged_label = judge_fn(row['answer'], row['generation'])
        step_texts, step_targets = decompose_samples(
            prompt_template, row['question'], judged_pred, judged_label
        )
        texts.extend(step_texts)
        targets.extend(step_targets)

    # Wrap the flat lists as a HuggingFace dataset that yields torch values.
    return Dataset.from_dict({"label": targets, "text": texts}).with_format("torch")

def filter_dpo_prm_positive_data(config):
    """Collect generations whose final answer matches the reference and save them.

    Loads the JSON dataset named by ``config['reward_model']['dataset_name']``
    (split ``config['reward_model']['split']``). For every row, the reference
    answer is the text after the last ``'\\n#### '`` marker in ``answer``.
    A generation is considered only if it contains that marker; its predicted
    answer is the first line following the first marker, and anything after
    that line is discarded. Matching generations are rendered as
    ``"Q: {question}\\nA: {answer}"``.

    The rendered strings are appended, one JSON-encoded string per line, to
    the dataset path with ``.jsonl`` replaced by ``_correct.jsonl``.

    Args:
        config: Mapping providing ``reward_model.dataset_name`` (a path
            string) and ``reward_model.split``.

    Returns:
        The list of rendered correct samples, in dataset order.
    """
    dataset_path = config['reward_model']['dataset_name']
    dataset = load_dataset('json', data_files=dataset_path, split=config['reward_model']['split'])
    # Sanity-check print of the first record; skipped for an empty dataset,
    # where indexing would raise IndexError.
    if len(dataset) > 0:
        print(dataset[0])

    answer_marker = '\n#### '
    template = "Q: {question}\nA: {answer}"
    correct_samples = []
    # Every matching generation is kept, not just the first one per question.
    for sample in tqdm(dataset):
        answer = sample['answer'].split(answer_marker)[-1]
        parts = sample['generation'].split(answer_marker)
        if len(parts) == 1:
            # The generation never produced a final-answer marker.
            continue
        pred_answer = parts[1].split('\n')[0]
        # NOTE(review): this is a substring test, so a reference of "5" also
        # matches a predicted "15" -- confirm this leniency is intended.
        if answer in pred_answer:
            # Keep the reasoning plus the answer line only; trailing text
            # after the answer line is dropped.
            prediction = parts[0] + answer_marker + pred_answer
            correct_samples.append(template.format(question=sample['question'], answer=prediction))

    out_path = dataset_path.replace('.jsonl', '_correct.jsonl')
    if out_path == dataset_path:
        # The path has no ".jsonl" to replace; never append to the source
        # dataset itself, so derive a distinct output name instead.
        out_path = dataset_path + '_correct.jsonl'
    # Append mode is kept from the original: repeated runs accumulate lines.
    with open(out_path, 'a', encoding='utf-8') as f:
        # Write the serialized sample itself; previously the json.dumps result
        # was discarded and only blank lines reached the file.
        f.writelines(json.dumps(correct_sample) + '\n' for correct_sample in correct_samples)
    return correct_samples