# Erik McGuire

from statsmodels.stats.power import TTestIndPower
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import pingouin as pg
import seaborn as sns
import pandas as pd
import numpy as np
import random

# Apply the 'seaborn' matplotlib style sheet to every figure drawn by this module.
# NOTE(review): newer matplotlib releases renamed the bundled seaborn style
# sheets (e.g. 'seaborn-v0_8') — confirm this name against the installed version.
plt.style.use('seaborn')

# Base directory; the figure and model locations below are derived from it.
ZUCO_PTH = "data"
# Prefix for saved figures (report material). The trailing slash matters:
# file names are appended directly to this string when figures are saved.
FIGPTH = f"{ZUCO_PTH}/figures/" # For saving for report.
# Root under which load_data() looks for per-checkpoint labels.csv / preds.csv.
MODEL_PTH = f"{ZUCO_PTH}/models"

def permutation_test(A, B, t, n, R, power, save, plot: bool = True, verbose: bool = True):
    """Two-tailed permutation test for comparing classifiers.

    Every round draws one random permutation of the ``n`` sample indices,
    reorders both models' predictions with it, scores the reordered
    predictions against the (unshuffled) labels and records the absolute
    accuracy difference. The p-value is the share of rounds whose difference
    is at least as large as the observed one.

    NOTE(review): both prediction vectors are shuffled with the *same* index
    permutation; a paired test would more commonly swap the two models'
    predictions per sample — confirm this null distribution is the intended one.

    Args:
        A: ``(name, data)`` pair for model A; ``data`` is a dict holding
            equal-length ``"labels"`` and ``"preds"`` sequences.
        B: ``(name, data)`` pair for model B; its labels must equal A's.
        t: Tag embedded in saved figure file names and forwarded to
            ``plot_deltas``.
        n: Number of samples to permute; expected to equal the number of
            labels, since permuted predictions are scored against all labels.
        R: Number of permutation rounds.
        power: Target power used when solving for the required sample size.
        save: Save the figures under ``FIGPTH``. Only has an effect when
            ``plot`` is true, because nothing is drawn otherwise.
        plot: Draw the delta distribution and the accuracy histograms.
        verbose: Print the mean permutation accuracy of each model.

    Returns:
        Tuple ``(pval, cohensd, sample_size, power, ci)``:
        the permutation p-value ``(cnt + 1) / (R + 1)``; Cohen's d between
        the permutation accuracies of B and A; the per-group sample size
        needed to reach the requested power at alpha 0.05 (0 when d is 0);
        the power achieved with ``n`` observations at alpha 0.05 (0 when d
        is 0); and the confidence interval for d.
    """
    model_A, data_A = A
    model_B, data_B = B

    assert len(data_A["labels"]) == len(data_A["preds"])
    assert data_A["labels"] == data_B["labels"]
    y_true = data_A["labels"]
    preds_a, preds_b = data_A["preds"], data_B["preds"]

    acc_a = accuracy_score(y_true, preds_a)
    acc_b = accuracy_score(y_true, preds_b)
    delta_orig = np.abs(acc_b - acc_a)

    cnt = 0
    deltas, accs_a, accs_b = [], [], []
    # Score each permutation as soon as it is drawn instead of keeping all
    # R reordered prediction lists in memory first.
    for _ in tqdm(range(R), desc="Permuting data"):
        # without replacement: permutation (cf. bootstrap)
        # may alternatively permute labels
        indices = random.sample(range(n), n)
        sample_acc_a = accuracy_score(y_true, [preds_a[z] for z in indices])
        sample_acc_b = accuracy_score(y_true, [preds_b[z] for z in indices])
        delta = np.abs(sample_acc_b - sample_acc_a)
        deltas.append(delta)
        accs_a.append(sample_acc_a)
        accs_b.append(sample_acc_b)
        # "At least as extreme": ties must count. With a strict ">" two
        # models with identical predictions would get p = 1 / (R + 1).
        if delta >= delta_orig:
            cnt += 1

    deltas = np.array(deltas)
    accs_a, accs_b = np.array(accs_a), np.array(accs_b)

    if plot:
        plot_deltas(model_A, model_B, deltas, delta_orig, t, save)
    if verbose:
        print(f"The mean of {model_B} permutation accuracies: {accs_b.mean()}.")
        print(f"The mean of {model_A} permutation accuracies: {accs_a.mean()}.\n")
    if plot:
        sns.histplot(accs_a, color="red", label=f"Model A: {model_A}")
        sns.histplot(accs_b, color="skyblue", label=f"Model B: {model_B}")
        plt.xlabel("permutation accuracy")
        plt.legend()
        # Save only after the axis label and legend exist, and only when a
        # histogram was actually drawn.
        if save:
            plt.savefig(f"{FIGPTH}{model_A}_{model_B}_{t}_accs.png",
                        transparent=True)

    # Effect size: Cohen's d (B relative to A)
    cohensd = pg.compute_effsize(accs_b, accs_a, eftype='cohen')

    # Confidence intervals (lower/upper bounds) for Cohen's d
    ci = pg.compute_esci(cohensd, accs_b.shape[0], accs_a.shape[0], decimals=3)

    # p-value; the +1 terms count the observed labelling itself
    pval = float(cnt + 1) / float(R + 1)

    analysis = TTestIndPower()
    # Obs needed to achieve given power w/ given significance level, effect size:
    sample_size = (analysis.solve_power(cohensd, power=power,
                                        nobs1=None, ratio=1.0,
                                        alpha=0.05)
                   if cohensd != 0 else 0)

    # Power w/ given significance, effect size, sample size
    achieved_power = (analysis.solve_power(cohensd, power=None,
                                           nobs1=n, ratio=1.0,
                                           alpha=0.05)
                      if cohensd != 0 else 0)

    return pval, cohensd, sample_size, achieved_power, ci

def load_data(model, t, chkpt, experiment):
    """
        Load model predictions, labels.
        Combine into dataframe, convert to dictionary.

        The files ``labels.csv`` and ``preds.csv`` (no header row) are read
        from ``<MODEL_PTH>/experiment_<experiment>/<model>/checkpoint-<chkpt>/
        <model>_<t>-<chkpt>``; if that directory's files are missing, the
        sibling directory ``_<t>-<chkpt>`` is tried instead. Model names other
        than "baseline", "baseline_filtered" and "random" are prefixed with
        ``experiment_<experiment>_``.

        Returns a tuple ``(data, n, sentences)``: ``data`` is a dict with
        list-valued keys "labels" and "preds", ``n`` is the number of
        samples, and ``sentences`` is an empty DataFrame with a single
        "text" column (no sentence text is loaded here).

        Raises FileNotFoundError when neither location has the CSV files and
        AssertionError when labels and predictions differ in length.
    """
    exp = f"experiment_{experiment}"
    if model not in ("baseline", "baseline_filtered", "random"):
        model = f"{exp}_{model}"

    chkpt_dir = f"{MODEL_PTH}/{exp}/{model}/checkpoint-{chkpt}"
    pth = f"{chkpt_dir}/{model}_{t}-{chkpt}"
    try:
        labels = pd.read_csv(f"{pth}/labels.csv", header=None)
        preds = pd.read_csv(f"{pth}/preds.csv", header=None)
    except FileNotFoundError:
        # Alternative directory naming: no model prefix on the result folder.
        npth = f"{chkpt_dir}/_{t}-{chkpt}"
        labels = pd.read_csv(f"{npth}/labels.csv", header=None)
        preds = pd.read_csv(f"{npth}/preds.csv", header=None)
    labels.columns = ["labels"]
    preds.columns = ["preds"]

    # Build the placeholder with its column up front: assigning a one-element
    # column list to a column-less DataFrame raises a length-mismatch error.
    sentences = pd.DataFrame(columns=["text"])

    # Check lengths before concatenating; concat pads the shorter column
    # with NaN, which would hide a mismatch.
    n = len(labels)
    assert n == len(preds), "labels.csv and preds.csv differ in length"

    data = pd.concat([labels, preds], axis=1).to_dict(orient='list')
    return data, n, sentences

def get_size(cohensd):
    """Describe a Cohen's d value in words (heuristic). More info:
         https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3444174/

    The sign of ``cohensd`` selects "positive"/"negative"; its magnitude
    selects the size label and the overlap remark. Returns the description
    string. A value that matches no band (e.g. NaN) yields just the sign
    word preceded by a space.

    NOTE(review): the "Medium" and "Large" bands report the same 53% overlap
    and the 0.8-1.0 band says "> 53%"; these figures look inconsistent with
    the neighbouring bands — confirm the intended wording.
    """
    if cohensd >= 0:
        sign = "positive"
    else:
        sign = "negative"
    magnitude = np.abs(cohensd)

    if magnitude == 0:
        return "Effect size 0 with 100% overlap"
    if magnitude <= 0.2:
        # 0.2: 58th percentile
        return f"Small {sign} effect size with <= 85% overlap"
    if magnitude <= 0.5:
        # 0.5: 69th percentile
        return f"Medium {sign} effect size with <= 53% overlap"
    if magnitude <= 0.8:
        # 0.8: 79th percentile
        # group 2: a higher score than 79% of the people from group 1
        return f"Large {sign} effect size with <= 53% overlap"

    # Beyond 0.8: test the highest band first so each check needs one bound.
    if magnitude >= 2.0:
        # 97th percentile
        return f"Very large {sign} effect size with <= 19% overlap"
    if magnitude >= 1.5:
        # 93rd percentile
        return f"Very large {sign} effect size with <= 29% overlap"
    if magnitude >= 1.0:
        # 84th percentile
        return f"Very large {sign} effect size with <= 45% overlap"
    if magnitude > 0.8:
        return f"Large {sign} effect size with > 53% overlap"

    # Every comparison above was false (NaN input): only the sign word remains.
    return f" {sign}"

def plot_deltas(model_A,
                model_B,
                deltas,
                delta_orig,
                t,
                save):
    """Show differences between samples' and original's test statistic across permutations.

    Draws the distribution of the permutation statistics (histogram plus
    density curve), marks the observed statistic with a short orange vertical
    line and a text label, optionally saves the figure, then shows it.

    Args:
        model_A: Name of model A; used in the saved file name.
        model_B: Name of model B; used in the saved file name.
        deltas: Test statistic of every permutation round.
        delta_orig: Test statistic observed on the original data.
        t: Tag embedded in the saved file name.
        save: When true, write ``<FIGPTH><model_A>_<model_B>_<t>_deltas.png``.
    """
    # histplot replaces the deprecated distplot; kde + density scaling keeps
    # the histogram-with-curve appearance.
    ax = sns.histplot(list(deltas), kde=True, stat="density")
    plt.ylabel('Frequency')
    # Raw strings: "\d" is not a valid escape sequence in a normal literal.
    plt.xlabel(r'$\delta(X)$')
    # Height of the tallest bar positions the label; 0 when no bar was drawn.
    tallest = max((bar.get_height() for bar in ax.patches), default=0)
    ax.text(delta_orig * .95,
            tallest * .25,
            s=rf'$\delta(x)$ = {delta_orig:.3f}',
            rotation=0,
            horizontalalignment='center',
            verticalalignment='center')
    ax.axvline(x=delta_orig,
               ymax=.2,
               color='orange')
    if save:
        plt.savefig(f"{FIGPTH}{model_A}_{model_B}_{t}_deltas.png",
                    transparent=True)
    plt.show()


def f(model_A: str, model_B: str, rounds: int,
      t: str, run: bool, save: bool, chkpt_A: int, chkpt_B: int,
      experiment: str):
    """Load two models' predictions, run the permutation test and print a report.

    Does nothing unless ``run`` is true. Otherwise it loads the data of both
    models for tag ``t`` and their checkpoints, requires equal sample counts,
    runs ``permutation_test`` with ``rounds`` permutations and a target power
    of 0.8, and prints significance (threshold 0.05), effect size, confidence
    interval and power. Returns nothing.
    """
    if not run:
        return

    # `dph` is not defined in the visible code; presumably this resets the
    # widget that triggered the run — TODO confirm.
    dph.children[1].value = False

    data_A, n, _ = load_data(model_A, t, chkpt_A, experiment)
    data_B, m, _ = load_data(model_B, t, chkpt_B, experiment)
    assert n == m
    print(f"\nLoaded {n} samples for each model.\n")

    # prob of finding effect when actual stat. sig. difference exists:
    power = 0.8
    classifier_a = (model_A, data_A)
    classifier_b = (model_B, data_B)
    pval, cohensd, sample_size, achieved, ci = permutation_test(
        classifier_a, classifier_b, t, n, rounds, power, save)

    if float(pval) <= 0.05:
        print(f"Significant (p-value: {pval:.3f})")
    else:
        print(f"\nNot significant (p-value: {pval})")

    # permutation_test reports a sample size of 0 when the effect size is 0.
    if sample_size == 0:
        print("100% overlap in results.")
        return

    description = get_size(cohensd)
    print(f"\n{description}: {cohensd:.2f}.\n")
    print(f"CI (lower/upper): {ci}.\n")
    print(f"Power: {achieved * 100:.2f}%.")
    if achieved < power:
        print(f'\nNeeded sample size for power {power}: {int(sample_size)}.')
