#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Train the glyph diffusion model (GlyphDiffusionModel) on text-image pairs.
"""
# You can also adapt this script to your own text-to-image task.

import itertools
import logging
import os
import json
from ldm.txt2img_dataset import TextImagePairDataset

import sys
import warnings
from typing import Dict, Optional
import paddle
import numpy as np
from ldm.callback import (
    VisualDLCallbackWithImage,
)

from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
from paddlenlp.utils.log import logger

from ldm.model import GlyphDiffusionModel
from paddlenlp.trainer.integrations import (
    INTEGRATION_TO_CALLBACK,
)
import math
from ldm.ldm_trainer import GlyphDiffusionTrainer
from ldm.ldm_args import ModelArguments, DataTrainingArguments
from PIL import Image
import cv2

# Restrict OpenCV to a single internal thread
# (presumably to avoid oversubscribing CPUs alongside data-loading workers).
cv2.setNumThreads(1)

# Raise the PIL loggers' threshold above CRITICAL, which silences every
# standard-level message they would otherwise emit.
for _pil_logger_name in (
    "PIL.TiffImagePlugin",
    "PIL.PngImagePlugin",
    "PIL.Image",
    "PIL",
):
    logging.getLogger(_pil_logger_name).setLevel(logging.CRITICAL + 1)

# Add the current working directory to the module search path, then print
# the resulting path list.
sys.path.append(os.getcwd())

print("path: ", sys.path)

from paddlenlp.utils.log import logger


def _sync_shared_arguments(model_args, data_args, training_args):
    """Mirror settings that several components read across the argument groups.

    Mutates the three argument objects in place; returns nothing.
    """
    from datetime import datetime

    # Round the image logging interval up to the nearest multiple of
    # ``logging_steps`` (presumably so image logging coincides with a regular
    # logging step).
    aligned_image_steps = (
        math.ceil(data_args.image_logging_steps / training_args.logging_steps)
        * training_args.logging_steps
    )
    training_args.image_logging_steps = aligned_image_steps
    data_args.image_logging_steps = aligned_image_steps

    training_args.max_logging_samples = data_args.max_logging_samples
    training_args.log_file_list = data_args.log_file_list

    ocr_model_path = data_args.ocr_model_path
    model_args.ocr_model_path = ocr_model_path
    training_args.ocr_model_path = ocr_model_path

    # Give every run its own timestamped sub-directory under logging_dir.
    timestamp = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
    training_args.logging_dir = os.path.join(training_args.logging_dir, timestamp)

    training_args.font_path = data_args.font_path

    resolution = data_args.resolution
    training_args.resolution = resolution
    model_args.resolution = resolution

    # The tokenization level is taken from the model arguments.
    tokenization_level = model_args.tokenization_level
    training_args.tokenization_level = tokenization_level
    data_args.tokenization_level = tokenization_level


def _detect_last_checkpoint(training_args):
    """Return the last checkpoint found in ``output_dir``, or ``None``.

    Detection only runs when ``output_dir`` exists, training is requested and
    ``overwrite_output_dir`` is not set.

    Raises:
        ValueError: if ``output_dir`` is non-empty but holds no checkpoint.
    """
    should_look = (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    )
    if not should_look:
        return None

    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None:
        if os.listdir(training_args.output_dir):
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
    elif training_args.resume_from_checkpoint is None:
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )
    return last_checkpoint


def main():
    """Parse the command-line arguments, build model/dataset/trainer and train.

    Training, model saving and state saving only happen when ``do_train`` is
    set; otherwise the trainer is built without a dataset and the function
    returns after logging a warning.

    Raises:
        ValueError: if the output directory is non-empty, contains no
            checkpoint, and ``overwrite_output_dir`` was not requested.
    """
    parser = PdArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    _sync_shared_arguments(model_args, data_args, training_args)

    # Log on each process the small summary:
    logger.info(f"Training/evaluation parameters {training_args}")
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    ## ************** Load checkpoint. ************** ##
    paddle.set_device(training_args.device)
    last_checkpoint = _detect_last_checkpoint(training_args)

    ## ****************** Load Model ****************** ##
    model = GlyphDiffusionModel(model_args=model_args)

    ## ***************** Load Dataset ***************** ##
    train_dataset = TextImagePairDataset(
        file_list=data_args.train_file_list,
        size=data_args.resolution,
        num_records=data_args.num_records,
        tokenizer=model.torch_tokenizer,
        ocr_model_path=data_args.ocr_model_path,
        font_path=data_args.font_path,
        # 'BPE' is forced unless use_vision_embedding is exactly False.
        tokenization_level=(
            data_args.tokenization_level
            if model_args.use_vision_embedding is False
            else 'BPE'
        ),
    )

    # Register the image-logging VisualDL callback under a custom name.
    INTEGRATION_TO_CALLBACK.update({"custom_visualdl": VisualDLCallbackWithImage})

    # Initialize Trainer
    trainer = GlyphDiffusionTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
    )

    # Recompute is configured on the trainer's model after the trainer exists.
    trainer.model.set_recompute(training_args.recompute)

    # get_trainable_parameters() yields groups of parameters; flatten them
    # into a single iterable for the optimizer.
    params_to_train = itertools.chain.from_iterable(model.get_trainable_parameters())
    trainer.set_optimizer_grouped_parameters(params_to_train)

    if not training_args.do_train:
        # The trainer was built without a dataset, so training cannot run.
        logger.warning("`--do_train` is not set; skipping training and saving.")
        return

    # An explicitly requested checkpoint wins over an auto-detected one.
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_model()
    trainer.save_state()


# Run main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
