from dataclasses import dataclass, field
from typing import Dict, Optional, Union

from transformers import (
    IntervalStrategy,
    TrainingArguments as HfTrainingArguments
)
from common.invariants import n_rewards


@dataclass
class BaseDataArguments:
    """Dataset, instruction and reward file locations plus the data prompt template.

    ``train_dataset_path`` is the only field without a default, so it must be
    supplied when the class is instantiated. Every other field is optional.
    """

    train_dataset_path: Optional[str] = field(
        metadata={"help": "directory path that includes all .pkl training datasets"}
    )
    eval_dataset_path: Optional[str] = field(
        default=None,
        # The help text previously said "training datasets" (copy-paste slip).
        metadata={"help": "directory path that includes all .pkl evaluation datasets"},
    )
    num_eval_data_limit: Optional[int] = field(
        default=None,
        metadata={"help": "If eval dataset has more than this number, it will be truncated"},
    )

    # Un-annotated on purpose of keeping the interface as-is: without an
    # annotation the dataclass machinery treats this as a plain class attribute,
    # so it is not an ``__init__`` parameter and is shared by all instances.
    prompt_format = "Task: {instruction}\nVisible objects: {objects}\nGrabbed: {grab}"

    # Instruction files; an empty string is the "not provided" default.
    train_instruction_path: Optional[str] = field(default="")
    eval_instruction_path: Optional[str] = field(default="")

    # Reward files; an empty string is the "not provided" default.
    # ``temporal_reward_path`` was annotated ``Optional[list]`` while defaulting
    # to ``""``; it is now ``Optional[str]`` so that it matches its default and
    # the sibling ``*_reward_path`` fields.
    temporal_reward_path: Optional[str] = field(default="")
    relational_reward_path: Optional[str] = field(default="")
    procedure_reward_path: Optional[str] = field(default="")
    expert_reward_path: Optional[str] = field(default="")

def _default_classifier_cfg() -> Dict:
    """Return a fresh default ``classifier_cfg`` mapping for one instance."""
    return {
        "n_rewards": n_rewards,  # imported from common.invariants
        "lr": 1e-4,
    }


@dataclass
class BaseModelArguments:
    """Model-side settings: a prompt template and the classifier configuration."""

    # No annotation, so this is a class attribute rather than a dataclass field.
    # NOTE(review): this template uses ``{observation}`` where the data
    # arguments' template uses ``{objects}`` - confirm the difference is intended.
    prompt_format = "Task: {instruction}\nVisible objects: {observation}\nGrabbed: {grab}"

    # Built per instance through the factory so instances never share one dict.
    classifier_cfg: Optional[Dict] = field(default_factory=_default_classifier_cfg)


def _default_wandb_cfg() -> Dict:
    """Return a fresh default ``wandb_cfg`` mapping for one instance."""
    return {
        "project": "REWARD_ENSEMBLE",
        "entity": "YOUR_ID",
        "name": None,
    }


@dataclass
class BaseTrainingArguments(HfTrainingArguments):
    """``HfTrainingArguments`` subclass carrying this project's default values.

    Any field declared here whose name also exists on the parent class replaces
    the parent's default for that field; the remaining names are additions.
    """

    # Built per instance through the factory so instances never share one dict.
    wandb_cfg: Optional[Dict] = field(default_factory=_default_wandb_cfg)

    # Output location and file name.
    output_dir: str = "./model_save/"
    output_filename: str = ""

    # Reproducibility and batch sizes.
    seed: int = 0
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 16

    # Evaluation is step-based, with a very large default step interval.
    evaluation_strategy: Union[IntervalStrategy, str] = IntervalStrategy.STEPS
    eval_steps: Optional[int] = 999999999

    # Training length and logging frequency.
    num_train_epochs: int = 3
    logging_steps: int = 3
