# Loss-term weights for the glyph-diffusion trainer. The OCR recognition loss
# receives whatever weight remains, so the three weights sum to 1.
cross_weight=0.4
latent_loss_weight=0.2
# A plain shell assignment does no arithmetic (and bash $(( )) is integer-only),
# so `1-${cross_weight}-${latent_loss_weight}` would yield the literal string
# "1-0.4-0.2". Compute the float with awk; LC_ALL=C keeps '.' as the decimal
# separator regardless of the caller's locale.
ocr_loss_weight=$(LC_ALL=C awk -v c="${cross_weight}" -v l="${latent_loss_weight}" \
  'BEGIN { printf "%.10g\n", 1 - c - l }')
# Interpreter used to run paddle.distributed.launch.
python_path=python3

# Launch glyph_diffusion/train_txt2img_trainer.py on GPUs 0-7 through
# paddle.distributed.launch; all stdout/stderr goes to train.log.
#
# Required variables (not assigned in this script; export them or set them
# before this point):
#   data        - suffix selecting data/{train,log}.filelist.pdc.<data>
#   vision_emb  - value passed to --use_vision_embedding
#   token       - value passed to --tokenization_level
# Without these guards an unset value would make e.g. --use_vision_embedding
# swallow the next flag as its argument, silently misconfiguring the run.
: "${data:?data must be set (suffix of data/train.filelist.pdc.<data>)}"
: "${vision_emb:?vision_emb must be set (value for --use_vision_embedding)}"
: "${token:?token must be set (value for --tokenization_level)}"

# Trainer arguments, one flag (plus value) per line. Bare flags such as
# --overwrite_output_dir and --recompute take no value.
train_args=(
  --font_path data/fonts/SimSun.ttf
  --do_train True
  --seed 42
  --per_device_train_batch_size 3
  --learning_rate 2e-5
  --lr_scheduler_type constant
  --report_to custom_visualdl
  --max_steps 10000
  --pretrained_model_name_or_path stable-diffusion-xl-base-1.0
  --train_file_list "data/train.filelist.pdc.${data}"
  --log_file_list "data/log.filelist.pdc.${data}"
  --resolution 1024
  --dataloader_num_workers 8
  --proportion_empty_prompts 0.1
  --overwrite_output_dir
  --output_dir output
  --logging_dir logs
  --logging_steps 5
  --max_logging_samples 20
  --image_logging_steps 1000
  --save_strategy steps
  --save_steps 1000
  --save_total_limit 10
  --train_lora False
  --train_text_encoder False
  --model_max_length 77
  --snr_gamma 5.0
  --ocr_rec_loss_weight "${ocr_loss_weight}"
  --ocr_rec_config data/config/ch_PP-OCRv4_rec_hgnet_custom.yml
  --use_vision_embedding "${vision_emb}"
  --tokenization_level "${token}"
  --local_mse_loss_weight "${latent_loss_weight}"
  --use_glyph_localization True
  --cross_glyph_localization_weight "${cross_weight}"
  # NOTE(review): sharding_degree presumably has to match the 8 GPUs listed
  # in --gpus below; keep the two in sync when changing either.
  --sharding stage2
  --sharding_degree 8
  --enable_xformers_memory_efficient_attention True
  --bf16 True
  --fp16_opt_level O2
  --max_grad_norm 1
  --recompute
  --gradient_accumulation_steps 1
)

# TORCH_CPP_LOG_LEVEL=1 is kept from the original invocation (purpose not
# visible here - presumably to quiet torch C++ logging; TODO confirm).
TORCH_CPP_LOG_LEVEL=1 "${python_path}" -u -m paddle.distributed.launch \
  --gpus "0,1,2,3,4,5,6,7" \
  glyph_diffusion/train_txt2img_trainer.py \
  "${train_args[@]}" > train.log 2>&1