from os.path import join
from run_tag import TagRunner
import os
import sys
import util
import pickle
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
from seqeval.metrics.sequence_labeling import get_entities


class Uncertainty:
    """Monte-Carlo-dropout uncertainty analysis for a tagging model.

    The model is evaluated ``mc`` times with dropout enabled; the per-pass
    class probabilities are then turned into three uncertainty scores
    (referred to throughout as aleatoric, epistemic and entropy) at token or
    entity-span level, and scored with AUROC / AUPR against prediction
    correctness.

    Raw predictions are cached as pickle files under
    ``<log_dir>/uncertainty/<saved_suffix>/``.
    """

    def __init__(self, config_name, saved_suffix, gpu_id, mc=10):
        """Create the runner and the output directory.

        Args:
            config_name: Configuration name forwarded to ``TagRunner``.
            saved_suffix: Suffix identifying the saved model; also used as the
                name of the cache sub-directory.
            gpu_id: GPU id forwarded to ``TagRunner``.
            mc: Number of dropout-enabled evaluation passes.
        """
        self.saved_suffix = saved_suffix
        self.runner = TagRunner(config_name, gpu_id)
        self.model = None  # Lazy loading
        self.mc = mc

        self.output_dir = join(self.runner.config['log_dir'], 'uncertainty', saved_suffix)
        os.makedirs(self.output_dir, exist_ok=True)

    def _get_predictions(self, dataset_name, lang, use_un_probs):
        """Return ``(all_results, golds)`` for the test split, using the cache.

        Args:
            dataset_name: Dataset to evaluate.
            lang: Language of the test split.
            use_un_probs: Forwarded to ``runner.evaluate``; a falsy value also
                adds the ``_nsp`` suffix to the cache file name.

        Returns:
            ``all_results``: one ``(metrics, probs, logits)`` tuple per pass.
            ``golds``: gold labels as returned by the last evaluation pass
            (``None`` if ``mc`` is 0).
        """
        other = '_nsp' if not use_un_probs else ''
        save_path = join(self.output_dir, f'pred_{dataset_name}_{lang}{other}.bin')
        if os.path.exists(save_path):
            # NOTE: pickle can execute arbitrary code while loading; only point
            # output_dir at cache files produced by this script.
            with open(save_path, 'rb') as f:
                all_results, golds = pickle.load(f)
            return all_results, golds

        # Get results
        if self.model is None:
            # Use the suffix stored on the instance, not a module-level global,
            # so the class also works when imported from another module.
            self.model = self.runner.initialize_model(self.saved_suffix)
        all_results, golds = [], None
        dataset = self.runner.data.get_data(dataset_name, 'test', lang, only_dataset=True)
        for _ in range(self.mc):
            metrics, _, golds, probs, logits = self.runner.evaluate(self.model, dataset, dropout=True,
                                                                    dataset_name=dataset_name, lang=lang,
                                                                    use_un_probs=use_un_probs)
            all_results.append((metrics, probs, logits))

        with open(save_path, 'wb') as f:
            pickle.dump((all_results, golds), f)
        return all_results, golds

    def _get_metrics(self, labels, un):
        """Compute AUROC and AUPR for every uncertainty score in ``un``.

        Scores are negated before ranking, so a *lower* uncertainty counts as
        stronger evidence for the positive label.

        Args:
            labels: Binary labels, one per scored item.
            un: Iterable of score arrays, each aligned with ``labels``.

        Returns:
            ``(auroc, aupr)``: two lists with one value per entry of ``un``.
        """
        auroc = [roc_auc_score(labels, -scores) for scores in un]
        aupr = [average_precision_score(labels, -scores) for scores in un]
        return auroc, aupr

    @staticmethod
    def _stack_mc_probs(all_results):
        """Flatten per-instance probabilities and stack the passes on the last axis."""
        # [total tok, num classes, mc]
        return np.stack([np.concatenate(probs, axis=0) for _, probs, _ in all_results], axis=-1)

    @staticmethod
    def _print_metrics(dataset_name, lang, auroc, aupr):
        """Print the AUROC / AUPR of the three uncertainty scores for one language."""
        print(f'----------Uncertainties for {dataset_name}-{lang}')
        print(f'AUROC: Aleatoric = {auroc[0]:.4f}, Epistemic = {auroc[1]:.4f}, Entropy = {auroc[2]:.4f}')
        print(f'AUPR:  Aleatoric = {aupr[0]:.4f}, Epistemic = {aupr[1]:.4f}, Entropy = {aupr[2]:.4f}')

    def _get_span_uncertainty(self, probs_mc, probs_mean, un_entity):
        """Return ``(aleatoric, epistemic, entropy)`` for one entity span.

        The span value is the average of the token-level uncertainties over
        the tokens ``s_i..e_i`` (inclusive) of the entity.
        """
        inst_i, (s_i, e_i), _, _ = un_entity
        assert e_i < probs_mean[inst_i].shape[0]
        tok_probs_mc = probs_mc[inst_i][s_i:e_i + 1, :, :]  # [num tok, num classes, mc]
        tok_probs_mean = probs_mean[inst_i][s_i:e_i + 1, :]  # [num tok, num classes]

        # Same arithmetic as the token-level analysis, restricted to the span.
        exp_of_ent_total, diff_total, ent_of_exp_total = self._get_tok_uncertainty(tok_probs_mc, tok_probs_mean)

        # Take avg token uncertainty as span uncertainty
        return exp_of_ent_total.mean().item(), diff_total.mean().item(), ent_of_exp_total.mean().item()

    def _get_entity_uncertainty(self, all_results, golds, dataset_name=None):
        """Compute span-level uncertainties for predicted and gold entities.

        Args:
            all_results: Output of ``_get_predictions``.
            golds: Gold tag sequences, one per instance.
            dataset_name: Dataset whose label inventory maps class indices to
                tag strings. Required; it is a keyword with a default only to
                keep the original positional signature valid.

        Returns:
            ``(aleatoric, epistemic, entropy)`` arrays (one value per entity),
            the entity list of ``(inst_i, (start, end), type, 'tp'|'fp'|'fn')``
            tuples, the binary labels (1 for true positives, 0 otherwise), and
            the per-instance mean probabilities.

        Raises:
            ValueError: If ``dataset_name`` is not provided.
        """
        if dataset_name is None:
            raise ValueError('dataset_name is required to map class indices to tag labels')
        label_map = dict(enumerate(self.runner.data.get_labels(dataset_name)))

        probs_mc = [np.stack(inst_probs_mc, axis=-1) for inst_probs_mc in zip(*[probs for _, probs, _ in all_results])]  # [num inst, num tokens, num classes, mc]
        probs_mean = [inst_probs_mc.mean(axis=-1, keepdims=False) for inst_probs_mc in probs_mc]  # [num inst, num tokens, num classes]

        # Get raw tag labels and entities from prediction and gold
        final_labels = [[label_map[i] for i in inst_probs_mean.argmax(axis=-1)] for inst_probs_mean in probs_mean]  # [num inst, num tokens]
        pred_entities = [get_entities(inst_labels) for inst_labels in final_labels]  # [num inst, num entities]
        gold_entities = [get_entities(inst_golds) for inst_golds in golds]

        # Obtain entities and labels for uncertainty
        un_entities, un_labels = [], []
        for inst_i, (inst_pred_entities, inst_gold_entities) in enumerate(zip(pred_entities, gold_entities)):
            inst_pred_entities, inst_gold_entities = set(inst_pred_entities), set(inst_gold_entities)
            tp_entities = inst_pred_entities & inst_gold_entities
            fp_entities = inst_pred_entities - tp_entities
            fn_entities = inst_gold_entities - tp_entities
            # A same-span/wrong-type error is deliberately counted as both an
            # fp and an fn to stay consistent with the entity-level evaluation.

            # Positive label for true positives (correct predictions); negative
            # label for false positives and false negatives (errors).
            for tag, entities, label in (('tp', tp_entities, 1), ('fp', fp_entities, 0), ('fn', fn_entities, 0)):
                un_entities += [(inst_i, (e[1], e[2]), e[0], tag) for e in entities]
                un_labels += [label] * len(entities)

        # Get uncertainty per entity
        entities_aleatoric, entities_epistemic, entities_entropy = [], [], []
        for un_entity in un_entities:
            aleatoric, epistemic, entropy = self._get_span_uncertainty(probs_mc, probs_mean, un_entity)
            entities_aleatoric.append(aleatoric)
            entities_epistemic.append(epistemic)
            entities_entropy.append(entropy)

        return (np.array(entities_aleatoric), np.array(entities_epistemic), np.array(entities_entropy)), \
               un_entities, un_labels, probs_mean

    def get_all_entity_uncertainty(self, dataset_name, use_un_probs=None):
        """Print entity-level AUROC / AUPR for every language of a dataset.

        After the per-language report, the entropy-based AUROC and AUPR of all
        languages are summarised through ``util.print_all_scores``.
        """
        all_auroc, all_aupr = [], []
        for lang in util.langs[dataset_name]:
            all_results, golds = self._get_predictions(dataset_name, lang, use_un_probs=use_un_probs)
            un, un_entities, un_labels, probs_mean = self._get_entity_uncertainty(
                all_results, golds, dataset_name=dataset_name)

            auroc, aupr = self._get_metrics(un_labels, un)
            all_auroc.append(auroc)
            all_aupr.append(aupr)
            self._print_metrics(dataset_name, lang, auroc, aupr)

        print('-' * 20)
        print('\n'.join(util.print_all_scores([auroc[2] for auroc in all_auroc], 'AUROC on entropy', with_en=True)))
        print('-' * 20)
        print('\n'.join(util.print_all_scores([aupr[2] for aupr in all_aupr], 'AUPR on entropy', with_en=True)))

    def get_entity_examples(self, dataset_name, lang, top_k, use_un_probs=None):
        """Print the ``top_k`` entities with the highest entropy.

        For each entity the full sentence is printed, followed by the span
        text, its gold / predicted type and the entropy value.
        """
        all_results, golds = self._get_predictions(dataset_name, lang, use_un_probs=use_un_probs)
        un, un_entities, _, probs_mean = self._get_entity_uncertainty(
            all_results, golds, dataset_name=dataset_name)

        # Keep examples and features consistent: an example whose length does
        # not match its instance is cut into consecutive instance-sized chunks.
        examples, _, _ = self.runner.data.get_data(dataset_name, 'test', lang, only_dataset=False)
        words, inst_i = [], 0
        for example in examples:
            # Slice a local reference so the loader's example objects are not
            # modified (they may be shared with other callers).
            remaining = example.words
            if len(remaining) == probs_mean[inst_i].shape[0]:
                words.append(remaining)
                inst_i += 1
                continue
            while remaining:
                chunk_len = probs_mean[inst_i].shape[0]
                words.append(remaining[:chunk_len])
                remaining = remaining[chunk_len:]
                inst_i += 1
        assert inst_i == len(probs_mean)

        # Select top uncertain entities
        un_entities = [entity + (ent,) for ent, entity in zip(un[-1].tolist(), un_entities)]
        top_entities = sorted(un_entities, key=lambda entity: entity[-1], reverse=True)[:top_k]

        # Print
        for entity in top_entities:
            inst_i, (s_i, e_i), entity_type, label_type, entropy = entity
            inst_words = words[inst_i]
            assert e_i < len(inst_words)
            print(' '.join(inst_words))
            print(' '.join(inst_words[s_i:e_i+1]), ':', f'GOLD-{entity_type}' if label_type in ['tp', 'fn'] else 'GOLD-NONE',
                  f'PRED-{entity_type}' if label_type in ['tp', 'fp'] else 'PRED-NONE', f'Entropy: {entropy:.4f}')
            print()

    def _get_tok_uncertainty(self, all_probs, mean_probs=None):
        """Decompose token-level uncertainty from Monte-Carlo probabilities.

        Args:
            all_probs: Array of shape ``[tok, num classes, mc]``.
            mean_probs: Optional mean of ``all_probs`` over the last axis,
                shape ``[tok, num classes]``; computed when omitted.

        Returns:
            Three ``[tok]`` arrays: the expected entropy over passes
            (reported as aleatoric), entropy of the mean minus expected
            entropy (reported as epistemic), and the entropy of the mean.
            Assumes ``util.compute_entropy`` returns element-wise per-class
            entropy terms with the shape of its input - TODO confirm.
        """
        if mean_probs is None:
            mean_probs = all_probs.mean(axis=-1, keepdims=False)

        ent_of_exp_per_cls = util.compute_entropy(mean_probs)  # [total tok, num classes]
        ent_of_exp_total = ent_of_exp_per_cls.sum(axis=-1, keepdims=False)

        ent_per_cls = util.compute_entropy(all_probs)  # [total tok, num classes, mc]
        ent_total = ent_per_cls.sum(axis=1, keepdims=False)  # [total tok, mc]
        exp_of_ent_per_cls = ent_per_cls.mean(axis=-1, keepdims=False)  # [total tok, num classes]
        exp_of_ent_total = ent_total.mean(axis=-1, keepdims=False)

        diff_per_cls = ent_of_exp_per_cls - exp_of_ent_per_cls  # [total tok, num classes]
        diff_total = diff_per_cls.sum(axis=-1, keepdims=False)

        return exp_of_ent_total, diff_total, ent_of_exp_total

    def get_all_tok_uncertainty(self, dataset_name, use_un_probs=None):
        """Print token-level AUROC / AUPR for every language of a dataset.

        A token is labelled positive when the argmax of its mean probability
        equals the gold label.
        """
        label_map = {label: i for i, label in enumerate(self.runner.data.get_labels(dataset_name))}
        for lang in util.langs[dataset_name]:
            all_results, golds = self._get_predictions(dataset_name, lang, use_un_probs=use_un_probs)

            # Flatten results
            all_probs = self._stack_mc_probs(all_results)  # [total tok, num classes, mc]
            all_golds = np.array([label_map[gold] for gold in util.flatten(golds)])  # [total tok]
            assert all_probs.shape[0] == all_golds.shape[0]

            # Get final labels
            mean_probs = all_probs.mean(axis=-1, keepdims=False)
            final_labels = mean_probs.argmax(axis=-1) == all_golds  # [total tok]

            # Get uncertainties and metrics; order is aleatoric, epistemic, entropy
            un = self._get_tok_uncertainty(all_probs, mean_probs)
            auroc, aupr = self._get_metrics(final_labels, un)
            self._print_metrics(dataset_name, lang, auroc, aupr)

    def get_tok_examples(self, dataset_name, lang, top_k, use_un_probs=None):
        """Print the sentences containing the ``top_k`` highest-entropy tokens.

        Every token of such a sentence is printed with its gold label and its
        uncertainties; the selected tokens additionally show the full mean
        class distribution instead of only the predicted label. Nothing is
        printed when ``top_k`` is not positive.
        """
        if top_k <= 0:
            # Guard: ``argsort[-0:]`` would select every token, not none.
            return

        label_map = dict(enumerate(self.runner.data.get_labels(dataset_name)))
        all_results, _ = self._get_predictions(dataset_name, lang, use_un_probs=use_un_probs)

        all_probs = self._stack_mc_probs(all_results)  # [total tok, num classes, mc]
        mean_probs = all_probs.mean(axis=-1, keepdims=False)
        predicted_labels = mean_probs.argmax(axis=-1)

        exp_of_ent_total, diff_total, ent_of_exp_total = self._get_tok_uncertainty(all_probs, mean_probs)
        argsort_ent_of_exp_total = np.argsort(ent_of_exp_total, axis=-1)
        # Flat indices of the selected tokens, ascending, so they can be
        # consumed in a single forward pass over the examples.
        un_tok_i = np.sort(argsort_ent_of_exp_total[-top_k:], axis=-1)

        examples, _, _ = self.runner.data.get_data(dataset_name, 'test', lang, only_dataset=False)
        tok_i, un_tok_offset = 0, 0
        for example in examples:
            if un_tok_offset == len(un_tok_i):
                break
            # Skip sentences that end before the next selected token.
            if len(example.words) + tok_i - 1 < un_tok_i[un_tok_offset]:
                tok_i += len(example.words)
                continue

            for word, label in zip(example.words, example.labels):
                if un_tok_offset < len(un_tok_i) and tok_i == un_tok_i[un_tok_offset]:
                    distribution = "  ".join(f"{label_map[i]}: {prob:.4f}" for i, prob in enumerate(mean_probs[tok_i]))
                    print(f'{word}\t{label}\t{distribution}', end='\t|\t')
                    un_tok_offset += 1
                else:
                    print(f'{word}\t{label}\t{label_map[predicted_labels[tok_i]]}', end='\t|\t')
                print(f'Aleatoric: {exp_of_ent_total[tok_i]:.2e}\tEpistemic: {diff_total[tok_i]:.2e}\tEntropy: {ent_of_exp_total[tok_i]:.2e}\t')
                tok_i += 1
            print()


if __name__ == '__main__':
    # Command line: <config_name> <saved_suffix> <gpu_id> <dataset_name>
    config_name = sys.argv[1]
    saved_suffix = sys.argv[2]
    gpu_id = int(sys.argv[3])
    dataset_name = sys.argv[4]

    uncertainty = Uncertainty(config_name, saved_suffix, gpu_id)

    # Default report: entity-level uncertainty for every language of the dataset.
    uncertainty.get_all_entity_uncertainty(dataset_name, use_un_probs=True)

    # Other analyses that can be enabled by hand:
    # uncertainty.get_tok_examples(dataset_name, 'en', 100)
    # uncertainty.get_entity_examples(dataset_name, 'en', 300, use_un_probs=True)
