from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class ModelArguments:
    """Model-side arguments.

    Groups the names/paths of the pretrained components (VAE, tokenizers,
    text encoders, UNet), loss and training hyper-parameters, and the custom
    vision-embedding, glyph-localization and OCR-recognition options.

    Every field has a default, so ``ModelArguments()`` is constructible with
    no arguments. Fields whose default is ``None`` are annotated
    ``Optional[...]`` so the declared type matches the value an unconfigured
    instance actually holds.
    """

    is_train: bool = field(default=True, metadata={"help": "is_train"})

    ### model component
    pretrained_model_name_or_path: str = field(
        default="stable-diffusionxl-v1.0",
        metadata={"help": "Path to pretrained name or model"},
    )
    vae_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "pretrained_vae_name_or_path"}
    )
    model_max_length: Optional[int] = field(
        default=77, metadata={"help": "Pretrained tokenizer model_max_length"}
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path"}
    )
    tokenizer_2_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer 2. "}
    )
    text_encoder_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "text_encoder_name_or_path"}
    )
    text_encoder_2_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "text_encoder_2_name_or_path"}
    )
    unet_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "unet_encoder_name_or_path"}
    )
    num_inference_steps: Optional[int] = field(
        default=50, metadata={"help": "num_inference_steps"}
    )
    proportion_empty_prompts: float = field(
        default=0.1,
        metadata={"help": "The proportion to replace prompts with empty strings."},
    )
    # Unset (None) by default; the original inline note on this field read "5.0".
    snr_gamma: Optional[float] = field(default=None, metadata={"help": "snr_gamma."})

    ### hyper params
    local_mse_loss_weight: float = field(
        default=0.0, metadata={"help": "local_mse_loss_weight."}
    )
    use_ema: bool = field(default=False, metadata={"help": "Whether or not use ema"})
    # Original note: "paddle >= 2.0 use this by default" (not verifiable from this file).
    enable_xformers_memory_efficient_attention: bool = field(
        default=False, metadata={"help": "enable_xformers_memory_efficient_attention."}
    )
    noise_offset: float = field(default=0.1, metadata={"help": "noise_offset."})

    ### custom model config
    vision_embedding_ckpt_1: Optional[str] = field(
        default=None, metadata={"help": "Path to vision embedding ckpt."}
    )
    vision_embedding_ckpt_2: Optional[str] = field(
        default=None, metadata={"help": "Path to vision embedding ckpt."}
    )
    tokenization_level: str = field(
        default="BPE",
        metadata={
            "help": "tokenization level",
            "choices": ["Word", "BPE", "Char"],
        },
    )

    # object localization loss
    use_glyph_localization: bool = field(
        default=False, metadata={"help": "Whether or not use object localization."}
    )
    # Float literal so the default matches the declared type (0.0 == 0).
    cross_glyph_localization_weight: float = field(
        default=0.0, metadata={"help": "cross_glyph_localization_weight."}
    )

    # ocr rec loss
    ocr_predictor_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "ocr_predictor_name_or_path"}
    )
    train_ocr_predictor: bool = field(
        default=False, metadata={"help": "Whether or not train ocr rec model."}
    )
    # Float literal so the default matches the declared type (0.0 == 0).
    ocr_rec_loss_weight: float = field(
        default=0.0, metadata={"help": "ocr_rec_loss_weight."}
    )
    ocr_rec_config: Optional[str] = field(
        default=None, metadata={"help": "Path to ocr rec model config."}
    )


@dataclass
class DataTrainingArguments:
    """Data-side arguments.

    Covers the training and image-logging input file lists, image resolution
    and logging cadence, dataset length bookkeeping, and the auxiliary
    OCR-model and font resources. Every field has a default.
    """

    # --- input files ---
    train_file_list: Optional[str] = field(
        default=None,
        metadata=dict(help="The input training data file (a jsonlines)."),
    )
    log_file_list: Optional[str] = field(
        default=None,
        metadata=dict(help="An optional input log data file to log images."),
    )

    # --- data config ---
    resolution: int = field(
        default=1024,
        metadata=dict(help="The resolution for input images"),
    )
    image_logging_steps: Optional[int] = field(
        default=100,
        metadata=dict(help="Log image every X steps."),
    )
    num_records: int = field(
        default=1000000,
        metadata=dict(
            help="num records, used as len(dataset) in iterable data reader"
        ),
    )
    max_logging_samples: Optional[int] = field(
        default=100,
        metadata=dict(help="For debugging purposes or quicker training"),
    )

    # --- additional resources ---
    ocr_model_path: str = field(
        default="ocr_models",
        metadata=dict(help="dirs to cache ocr models"),
    )
    font_path: str = field(
        default="data/fonts/SimSun.ttf",
        metadata=dict(
            help="Path to place fonts that are used in rendering image"
        ),
    )
