from dataclasses import dataclass, field
from typing import Optional, Dict
import transformers
from common.invariants import n_rewards, available
from config.args import BaseDataArguments, BaseModelArguments, BaseTrainingArguments
from model.model import Classifier
import natsort
import json
from tqdm import tqdm
from common.utils import  TestDataset,   test_collate_fn, TestSample
from torch.utils.data import DataLoader


@dataclass
class ModelArguments(BaseModelArguments):
    """Model-side CLI arguments.

    ``classifier_cfg`` is passed as ``cfg`` when ``program()`` constructs the
    ``Classifier``.
    """

    # default_factory builds a fresh dict per instance, so instances never
    # share (and mutate) one default dict.
    classifier_cfg: Optional[Dict] = field(
        default_factory=lambda: dict(n_rewards=n_rewards, lr=1e-4, grad_clip=3)
    )


@dataclass
class DataArguments(BaseDataArguments):
    """Data-side CLI arguments: dataset, instruction and reward file locations."""

    train_dataset_path: Optional[str] = field(
        metadata={"help": "path to the .json training dataset"},
        default="dataset/trajectory/single/EoE_trajectory_succ.json"
    )
    # Help text previously duplicated the training field's description.
    eval_dataset_path: Optional[str] = field(
        metadata={"help": "path to the .json evaluation dataset"},
        default="dataset/trajectory/single/test.json"
    )
    # NOTE(review): program() reads ``data_args.num_eval_data_limit`` (and
    # ``prompt_format``); presumably both are defined on BaseDataArguments —
    # confirm before relying on this class alone.
    train_instruction_path: Optional[str] = field(default="dataset/instruction/EoE_train.json")
    eval_instruction_path: Optional[str] = field(default="dataset/instruction/EoE_test.json")
    temporal_reward_path: Optional[str] = field(default="dataset/gemini_reward/temporal.json")
    relational_reward_path: Optional[str] = field(default="dataset/gemini_reward/relational.json")
    procedure_reward_path: Optional[str] = field(default="dataset/gemini_reward/procedure.json")
    expert_reward_path: Optional[str] = field(default="dataset/gemini_reward/expert.json")


@dataclass
class TrainingArguments(BaseTrainingArguments):
    """Training/evaluation CLI arguments with this script's defaults."""

    training_mode: Optional[str] = field(default="rl")
    # Batch sizes are annotated as int (they were Optional[str]):
    # HfArgumentParser uses the annotation as the argparse type, so a str
    # annotation turned CLI-supplied batch sizes into strings.
    per_device_train_batch_size: Optional[int] = field(default=16)
    per_device_eval_batch_size: Optional[int] = field(default=100)
    logging_steps: Optional[int] = field(default=10)
    eval_steps: Optional[int] = field(default=10)
    num_train_epochs: Optional[int] = field(default=10000)
    save_steps: Optional[int] = field(default=50000)
      

def program(model_file: str = "MODEL DIR", output_path: Optional[str] = None):
    """Load a trained ``Classifier`` and run it over the evaluation dataset.

    CLI arguments are parsed into ``TrainingArguments``, ``ModelArguments``
    and ``DataArguments`` with ``transformers.HfArgumentParser``.

    Args:
        model_file: Path handed to ``Classifier.load``. The default is the
            original placeholder string and must be replaced with a real
            checkpoint location.
        output_path: If not None, the flattened prediction list is written to
            this path as JSON.

    Returns:
        list: Flattened predictions, in evaluation-dataset order.
    """
    ############################# MODEL & ENV #############################
    parser = transformers.HfArgumentParser((TrainingArguments, ModelArguments, DataArguments))
    # Dataclasses come back in the same order as the tuple given to the parser.
    train_args, model_args, data_args = parser.parse_args_into_dataclasses()
    model = Classifier(seed=train_args.seed, cfg=model_args.classifier_cfg, init_build_model=True)
    model = model.load(model_file)

    eval_dataset = TestDataset(
        dataset_path=data_args.eval_dataset_path,
        instruction_path=data_args.eval_instruction_path,
        prompt_format=data_args.prompt_format,
        temporal_reward_path=data_args.temporal_reward_path,
        relational_reward_path=data_args.relational_reward_path,
        procedure_reward_path=data_args.procedure_reward_path,
        expert_reward_path=data_args.expert_reward_path,
        num_data_limit=data_args.num_eval_data_limit,
        for_eval=True
    )
    # shuffle=False: evaluation must iterate in a fixed order so the collected
    # predictions can be matched back to dataset samples.
    eval_dataloader = DataLoader(
        dataset=eval_dataset,
        batch_size=int(train_args.per_device_eval_batch_size),
        collate_fn=test_collate_fn,
        shuffle=False
    )

    predictions = []
    for data in tqdm(eval_dataloader):
        data: TestSample
        # The second return value (weights) is not used here.
        prediction, _weight = model.predict(
            prompts=data.prompts,
            rewards=data.rewards,
            deterministic=True,
        )
        # Original note said shape is (batch * max length, 1) — assumed; the
        # reshape flattens it either way.
        predictions.extend(prediction.reshape(-1).tolist())

    if output_path is not None:
        with open(output_path, "w", encoding="utf-8") as out_file:
            json.dump(predictions, out_file)

    return predictions



if __name__ == "__main__":
    # Run evaluation only when executed as a script, not when imported.
    program()

