import functools
from importlib import reload

import ipdb
import numpy as np
import pandas as pd
import persist_to_disk as ptd
import tqdm

import dataeval.load as dload
import pipeline.summ as summ


class CaseStudy(summ.UQ_summ):
    """Case-study helper for inspecting attention-weighted NLL scores per example.

    Extends ``summ.UQ_summ`` by loading token-level attention log-likelihoods
    through ``dataeval.load`` and by pinning the tuning configuration
    (setting, curve, number of generations, accuracy name) that every lookup
    in this class uses.
    """

    def __init__(
        self,
        path,
        clean=True,
        split=None,
        cal_size: int = 1000,
        seed=0,
        setting="mlga~neg_mlgc",
        curve="auroc",
        num_gen=5,
        topk=10,
        acc_name="moe|acc",
        next_token=False,
    ) -> None:
        """Load token-level attention data and store the tuning configuration.

        Args:
            path, clean, split, cal_size, seed: Forwarded unchanged to
                ``summ.UQ_summ.__init__``.
            setting, curve, num_gen, acc_name: Forwarded to the tuning lookups
                made by ``get_tuned_uq`` and ``layer_heads``.
            topk: How many of the best-scoring parameter settings contribute
                their heads to ``layer_heads``.
            next_token: When true, read the next-token variant of the
                attention log-likelihoods and use the ``"attnnll_nexttoken"``
                measure name instead of ``"attnnll"``.
        """
        super().__init__(path, clean, split, cal_size, seed)
        reader = (
            dload.read_attn_loglikelihoods_all_next_token
            if next_token
            else dload.read_attn_loglikelihoods_all
        )
        # self.key[1] is set by the parent class; it is used here as the
        # "clean" flag (presumably mirrors the `clean` argument -- TODO confirm).
        self.attn_token_level = reader(self.path, clean=self.key[1], readonly=True)
        self.attn_name = "attnnll" + ("_nexttoken" if next_token else "")

        self.setting = setting
        self.curve = curve
        self.num_gen = num_gen
        self.topk = topk
        self.acc_name = acc_name

    def get_tuned_uq(self, name):
        """Return the parent's tuned UQ result for ``name`` under this object's configuration."""
        return super().get_tuned_uq(
            name,
            setting=self.setting,
            curve=self.curve,
            num_gens=self.num_gen,
            acc_name=self.acc_name,
        )

    @functools.cached_property
    def layer_heads(self):
        """List of ``"layer_heads"`` values of the ``topk`` best parameter settings.

        The performances returned by ``tune_cal_obj._param_to_perfs`` are
        sorted ascending and the last ``topk`` entries are kept, so the list is
        ordered from worst to best among the selected settings.
        """
        perf = (
            self.tune_cal_obj._param_to_perfs(
                self.attn_name,
                setting=self.setting,
                acc_name=self.acc_name,
                num_gens=self.num_gen,
                curve=self.curve,
            )
            .sort_values()
            # tail() instead of iloc[-topk:]: identical for any non-zero topk,
            # but topk == 0 now selects nothing rather than every row
            # (iloc[-0:] is iloc[0:]).
            .tail(self.topk)
        )
        # NOTE(review): eval() executes each index label as Python code. This
        # is only safe while the labels are produced by this project; if they
        # are plain literals, ast.literal_eval would be the safer choice.
        return [eval(_)["layer_heads"] for _ in perf.index]

    def summ_df(self, gen_i="mlg", ref="nll|norm", x=None):
        """Return the (cached) per-example summary table for one generation.

        Args:
            gen_i: ``"mlg"`` / ``"most_likely_generation"`` for the most likely
                generation, otherwise the key of a sampled generation.
            ref: Name of the reference confidence measure.
            x: Name of the compared measure; defaults to
                ``f"{self.attn_name}@10"``.

        Returns:
            A ``pandas.DataFrame`` with columns ``x`` and ``ref`` (negated so
            that higher means more confident), ``acc``, ``text``, the
            percentile ranks ``{x}_rank`` / ``{ref}_rank``, ``diff``,
            ``diff_rank``, ``question``, ``answer`` and ``len``.

        The result is memoised per instance and the cached object itself is
        returned, so callers should not mutate it in place.
        """
        if x is None:
            x = f"{self.attn_name}@10"
        # Per-instance cache: functools.lru_cache on a method would key on
        # `self` and keep every instance (and its loaded data) alive for the
        # lifetime of the class.
        cache = self.__dict__.setdefault("_summ_df_cache", {})
        # Both spellings of the most likely generation build the same table.
        cache_key = ("most_likely_generation" if gen_i == "mlg" else gen_i, ref, x)
        if cache_key not in cache:
            cache[cache_key] = self._build_summ_df(gen_i, ref, x)
        return cache[cache_key]

    def _build_summ_df(self, gen_i, ref, x):
        """Assemble the summary table described in ``summ_df`` (no caching)."""
        text_key = "text_cleaned" if self.key[1] else "text"
        if gen_i in {"mlg", "most_likely_generation"}:
            df = {
                _: (-self.get_tuned_uq(_)["neg_mlgc"]) for _ in [x, ref]
            }  # higher = more confident
            df["acc"] = self.get_acc()["mlga"]
            df["text"] = [
                _["most_likely_generation"][text_key] for _ in self.generations
            ]
        else:
            df = {
                _: (-self.get_tuned_uq(_)["neg_ic"][gen_i]) for _ in [x, ref]
            }  # higher = more confident
            df["acc"] = self.get_acc()["ia"][gen_i]
            df["text"] = [_["generations"][text_key][gen_i] for _ in self.generations]
        # Percentile ranks make the two measures comparable despite their
        # different scales.
        for _ in [x, ref]:
            df[f"{_}_rank"] = df[_].rank(pct=True)
        df = pd.DataFrame(df)

        df["diff"] = df[x] - df[ref]
        df["diff_rank"] = df[f"{x}_rank"] - df[f"{ref}_rank"]

        df["question"] = [_["question"] for _ in self.generations]
        df["answer"] = [_["answer"] for _ in self.generations]
        df["len"] = df["text"].map(len)

        return df

    def _token_table(self, idx, gen_i):
        """Return the token-level table of example ``idx`` for generation ``gen_i``."""
        if gen_i == "mlg":
            gen_i = "most_likely_generation"
        curr = self.attn_token_level[idx]
        return curr["attn_loglikelihoods"][curr["mapping"][gen_i]]

    def _mean_over_heads(self, token_df):
        """Average the ``layer_heads`` columns of ``token_df`` row by row."""
        return token_df.reindex(columns=self.layer_heads).mean(1)

    def get_weight(self, idx, gen_i="mlg"):
        """Return the per-row mean over the selected heads for one generation.

        Args:
            idx: Example id, used as key into ``self.attn_token_level``.
            gen_i: ``"mlg"`` / ``"most_likely_generation"`` or the key of a
                sampled generation.
        """
        return self._mean_over_heads(self._token_table(idx, gen_i))

    def read_gen(
        self,
        idx,
        gen_i="mlg",
        ref="nll|norm",
        layer_head=None,
        show_all_heads=False,
    ):
        """Print a report comparing attention weights with uniform weights.

        Args:
            idx: Example id; must be in ``self.ids``.
            gen_i: An ``int`` generation index, ``"mlg"`` or
                ``"most_likely_generation"``.
            ref: Reference measure passed to ``summ_df``.
            layer_head: If given, use this single column of the token table as
                the weights instead of the mean over ``layer_heads``.
            show_all_heads: If true, also print the columns of all selected
                heads.

        Returns:
            None; the report is written to stdout.
        """
        assert idx in self.ids, f"unknown example id: {idx!r}"
        assert isinstance(gen_i, int) or gen_i in {
            "most_likely_generation",
            "mlg",
        }, f"unsupported gen_i: {gen_i!r}"
        if gen_i == "mlg":
            gen_i = "most_likely_generation"
        df = self._token_table(idx, gen_i)

        base = df["token_nll"]
        if layer_head is None:
            weighted = self._mean_over_heads(df)
        else:
            weighted = df[layer_head]  # for gemma
        if show_all_heads:
            print(df.reindex(columns=self.layer_heads))

        # Uniform weighting over the rows serves as the baseline ("old").
        equal_weight = np.ones_like(weighted) / len(weighted)
        diff_weight = weighted - equal_weight
        diff = diff_weight * (-base)
        ret = pd.DataFrame(
            {
                "Dweight (%)": diff_weight * 100,
                "diff": diff,
                "new": weighted,
                "old": equal_weight,
                "logits": -base,
            }
        )

        # add prompt etc
        curr_summ = self.summ_df(gen_i, ref=ref).loc[idx]

        print(f"""Question: {curr_summ['question']}

    Answer: {curr_summ['answer']}

    Response: {curr_summ['text']}

    Acc: {curr_summ['acc']}

    Diff Conf Rank (+): {curr_summ[f'{self.attn_name}@10_rank'] - curr_summ[f'{ref}_rank']:.3f}
    Diff (+): {diff.sum():.3f} ({curr_summ[f'{self.attn_name}@10']:.3f} - {curr_summ[ref]:.3f})

    Diff Weights (%):
    {ret}
    """)
        return None


@ptd.persistf()
def get_all_corrs(
    path,
    gen_i="mlg",
    clean=True,
    split="test",
    cal_size: int = 1000,
    seed=0,
    setting="mlga~neg_mlgc",
    curve="auroc",
    num_gen=5,
    topk=10,
    acc_name="moe|acc",
):
    """Collect head-averaged weights from both attention variants for every id.

    Two ``CaseStudy`` objects are built from the same arguments, one with the
    default attention log-likelihoods and one with ``next_token=True``.
    Despite the function name, no correlation is computed here; the raw
    weight series are returned.

    Args:
        path, clean, split, cal_size, seed, setting, curve, num_gen, topk,
            acc_name: Passed positionally to ``CaseStudy``.
        gen_i: Generation selector forwarded to ``CaseStudy.get_weight``.

    Returns:
        dict mapping each id of the default-variant object to a
        ``pandas.DataFrame`` with column ``"old"`` (default variant) and
        column ``"new"`` (next-token variant).

    NOTE(review): the ``ptd.persistf()`` decorator presumably persists the
    result on disk keyed by the arguments -- confirm against persist_to_disk.
    """
    # Same positional arguments for both variants; only next_token differs.
    shared_args = (
        path,
        clean,
        split,
        cal_size,
        seed,
        setting,
        curve,
        num_gen,
        topk,
        acc_name,
    )
    current_variant = CaseStudy(*shared_args)
    next_token_variant = CaseStudy(*shared_args, next_token=True)

    return {
        example_id: pd.DataFrame(
            {
                "old": current_variant.get_weight(example_id, gen_i=gen_i),
                "new": next_token_variant.get_weight(example_id, gen_i=gen_i),
            }
        )
        for example_id in tqdm.tqdm(current_variant.ids)
    }


if __name__ == "__main__":
    from _settings import GEN_PATHS

    # Run get_all_corrs once per dataset/model combination. GEN_PATHS is
    # indexed as [temp][dataset][model] (temp is presumably the sampling
    # temperature -- TODO confirm).
    datasets = ("nq_open_new", "triviaqa_new", "coqa_new")
    models = ("llama2-13b", "gemma-7b", "mistral-7b")
    temps = (0.5,)

    # Flattened in the same order as nested loops: dataset, then model, then temp.
    grid = [
        (temp, dataset, model)
        for dataset in datasets
        for model in models
        for temp in temps
    ]
    for temp, dataset, model in grid:
        get_all_corrs(GEN_PATHS[temp][dataset][model], acc_name="moe|acc")
