from construct.construct_active_learner import QUERY_STRATEGIES


class MetaActiveLearner:
    """Drive a single query strategy over the estimators of several learners.

    The strategy callable is looked up in ``QUERY_STRATEGIES`` under a key
    derived from the config (see ``_strategy_key``). Only the ``estimator``
    attribute of each learner is retained.
    """

    def __init__(self, learners, config):
        """Select the query strategy and collect the learners' estimators.

        Args:
            learners: Iterable of objects exposing an ``estimator`` attribute.
            config: Config object supporting both attribute access and the
                ``in`` operator (presumably an OmegaConf ``DictConfig`` --
                TODO confirm). Reads ``strategy``, ``sampling_type`` and, if
                present, ``split_by_tokens``.

        Raises:
            KeyError: If no strategy is registered under the derived key.
        """
        # When a sampling type is configured, ``query`` forwards extra
        # arguments to the strategy.
        self.ups = config.sampling_type is not None

        key = self._strategy_key(config)
        try:
            self.query_strategy = QUERY_STRATEGIES[key]
        except KeyError as err:
            raise KeyError(
                f"Unknown query strategy {key!r}; available: "
                f"{', '.join(map(str, QUERY_STRATEGIES))}"
            ) from err

        self.estimators = [learner.estimator for learner in learners]

    @staticmethod
    def _strategy_key(config):
        """Return the ``QUERY_STRATEGIES`` key for ``config``.

        The key is ``config.strategy``, then ``"_<sampling_type>"`` if a
        sampling type is set, then ``"_tokens"`` or ``"_samples"`` according
        to ``config.split_by_tokens``. No split suffix is added when
        ``split_by_tokens`` is absent from the config.
        """
        key = f"{config.strategy}"
        if config.sampling_type is not None:
            key += "_" + config.sampling_type
        # Configs without ``split_by_tokens`` (e.g. classification) must not
        # receive a "_samples" suffix, so the membership test gates both cases.
        if "split_by_tokens" in config:
            key += "_tokens" if config.split_by_tokens else "_samples"
        return key

    def query(
        self,
        X_pool,
        n_instances,
        k_confident_to_save=None,
        T=None,
        require_sampling=True,
    ):
        """Run the configured query strategy on ``X_pool``.

        Args:
            X_pool: Candidate pool. It must support indexing by whatever the
                strategy returns when that result is not a tuple.
            n_instances: Number of instances requested from the strategy.
            k_confident_to_save: Forwarded only when ``self.ups`` is true.
            T: Forwarded only when ``self.ups`` is true.
            require_sampling: Forwarded as a keyword only when ``self.ups``
                is true.

        Returns:
            The strategy's result unchanged if it is a tuple; otherwise
            ``(result, X_pool[result])``.
        """
        if self.ups:
            result = self.query_strategy(
                self.estimators,
                X_pool,
                n_instances,
                k_confident_to_save,
                T,
                require_sampling=require_sampling,
            )
        else:
            result = self.query_strategy(self.estimators, X_pool, n_instances)

        # Some strategies return only the selection. Pair it with the selected
        # pool items so callers always receive a tuple.
        if not isinstance(result, tuple):
            result = (result, X_pool[result])
        return result
