import os.path as op

from utils.distributed_processing import synchronize
from utils.load_and_save import save_ema_model_from_ckpt


def model_factory(logger, args, config_class, model_class, tokenizer_class):
    """Build the ``(model, tokenizer, checkpoint)`` triple for training or evaluation.

    When ``args.do_train`` is truthy:
        The config, tokenizer and model are loaded from ``args.model_name_or_path``.
        The config source can be overridden with ``args.config_name`` and the
        tokenizer source with ``args.tokenizer_name``. The task options in
        ``args`` are then copied onto the config.

    Otherwise:
        Everything is loaded from the ``args.eval_model_dir`` directory. If
        ``args.eval_ema_num`` is set, the model is replaced by an EMA model built
        from that many of the last checkpoints.

    Args:
        logger: Logger for progress messages. Only used on the evaluation path.
        args: Namespace-like object holding the options read by this function.
        config_class: Class exposing ``from_pretrained`` that returns a config.
        model_class: Class exposing ``from_pretrained`` that returns a model.
        tokenizer_class: Class exposing ``from_pretrained`` that returns a tokenizer.

    Returns:
        ``(model, tokenizer, checkpoint)``. ``checkpoint`` is ``None`` when
        training. Otherwise it is the path the model was finally loaded from:
        either the eval dir or the EMA checkpoint path.

    Raises:
        ValueError: In any of these cases:
            - ``args.model_name_or_path`` is ``None`` when training.
            - ``args.eval_model_dir`` is not a directory when evaluating.
            - ``args.add_prefix is True`` but the tokenizer does not know
              the ``[prefix]`` token.
    """
    if args.do_train:
        return _build_for_training(args, config_class, model_class, tokenizer_class)
    return _build_for_eval(logger, args, config_class, model_class, tokenizer_class)


def _check_prefix_token(tokenizer):
    """Raise ``ValueError`` unless ``[prefix]`` maps to a real (non-``[UNK]``) int id."""
    prefix_id = tokenizer.convert_tokens_to_ids("[prefix]")
    unk_id = tokenizer.convert_tokens_to_ids("[UNK]")
    # An out-of-vocabulary token comes back as the [UNK] id; a non-int id is rejected too.
    if not isinstance(prefix_id, int) or prefix_id == unk_id:
        raise ValueError("[prefix] Token is not in the vocab.txt")


def _special_tokens_map(tokenizer):
    """Return a ``{special_token: token_id}`` dict built from the tokenizer."""
    return dict(zip(tokenizer.all_special_tokens, tokenizer.all_special_ids))


def _build_for_training(args, config_class, model_class, tokenizer_class):
    """Load config/tokenizer/model from ``args.model_name_or_path`` for fine-tuning."""
    # Explicit check instead of `assert`, so it still runs under `python -O`.
    if args.model_name_or_path is None:
        raise ValueError("args.model_name_or_path is required when args.do_train is set")

    config = config_class.from_pretrained(
        args.config_name or args.model_name_or_path,
        num_labels=args.num_labels,
        finetuning_task='gqa',
    )
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name or args.model_name_or_path,
        do_lower_case=args.do_lower_case,
    )

    # Task / architecture options copied verbatim from args onto the config.
    config.img_feature_dim = args.img_feature_dim
    config.img_feature_type = args.img_feature_type
    config.hidden_dropout_prob = args.drop_out
    config.loss_type = args.loss_type
    config.tie_weights = args.tie_weights
    config.freeze_embedding = args.freeze_embedding
    config.freeze_backbone = args.freeze_backbone
    config.classifier = args.classifier
    config.cls_hidden_scale = args.cls_hidden_scale

    # Prefix options. Only the literal value True enables them (identity check kept).
    config.add_prefix = args.add_prefix
    if config.add_prefix is True:
        _check_prefix_token(tokenizer)
        config.num_prefix = args.num_prefix
        config.freeze_prefix = args.freeze_prefix
        config.mlp_for_prefix = args.mlp_for_prefix
        config.prefix_drop_prob = args.prefix_drop_prob
        config.prefix_shuffle_prob = args.prefix_shuffle_prob
        # Optional option; older arg parsers may not define it.
        config.prefix_no_pos_emb = getattr(args, "prefix_no_pos_emb", None)
    config.special_tokens = _special_tokens_map(tokenizer)

    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf='.ckpt' in args.model_name_or_path,
        config=config,
    )
    return model, tokenizer, None


def _build_for_eval(logger, args, config_class, model_class, tokenizer_class):
    """Load config/tokenizer/model from ``args.eval_model_dir``.

    If ``args.eval_ema_num`` is set, the model is afterwards swapped for an EMA model.
    """
    checkpoint = args.eval_model_dir
    # Explicit check instead of `assert`, so it still runs under `python -O`.
    if not op.isdir(checkpoint):
        raise ValueError(f"args.eval_model_dir is not a directory: {checkpoint}")

    config = config_class.from_pretrained(checkpoint)
    config.output_hidden_states = args.output_hidden_states
    config.output_attentions = args.output_attentions
    tokenizer = tokenizer_class.from_pretrained(checkpoint)
    logger.info(f"Evaluate the following checkpoint: {checkpoint}")

    # Prefix options. The remaining prefix fields come from the saved config.
    config.add_prefix = args.add_prefix
    if config.add_prefix is True:
        _check_prefix_token(tokenizer)
        config.num_prefix = args.num_prefix
        config.prefix_no_pos_emb = getattr(args, "prefix_no_pos_emb", None)
    config.special_tokens = _special_tokens_map(tokenizer)

    model = model_class.from_pretrained(checkpoint, config=config)
    synchronize()

    if args.eval_ema_num is not None:
        num_ema = int(args.eval_ema_num)
        logger.info(f"Calculating EMA model with last {num_ema} checkpoints:")
        # save_ema_model_from_ckpt returns the path the EMA weights were written to.
        checkpoint = save_ema_model_from_ckpt(
            model=model,
            checkpoint_path=checkpoint,
            num_model=num_ema,
        )
        synchronize()

        model = model_class.from_pretrained(checkpoint, config=config)
        logger.info("EMA Model Checkpoint Loaded:")
        # NOTE(review): print rather than logger, presumably so every rank reports; confirm.
        print(f" - Process in local rank {args.local_rank} load ckpt from: {checkpoint}")
        synchronize()

    return model, tokenizer, checkpoint