
import logging

from .datasets_classes import DATASET_CLASS
from .datasets_meta import MetaHandler

LOGGER = logging.getLogger(__name__)


def _set_to_first_list_element(milie_args):
    """
    Collapse the list-valued dataset arguments to their first entry, in place.

    For each of ``data_set``, ``train_file``, ``predict_file`` and ``valid_gold``: a list
    value is replaced by its first element, any other value is written back unchanged.

    :param milie_args: argument object exposing the four attributes above; it is modified
    :return: the same ``milie_args`` object that was passed in
    """
    for field_name in ("data_set", "train_file", "predict_file", "valid_gold"):
        current = getattr(milie_args, field_name)
        if isinstance(current, list):
            current = current[0]
        setattr(milie_args, field_name, current)
    return milie_args


def _require_one_entry_per_data_set(milie_args, attribute):
    """
    Check that a list-valued argument has exactly one entry per data set.

    :param milie_args: argument object whose ``data_set`` attribute is a list
    :param attribute: name of the attribute on ``milie_args`` to validate
    :raises ValueError: if the attribute is not a list or its length differs from
                        ``len(milie_args.data_set)``
    """
    entries = getattr(milie_args, attribute)
    expected = len(milie_args.data_set)
    if not isinstance(entries, list) or len(entries) != expected:
        raise ValueError(
            "'{}' must be a list with one entry per data set ({} expected), got: {!r}".format(
                attribute, expected, entries
            )
        )


def get_data_handler(milie_args, tokenizer=None, predict=False):
    """
    Factory for returning dataset specific handlers.

    Dispatch rules:

    * ``data_set`` and ``train_file`` are both lists, more than one data set is given and
      ``predict`` is False: a ``MetaHandler`` built from ``milie_args`` is returned.
    * ``data_set`` is a list and ``predict`` is True: the handler registered in
      ``DATASET_CLASS`` for the first data set is returned. It is built from a deep copy of
      ``milie_args`` reduced to the first list elements, so ``milie_args`` is left untouched.
    * otherwise: ``milie_args`` itself is reduced to the first list elements (in place) and
      the handler registered in ``DATASET_CLASS`` for that data set is returned. Later calls
      with the same object therefore see single values instead of lists.

    :param milie_args: instance of NspArguments
    :param tokenizer: a tokenizer instance; if not None it is assigned to the handler's
                      ``tokenizer`` attribute
    :param predict: if True then the requested datasethandler is for prediction, needed for
                    case distinction for the Meta dataset case where several datasets are provided
                    for training but only the first is used for prediction
    :return: an instance or a subclass instance of :py:class:`~milie.dataset_handlers.dataset_bitext.BitextHandler`
    :raises ValueError: if ``data_set`` and ``train_file`` are lists of different lengths, or,
                        when ``predict`` is set, if ``predict_file`` or ``valid_gold`` is not a
                        list with one entry per data set
    """
    # Only call len() on data_set once it is known to be a list: after a previous call went
    # through the in-place branch below it is a plain string, whose length is meaningless here.
    several_data_sets = isinstance(milie_args.data_set, list)
    several_train_files = isinstance(milie_args.train_file, list)

    if several_data_sets and several_train_files:  # else this has been done in a previous call
        # Explicit exceptions instead of assert, which is removed when running with -O.
        _require_one_entry_per_data_set(milie_args, "train_file")
        if predict:
            _require_one_entry_per_data_set(milie_args, "predict_file")
            _require_one_entry_per_data_set(milie_args, "valid_gold")

    # The identity checks on predict are kept as they were: a value that is neither True nor
    # False falls through to the final branch.
    if several_data_sets and several_train_files and len(milie_args.data_set) > 1 and predict is False:
        dataset_handler = MetaHandler(milie_args)
    elif several_data_sets and predict is True:
        import copy
        # Work on a copy so the lists needed for a later training call survive.
        milie_args_predict = _set_to_first_list_element(copy.deepcopy(milie_args))
        dataset_handler = DATASET_CLASS[milie_args_predict.data_set](milie_args_predict)
    else:
        # Modifies the caller's milie_args in place.
        milie_args = _set_to_first_list_element(milie_args)
        dataset_handler = DATASET_CLASS[milie_args.data_set](milie_args)

    if tokenizer is not None:
        dataset_handler.tokenizer = tokenizer
    return dataset_handler
