#!/usr/bin/env python3

import hydra
import importlib
import os
import torch
import transformers
import argparse
from pathlib import Path
import json

import logging

# Script-wide logger; `logging.getLogger()` with no name returns the root logger.
log = logging.getLogger()

from lm_polygraph.utils.manager import UEManager
from lm_polygraph.utils.dataset import Dataset
from lm_polygraph.utils.model import WhiteboxModel, create_ensemble
from lm_polygraph.utils.processor import Logger
from lm_polygraph.generation_metrics.accuracy import AccuracyMetric
from lm_polygraph.generation_metrics.bart_score import BartScoreSeqMetric
from lm_polygraph.generation_metrics.rouge import RougeMetric
from lm_polygraph.generation_metrics.bert_score import BertScoreMetric
from lm_polygraph.generation_metrics.sbert import SbertMetric
from lm_polygraph.estimators import *
from lm_polygraph.estimators.ensemble_token_measures import all_token_estimators
from lm_polygraph.estimators.ensemble_sequence_measures import all_ep_estimators, all_pe_estimators
from lm_polygraph.estimators.ensemble_token_measures import *
from lm_polygraph.ue_metrics import (
    ReversedPairsProportion,
    PredictionRejectionArea,
    RiskCoverageCurveAUC,
)

# The Hydra config file is given by absolute/relative path in $HYDRA_CONFIG;
# its directory and file name are handed to @hydra.main below.
hydra_config = Path(os.environ["HYDRA_CONFIG"])


@hydra.main(
    version_base=None,
    config_path=str(hydra_config.parent),
    config_name=str(hydra_config.name),
)
def main(args):
    """Run the uncertainty-estimation benchmark once per configured seed.

    For every seed this loads the model (and optionally an MC ensemble), the
    evaluation dataset and, when some density-based estimator is not fitted,
    training/background data; it then runs a ``UEManager`` and saves it to
    ``<save_path>/ue_manager_seed<seed>``.

    Args:
        args: Config object injected by ``@hydra.main`` (loaded from the file
            named by ``$HYDRA_CONFIG``). ``args.seed`` may be a single int, a
            list of ints, or empty/None (which means ``[1]``).
    """
    # The working directory at entry is the one Hydra chose for this run;
    # remember it as the default output location, then return to the launch
    # directory so relative paths in the config resolve against it.
    save_path = os.getcwd()
    log.info(f"Main directory: {save_path}")
    os.chdir(hydra.utils.get_original_cwd())

    save_path = args.save_path if "save_path" in args else save_path

    # Accept a single integer seed as well as a (possibly empty) list of seeds;
    # previously a scalar seed crashed on len().
    seeds = args.seed
    if isinstance(seeds, int):
        seeds = [seeds]
    elif seeds is None or len(seeds) == 0:
        seeds = [1]

    model_kwargs = {'use_auth_token': getattr(args, 'use_auth_token', None)}

    # Only point loaders at the local cache when HF datasets run offline.
    cache_kwargs = {}
    if os.environ.get('HF_DATASETS_OFFLINE', '').strip() == '1':
        cache_kwargs = {'cache_dir': args.cache_path}

    device = args.device
    if device is None:
        device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu"

    for seed in seeds:
        log.info("=" * 100)
        log.info(f"SEED: {seed}")

        log.info(f"Loading model {args.model}...")
        # Seed before loading/subsampling so each run is reproducible.
        transformers.set_seed(seed)

        model = WhiteboxModel.from_pretrained(
            args.model,
            device=device,
            **cache_kwargs,
            **model_kwargs
        )

        if args.ensemble:
            # Only MC-ensembles for now.
            # NOTE(review): the ensemble is always seeded with the first seed,
            # not the current loop seed -- confirm this is intended.
            log.info("Creating ensemble...")
            ensemble_model = create_ensemble(model_paths=[args.model],
                                             mc=True,
                                             device='cpu',
                                             seed=seeds[0],
                                             ensembling_mode=args.ensembling_mode,
                                             mc_seeds=args.mc_seeds,
                                             dropout_rate=float(args.dropout_rate),
                                             **cache_kwargs,
                                             **model_kwargs
                                             )
        else:
            ensemble_model = None

        log.info("Done with loading model.")

        log.info(f"Loading dataset {args.dataset}...")
        dataset = Dataset.load(
            args.dataset,
            args.text_column,
            args.label_column,
            batch_size=args.batch_size,
            prompt=args.prompt,
            split=args.eval_split,
            load_from_disk=args.load_from_disk,
            **cache_kwargs
        )

        # Several labels per input: duplicate each input once per label so
        # every (input, label) pair becomes its own example. The length guard
        # avoids an IndexError on an empty dataset.
        if len(dataset.y) > 0 and isinstance(dataset.y[0], list):
            x_proc, y_proc = [], []
            for x, labels in zip(dataset.x, dataset.y):
                for label in labels:
                    x_proc.append(x)
                    y_proc.append(label)

            dataset = Dataset(x_proc, y_proc, args.batch_size)

        estimators = []
        estimators += get_ue_methods(args, model)
        density_based_ue_methods = get_density_based_ue_methods(args, model.model_type)
        estimators += density_based_ue_methods

        # Training/background data is only needed when some density-based
        # estimator still has to be fitted.
        if any(not getattr(method, "is_fitted", False) for method in density_based_ue_methods):
            if (args.train_dataset is not None) and (
                    args.train_dataset != args.dataset
            ):
                train_dataset = Dataset.load(
                    args.train_dataset,
                    args.text_column,
                    args.label_column,
                    batch_size=args.batch_size,
                    prompt=args.prompt,
                    split=args.train_split,
                    size=10_000,
                    load_from_disk=args.load_from_disk,
                    **cache_kwargs
                )
            elif args.train_test_split:
                # NOTE(review): only the training half is used here; presumably
                # Dataset.train_test_split also narrows `dataset` in place to
                # the eval part -- confirm, otherwise train data leaks into eval.
                X_train, X_test, y_train, y_test = dataset.train_test_split(
                    test_size=args.test_split_size, seed=seed, split=args.eval_split
                )
                train_dataset = Dataset(
                    x=X_train, y=y_train, batch_size=args.batch_size
                )
            else:
                train_dataset = Dataset.load(
                    args.dataset,
                    args.text_column,
                    args.label_column,
                    batch_size=args.batch_size,
                    prompt=args.prompt,
                    split=args.train_split,
                    size=10_000,
                    load_from_disk=args.load_from_disk,
                    **cache_kwargs
                )

            background_train_dataset = Dataset.load(
                args.background_train_dataset,
                args.background_train_dataset_text_column,
                args.background_train_dataset_label_column,
                batch_size=args.batch_size,
                data_files=args.background_train_dataset_data_files,
                split="train",
                size=100_000,
                load_from_disk=args.background_load_from_disk,
                **cache_kwargs
            )

            # -1 means "use everything".
            if args.subsample_train_dataset != -1:
                train_dataset.subsample(args.subsample_train_dataset, seed=seed)
            if args.subsample_background_train_dataset != -1:
                background_train_dataset.subsample(
                    args.subsample_background_train_dataset, seed=seed
                )
        else:
            train_dataset = None
            background_train_dataset = None

        if args.subsample_eval_dataset != -1:
            dataset.subsample(args.subsample_eval_dataset, seed=seed)

        log.info("Done with loading data.")

        generation_metrics = get_generation_metrics(args.generation_metrics)

        man = UEManager(
            dataset,
            model,
            estimators,
            generation_metrics,
            [
                ReversedPairsProportion(),
                PredictionRejectionArea(),
                RiskCoverageCurveAUC(),
            ],
            [
                Logger(),
            ],
            deberta_batch_size=getattr(args, 'deberta_batch_size', 10),
            train_data=train_dataset,
            ignore_exceptions=args.ignore_exceptions,
            background_train_data=background_train_dataset,
            max_new_tokens=args.max_new_tokens,
            ensemble_model=ensemble_model
        )

        man()

        man.save(os.path.join(save_path, f"ue_manager_seed{seed}"))


def get_density_based_ue_methods(args, model_type):
    """Build the density-based estimators requested by the config.

    Produces MahalanobisDistanceSeq, RelativeMahalanobisDistanceSeq, RDESeq
    and PPLMDSeq (with md_type "MD" and "RMD") instances. Seq2Seq models get
    each estimator for both "encoder" and "decoder"; any other model type
    only for "decoder".

    Args:
        args: Config; reads ``use_density_based_ue`` and, unless a truthy
            ``parameters_path`` is given, ``dataset``, ``model`` and
            ``cache_path`` to derive one.
        model_type: Model type string; "Seq2SeqLM" enables encoder estimators.

    Returns:
        A list of estimators (empty when ``use_density_based_ue`` is falsy).
    """
    if not args.use_density_based_ue:
        return []

    if getattr(args, 'parameters_path', False):
        parameters_path = args.parameters_path
    else:
        # Derive <cache_path>/density_stats/<dataset>/<model> from the last
        # path component of the dataset (extension dropped) and model names.
        if isinstance(args.dataset, str):
            joined_dataset = args.dataset
        else:
            joined_dataset = '_'.join(args.dataset)
        short_dataset = joined_dataset.split("/")[-1].split(".")[0]
        short_model = args.model.split("/")[-1]
        parameters_path = f"{args.cache_path}/density_stats/{short_dataset}/{short_model}"

    sources = ("encoder", "decoder") if model_type == "Seq2SeqLM" else ("decoder",)

    # Order: each estimator family in turn, sources inside each family.
    methods = [
        MahalanobisDistanceSeq(src, parameters_path=parameters_path)
        for src in sources
    ]
    methods.extend(
        RelativeMahalanobisDistanceSeq(src, parameters_path=parameters_path)
        for src in sources
    )
    methods.extend(
        RDESeq(src, parameters_path=parameters_path)
        for src in sources
    )
    methods.extend(
        PPLMDSeq(src, md_type=md_kind, parameters_path=parameters_path)
        for src in sources
        for md_kind in ("MD", "RMD")
    )
    return methods


def get_ue_methods(args, model):
    """Instantiate the non-density-based uncertainty estimators enabled in ``args``.

    Groups are appended in this order: sequence-level (``use_seq_ue``),
    ensemble (``use_ens_ue``), token-level (``use_tok_ue``), then any extra
    estimators listed in ``additional_estimators``.

    Args:
        args: Config. Reads ``use_seq_ue``, ``use_ens_ue``,
            ``ensembling_mode`` (when ensembling), ``use_tok_ue``,
            ``cache_path`` (when token-level estimators are enabled) and the
            optional ``additional_estimators`` (module name -> list of class
            names) and ``additional_estimators_kwargs`` (class name -> kwargs).
        model: Loaded model; ``model_type`` and ``model_path`` are read.

    Returns:
        A list of estimator instances.

    Raises:
        NotImplementedError: ensemble estimators requested for a model whose
            ``model_type`` is not "Seq2SeqLM".
        ValueError: ``ensembling_mode`` is neither "pe" nor "ep".
        TypeError: an additional estimator has no entry in
            ``additional_estimators_kwargs``.
    """
    estimators = []
    if args.use_seq_ue:
        estimators += [
            MaximumSequenceProbability(),
            Perplexity(),
            MeanTokenEntropy(),
            MeanPointwiseMutualInformation(),
            MeanConditionalPointwiseMutualInformation(),
            PTrue(),
            PTrueSampling(),
            MonteCarloSequenceEntropy(),
            MonteCarloNormalizedSequenceEntropy(),
            LexicalSimilarity(metric="rouge1"),
            LexicalSimilarity(metric="rouge2"),
            LexicalSimilarity(metric="rougeL"),
            LexicalSimilarity(metric="BLEU"),
            NumSemSets(),
            EigValLaplacian(similarity_score="NLI_score", affinity="entail"),
            EigValLaplacian(similarity_score="NLI_score", affinity="contra"),
            EigValLaplacian(similarity_score="Jaccard_score"),
            DegMat(similarity_score="NLI_score", affinity="entail"),
            DegMat(similarity_score="NLI_score", affinity="contra"),
            DegMat(similarity_score="Jaccard_score"),
            Eccentricity(similarity_score="NLI_score", affinity="entail"),
            Eccentricity(similarity_score="NLI_score", affinity="contra"),
            Eccentricity(similarity_score="Jaccard_score"),
            SemanticEntropy(),
        ]

    if args.use_ens_ue:
        if model.model_type != "Seq2SeqLM":
            raise NotImplementedError('Only Encoder-Decoder models can be ensembled at this time')

        # Validate the mode up front. The old message read the nonexistent
        # `args.ens_type`, turning this ValueError into an AttributeError.
        mode = args.ensembling_mode
        if mode not in ('pe', 'ep'):
            raise ValueError(f'Ensemble type should be one of: "pe", "ep", but is {mode} instead')

        token_measures = all_token_estimators()
        sequence_measures = all_pe_estimators() if mode == 'pe' else all_ep_estimators()
        estimators += token_measures + sequence_measures

    if args.use_tok_ue:
        estimators += [
            MaximumTokenProbability(),
            TokenEntropy(),
            PointwiseMutualInformation(),
            ConditionalPointwiseMutualInformation(),
            SemanticEntropyToken(model.model_path, args.cache_path),
        ]

    # A key that is present but null in the config is treated like a missing one.
    additional_estimators = getattr(args, "additional_estimators", None) or {}
    additional_estimators_kwargs = getattr(args, "additional_estimators_kwargs", None) or {}

    for module_name, estimator_classes in additional_estimators.items():
        module = importlib.import_module(module_name)
        for estimator_class in estimator_classes:
            try:
                estimator_kwargs = additional_estimators_kwargs[estimator_class]
            except KeyError:
                # Previously referenced an undefined name and raised NameError.
                raise TypeError(f'Arguments for {estimator_class} were not passed') from None

            estimators.append(getattr(module, estimator_class)(**estimator_kwargs))

    return estimators


def get_generation_metrics(generation_metrics):
    """Build the list of generation-quality metrics to compute.

    Args:
        generation_metrics: ``None`` for the default set (ROUGE-1/2/L,
            BartScore "rh", BertScore, SBERT, accuracy), or an iterable of
            mappings, each with a ``"name"`` key naming a class available in
            this module's namespace and an optional ``"args"`` list of
            positional constructor arguments.

    Returns:
        A list of instantiated metric objects.

    Raises:
        KeyError: a ``"name"`` does not match any name in this module.
    """
    if generation_metrics is None:
        return [
            RougeMetric("rouge1"),
            RougeMetric("rouge2"),
            RougeMetric("rougeL"),
            BartScoreSeqMetric("rh"),
            BertScoreMetric(),
            SbertMetric(),
            AccuracyMetric(),
        ]

    namespace = globals()
    result = []
    for metric in generation_metrics:
        metric_name = metric["name"]
        try:
            metric_class = namespace[metric_name]
        except KeyError:
            # Give the config author an actionable message instead of a bare
            # KeyError carrying only the name.
            raise KeyError(
                f"Unknown generation metric {metric_name!r}: "
                "it must be a name imported into this script"
            ) from None
        # A missing or explicitly null "args" entry means "no arguments".
        metric_args = metric.get("args") or []
        result.append(metric_class(*metric_args))
    return result


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on `main` loads the
    # config named by $HYDRA_CONFIG and passes it in as `args`.
    main()
