import dataclasses
import json
from json import load as load_json
from pathlib import Path


@dataclasses.dataclass
class Hparams:
    patience: int
    num_epochs: int
    max_epochs: int
    lr_t5: float
    lr_iter: float
    lr_scheduler: str
    weight_decay: float
    warmup_steps: float | int
    task_weight_decay: float
    task_warmup_steps: float | int
    task_lr_scheduler: str
    activation_fn: str = "gelu"
    dropout: float = 0.3

    batch_size: int = 8
    gradient_accumulation: int = 1
    metric_average: str = "micro"
    optimize_for: str = "ere"
    eval_batch_size: int = 0

    @classmethod
    def from_name(cls, name: str) -> "Hparams":
        with open(Path.cwd() / "cfg" / (name + ".json")) as f:
            config = load_json(f)
        return cls(**config["training"])

    def to_json(self):
        return json.dumps(self.__dict__, indent=2)
