from dataclasses import dataclass, field
from typing import Dict, Optional, Union

from common.vh_invariants import n_skills
from transformers import (
    IntervalStrategy,
    TrainingArguments as HfTrainingArguments
)


@dataclass
class BaseDataArguments:
    """Data-related arguments: dataset directories, instruction paths and a reward path.

    Only ``train_dataset_path`` is required; every other field has a default.
    """

    # Required (no default): directory holding the .pkl training datasets.
    train_dataset_path: Optional[str] = field(
        metadata={"help": "directory path that includes all .pkl training datasets"}
    )
    # Directory holding the .pkl evaluation datasets; defaults to None.
    eval_dataset_path: Optional[str] = field(
        metadata={"help": "directory path that includes all .pkl evaluation datasets"},
        default=None
    )
    # Upper bound on the number of eval examples; None presumably means "no limit"
    # (the truncation itself happens elsewhere -- verify against the loader).
    num_eval_data_limit: Optional[int] = field(
        metadata={"help": "If eval dataset has more than this number, it will be truncated"},
        default=None
    )
    # Instruction-file paths for train / eval; empty string by default.
    train_instruction_path: Optional[str] = field(default="")
    eval_instruction_path: Optional[str] = field(default="")
    # Optional path to reward data/model; semantics defined by the consumer (verify).
    reward_path: Optional[str] = field(default=None)


def _default_multimodal_encoder_cfg() -> Dict:
    """Return a fresh default configuration for the multimodal encoder."""
    return dict(pretrained_model_name_or_path="bczhou/tiny-llava-v1-hf")


def _default_skill_decoder_cfg() -> Dict:
    """Return a fresh default configuration for the skill decoder.

    A brand-new nested ``gpt2_config`` dict is built on every call, so two
    ``BaseModelArguments`` instances never share mutable config state.
    ``n_skills`` is looked up in module scope when this runs, not at import.
    """
    gpt2_config = dict(
        vocab_size=1,  # Doesn't matter
        n_positions=2048,
        n_layer=1,
        n_head=4,
        activation_function="relu",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        # NOTE(review): a zero epsilon removes LayerNorm's guard against a
        # zero variance; confirm this is intentional.
        layer_norm_epsilon=0,
    )
    return dict(
        mode="imitation",
        gpt2_config=gpt2_config,
        multimodal_embed_dim=4096,
        n_skills=n_skills,
        lr=1e-5,
    )


@dataclass
class BaseModelArguments:
    """Model-construction arguments: a prompt template plus encoder/decoder configs."""

    # Prompt template; the {instruction} / {history_prompt} placeholders are
    # presumably filled with str.format by the caller (verify).
    multimodal_prompt: Optional[str] = field(
        default="Instruction: {instruction}\nHistory: {history_prompt}\nNext Skill: [MASK]"
    )
    # Keyword configuration for the multimodal encoder (pretrained checkpoint id).
    multimodal_encoder_cfg: Optional[Dict] = field(
        default_factory=_default_multimodal_encoder_cfg
    )
    # Skill-decoder configuration, including a nested GPT-2 config dict.
    skill_decoder_cfg: Optional[Dict] = field(
        default_factory=_default_skill_decoder_cfg
    )


@dataclass
class BaseTrainingArguments(HfTrainingArguments):
    """Hugging Face ``TrainingArguments`` with project defaults and extra fields.

    Adds ``wandb_cfg``, ``output_filename`` and ``training_mode`` and overrides
    several inherited defaults (output dir, seed, batch sizes, eval/logging cadence).
    """

    # Weights & Biases run settings; "YOUR_ID" is a placeholder entity that must
    # be replaced. Presumably forwarded to wandb.init by the trainer (verify).
    wandb_cfg: Optional[Dict] = field(
        default_factory=lambda: {
            "project": "REWARD_ENSEMBLE",
            "entity": "YOUR_ID",
            "name": None
        }
    )

    output_dir: str = field(default="./model_save/")
    output_filename: str = field(default="")

    seed: int = field(default=0)
    per_device_train_batch_size: int = field(default=8)
    per_device_eval_batch_size: int = field(default=16)
    # NOTE(review): newer transformers releases rename this argument to
    # ``eval_strategy`` and drop ``evaluation_strategy``; confirm against the
    # pinned transformers version, otherwise this override may be ignored.
    evaluation_strategy: Union[IntervalStrategy, str] = field(default=IntervalStrategy.STEPS)
    # Huge interval: presumably intended to effectively disable periodic
    # evaluation unless overridden on the command line.
    eval_steps: Optional[int] = field(default=999999999)
    num_train_epochs: int = field(default=3)
    logging_steps: int = field(default=3)
    # Defaults to None (unset), so the annotation is Optional to match.
    training_mode: Optional[str] = field(
        metadata={"help": "one of (1)imitation, (2)rl, and (3)irl"},
        default=None
    )
