from dataclasses import dataclass, field
from typing import Dict, Optional

import transformers
from common.vh_invariants import n_skills
from config.args import BaseDataArguments, BaseModelArguments, BaseTrainingArguments
from models.mm_student import SkillDecoder
from models.multimodal_encoders import VitBertMultiModalEncoderForCaption
from trainer import EalfTrainer

import os

@dataclass
class DataArguments(BaseDataArguments):
    """Dataset, instruction and reward file locations for the RL run.

    Every default path is rooted at the working directory captured when this
    class body is executed (i.e. at import time), not when an instance is
    created.
    """

    # Unannotated on purpose: a plain class attribute, not a dataclass field,
    # so it never becomes a command-line argument.
    current_path = os.getcwd()

    # Trajectory data (train / eval splits).
    train_dataset_path: Optional[str] = f"{current_path}/Dataset/trajectory/train.json"
    eval_dataset_path: Optional[str] = f"{current_path}/Dataset/trajectory/test.json"

    # Instruction data (train / eval splits).
    train_instruction_path: Optional[str] = f"{current_path}/Dataset/instruction/rl_train.json"
    eval_instruction_path: Optional[str] = f"{current_path}/Dataset/instruction/rl_test.json"

    # Reward definition file.
    reward_path: Optional[str] = f"{current_path}/Dataset/reward/reward.json"

    # Presumably caps how many evaluation samples are used -- confirm against
    # the consumer of this field.
    num_eval_data_limit: Optional[int] = 300

def _default_skill_decoder_cfg() -> Dict:
    """Build a fresh default configuration dict for the skill decoder."""
    # Nested under "gpt2_config"; presumably forwarded to a GPT-2 style
    # configuration object by the decoder -- confirm in SkillDecoder.
    gpt2_config = {
        "vocab_size": 1,  # Placeholder; per the original note the value does not matter.
        "n_positions": 768 * 2,
        "n_layer": 2,
        "n_head": 4,
        "activation_function": "relu",
        "resid_pdrop": 0.1,
        "embd_pdrop": 0.1,
        "attn_pdrop": 0.1,
        # NOTE(review): an epsilon of 0 removes the usual numerical-stability
        # term from layer normalisation -- confirm this is intentional.
        "layer_norm_epsilon": 0
    }
    return {
        "mode": "rl",
        "arch": "transformer",
        "gpt2_config": gpt2_config,
        "multimodal_embed_dim": 768,
        "n_skills": n_skills,
        "lr": 1e-4,
        "target_update_interval": 250,
        "gamma": 0.99,
        "tau": 0.005,
        "net_arch": [512, 512]
    }


@dataclass
class ModelArguments(BaseModelArguments):
    """Model-side settings for the RL run: decoder config and save location."""

    # A factory is used so that every instance receives its own dict.
    skill_decoder_cfg: Optional[Dict] = field(default_factory=_default_skill_decoder_cfg)
    save_path_dir: Optional[str] = field(default="model_save/rl")

# NOTE(review): "trp_reward_noinit" looks like a leftover experiment/run tag -- confirm its meaning or remove.

@dataclass
class TrainingArguments(BaseTrainingArguments):
    """Default training-loop settings for the RL run.

    The fields below supply defaults on top of ``BaseTrainingArguments``; their
    exact semantics are presumably those of the base class -- confirm there.
    """

    training_mode: Optional[str] = field(default="rl")
    output_dir: Optional[str] = field(default="model_save/rl")
    output_filename: Optional[str] = field(default="rl")
    # Batch sizes are integers. They were previously annotated as Optional[str],
    # which makes HfArgumentParser parse command-line overrides as strings.
    per_device_train_batch_size: Optional[int] = field(default=32)
    per_device_eval_batch_size: Optional[int] = field(default=80)
    logging_steps: Optional[int] = field(default=10)
    eval_steps: Optional[int] = field(default=10)
    num_train_epochs: Optional[int] = field(default=22)
    save_strategy: Optional[str] = field(default="steps")
    save_steps: Optional[int] = field(default=10)

def program():
    """Parse the argument dataclasses, assemble the agent, and run the trainer."""
    dataclass_types = (TrainingArguments, ModelArguments, DataArguments)
    parsed = transformers.HfArgumentParser(dataclass_types).parse_args_into_dataclasses()
    # Results come back in the same order as the dataclass types listed above.
    train_args, model_args, data_args = parsed

    # Build the encoder first, then the decoder; the encoder is attached to the
    # decoder as a plain attribute afterwards.
    encoder = VitBertMultiModalEncoderForCaption(model_args.multimodal_encoder_cfg)
    agent = SkillDecoder(
        seed=train_args.seed,
        cfg=model_args.skill_decoder_cfg,
        init_build_model=True,
    )
    agent.multimodal_encoder = encoder

    EalfTrainer(
        agent=agent,
        data_args=data_args,
        model_args=model_args,
        train_args=train_args,
    ).run()


# Run the training program only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    program()
