import numpy as np
from sklearn.model_selection import StratifiedKFold, KFold
from sklearn.linear_model import Lasso
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from scipy.optimize import root_scalar, minimize_scalar
from scipy.stats import ncx2, norm
import math


def weighted_mse(y_true, y_pred, weights):
    """Return the mean of squared errors after dividing each one by its weight.

    Args:
        y_true: Observed values.
        y_pred: Predicted values, broadcastable against ``y_true``.
        weights: Divisors applied element-wise to the squared errors
            (a larger weight makes the corresponding error count less).

    Returns:
        The average of ``(y_true - y_pred) ** 2 / weights``.
    """
    squared_errors = (y_true - y_pred) ** 2
    scaled_errors = squared_errors / weights
    return np.mean(scaled_errors)


def calculate_means(X, z, groups):
    """Collapse row-level data to one mean per group.

    Args:
        X: 2-D array of features, one row per observation.
        z: 1-D array of outcomes aligned with the rows of ``X``.
        groups: 1-D array of group labels aligned with the rows of ``X``.

    Returns:
        A tuple ``(X_means, z_means, unique_groups)`` where ``unique_groups``
        is ``np.unique(groups)`` and row ``i`` of ``X_means`` / entry ``i`` of
        ``z_means`` is the mean over the observations labelled
        ``unique_groups[i]``.
    """
    unique_groups = np.unique(groups)

    feature_rows = []
    outcome_values = []
    for label in unique_groups:
        members = groups == label
        feature_rows.append(X[members].mean(axis=0))
        outcome_values.append(z[members].mean())

    return np.array(feature_rows), np.array(outcome_values), unique_groups



def cross_validate_lasso(X, z, groups, alpha, n_splits=5):
    """Score one Lasso penalty by stratified K-fold CV on group-level means.

    In every fold the training rows and the test rows are each collapsed to
    per-group means with ``calculate_means``; a ``Lasso(alpha=alpha)`` is fit
    on the training means and scored on the test means with ``weighted_mse``.

    Args:
        X: 2-D array of row-level features.
        z: 1-D array of row-level outcomes.
        groups: 1-D array of group labels (any type ``LabelEncoder`` accepts).
        alpha: Lasso regularisation strength.
        n_splits: Number of stratified folds.

    Returns:
        The mean weighted MSE across folds.
    """
    # Encode labels as integers 0..G-1 so they can double as indices into
    # ``weights`` below.
    le = LabelEncoder()
    groups_encoded = le.fit_transform(groups)

    skf = StratifiedKFold(n_splits=n_splits)

    # One weight per encoded group: the overall variance of z (floored at
    # 0.01) divided by the group's size in the full sample.
    _, group_sizes = np.unique(groups_encoded, return_counts=True)
    weights = np.max([np.var(z), 0.01]) / group_sizes

    scores = []
    for train_idx, test_idx in skf.split(X=X, y=groups):
        X_train, z_train = calculate_means(
            X[train_idx], z[train_idx], groups_encoded[train_idx]
        )[:2]

        model = Lasso(alpha=alpha)
        model.fit(X_train, z_train)

        X_test, z_test, test_groups = calculate_means(
            X[test_idx], z[test_idx], groups_encoded[test_idx]
        )
        z_pred = model.predict(X_test)

        # Pick the weights of the groups actually present in this test fold.
        # A group with fewer than ``n_splits`` rows is absent from some folds;
        # passing the full weight vector would then misalign weights and
        # errors (or fail to broadcast). When every group is present this is
        # the full vector, unchanged.
        scores.append(weighted_mse(z_test, z_pred, weights[test_groups]))

    # Average score across all folds
    return np.mean(scores)


# The code for this function was taken from 
# https://gitlab.com/dsbowen/conditional-inference/-/blob/master/src/multiple_inference/bayes/normal.py#L70
def compute_robust_critical_value(
    m2: float, kurtosis: float = np.inf, alpha: float = 0.05
) -> tuple[float, np.ndarray, np.ndarray]:

    """Compute the critical value for robust confidence intervals.

    Args:
        m2 (float): Equality constraint on :math:`E[b^2]`.
        kurtosis (float, optional): Estimated kurtosis of the prior distribution.
            Defaults to np.inf.
        alpha (float, optional): Significance level. Defaults to .05.

    Returns:
        tuple[float, np.ndarray, np.ndarray]: Critical value, array of :math:`x` values
        for the least favorable mass function, array of probabilities for the least
        favorable mass function.

    Raises:
        RuntimeError: Raised inside the ``lam`` helper when none of its bracketing
            rules applies (reported as multiple optima of ``delta``). Only reachable
            when a finite ``kurtosis`` other than 1 is used.

    Notes:

        See Armstrong et al., 2020 for mathematical details. This function is equivalent
        to the ``cva`` function in the
        `ebci R package <https://cran.r-project.org/web/packages/ebci/index.html>`_ and
        uses the same variable names and tolerance thresholds.

        All helpers below are closures over ``m2``, ``kurtosis``, ``alpha`` and
        ``tol``; ``tol`` is assigned after the helper definitions but before any
        helper is called.
    """

    def rho0(chi):
        # Loss using only the second-moment constraint: r itself when m2 is at or
        # beyond t0, otherwise the first-order extension of r from t0 to m2.
        t0 = rt0(chi)[0]
        return r(m2, chi) if m2 >= t0 else r(t0, chi) + (m2 - t0) * r1(t0, chi)

    def rt0(chi):
        # returns t0, inflection point
        # the inflection point is denoted t1 in the paper
        if (chi_sq := chi**2) < 3:
            return 0, 0

        # Try chi^2 - 1.5 as the root of r2 first; only run the root finder when
        # r2 is not already within tol of zero there.
        if abs(r2(inflection := chi_sq - 1.5, chi=chi)) > tol:
            inflection = root_scalar(
                lambda x: r2(x, chi=chi), bracket=(chi_sq - 3, chi_sq), method="brentq"
            ).root

        # func(t) = r(t) - t * r1(t) - r(0); t0 is its root on
        # [inflection, 2 * chi^2] when func is negative at the lower end,
        # otherwise the lower end itself.
        func = lambda t: r(t, chi) - t * r1(t, chi) - r(0, chi)
        lower, upper = inflection, 2 * chi_sq
        t0 = (
            root_scalar(lambda x: func(x), bracket=(lower, upper), method="brentq").root
            if func(lower) < 0
            else lower
        )
        return t0, inflection

    def r(t, chi):
        # called r0 in the paper
        sqrt_t = np.sqrt(t)
        # Returns exactly 1 once sqrt(t) - chi reaches 5 instead of evaluating
        # the normal CDFs.
        return (
            (norm.cdf(-sqrt_t - chi) + norm.cdf(sqrt_t - chi))
            if sqrt_t - chi < 5
            else 1
        )

    def r1(t, chi):
        # first derivative of r
        sqrt_t = np.sqrt(t)
        if t < 1e-8:
            # apply L'Hopital's rule
            return chi * norm.pdf(chi)
        return (norm.pdf(sqrt_t - chi) - norm.pdf(sqrt_t + chi)) / (2 * sqrt_t)

    def r2(t, chi):
        # second derivative of r
        sqrt_t = np.sqrt(t)
        if t < 2e-6:
            # apply L'Hopital's rule
            return chi * (chi**2 - 3) * norm.pdf(chi) / 6
        coef0 = chi * sqrt_t
        coef1 = t + 1
        return (
            (coef0 + coef1) * norm.pdf(sqrt_t + chi)
            + (coef0 - coef1) * norm.pdf(sqrt_t - chi)
        ) / (4 * t**1.5)

    def r3(t, chi):
        # third derivative of r
        sqrt_t = np.sqrt(t)
        if t < 2e-4:
            # apply L'Hopital's rule
            return (chi**5 - 10 * chi**3 + 15 * chi) * norm.pdf(chi) / 60
        coef0 = t**2 + (2 + chi**2) * t + 3
        coef1 = 2 * chi * t**1.5 + 3 * chi * sqrt_t
        return (
            (coef0 - coef1) * norm.pdf(chi - sqrt_t)
            - (coef0 + coef1) * norm.pdf(chi + sqrt_t)
        ) / (8 * t**2.5)

    def rho(chi):
        # return (optimum loss, x values for pmf, probabilities for pmf)
        if kurtosis == 1:
            return r(m2, chi), np.array([0, m2]), np.array([0, 1])

        r0, t0 = rho0(chi), rt0(chi)[0]
        # m2 at or beyond t0: all probability mass is placed on m2.
        if m2 >= t0:
            return r0, np.array([0, m2]), np.array([0, 1])

        # Infinite kurtosis, or m2 * kurtosis >= t0: two-point mass function on
        # {0, t0} with probability m2 / t0 on t0.
        if kurtosis == np.inf or m2 * kurtosis >= t0:
            return r0, np.array([0, t0]), np.array([1 - m2 / t0, m2 / t0])

        # Remaining case: minimise ``loss`` over x0, separately on [tbar, t0] and
        # (when tbar > 0) on [0, tbar], and keep the smaller of the two minima.
        tbar = lam(0, chi)[1]
        lammax = lambda x0: delta(0, x0, chi) if x0 >= tbar else max(lam(x0, chi)[0], 0)
        loss = (
            lambda x0: r(x0, chi)
            + (m2 - x0) * r1(x0, chi)
            + (kurtosis * m2**2 - 2 * x0 * m2 + x0**2) * lammax(x0)
        )
        result_above = minimize_scalar(loss, bounds=(tbar, t0), method="bounded")
        if tbar > 0:
            result_below = minimize_scalar(loss, bounds=(0, tbar), method="bounded")
            minimum_below, fun_below = result_below.x, result_below.fun
        else:
            minimum_below, fun_below = 0, loss(0)
        minimum, fun = (
            (result_above.x, result_above.fun)
            if result_above.fun < fun_below
            else (minimum_below, fun_below)
        )

        # Support points are the minimiser and the matching maximiser from lam;
        # the probabilities are chosen so the two-point distribution has mean m2.
        values = np.sort([minimum, lam(minimum, chi)[1]])
        probability = (m2 - values[1]) / (values[0] - values[1])
        return fun, values, np.array([probability, 1 - probability])

    def lam(x0, chi):
        # returns delta(x*, x0, chi), x*
        # where x* = argmax_x(delta(x, x0, chi))
        # check 0, inflection point, t0, and x0
        x = np.sort(rt0(chi))
        x = np.array([0, x[0]]) if x0 >= x[0] else np.unique([0, x0, *x])
        derivatives = delta1(x, x0, chi)
        values = delta(x, x0, chi)
        # Candidate optimum at x = 0; returned directly when delta is
        # non-increasing at every checked point and largest at 0.
        optimum = values[0], 0
        if (derivatives <= 0).all() and values.argmax() == 0:
            return optimum

        # NOTE(review): ``derivatives >= 0`` is a boolean array, so np.diff of it
        # is also boolean and the ``>= 0`` comparison looks like it is always
        # true; this branch then depends only on ``derivatives[-1] <= 0``.
        # Confirm against the upstream implementation.
        if (np.diff(derivatives >= 0) >= 0).all() and derivatives[-1] <= 0:
            index = max((derivatives >= 0).argmin(), 1)
            bounds = x[index - 1], x[index]
        elif abs(derivatives).min() < 1e-6:
            argmax = values.argmax()
            bounds = x[max(argmax - 1, 0)], x[min(argmax + 1, len(values) - 1)]
        else:
            raise RuntimeError(
                f"There are multiple optima in the function delta(x, x0={x0}, chi={chi})."
            )

        # Maximise delta inside the chosen bracket (by minimising its negative)
        # and keep whichever of the bracket optimum and x = 0 is larger.
        result = minimize_scalar(
            lambda x: -delta(x, x0, chi), bounds=bounds, method="bounded"
        )
        return (-result.fun, result.x) if -result.fun > optimum[0] else optimum

    def delta(x, x0, chi):
        # (r(x) - r(x0) - (x - x0) * r1(x0)) / (x - x0)^2, evaluated for a scalar
        # x or element-wise for an array of x values.
        def func(x):
            return (
                0.5 * r2(x0, chi)  # apply L'Hopital's rule
                if abs(x - x0) < 1e-4
                else (r(x, chi) - r(x0, chi) - (x - x0) * r1(x0, chi)) / (x - x0) ** 2
            )

        return func(x) if np.isscalar(x) else np.array([func(x_i) for x_i in x])

    def delta1(x, x0, chi):
        # first derivative of delta
        def func(x):
            if abs(x - x0) < 1e-3:
                # apply L'Hopital's rule
                return r3(x0, chi)
            return (
                r1(x, chi) + r1(x0, chi) - 2 * (r(x, chi) - r(x0, chi)) / (x - x0)
            ) / (x - x0) ** 2

        return func(x) if np.isscalar(x) else np.array([func(x_i) for x_i in x])

    tol = 1e-12
    # Critical value used when m2 == 0 or kurtosis == 1, and (minus 0.01) as
    # the lower end of the search bracket. Computed from the noncentral
    # chi-square quantile for m2 < 100 and from a normal quantile with mean
    # sqrt(m2) otherwise.
    critical_value_b = (
        np.sqrt(ncx2.ppf(1 - alpha, nc=m2, df=1))
        if m2 < 100
        else norm.ppf(1 - alpha, np.sqrt(m2))
    )
    if m2 == 0 or kurtosis == 1:
        return critical_value_b, np.array([0, m2]), np.array([0, 1])

    # For very large m2 a finite kurtosis is replaced by infinity.
    if m2 > 1 / tol and kurtosis != np.inf:
        kurtosis = np.inf

    # get bounds for when kappa == 1 and kappa == infinity
    lower, upper = critical_value_b - 0.01, np.sqrt((m2 + 1) / alpha)
    # Tighten the upper bound to the chi solving rho0(chi) = alpha unless the
    # initial upper bound already satisfies it to within 9e-6.
    if abs(rho0(upper) - alpha) > 9e-6:
        upper = root_scalar(
            lambda chi: rho0(chi) - alpha, bracket=(lower, upper), method="brentq"
        ).root

    # Solve rho(chi) = alpha only when rho at the upper bound falls below alpha
    # by more than 1e-5; otherwise the upper bound is the critical value.
    if rho(upper)[0] - alpha < -1e-5:
        critical_value = root_scalar(
            lambda chi: rho(chi)[0] - alpha, bracket=(lower, upper), method="brentq"
        ).root
    else:
        critical_value = upper
    return critical_value, *rho(critical_value)[1:]



class Estimator:
    """Group-level estimators built from row-level data.

    On construction the row-level inputs are collapsed to one observation per
    group (feature means, ``metric_fn`` of the outcome, and a sampling-variance
    estimate), and ``reg`` is fit on those group-level observations. The
    ``estimate_*`` methods then return ``{group: estimate}`` dictionaries.

    Attributes set by ``__init__``:
        unique_groups: ``np.unique(groups)``.
        n_groups: ``{group: number of rows}``.
        N, G: Total number of rows and number of groups.
        Xpall, zall, groupsall: Row-level features (with one-hot group
            indicators appended), outcomes and labels.
        X, Xp, z: ``{group: ...}`` group-level features, features with
            indicators, and outcomes.
        sigma2: ``{group: max(var(z_g), 0.001) / n_g}``.
        preds: ``{group: prediction of the fitted reg}``.
    """

    # Executed once, when the class body runs: switches scikit-learn's metadata
    # routing on for the whole process.
    import sklearn
    sklearn.set_config(enable_metadata_routing=True)

    def __init__(self, X, z, groups, reg, metric_fn=np.mean):
        """Aggregate the data by group and fit ``reg`` on the group level.

        Args:
            X: 2-D array of row-level features.
            z: 1-D array of row-level outcomes.
            groups: 1-D NumPy array of group labels (``reshape`` is called on it).
            reg: Regressor exposing ``fit(X, y, sample_weight=...)`` and
                ``predict``; it is fit in place with inverse-variance weights.
            metric_fn: Reduces one group's outcomes to a scalar.
        """
        self.groups = groups
        self.unique_groups, counts = np.unique(groups, return_counts=True)
        self.n_groups = dict(zip(self.unique_groups, counts))
        self.N = sum(counts)
        self.G = len(self.n_groups)

        # Append one-hot subgroup indicators to the row-level features.
        encoder = OneHotEncoder(sparse_output=False)
        group_indicators = encoder.fit_transform(groups.reshape(-1, 1))
        self.Xpall = np.hstack((X, group_indicators))

        # Keep the row-level observations for the structured regression.
        self.zall = z
        self.groupsall = groups

        # Compress to group-level observations.
        self.X = {g: X[groups == g].mean(axis=0) for g in self.unique_groups}
        self.Xp = {
            g: self.Xpall[groups == g].mean(axis=0) for g in self.unique_groups
        }
        self.z = {g: metric_fn(z[groups == g]) for g in self.unique_groups}

        # Per-group variance of the group-level outcome; the within-group
        # variance is floored at 0.001 so the inverse-variance weights below
        # stay finite.
        self.sigma2 = {
            g: np.max([np.var(z[groups == g]), 0.001]) / self.n_groups[g]
            for g in self.unique_groups
        }

        # Estimate the regression on group-level observations.
        reg.fit(
            X=[self.X[g] for g in self.unique_groups],
            y=[self.z[g] for g in self.unique_groups],
            sample_weight=[1 / self.sigma2[g] for g in self.unique_groups],
        )

        # Save the predictions.
        self.preds = {g: reg.predict([self.X[g]])[0] for g in self.unique_groups}

    def estimate_structured_regression(self, alphas=np.logspace(-10, 10, 20)):
        """Fit a Lasso on group means with the penalty chosen by cross-validation.

        Each candidate in ``alphas`` is scored with ``cross_validate_lasso`` on
        the row-level data (features include the group indicators); the
        lowest-scoring penalty is refit on the group-level observations.

        Args:
            alphas: Candidate Lasso penalties.

        Returns:
            ``{group: prediction}``; also stored as ``self.preds_streg``.
        """
        cv_scores = [
            cross_validate_lasso(self.Xpall, self.zall, self.groupsall, alpha=alpha)
            for alpha in alphas
        ]
        best_alpha = alphas[np.argmin(cv_scores)]

        model = Lasso(alpha=best_alpha)
        model.fit(
            [self.Xp[g] for g in self.unique_groups],
            [self.z[g] for g in self.unique_groups],
        )
        self.preds_streg = {
            g: model.predict(self.Xp[g].reshape(1, -1))[0] for g in self.unique_groups
        }
        return self.preds_streg

    def estimate_gm(self):
        """Return the unweighted mean of the group-level outcomes for every group."""
        gm = np.mean(list(self.z.values()))
        return {g: gm for g in self.unique_groups}

    def estimate_js(self):
        """Return positive-part James-Stein estimates shrunk toward the grand mean.

        The shrinkage factor ``max(1 - (G - 3) / S, 0)`` with
        ``S = sum_g (z_g - mean)^2 / sigma2_g`` is shared by all groups, so it
        is computed once rather than once per group.

        Returns:
            ``{group: estimate}``.
        """
        epsilon = np.mean(list(self.z.values()))
        dispersion = np.sum(
            [(self.z[g] - epsilon) ** 2 / self.sigma2[g] for g in self.unique_groups]
        )
        # With zero dispersion every group mean already equals the grand mean;
        # skip the division so the result is the grand mean instead of NaN.
        shrinkage = max([1 - (self.G - 3) / dispersion, 0]) if dispersion > 0 else 0
        return {
            g: epsilon + shrinkage * (self.z[g] - epsilon) for g in self.unique_groups
        }

    def estimate_eb(self, reg, compute_cis=False, n_splits=10):
        """Return cross-fitted empirical-Bayes estimates, optionally with CIs.

        The groups are split into ``n_splits`` folds. For each fold ``reg`` is
        fit (in place, with inverse-variance weights) on the remaining groups
        and used to predict the held-out groups; each held-out group mean is
        then shrunk toward its prediction with weight ``mu2 / (mu2 + sigma2)``,
        where ``mu2`` is estimated from the held-out fold only.

        Args:
            reg: Regressor exposing ``fit(X, y, sample_weight=...)`` and
                ``predict``.
            compute_cis: Also return robust confidence intervals.
            n_splits: Number of ``KFold`` splits; must not exceed the number of
                groups (``KFold`` rejects larger values).

        Returns:
            ``{group: estimate}``, or ``(estimates, {group: [lower, upper]})``
            when ``compute_cis`` is true. Out-of-fold predictions are stored in
            ``self.preds_ebreg``.
        """
        # cross-fitting
        X_flat = [self.Xp[g] for g in self.unique_groups]
        z_flat = [self.z[g] for g in self.unique_groups]
        groups_flat = list(self.unique_groups)

        kf = KFold(n_splits=n_splits)

        eb = {}
        cis = {}
        self.preds_ebreg = {}

        for train_idx, test_idx in kf.split(X_flat, z_flat):
            groups_train = [groups_flat[i] for i in train_idx]
            reg.fit(
                [X_flat[i] for i in train_idx],
                [z_flat[i] for i in train_idx],
                sample_weight=[1 / self.sigma2[g] for g in groups_train],
            )
            preds_test = reg.predict([X_flat[i] for i in test_idx])

            preds = {}
            for i, pred in zip(test_idx, preds_test):
                g = groups_flat[i]
                preds[g] = pred
                self.preds_ebreg[g] = pred

            batch = list(preds)
            # Kept as an attribute because the original code exposed it.
            self.unique_groups_batch = batch

            # follow the setup of https://onlinelibrary.wiley.com/doi/pdf/10.3982/ECTA18597
            # includes what is in the paper but with some minor corrections
            epsilon2 = {g: (preds[g] - self.z[g]) ** 2 for g in batch}

            # Uniform weights over the held-out groups.
            omega = {g: 1 / len(batch) for g in batch}
            omega_total = sum(omega.values())
            omega_sigma2 = np.sum([omega[g] * self.sigma2[g] for g in batch])

            # mu2 is A in the paper; the second entry is a positive floor.
            mu2 = np.max(
                [
                    np.sum([omega[g] * (epsilon2[g] - self.sigma2[g]) for g in batch])
                    / omega_total,
                    np.sum([2 * omega[g] ** 2 * self.sigma2[g] ** 2 for g in batch])
                    / (omega_total * omega_sigma2),
                ]
            )

            if compute_cis:
                # Kurtosis estimate. It is computed but deliberately not passed
                # on below: the intervals use an infinite kurtosis instead.
                # NOTE(review): the second term divides sigma2**4 by
                # mu2**2 * sum(omega * sigma2), which is not unit-free;
                # sigma2**2 in that sum would be. Confirm against the paper
                # before using kappa.
                kappa = np.max(
                    [
                        np.sum(
                            [
                                omega[g]
                                * (
                                    epsilon2[g] ** 2
                                    - 6 * self.sigma2[g] * epsilon2[g]
                                    + 3 * self.sigma2[g] ** 2
                                )
                                for g in batch
                            ]
                        )
                        / (mu2 ** 2 * np.sum([omega[g] for g in batch])),
                        1
                        + (32 * np.sum([omega[g] ** 2 * self.sigma2[g] ** 4 for g in batch]))
                        / (mu2 ** 2 * omega_total * omega_sigma2),
                    ]
                )

            # compute weights
            weights = {g: mu2 / (mu2 + self.sigma2[g]) for g in batch}
            for g in batch:
                eb[g] = preds[g] + weights[g] * (self.z[g] - preds[g])
                if compute_cis:
                    # using infinity for the kurtosis works better
                    cva = compute_robust_critical_value(
                        m2=float(self.sigma2[g] / mu2), kurtosis=math.inf
                    )[0]
                    half_width = cva * np.sqrt(self.sigma2[g]) * weights[g]
                    cis[g] = [eb[g] - half_width, eb[g] + half_width]

        if compute_cis:
            return eb, cis
        return eb
        

