# %%
import json

import numpy as np
from scipy import stats


# %%
# n_shuffles: number of trials per prompt (default 10)
# alpha: significance level for the confidence intervals (default 0.05)
def _wilson_interval(successes, trials, confidence_level):
  """Return the (low, high) Wilson score interval for `successes` out of `trials`.

  `binomtest` is only used as a carrier for `proportion_ci`; its `p` and
  `alternative` arguments do not influence the interval.
  """
  result = stats.binomtest(successes, n=trials, p=0.5, alternative="two-sided")
  ci = result.proportion_ci(confidence_level=confidence_level, method="wilson")
  return ci.low, ci.high


def run_dominance_test(results: dict, n_shuffles=10, alpha=0.05):
  """Summarise how often the "optim" values beat the "orig" values across prompts.

  For each prompt, `w` is the number of trials in which `optim < orig`
  (strictly), divided by `n_shuffles`. The prompt scores 1 if `w > 0.5`,
  0 if `w < 0.5` and 0.5 on an exact tie. `U` is the mean of those scores.

  Wilson score intervals at confidence `1 - alpha` are reported for `U`
  and for the mean of `w`. Each treats the floored implied number of
  successes out of `len(results)` prompts as a binomial count.

  Args:
    results: mapping prompt_id -> {"orig": [...], "optim": [...]}, two
      element-wise comparable sequences of per-trial values. They are
      presumably KL divergences, judging by the variable names; confirm
      against the producer of the JSON.
    n_shuffles: number of trials per prompt; used as the denominator of `w`.
    alpha: significance level; intervals use confidence `1 - alpha`
      (0.95 for the default).

  Returns:
    dict with keys "U", "U_lower", "U_upper", "w_avg", "w_lower", "w_upper".

  Raises:
    ValueError: if `results` is empty.
  """
  n_prompts = len(results)
  if n_prompts == 0:
    raise ValueError("results must contain at least one prompt")

  confidence_level = 1 - alpha

  w_vals = []
  # Sum of per-prompt scores in {0, 0.5, 1}. These are exact in floating
  # point, unlike accumulating score / n_prompts, which drifts
  # (e.g. ten additions of 0.1 give 0.9999999999999999).
  dominance_score = 0.0
  # Exact integer count of trials where optim < orig, across all prompts.
  total_opt_wins = 0
  for prompt_res in results.values():
    kl_orig_arr = np.asarray(prompt_res["orig"])
    kl_opt_arr = np.asarray(prompt_res["optim"])
    n_better = int(np.sum(kl_opt_arr < kl_orig_arr))
    total_opt_wins += n_better
    w = n_better / n_shuffles
    w_vals.append(w)
    dominance_score += 1.0 if w > 0.5 else 0.0 if w < 0.5 else 0.5

  U = dominance_score / n_prompts

  # U CIs: ties contribute half a success, so the count is floored to an int
  # as binomtest requires.
  U_lower, U_upper = _wilson_interval(
    int(dominance_score), n_prompts, confidence_level
  )

  # w CIs: mean(w) * n_prompts == total_opt_wins / n_shuffles exactly, so the
  # floored success count is computed with integer division to avoid
  # float truncation errors.
  # NOTE(review): treating a mean of per-prompt proportions as a binomial
  # count over n_prompts is an approximation; pooling all
  # n_prompts * n_shuffles trials may be the intended model.
  w_mean = np.mean(w_vals)
  w_lower, w_upper = _wilson_interval(
    int(total_opt_wins // n_shuffles), n_prompts, confidence_level
  )

  return {
    "U": U,
    "U_lower": U_lower,
    "U_upper": U_upper,
    "w_avg": w_mean,
    "w_lower": w_lower,
    "w_upper": w_upper,
  }


# %%
# Load per-model results. Keys are model names; each value is passed to
# run_dominance_test as its `results` mapping.
# The `with` block closes the file handle, which json.load(open(...)) leaked.
with open("shuffle_results_gemma_COLD.json", "r", encoding="utf-8") as f:
  all_res = json.load(f)
# all_res
# %%
# Build one LaTeX table row per model:
#   name & U (U_lower, U_upper) & w_avg (w_lower, w_upper) \\
table_rows = []
for model_name, model_results in all_res.items():
  if "pythia" in model_name:
    continue  # Pythia models are left out of this table.

  cur = run_dominance_test(results=model_results, n_shuffles=10, alpha=0.05)
  u_cell = f"${cur['U']:.2f}$ (${cur['U_lower']:.2f}$, ${cur['U_upper']:.2f}$)"
  w_cell = f"${cur['w_avg']:.2f}$ (${cur['w_lower']:.2f}$, ${cur['w_upper']:.2f}$)"
  table_rows.append(f"{model_name} & {u_cell} & {w_cell} \\\\\n")

latex_output = "".join(table_rows)
print(latex_output)

# %%
print(latex_output)

# %%
