from dataclasses import dataclass, field
from typing import Dict, Optional

from common.vh_invariants import n_skills
from transformers import (TrainingArguments as HfTrainingArguments)


@dataclass
class BaseDataArguments:
    """Container for data-location settings.

    Both attributes are optional strings whose default is the empty string.
    """

    # NOTE(review): presumably the location of the dataset -- confirm
    # against the code that reads this attribute.
    dataset_path: Optional[str] = ""
    # NOTE(review): presumably the location of the instruction data --
    # confirm against the code that reads this attribute.
    instruction_path: Optional[str] = ""


def _default_multimodal_encoder_cfg() -> Dict:
    """Return a fresh default configuration dict for the multimodal encoder."""
    return {"pretrained_model_name_or_path": "llava-hf/llava-1.5-7b-hf"}


def _default_skill_decoder_cfg() -> Dict:
    """Return a fresh default configuration dict for the skill decoder.

    ``n_skills`` is looked up each time this function is called, not when
    the module is imported.
    """
    gpt2_config = {
        # NOTE(review): the original author marked this value as irrelevant;
        # presumably the decoder never uses the token vocabulary -- confirm.
        "vocab_size": 1,
        "n_positions": 1024,
        "n_layer": 1,
        "n_head": 4,
        "activation_function": "relu",
        "resid_pdrop": 0.1,
        "embd_pdrop": 0.1,
        "attn_pdrop": 0.1,
        # NOTE(review): an epsilon of exactly 0 is unusual for layer
        # normalisation -- confirm this is intentional.
        "layer_norm_epsilon": 0,
    }
    return {
        "mode": "imitation",
        "gpt2_config": gpt2_config,
        "multimodal_embed_dim": 4096,
        "n_skills": n_skills,
        "lr": 1e-5,
    }


@dataclass
class BaseModelArguments:
    """Container for model-related settings.

    The two ``*_cfg`` attributes are dictionaries; each instance receives its
    own freshly built dict, so mutating one instance's config does not affect
    another's.
    """

    batch_size: int = 5
    # Template containing ``{instruction}`` and ``{history_prompt}``
    # placeholders.
    multimodal_prompt: Optional[str] = (
        "<image>\nUSER: {instruction}{history_prompt}\nASSISTANT: I should do "
    )
    multimodal_encoder_cfg: Optional[Dict] = field(
        default_factory=_default_multimodal_encoder_cfg
    )
    skill_decoder_cfg: Optional[Dict] = field(
        default_factory=_default_skill_decoder_cfg
    )


@dataclass
class BaseTrainingArguments(HfTrainingArguments):
    """Training settings built on ``transformers.TrainingArguments``.

    Declares the attributes below with the defaults shown; everything else
    comes from the parent class.
    """

    output_dir: str = "./model_save/"
    seed: int = 0
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 16
    # NOTE(review): the parent class may annotate this attribute as a float;
    # the ``int`` annotation here would then narrow it -- confirm intended.
    num_train_epochs: int = 3
    logging_steps: int = 100
