import logging
from typing import Union

import math
import torch
import numpy as np
from toma import toma
from tqdm.auto import tqdm
from datasets.arrow_dataset import Dataset

from .al_strategy_utils import (
    take_idx,
    calculate_bald_score_cls,
    calculate_bald_score_ner,
    get_query_idx_for_selecting_by_number_of_tokens,
)
from .strategy_utils.batchbald.consistent_dropout import make_dropouts_consistent
from utils.transformers_dataset import TransformersDataset

# NOTE: getLogger() is called without a name, so this is the root logger,
# not a logger scoped to this module.
log = logging.getLogger()

def compute_total_correlation(log_probs_N_K_C: torch.Tensor, device) -> torch.Tensor:
    """Compute a per-point pairwise-overlap score from sampled log-probabilities.

    The input is split into chunks along its first axis by
    ``toma.execute.chunked`` (initial chunk size 1024). Inside one chunk of
    ``n`` points, with ``p = exp(log_probs)``, the score of point ``j`` is::

        a[j, c] = 1/K * sum_k sum_{i != j} p[i, k, c] * p[j, k, c]
        m[j, c] = 1/K * sum_k p[j, k, c]
        b[j, c] = sum_{i != j} m[i, c] * m[j, c]
        score[j] = -sum_c a[j, c] * (log(1 + a[j, c]) - log(1 + b[j, c]))

    Pairs ``(i, j)`` are only formed between points of the same chunk, so the
    returned values depend on how the input is chunked.

    The sums over ``i != j`` are evaluated in closed form as
    ``x[j] * (sum_i x[i] - x[j])``. This keeps memory linear in the chunk
    size instead of building ``(C, K, n, n)`` outer-product tensors.

    Args:
        log_probs_N_K_C: Tensor of shape ``(N, K, C)`` holding log-probabilities
            (it is exponentiated before use).
        device: Device (string or ``torch.device``) on which each chunk is
            processed.

    Returns:
        A CPU tensor of shape ``(N,)`` and dtype ``torch.double`` with one
        score per row of the input.
    """
    N, K, _ = log_probs_N_K_C.shape
    total_correlation_N = torch.empty(N, dtype=torch.double)

    # The context manager closes the bar even if a chunk raises.
    with tqdm(total=N, desc="Total correlation", leave=False) as pbar:

        @toma.execute.chunked(log_probs_N_K_C, 1024)
        def compute(log_probs_n_K_C, start: int, end: int):
            # Move the chunk once and exponentiate once, on the target device.
            probs_n_K_C = torch.exp(log_probs_n_K_C.to(device))
            mean_probs_n_C = probs_n_K_C.mean(dim=1)

            # sum_{i != j} p_i * p_j == p_j * (sum_i p_i - p_j), per (k, c).
            chunk_sum_1_K_C = probs_n_K_C.sum(dim=0, keepdim=True)
            a_n_C = (probs_n_K_C * (chunk_sum_1_K_C - probs_n_K_C)).sum(dim=1) / K

            # Same identity applied to the sample-averaged probabilities.
            mean_sum_1_C = mean_probs_n_C.sum(dim=0, keepdim=True)
            b_n_C = mean_probs_n_C * (mean_sum_1_C - mean_probs_n_C)

            # log1p(x) is log(1 + x) without the rounding loss for small x.
            nats_n = (a_n_C * (torch.log1p(a_n_C) - torch.log1p(b_n_C))).sum(dim=1)

            # copy_ handles the device -> CPU transfer and the cast to double.
            total_correlation_N[start:end].copy_(-nats_n)
            pbar.update(end - start)

    return total_correlation_N

def lbb_sampling(
    model,
    X_pool: Union[np.ndarray, Dataset, TransformersDataset],
    n_instances: int,
    select_by_number_of_tokens: bool = False,
    **bald_kwargs,
):
    """
    Jointly score pool points by combining the BALD mutual information with an
    estimate of the pairwise mutual information between pool points.

    The pool is predicted ``mc_iterations`` times; the stacked predictions are
    scored with ``calculate_bald_score_cls`` and the output of
    ``compute_total_correlation`` is subtracted from that score. Points are
    then queried in descending order of the resulting score.

    Args:
        model: Wrapper exposing ``task``, ``model``, ``data_config``,
            ``predict_proba``, ``enable_dropout`` and ``disable_dropout``.
        X_pool: Unlabelled pool to select from.
        n_instances: Query budget (number of instances, or of tokens when
            ``select_by_number_of_tokens`` is set).
        select_by_number_of_tokens: If true, the query is built by
            ``get_query_idx_for_selecting_by_number_of_tokens`` instead of
            taking the top ``n_instances`` indices.
        **bald_kwargs: Optional settings: ``mc_iterations`` (default 10),
            ``use_stable_dropout`` (default True), ``only_head_dropout``
            (default False, not supported).

    Returns:
        Tuple ``(query_idx, query, uncertainty_estimates)``.

    Raises:
        NotImplementedError: If ``only_head_dropout`` is requested, or if
            ``model.task`` is not ``"cls"``. The pairwise term needs one
            ``(N, K, C)`` array of log-probabilities, which is only built for
            classification.
    """
    if bald_kwargs.get("only_head_dropout", False):
        raise NotImplementedError(
            "only_head_dropout is not supported by lbb_sampling"
        )
    # Fail before the expensive MC passes rather than with a NameError after
    # them: no stacked log-probabilities exist for other tasks.
    if model.task != "cls":
        raise NotImplementedError(
            f"lbb_sampling supports only the 'cls' task, got {model.task!r}"
        )

    mc_iterations = bald_kwargs.get("mc_iterations", 10)
    use_stable_dropout = bald_kwargs.get("use_stable_dropout", True)

    # Make dropout consistent inside huggingface model
    if use_stable_dropout:
        make_dropouts_consistent(model.model)
    else:
        model.enable_dropout()

    probas = []
    for _ in range(mc_iterations):
        if use_stable_dropout:
            # Reset masks (presumably toggling dropout makes the consistent
            # dropout layers draw new masks; confirm in consistent_dropout).
            model.enable_dropout()
            model.disable_dropout()
        probas.append(model.predict_proba(X_pool, to_eval_mode=False))

    # Stack the per-iteration predictions on a new second-to-last axis.
    log_probs_N_K_C = np.log(np.stack(probas, -2))
    uncertainty_estimates = calculate_bald_score_cls(log_probs_N_K_C)

    # Use the GPU when there is one, but keep working on CPU-only hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pairwise_term = compute_total_correlation(
        torch.from_numpy(log_probs_N_K_C), device
    )
    uncertainty_estimates -= pairwise_term.cpu().detach().numpy()

    # Sort in descending score order: the highest-scoring points are queried first.
    argsort = np.argsort(-uncertainty_estimates)

    if select_by_number_of_tokens:
        query_idx = get_query_idx_for_selecting_by_number_of_tokens(
            X_pool,
            argsort,
            n_instances,
            tokens_column_name=model.data_config["text_name"],
        )
    else:
        query_idx = argsort[:n_instances]
    query = take_idx(X_pool, query_idx)

    return query_idx, query, uncertainty_estimates