import numpy as np
import transformers
from datasets.arrow_dataset import Dataset
from typing import Union, Optional, Tuple
import torch
from datasets import load_metric
from tqdm import tqdm
from math import ceil
from omegaconf.dictconfig import DictConfig
import logging

from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

from .al_strategy_utils import (
    get_X_pool_subsample,
    get_similarities,
    filter_by_uncertainty,
    filter_by_metric,
    calculate_bleuvar_scores,
    assign_ue_scores_for_unlabeled_data,
    take_idx,
    calculate_unicentroid_mahalanobis_distance,
)

from utils.transformers_dataset import TransformersDataset
from construct.transformers_api.construct_transformers_wrapper import (
    construct_transformers_wrapper,
)


# Module-wide logger used by the strategies below.
# NOTE(review): `logging.getLogger()` without a name returns the *root* logger;
# `logging.getLogger(__name__)` would be the conventional per-module logger.
log = logging.getLogger()

### Abstractive summarization strategies


def sequence_score_sampling(
    model,
    X_pool: Union[np.ndarray, Dataset, TransformersDataset],
    n_instances: int,
    **kwargs,
):
    """Query the pool instances whose generated sequences get the lowest scores.

    ``model.generate`` produces one output per pool instance. The negated
    ``sequences_scores`` are used as uncertainty, because a larger generation
    score means a more confident model.

    Optional filtering (``kwargs["filtering_mode"]``):

    * ``"uncertainty"``: selection is delegated to ``filter_by_uncertainty``.
    * one of ``"rouge1"``, ``"rouge2"``, ``"rougeL"``, ``"sacrebleu"``:
      selection is delegated to ``filter_by_metric``, which receives the
      source texts and the generated token ids.
    * anything else: the ``n_instances`` most uncertain instances are taken.

    Returns:
        ``(query_idx, query, uncertainty_estimates)``.
    """
    metric_filters = ("rouge1", "rouge2", "rougeL", "sacrebleu")
    mode = kwargs.get("filtering_mode", None)
    threshold = kwargs.get("uncertainty_threshold", None)
    # "relative" or "absolute"
    threshold_mode = kwargs.get("uncertainty_mode", "absolute")

    output = model.generate(X_pool, to_numpy=True)
    generated_ids = output["sequences"]
    # Higher sequence score == more confident model, hence the negation.
    ue = -output["sequences_scores"]

    if mode == "uncertainty":
        query_idx, ue = filter_by_uncertainty(
            uncertainty_estimates=ue,
            uncertainty_threshold=threshold,
            uncertainty_mode=threshold_mode,
            n_instances=n_instances,
        )
    elif mode in metric_filters:
        query_idx, ue = filter_by_metric(
            uncertainty_estimates=ue,
            uncertainty_threshold=threshold,
            uncertainty_mode=threshold_mode,
            n_instances=n_instances,
            texts=X_pool[model.data_config["text_name"]],
            generated_sequences_ids=generated_ids,
            tokenizer=model.tokenizer,
            metric_cache_dir=model.cache_dir / "metrics",
            metric_name=mode,
            agg=kwargs.get("filtering_aggregation", "precision"),
        )
    else:
        query_idx = np.argsort(-ue)[:n_instances]

    return query_idx, take_idx(X_pool, query_idx), ue


# def mc_dropout_sequence_score_sampling(
#     model,
#     X_pool: Union[np.ndarray, Dataset, TransformersDataset],
#     n_instances: int,
#     **kwargs,
# ):
#     mc_iterations = kwargs.get("mc_iterations", 10)
#     # Filtering args
#     filtering_mode = kwargs.get("filtering_mode", None)
#     uncertainty_threshold = kwargs.get("uncertainty_threshold", None)
#     uncertainty_mode = kwargs.get(
#         "uncertainty_mode", "absolute"
#     )  # "relative" or "absolute"
#     metric_name = kwargs.get("var_metric", "sacrebleu")
#     bleuvar_threshold = kwargs.get("uncertainty_threshold", 1.0)
#
#     X_pool_subsample, subsample_indices = get_X_pool_subsample(
#         X_pool, mc_iterations, model.seed
#     )
#
#     model.enable_dropout()
#     log_scores = []  # mc_iterations x num_samples
#     summaries = []
#     for _ in range(mc_iterations):
#         generate_output = model.generate(
#             X_pool_subsample, return_decoded_preds=True, to_eval_mode=False
#         )
#         log_scores.append(generate_output["sequences_scores"])
#         summaries.append(generate_output["predictions"])
#
#     mc_scores = np.exp(log_scores)
#     scores = np.mean(mc_scores)
#
#     if filtering_mode == "uncertainty":
#         summaries = [
#             [text if len(text) > 0 else "CLS" for text in mc_sequence] for mc_sequence in summaries
#         ]  # fix empty texts
#         bleu_vars = calculate_bleuvar_scores(summaries, metric_name=metric_name, cache_dir=model.cache_dir / "metrics")
#         subsample_query_idx, bleu_vars = filter_by_uncertainty(
#             bleu_vars, bleuvar_threshold, uncertainty_mode, n_instances
#         )
#     else:
#         pass
#         # subsample_query_idx = np.argsort(-bleu_vars)[:n_instances]
#
#     generation_output = model.generate(X_pool, to_numpy=True)
#     scores = generation_output["sequences_scores"]
#     sequences_ids = generation_output["sequences"]
#     # The larger the score, the more confident the model is
#     uncertainty_estimates = -scores
#     # Filtering part begin
#     if filtering_mode == "uncertainty":
#         query_idx, uncertainty_estimates = filter_by_uncertainty(
#             uncertainty_estimates=uncertainty_estimates,
#             uncertainty_threshold=uncertainty_threshold,
#             uncertainty_mode=uncertainty_mode,
#             n_instances=n_instances,
#         )
#     if filtering_mode in ["rouge1", "rouge2", "rougeL", "sacrebleu"]:
#         query_idx, uncertainty_estimates = filter_by_metric(
#             uncertainty_estimates=uncertainty_estimates,
#             uncertainty_threshold=uncertainty_threshold,
#             uncertainty_mode=uncertainty_mode,
#             n_instances=n_instances,
#             texts=X_pool[model.data_config["text_name"]],
#             generated_sequences_ids=sequences_ids,
#             tokenizer=model.tokenizer,
#             metric_cache_dir=model.cache_dir / "metrics",
#             metric_name=filtering_mode,
#             agg=kwargs.get("filtering_aggregation", "precision"),
#         )
#     else:
#         argsort = np.argsort(-uncertainty_estimates)
#         query_idx = argsort[:n_instances]
#     # Filtering part end
#     query = take_idx(X_pool, query_idx)
#
#     return query_idx, query, uncertainty_estimates


def longest_sampling(
    model: "ModalTransformersWrapper",
    X_pool: Union[np.ndarray, Dataset, TransformersDataset],
    n_instances: int,
    **kwargs,
):
    """Query the pool instances with the longest tokenized source texts.

    The "uncertainty" of an instance is the number of token ids that
    ``model.tokenizer`` produces for its ``data_config["text_name"]`` column.
    No truncation is requested, so longer documents are queried first.
    ``X_pool`` must support ``.map(..., batched=True)``.

    Returns:
        ``(query_idx, query, uncertainty_estimates)``.
    """
    text_column = model.data_config["text_name"]
    tok = model.tokenizer

    def _encode_batch(batch):
        return tok(batch[text_column])

    encoded_pool = X_pool.map(_encode_batch, batched=True)
    token_counts = np.array([len(ids) for ids in encoded_pool["input_ids"]])
    query_idx = np.argsort(-token_counts)[:n_instances]
    return query_idx, take_idx(X_pool, query_idx), token_counts


def longest_generation_sampling(
    model: "ModalTransformersWrapper",
    X_pool: Union[np.ndarray, Dataset, TransformersDataset],
    n_instances: int,
    **kwargs,
):
    """Query the pool instances for which the model generates the longest outputs.

    The uncertainty of an instance is ``len`` of its generated token sequence.

    Optional filtering (``kwargs["filtering_mode"]``):

    * ``"uncertainty"``: ``filter_by_uncertainty`` receives the lengths as
      estimates and the negated sequence scores as ``uncertainty_scores``.
    * one of ``"rouge1"``, ``"rouge2"``, ``"rougeL"``, ``"sacrebleu"``:
      ``filter_by_metric`` receives the source texts and the generated ids.
    * anything else: the ``n_instances`` longest generations are queried.

    Returns:
        ``(query_idx, query, uncertainty_estimates)``.
    """
    metric_filters = ("rouge1", "rouge2", "rougeL", "sacrebleu")
    mode = kwargs.get("filtering_mode", None)
    threshold = kwargs.get("uncertainty_threshold", None)
    # "relative" or "absolute"
    threshold_mode = kwargs.get("uncertainty_mode", "absolute")

    output = model.generate(X_pool, to_numpy=True)
    # A larger sequence score means a more confident model; it is negated
    # only where it is used (the "uncertainty" filter).
    seq_scores = output["sequences_scores"]
    generated_ids = output["sequences"]
    # Ranking criterion: the length of every generated sequence.
    # NOTE(review): if `sequences` is a padded 2-D array, every length equals
    # the padded width — confirm that model.generate returns unpadded rows.
    ue = np.array([len(ids) for ids in generated_ids], dtype=float)

    if mode == "uncertainty":
        query_idx, ue = filter_by_uncertainty(
            uncertainty_estimates=ue,
            uncertainty_threshold=threshold,
            uncertainty_mode=threshold_mode,
            n_instances=n_instances,
            uncertainty_scores=-seq_scores,
        )
    elif mode in metric_filters:
        query_idx, ue = filter_by_metric(
            uncertainty_estimates=ue,
            uncertainty_threshold=threshold,
            uncertainty_mode=threshold_mode,
            n_instances=n_instances,
            texts=X_pool[model.data_config["text_name"]],
            generated_sequences_ids=generated_ids,
            tokenizer=model.tokenizer,
            metric_cache_dir=model.cache_dir / "metrics",
            metric_name=mode,
            agg=kwargs.get("filtering_aggregation", "precision"),
        )
    else:
        query_idx = np.argsort(-ue)[:n_instances]

    return query_idx, take_idx(X_pool, query_idx), ue


def sequence_score_stochastic_sampling(
    model: "ModalTransformersWrapper",
    X_pool: Union[np.ndarray, Dataset, TransformersDataset],
    n_instances: int,
    **kwargs,
):
    """Rank pool instances by an aggregate of sequence scores over stochastic generations.

    The model generates ``mc_iterations`` times for a pool subsample. The
    randomness comes from one of two sources:

    * dropout (``enable_dropout=True``): the network is put into train mode
      and ``do_sample=False``;
    * nucleus sampling in eval mode: ``do_sample=True`` with
      ``top_p=generate_top_p``.

    The per-run ``sequences_scores`` are exponentiated unless ``use_log`` is
    set, then reduced over the runs with ``getattr(np, aggregation)``. For
    ``aggregation`` in ``{"mean", "median", "max", "min"}``, a smaller
    aggregate means more uncertain. For any other aggregation (default
    ``"var"``), a larger aggregate means more uncertain.

    Keyword Args:
        aggregation (str): NumPy reduction name, default ``"var"``.
        use_log (bool): aggregate log-scores instead of probabilities.
        enable_dropout (bool): selects the randomness source (see above).
        generate_top_p (float): nucleus-sampling ``top_p``, default 0.95.
        mc_iterations (int): number of generation runs, default 5.
        subsample_size_mc_dropout (bool): subsample the pool (default True).

    Returns:
        ``(query_idx, query, uncertainty_estimates)``.
    """
    aggregation = kwargs.get("aggregation", "var")
    func_for_agg = getattr(np, aggregation)
    # Location statistics of the scores: a low value means low confidence.
    larger_is_more_uncertain = aggregation not in ["mean", "median", "max", "min"]
    use_log = kwargs.get("use_log", False)

    generate_kwargs = dict(to_numpy=True, do_sample=False, to_eval_mode=False)
    if kwargs.get("enable_dropout", False):
        # Since BART exploits F.dropout instead of nn.Dropout, we can only turn it on via .train()
        model.model.train()
    else:
        model.model.eval()
        generate_kwargs["do_sample"] = True
        generate_kwargs["top_p"] = kwargs.get("generate_top_p", 0.95)

    mc_iterations = kwargs.get("mc_iterations", 5)
    if kwargs.get("subsample_size_mc_dropout", True):
        X_pool_subsample, subsample_indices = get_X_pool_subsample(
            X_pool, mc_iterations, model.seed
        )
    else:
        X_pool_subsample, subsample_indices = X_pool, np.arange(len(X_pool))

    # Only the scores are needed. Decoded texts are not requested here (other
    # strategies pass `return_decoded_preds=True` to obtain "predictions"),
    # so the output is not indexed with "predictions" either.
    log_scores = [
        model.generate(X_pool_subsample, **generate_kwargs)["sequences_scores"]
        for _ in range(mc_iterations)
    ]
    scores = np.asarray(log_scores)  # one row per generation run
    if not use_log:
        scores = np.exp(scores)
    subsample_uncertainty_estimates = func_for_agg(scores, axis=0)
    if larger_is_more_uncertain:
        subsample_uncertainty_estimates = -subsample_uncertainty_estimates

    # Ascending order: after the optional sign flip, the most uncertain come first.
    subsample_query_idx = np.argsort(subsample_uncertainty_estimates)[:n_instances]
    # `take_idx` (like the other subsampling strategies) also supports array pools.
    query = take_idx(X_pool_subsample, subsample_query_idx)
    query_idx = subsample_indices[subsample_query_idx]

    # NOTE(review): these estimates follow a "smaller = more uncertain"
    # convention, the opposite of most other strategies in this module.
    uncertainty_estimates = assign_ue_scores_for_unlabeled_data(
        len(X_pool), subsample_indices, subsample_uncertainty_estimates
    )

    return query_idx, query, uncertainty_estimates


def label_discrimination_sampling(
    model,
    X_pool: Union[Dataset, TransformersDataset],
    n_instances: int,
    X_train: Union[Dataset, TransformersDataset],
    discriminator_test_size: Optional[int] = None,
    **kwargs,
):
    """Query pool instances whose generated summaries a discriminator finds least human-like.

    A binary classifier (the "discriminator") is trained to separate two
    kinds of summaries:

    * the reference summaries of ``X_train`` (label 1, ``"real"``);
    * summaries that ``model`` generated for a random subset of ``X_pool`` of
      the same size (label 0, ``"fake"``).

    It then scores the generated summaries of the remaining pool instances.
    The uncertainty is ``1 - P(real)``, and the ``n_instances`` largest values
    are queried. Pool instances whose summaries were used for training get
    uncertainty 0.

    Args:
        model: summarization wrapper. This function uses its ``generate``,
            ``tokenizer``, ``data_config``, ``seed``, ``cache_dir`` and
            ``cache_model`` attributes.
        X_pool: unlabeled pool.
        n_instances: number of instances to query.
        X_train: labeled data; its ``data_config["label_name"]`` column
            provides the human-written summaries.
        discriminator_test_size: if not None, forwarded as ``test_size`` to
            ``train_test_split``. A stratified part of the discriminator data
            is then held out, and ROC AUC is computed on it.
        **kwargs: must contain ``"discriminator_model"``, the model config
            passed to ``construct_transformers_wrapper``.

    Returns:
        ``(query_idx, query, uncertainty_estimates, query_meta)``, where
        ``query_meta = {"discriminator_test_score": roc_auc_or_None}``.

    Note:
        Calls ``np.random.seed(model.seed)``, i.e. reseeds NumPy's global RNG.
        The training subset is drawn without replacement, so the pool must
        hold at least ``len(X_train)`` instances.
    """
    # Human-written references (class "real") and model generations for the whole pool.
    human_summaries = np.array(X_train[model.data_config["label_name"]])
    generated_summaries_tokenized = model.generate(X_pool, to_numpy=True)["sequences"]
    generated_summaries = model.tokenizer.batch_decode(
        generated_summaries_tokenized, skip_special_tokens=True
    )

    # NOTE: reseeds NumPy's *global* RNG, which also affects later np.random calls.
    np.random.seed(model.seed)
    # As many generated summaries as human ones (balanced classes) go to training;
    # the remaining pool instances (`gen_sum_query_idx`) are scored by the discriminator.
    gen_sum_train_idx = np.random.choice(
        range(len(generated_summaries)), len(human_summaries), False
    )
    gen_sum_query_idx = np.setdiff1d(range(len(generated_summaries)), gen_sum_train_idx)
    generated_summaries_for_training = np.array(generated_summaries)[gen_sum_train_idx]

    # Texts
    summaries = np.hstack([human_summaries, generated_summaries_for_training])
    # Labels: 1 = human-written ("real"), 0 = generated ("fake")
    is_human = np.hstack(
        [
            np.ones_like(human_summaries, dtype=np.int64),
            np.zeros_like(generated_summaries_for_training, dtype=np.int64),
        ]
    )

    # Optionally hold out a stratified test split to measure discriminator quality.
    if discriminator_test_size is not None:
        train_indices, test_indices = train_test_split(
            range(len(is_human)),
            test_size=discriminator_test_size,
            stratify=is_human,
            shuffle=True,
            random_state=model.seed,
        )
        test_summaries, test_is_human = summaries[test_indices], is_human[test_indices]
        summaries, is_human = summaries[train_indices], is_human[train_indices]
    else:
        test_summaries, test_is_human, test_indices = None, None, None

    id2label = {0: "fake", 1: "real"}
    default_data_config = {
        "dataset_name": None,
        "text_name": "text",
        "label_name": "label",
    }
    train_data = TransformersDataset(
        {"text": summaries, "label": is_human},
        text_column_name="text",
        label_column_name="label",
        id2label=id2label,
        task="cls",
    )
    if test_summaries is not None:
        dev_data = TransformersDataset(
            {"text": test_summaries, "label": test_is_human},
            text_column_name="text",
            label_column_name="label",
            id2label=id2label,
            task="cls",
        )
    else:
        dev_data = None
    # Generated summaries of the pool instances not used for training (to be scored).
    query_data = TransformersDataset(
        {"summary": np.array(generated_summaries)[gen_sum_query_idx]},
        text_column_name="summary",
        task="abs-sum",
    )

    # The discriminator is cached in its own sub-directory of the model cache.
    # NOTE(review): mkdir has no parents=True, so model.cache_dir must already exist.
    (model.cache_dir / "discriminator").mkdir(exist_ok=True)
    discriminator_model_config = kwargs["discriminator_model"]
    discriminator_general_config = DictConfig(
        dict(
            seed=discriminator_model_config.seed,
            cache_model_and_dataset=model.cache_model,
            cache_dir=str(model.cache_dir / "discriminator"),
        )
    )

    discriminator_model = construct_transformers_wrapper(
        config=discriminator_general_config,
        model_cfg=discriminator_model_config,
        dev_data=dev_data,
        id2label=id2label,
        name="discriminator",
        default_data_config=default_data_config,
    )
    discriminator_model.fit(
        train_data, from_scratch=False
    )  # to avoid re-creation of the model

    # ROC AUC of P(real) on the held-out split, if one was made.
    if discriminator_test_size is not None:
        test_is_human_probas = discriminator_model.predict_proba(
            dev_data, to_numpy=True
        )[:, 1]
        discriminator_test_score = roc_auc_score(test_is_human, test_is_human_probas)
    else:
        discriminator_test_score = None

    # P(real) for every pool instance; entries not scored by the discriminator stay at 1.
    all_is_human = np.ones(len(generated_summaries_tokenized))
    all_is_human[gen_sum_query_idx] = discriminator_model.predict_proba(
        query_data, to_numpy=True
    )[:, 1]
    # Remove examples selected to train classifier from consideration
    all_is_human[gen_sum_train_idx] = 1
    # take instances with small is_human probability
    uncertainty_estimates = 1 - all_is_human
    argsort = np.argsort(-uncertainty_estimates)
    query_idx = argsort[:n_instances]
    query = take_idx(X_pool, query_idx)

    query_meta = {"discriminator_test_score": discriminator_test_score}

    return query_idx, query, uncertainty_estimates, query_meta


def ngram_sampling(
    model,
    X_pool: Union[Dataset, TransformersDataset],
    n_instances: int,
    X_train: Union[Dataset, TransformersDataset],
    **strategy_kwargs,
) -> Tuple[np.ndarray, Union[Dataset, TransformersDataset], np.ndarray]:
    """Query pool texts built from n-grams that are frequent in the pool but rare in the labeled data.

    Every n-gram ``g`` of the pool vocabulary gets the score
    ``log(count_pool(g) / (count_train(g) + 1))``. Counts are summed over all
    texts of the pool and of ``X_train``, respectively. A text's uncertainty
    is the count-weighted mean score of its n-grams, and the ``n_instances``
    largest are queried.

    Args:
        model: wrapper providing ``data_config["text_name"]`` and ``tokenizer``.
        X_pool: unlabeled pool.
        n_instances: number of instances to query.
        X_train: labeled data.
        **strategy_kwargs:
            ngram_range: an int ``n`` (only n-grams of size ``n``) or a pair
                ``(min_n, max_n)`` given as a tuple or list. Default ``(1, 1)``.
            truncate_texts: if True (default), texts are round-tripped
                through the tokenizer with truncation, so only the part the
                model sees is counted. If False, the raw texts are used.

    Returns:
        ``(query_idx, query, uncertainty_estimates)``.
    """
    text_name = model.data_config["text_name"]
    ngram_range = strategy_kwargs.get("ngram_range", (1, 1))
    if isinstance(ngram_range, int):
        ngram_range = (ngram_range, ngram_range)
    else:
        # Configs may hand over a list (e.g. an OmegaConf ListConfig), while
        # recent scikit-learn versions validate `ngram_range` as a tuple.
        ngram_range = tuple(ngram_range)

    # Raw texts unless truncation is requested; both names must always be bound.
    X_train_texts = X_train[text_name]
    X_pool_texts = X_pool[text_name]
    if strategy_kwargs.get("truncate_texts", True):
        tokenizer = model.tokenizer
        X_train_texts = tokenizer.batch_decode(
            tokenizer(X_train_texts, truncation=True)["input_ids"]
        )
        X_pool_texts = tokenizer.batch_decode(
            tokenizer(X_pool_texts, truncation=True)["input_ids"]
        )

    counter = CountVectorizer(ngram_range=ngram_range)
    # The vocabulary comes from the pool only; labeled n-grams outside it are ignored.
    unlabeled_ngrams = counter.fit_transform(X_pool_texts)
    labeled_ngrams = counter.transform(X_train_texts)

    ngrams_unl_freq = unlabeled_ngrams.sum(axis=0).A.ravel()
    # +1 keeps the ratio finite for n-grams absent from the labeled data.
    ngrams_lab_freq = labeled_ngrams.sum(axis=0).A.ravel() + 1
    ngrams_scores = np.log(ngrams_unl_freq / ngrams_lab_freq)

    # Count-weighted mean n-gram score per pool text.
    uncertainty_estimates = (
        unlabeled_ngrams.dot(ngrams_scores) / unlabeled_ngrams.sum(axis=1).A.ravel()
    )
    query_idx = np.argsort(-uncertainty_estimates)[:n_instances]
    query = take_idx(X_pool, query_idx)
    return query_idx, query, uncertainty_estimates


def beam_variance(model, X_pool, n_instances, **kwargs):
    """Query instances whose top beam-search hypotheses disagree the most.

    For every instance of a pool subsample, ``num_summaries_per_object`` of
    the ``num_beams`` beams are decoded, and the uncertainty is their BLEUVar.
    BLEUVar is the mean, over ordered pairs ``(i, j)`` with ``i != j``, of
    ``(1 - score(i, j) / 100) ** 2``. Here ``score`` is the ``var_metric``
    (default ``"sacrebleu"``) score of hypothesis ``i`` against reference
    ``j``, rounded to 4 decimals.

    Keyword Args:
        var_metric (str): metric name passed to ``load_metric``.
        num_beams (int): beam size, default 10.
        num_summaries_per_object (int): hypotheses per instance, default 10;
            must not exceed ``num_beams``.
        filtering_mode (str | None): ``"uncertainty"`` enables
            ``filter_by_uncertainty``.
        uncertainty_threshold (float): filtering threshold, default 1.0.
        uncertainty_mode (str): ``"relative"`` or ``"absolute"`` (default).

    Returns:
        ``(query_idx, query, uncertainty_estimates)``.
    """
    var_metric = kwargs.get("var_metric", "sacrebleu")
    num_beams = kwargs.get("num_beams", 10)
    per_object = kwargs.get("num_summaries_per_object", 10)
    filtering_mode = kwargs.get("filtering_mode", None)
    threshold = kwargs.get("uncertainty_threshold", 1.0)
    # "relative" or "absolute"
    threshold_mode = kwargs.get("uncertainty_mode", "absolute")
    assert per_object <= num_beams, "num_beams must be >= num_summaries_per_object"

    # Default generation uses 4 beams; here `num_beams / 4` times more are used,
    # so the subsample ratio is scaled by the same factor.
    X_pool_subsample, subsample_indices = get_X_pool_subsample(
        X_pool, num_beams / 4, model.seed
    )
    hypothesis_ids = model.generate(
        X_pool_subsample,
        return_scores=False,
        to_eval_mode=False,
        num_beams=num_beams,
        num_return_sequences=per_object,
    )
    # Flat list: the hypotheses of instance `idx` occupy
    # positions [idx * per_object, (idx + 1) * per_object).
    hypotheses = model.tokenizer.batch_decode(hypothesis_ids, skip_special_tokens=True)
    metric = load_metric(var_metric, cache_dir=model.cache_dir / "metrics")

    def _squared_distance(prediction, reference):
        score = metric.compute(predictions=[prediction], references=[[reference]])[
            "score"
        ]
        return (1 - round(score, 4) / 100.0) ** 2

    bleu_vars = []
    for idx in tqdm(range(len(X_pool_subsample))):
        group = hypotheses[idx * per_object : (idx + 1) * per_object]
        pair_sum = sum(
            (
                _squared_distance(group[i], group[j])
                for i in range(per_object)
                for j in range(per_object)
                if i != j
            ),
            0.0,
        )
        bleu_vars.append(1 / (per_object * (per_object - 1)) * pair_sum)
    bleu_vars = np.array(bleu_vars)

    if filtering_mode == "uncertainty":
        subsample_query_idx, bleu_vars = filter_by_uncertainty(
            bleu_vars, threshold, threshold_mode, n_instances
        )
    else:
        subsample_query_idx = np.argsort(-bleu_vars)[:n_instances]

    query = take_idx(X_pool_subsample, subsample_query_idx)
    query_idx = subsample_indices[subsample_query_idx]
    uncertainty_estimates = assign_ue_scores_for_unlabeled_data(
        len(X_pool), subsample_indices, bleu_vars
    )
    return query_idx, query, uncertainty_estimates


# https://arxiv.org/pdf/2006.08344.pdf
def bleuvar_sampling(
    model: "ModalTransformersWrapper",
    X_pool: Union[Dataset, TransformersDataset],
    n_instances: int,
    **kwargs,
):
    """Query instances whose stochastic generations disagree the most (BLEUVar).

    For a pool subsample, ``mc_iterations`` summaries per instance are
    generated. The randomness comes either from dropout
    (``enable_dropout=True``) or from nucleus sampling in eval mode
    (``do_sample=True``, ``top_p=generate_top_p``). Instances are ranked by
    the scores returned by ``calculate_bleuvar_scores`` for the ``var_metric``
    (default ``"sacrebleu"``); larger scores are queried first.

    Keyword Args:
        mc_iterations (int): number of generations per instance, default 10.
        var_metric (str): metric name, default ``"sacrebleu"``.
        filtering_mode (str | None): ``"uncertainty"`` enables
            ``filter_by_uncertainty``.
        uncertainty_threshold (float): filtering threshold, default 1.0.
        uncertainty_mode (str): ``"relative"`` or ``"absolute"`` (default).
        subsample_size_mc_dropout (bool): subsample the pool (default True).
        enable_dropout (bool): see above.
        generate_top_p (float): nucleus-sampling ``top_p``, default 0.95.

    Returns:
        ``(query_idx, query, uncertainty_estimates)``.
    """
    mc_iterations = kwargs.get("mc_iterations", 10)
    var_metric = kwargs.get("var_metric", "sacrebleu")
    filtering_mode = kwargs.get("filtering_mode", None)
    threshold = kwargs.get("uncertainty_threshold", 1.0)
    # "relative" or "absolute"
    threshold_mode = kwargs.get("uncertainty_mode", "absolute")

    if not kwargs.get("subsample_size_mc_dropout", True):
        X_pool_subsample, subsample_indices = X_pool, np.arange(len(X_pool))
    else:
        X_pool_subsample, subsample_indices = get_X_pool_subsample(
            X_pool, mc_iterations, model.seed
        )

    generate_kwargs = {
        "return_decoded_preds": True,
        "do_sample": False,
        "to_eval_mode": False,
    }
    if not kwargs.get("enable_dropout", False):
        # Randomness comes from nucleus sampling rather than from dropout.
        model.model.eval()
        generate_kwargs.update(
            do_sample=True, top_p=kwargs.get("generate_top_p", 0.95)
        )
    else:
        model.enable_dropout()  # model.model.train()

    # mc_iterations lists, each holding one text per subsample instance.
    summaries = []
    for _ in range(mc_iterations):
        predictions = model.generate(X_pool_subsample, **generate_kwargs)[
            "predictions"
        ]
        # Replace empty generations with a placeholder text.
        summaries.append([text if len(text) > 0 else "CLS" for text in predictions])

    # sacrebleu is normally more robust than bleu
    bleu_vars = calculate_bleuvar_scores(
        summaries,
        metric_name=var_metric,
        cache_dir=model.cache_dir / "metrics",
        tokenizer=model.tokenizer,
    )

    if filtering_mode == "uncertainty":
        subsample_query_idx, bleu_vars = filter_by_uncertainty(
            bleu_vars, threshold, threshold_mode, n_instances
        )
    else:
        subsample_query_idx = np.argsort(-bleu_vars)[:n_instances]

    query = take_idx(X_pool_subsample, subsample_query_idx)
    query_idx = subsample_indices[subsample_query_idx]
    uncertainty_estimates = assign_ue_scores_for_unlabeled_data(
        len(X_pool), subsample_indices, bleu_vars
    )
    return query_idx, query, uncertainty_estimates


def frequent_words_sampling(
    model, X_pool, n_instances, text_field_name: str = None, **kwargs
):
    """Model-agnostic sampling by the fraction of unique tokens in each text.

    Returns the ``n_instances`` pool texts with the largest
    ``len(set(tokens)) / len(tokens)`` ratio. Texts are expected to be
    already tokenized, i.e. each element is an iterable of tokens.

    Args:
        model: unused; kept so the signature matches the other strategies.
        X_pool: either a ``Dataset``/``TransformersDataset``, whose tokens are
            read from the ``text_field_name`` column, or any other iterable
            whose elements are token sequences (e.g. an ``np.ndarray``).
        n_instances: number of instances to query. If it exceeds the pool
            size, the whole pool is returned.
        text_field_name: column holding the tokens; only used for datasets.
        **kwargs: ignored; accepted for interface compatibility with the
            other sampling strategies.

    Returns:
        ``(query_idx, query, uncertainty_estimates)``, with ``query_idx``
        ordered from the largest to the smallest unique-token fraction.
    """

    def _unique_fraction(tokens) -> float:
        # Score an empty text 0 (ranked last) instead of dividing by zero.
        n_tokens = len(tokens)
        if n_tokens == 0:
            return 0.0
        return len(set(tokens)) / n_tokens

    if isinstance(X_pool, (Dataset, TransformersDataset)):
        fracs = X_pool.map(
            lambda row: {"uniques_fraction": _unique_fraction(row[text_field_name])}
        )
        uncertainty_estimates = fracs.to_dict()["uniques_fraction"]
    else:
        # Build a real list: np.array() over a lazy `map` iterator would give a
        # 0-d object array instead of a vector of fractions.
        uncertainty_estimates = [_unique_fraction(tokens) for tokens in X_pool]

    uncertainty_estimates = np.array(uncertainty_estimates, dtype=float)
    # Descending sort, as in the other strategies; unlike
    # `np.argpartition(kth=n_instances)`, this also works when
    # `n_instances >= len(X_pool)`.
    query_idx = np.argsort(-uncertainty_estimates)[:n_instances]
    query = take_idx(X_pool, query_idx)

    return query_idx, query, uncertainty_estimates


def embeddings_similarity(
    model, X_pool, n_instances, X_train, seed=None, device=None, **kwargs
):
    """Greedily select pool instances that are representative of the pool and dissimilar to the labeled data.

    Similarities between the texts of a pool subsample and ``X_train`` come
    from ``get_similarities``. They are built from the embeddings of
    ``embeddings_model_name``, optionally using Mahalanobis similarities
    (``use_maha_sims``). Then ``n_instances`` objects are picked one by one.
    At each step, every unlabeled object gets the score

        lambda * (similarity to the not-yet-selected unlabeled objects)
        - (1 - lambda) * (mean similarity to the labeled + already selected objects)

    The arg-max is selected and treated as labeled from then on. When
    ``filter_outliers`` is set, after every selection except the first, the
    top ``round(filter_outliers * num_obs)`` indices of
    ``calculate_unicentroid_mahalanobis_distance`` are excluded from the
    following steps.

    Args:
        model: wrapper providing ``generate``, ``data_config``, ``tokenizer``,
            ``cache_dir``, ``seed`` and ``model.device``.
        X_pool: unlabeled pool; must support ``.select`` (used for the query).
        n_instances: number of objects to select.
        X_train: labeled data. A None value is replaced by ``[]``, but only
            after ``get_similarities`` has been called with it.
        seed: subsampling seed; defaults to ``model.seed``.
        device: device for the similarity computation; defaults to
            ``model.model.device``.
        **kwargs: ``cache_dir``, ``embeddings_model_name``
            ("bert-base-uncased"), ``text_name`` ("document"),
            ``subsample_ratio`` (5), ``lambda`` (0.667), ``normalize`` (True),
            ``average`` (False), ``use_maha_sims`` (False),
            ``filter_outliers``, ``filtering_mode``,
            ``embeddings_batch_size`` (100), ``uncertainty_threshold``,
            ``uncertainty_mode``, ``filtering_aggregation``.
            ``label_name`` and ``obj_id_name`` are read but not used.

    Returns:
        ``(query_idx, query, uncertainty_estimates)``. The estimates are the
        greedy scores, where each selected object keeps the score it had when
        picked. When subsampling is used, they are mapped to the full pool.
    """
    cache_dir = kwargs.get("cache_dir")
    model_name = kwargs.get("embeddings_model_name", "bert-base-uncased")
    text_name = kwargs.get("text_name", "document")
    label_name = kwargs.get("label_name", "summary")
    obj_id_name = kwargs.get("obj_id_name", "id")
    subsample_ratio = kwargs.get("subsample_ratio", 5)
    lamb = kwargs.get("lambda", 0.667)
    normalize = kwargs.get("normalize", True)
    average = kwargs.get("average", False)
    use_maha_sims = kwargs.get("use_maha_sims", False)
    filter_outliers = kwargs.get("filter_outliers", None)
    filtering_mode = kwargs.get("filtering_mode", None)
    batch_size = kwargs.get("embeddings_batch_size", 100)

    # NOTE(review): the `query_idx` / `uncertainty_estimates` computed in this
    # block are overwritten further below and never used, so the filtering has
    # no effect on the result beyond running `model.generate`. Also, unlike the
    # other call sites in this module, `filter_by_metric` here receives
    # neither `uncertainty_estimates` nor `n_instances`.
    if filtering_mode is not None:
        uncertainty_threshold = kwargs.get("uncertainty_threshold", 0.0)
        uncertainty_mode = kwargs.get(
            "uncertainty_mode", "absolute"
        )  # "relative" or "absolute"
        generation_output = model.generate(X_pool, to_numpy=True)
        scores = generation_output["sequences_scores"]
        sequences_ids = generation_output["sequences"]

        if filtering_mode == "uncertainty":
            query_idx, uncertainty_estimates = filter_by_uncertainty(
                uncertainty_estimates=-scores,
                uncertainty_threshold=uncertainty_threshold,
                uncertainty_mode=uncertainty_mode,
                n_instances=n_instances,
            )

        elif filtering_mode in ["rouge1", "rouge2", "rougeL", "sacrebleu"]:
            query_idx, uncertainty_estimates = filter_by_metric(
                uncertainty_threshold=uncertainty_threshold,
                uncertainty_mode=uncertainty_mode,
                texts=X_pool[model.data_config["text_name"]],
                generated_sequences_ids=sequences_ids,
                tokenizer=model.tokenizer,
                metric_cache_dir=model.cache_dir / "metrics",
                metric_name=filtering_mode,
                agg=kwargs.get("filtering_aggregation", "precision"),
                modify_uncertainties=False,
            )

    if use_maha_sims:
        log.info("Using mahalanobis similarities")
    # subsample size = pool size / subsample_ratio
    if device is None:
        device = model.model.device
    if seed is None:
        seed = model.seed

    if subsample_ratio is not None:
        X_pool_subsample, subsample_indices = get_X_pool_subsample(
            X_pool, subsample_ratio, seed
        )  # `subsample_indices` holds the indices of the subsample in the original data
    else:
        X_pool_subsample = X_pool

    similarities, counts, embeddings = get_similarities(
        model_name,
        X_pool_subsample,
        X_train,
        use_maha_sims=use_maha_sims,
        normalize=normalize,
        average=average,
        text_name=text_name,
        device=device,
        cache_dir=cache_dir,
        return_embeddings=True,
        batch_size=batch_size,
    )
    num_obs = len(similarities)
    if X_train is None:
        X_train = []

    # Rows/columns of `similarities`: unlabeled objects first, then the
    # `len(X_train)` labeled ones (as encoded by the index ranges below).
    labeled_indices = list(range(num_obs - len(X_train), num_obs))
    unlabeled_indices = list(range(num_obs - len(X_train)))

    unlabeled_indices_without_queries = list(unlabeled_indices)
    top_scores_indices = []
    top_scores = []

    if filter_outliers is not None:
        outliers_idx = []
        num_outliers = round(filter_outliers * num_obs)

    for i_query in range(n_instances):
        # Calculate similarities.
        # Similarity to the remaining unlabeled objects; the `- 1` / `- 1`
        # presumably removes the self-similarity (1 for normalized embeddings) — confirm.
        similarities_with_unlabeled = (
            similarities[unlabeled_indices][:, unlabeled_indices_without_queries].sum(
                dim=1
            )
            - 1
        ) / (len(unlabeled_indices_without_queries) - 1)
        if len(labeled_indices) == 0:
            similarities_with_labeled = torch.zeros(len(unlabeled_indices)).to(
                similarities_with_unlabeled
            )
        else:
            similarities_with_labeled = similarities[unlabeled_indices][
                :, labeled_indices
            ].mean(dim=1)
        # Trade-off between representativeness (unlabeled) and novelty (labeled).
        scores = (
            (
                similarities_with_unlabeled * lamb
                - similarities_with_labeled * (1 - lamb)
            )
            .cpu()
            .detach()
            .numpy()
        )
        # Already-selected objects and detected outliers cannot be picked.
        scores[top_scores_indices] = -np.inf
        if filter_outliers is not None and len(outliers_idx) > 0:
            scores[outliers_idx.cpu().numpy()] = -np.inf

        # TODO: BUG when subsample_ratio is not None
        most_similar_idx = np.argmax(scores)
        labeled_indices.append(most_similar_idx)
        unlabeled_indices_without_queries.remove(most_similar_idx)
        top_scores_indices.append(most_similar_idx)
        top_scores.append(scores[most_similar_idx])

        # NOTE(review): `outliers_idx` is later used to index `scores` (unlabeled
        # positions); confirm the distance helper returns values in that index space.
        if filter_outliers is not None and i_query > 0:
            outliers_idx = (
                calculate_unicentroid_mahalanobis_distance(embeddings, labeled_indices)
                .topk(num_outliers)
                .indices
            )

    # Restore, for each selected object, the score it had when it was picked.
    scores[top_scores_indices] = top_scores
    # NOTE(review): `counts` (from get_similarities) appears to map each subsample
    # position p to its row `counts[p]` in `similarities`; confirm.
    top_scores_idx = [counts.index(i) for i in top_scores_indices]
    scores = scores[counts]

    if subsample_ratio is not None:
        query_idx = subsample_indices[top_scores_idx]
        uncertainty_estimates = assign_ue_scores_for_unlabeled_data(
            len(X_pool), subsample_indices, scores
        )
    else:
        query_idx = np.array(top_scores_idx)
        uncertainty_estimates = scores

    query = X_pool.select(query_idx)

    return query_idx, query, uncertainty_estimates
