from dataclasses import dataclass, field
from typing import Optional, Dict

import transformers
from common.invariants import n_rewards
from config.args import BaseDataArguments, BaseModelArguments, BaseTrainingArguments
from model.model import Classifier
from trainer.trainer import Trainer
import os

@dataclass
class DataArguments(BaseDataArguments):
    """Dataset and reward-file locations used for training and evaluation."""

    # NOTE: this name is unannotated, so it is a plain class attribute, not a
    # dataclass field, and HfArgumentParser does not expose it as a CLI option.
    # It is evaluated once at import time. The two dataset defaults below are
    # therefore absolute paths anchored at the working directory in effect
    # when this module was imported.
    current_path = os.getcwd()
    train_dataset_path: Optional[str] = field(
        metadata={"help": "path to the .json training dataset"},
        default=f"{current_path}/dataset/trajectory/single/EoE_trajectory_succ.json"
    )
    eval_dataset_path: Optional[str] = field(
        metadata={"help": "path to the .json evaluation dataset"},
        default=f"{current_path}/dataset/trajectory/single/test_succ.json"
    )
    # presumably caps how many evaluation samples are used — TODO confirm in Trainer
    num_eval_data_limit: Optional[int] = 100
    # Unlike the dataset paths above, these defaults are relative paths.
    # They are presumably resolved against the working directory by whatever
    # code opens them.
    train_instruction_path: Optional[str] = field(default="dataset/instruction/EoE_train.json")
    eval_instruction_path: Optional[str] = field(default="dataset/instruction/EoE_test.json")
    temporal_reward_path: Optional[str] = field(default="dataset/reward/temporal.json")
    relational_reward_path: Optional[str] = field(default="dataset/reward/contextual.json")
    procedure_reward_path: Optional[str] = field(default="dataset/reward/structural.json")
    expert_reward_path: Optional[str] = field(default="dataset/reward/human_example.json")

@dataclass
class ModelArguments(BaseModelArguments):
    """Model-side arguments; ``classifier_cfg`` is passed to ``Classifier`` as ``cfg``."""

    # default_factory builds a new dict for every instance, so instances never
    # share (and mutate) one config object. ``n_rewards`` is looked up when the
    # factory runs, not when the class is defined.
    classifier_cfg: Optional[Dict] = field(
        default_factory=lambda: dict(n_rewards=n_rewards, lr=1e-4, grad_clip=3)
    )

@dataclass
class TrainingArguments(BaseTrainingArguments):
    """Training-loop settings with defaults specific to the ensemble classifier.

    HfArgumentParser derives each option's argparse ``type`` from its
    annotation. Numeric options must therefore be annotated with ``int``;
    otherwise values supplied on the command line arrive as strings.
    """

    training_mode: Optional[str] = "classifier"
    output_dir: Optional[str] = field(default="./model_save/ensemble")
    output_filename: Optional[str] = field(default="ensemble")
    # Previously annotated ``Optional[str]``, which made CLI overrides such as
    # ``--per_device_train_batch_size 32`` parse to the string "32".
    per_device_train_batch_size: Optional[int] = field(default=16)
    per_device_eval_batch_size: Optional[int] = field(default=80)
    logging_steps: Optional[int] = field(default=50)
    eval_steps: Optional[int] = field(default=100)
    save_strategy: Optional[str] = "steps"
    save_steps: Optional[int] = 500
    num_train_epochs: Optional[int] = field(default=300)


def program():
    """Parse CLI arguments, build the Classifier, and run training.

    The command line is parsed into ``TrainingArguments``, ``ModelArguments``
    and ``DataArguments`` (in that order). A ``Classifier`` is built from the
    model config and the training seed, and a ``Trainer`` over all three
    argument groups is run.
    """
    arg_parser = transformers.HfArgumentParser(
        (TrainingArguments, ModelArguments, DataArguments)
    )
    parsed = arg_parser.parse_args_into_dataclasses()
    train_args, model_args, data_args = parsed  # TrainingArguments, ModelArguments, DataArguments

    classifier = Classifier(
        seed=train_args.seed,
        cfg=model_args.classifier_cfg,
        init_build_model=True,
    )

    # NOTE(review): save_dir receives output_filename rather than output_dir —
    # confirm that Trainer combines it with train_args.output_dir.
    Trainer(
        classifier=classifier,
        data_args=data_args,
        model_args=model_args,
        train_args=train_args,
        save_dir=train_args.output_filename,
    ).run()


if __name__ == "__main__":
    # Script entry point: parse command-line arguments and launch training.
    program()
