#!/usr/bin/env python3

import hydra
import importlib
import os
import torch
import transformers
import argparse
from pathlib import Path
import json

import logging

log = logging.getLogger()

from lm_polygraph.utils.manager import UEManager
from lm_polygraph.utils.dataset import Dataset
from lm_polygraph.utils.model import WhiteboxModel
from lm_polygraph.utils.processor import Logger
from lm_polygraph.generation_metrics.accuracy import AccuracyMetric
from lm_polygraph.generation_metrics.bart_score import BartScoreSeqMetric
from lm_polygraph.generation_metrics.model_score import (
    ModelScoreSeqMetric,
    ModelScoreTokenwiseMetric,
)
from lm_polygraph.generation_metrics.rouge import RougeMetric
from lm_polygraph.estimators import *
from lm_polygraph.ue_metrics import (
    ReversedPairsProportion,
    PredictionRejectionArea,
    RiskCoverageCurveAUC,
)

# Estimator families that must fit statistics (covariance / density
# parameters) on a training set before scoring — used in main() to decide
# whether train / background datasets need to be loaded at all.
DENSITY_BASED_ESTIMATORS = [
    "MahalanobisDistanceSeq",
    "RelativeMahalanobisDistanceSeq",
    "RDESeq",
    "PPLMDSeq",
]

# Path to the hydra YAML config; the HYDRA_CONFIG environment variable is
# required — a missing variable fails fast here with a KeyError at import time.
hydra_config = Path(os.environ["HYDRA_CONFIG"])


@hydra.main(
    version_base=None,
    config_path=str(hydra_config.parent),
    config_name=str(hydra_config.name),
)
def main(args):
    """Run the uncertainty-estimation benchmark once per configured seed.

    For each seed: loads the model and evaluation dataset, builds the
    configured estimators (fitting data for density-based ones is loaded
    from a train / background set when needed), runs the ``UEManager`` and
    saves its results to ``<save_path>/ue_manager_seed<seed>``.

    Args:
        args: hydra/omegaconf config object (see the HYDRA_CONFIG file).
    """
    save_path = os.getcwd()
    log.info(f"Main directory: {save_path}")
    os.chdir(hydra.utils.get_original_cwd())

    # An explicit save_path in the config overrides hydra's run directory
    # captured above.
    save_path = args.save_path if "save_path" in args else save_path

    if args.seed is None or len(args.seed) == 0:
        args.seed = [1]

    model_kwargs = {}
    model_kwargs['use_auth_token'] = getattr(args, 'use_auth_token', None)

    # Only pass an explicit cache dir when HF is forced offline, so that
    # pre-downloaded datasets/models are resolved locally.
    cache_kwargs = {}
    if os.environ.get('HF_DATASETS_OFFLINE', '').strip() == '1':
        cache_kwargs = {'cache_dir': args.cache_dir}

    device = args.device
    if device is None:
        device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu"

    for seed in args.seed:
        log.info("=" * 100)
        log.info(f"SEED: {seed}")

        log.info(f"Loading model {args.model}...")
        transformers.set_seed(seed)
        model = WhiteboxModel.from_pretrained(
            args.model,
            device=device,
            **cache_kwargs,
            **model_kwargs
        )
        log.info("Done with loading model.")

        log.info(f"Loading dataset {args.dataset}...")
        dataset = Dataset.load(
            args.dataset,
            args.text_column,
            args.label_column,
            batch_size=args.batch_size,
            prompt=args.prompt,
            split=args.eval_split,
            load_from_disk=args.load_from_disk,
            **cache_kwargs
        )

        if isinstance(dataset.y[0], list):
            # Several reference labels per input: duplicate each input once
            # per label so every (x, y) pair becomes a separate sample.
            x_proc, y_proc = [], []
            for x, y in zip(dataset.x, dataset.y):
                for y_i in y:
                    x_proc.append(x)
                    y_proc.append(y_i)
            dataset = Dataset(x_proc, y_proc, args.batch_size)

        estimators = []
        estimators += get_ue_methods(args, model)
        estimators += get_density_based_ue_methods(args, model.model_type)

        # Density-based estimators need training data to fit their statistics.
        needs_train_data = any(
            str(estimator).split('_')[0] in DENSITY_BASED_ESTIMATORS
            for estimator in estimators
        )
        if needs_train_data:
            # BUGFIX: background_train_dataset was previously assigned only in
            # the final `else` branch, raising NameError at the UEManager call
            # whenever one of the first two branches was taken.
            background_train_dataset = None

            if (args.train_dataset is not None) and (
                    args.train_dataset != args.dataset
            ):
                # A dedicated train dataset, distinct from the eval one.
                train_dataset = Dataset.load(
                    args.train_dataset,
                    args.text_column,
                    args.label_column,
                    batch_size=args.batch_size,
                    prompt=args.prompt,
                    split=args.train_split,
                    size=10_000,
                    load_from_disk=args.load_from_disk,
                    **cache_kwargs
                )
            elif args.train_test_split:
                # Carve the train set out of the eval dataset itself.
                X_train, X_test, y_train, y_test = dataset.train_test_split(
                    test_size=args.test_split_size, seed=seed, split=args.eval_split
                )
                train_dataset = Dataset(
                    x=X_train, y=y_train, batch_size=args.batch_size
                )
            else:
                # Same dataset name, but a separate train split, plus a large
                # background corpus for relative-density estimators.
                train_dataset = Dataset.load(
                    args.dataset,
                    args.text_column,
                    args.label_column,
                    batch_size=args.batch_size,
                    prompt=args.prompt,
                    split=args.train_split,
                    size=10_000,
                    load_from_disk=args.load_from_disk,
                    **cache_kwargs
                )

                background_train_dataset = Dataset.load(
                    args.background_train_dataset,
                    args.background_train_dataset_text_column,
                    args.background_train_dataset_label_column,
                    batch_size=args.batch_size,
                    data_files=args.background_train_dataset_data_files,
                    split="train",
                    size=100_000,
                    load_from_disk=args.background_load_from_disk,
                    # BUGFIX: was `**data_kwargs`, an undefined name.
                    **cache_kwargs
                )

                if args.subsample_train_dataset != -1:
                    train_dataset.subsample(args.subsample_train_dataset, seed=seed)
                if args.subsample_background_train_dataset != -1:
                    background_train_dataset.subsample(
                        args.subsample_background_train_dataset, seed=seed
                    )
        else:
            train_dataset = None
            background_train_dataset = None

        if args.subsample_eval_dataset != -1:
            dataset.subsample(args.subsample_eval_dataset, seed=seed)

        log.info("Done with loading data.")

        man = UEManager(
            dataset,
            model,
            estimators,
            [
                RougeMetric("rouge1"),
                RougeMetric("rouge2"),
                RougeMetric("rougeL"),
                BartScoreSeqMetric("rh"),
                AccuracyMetric(),
            ],
            [
                ReversedPairsProportion(),
                PredictionRejectionArea(),
                RiskCoverageCurveAUC(),
            ],
            [
                Logger(),
            ],
            deberta_batch_size=getattr(args, 'deberta_batch_size', 10),
            train_data=train_dataset,
            ignore_exceptions=args.ignore_exceptions,
            background_train_data=background_train_dataset,
            max_new_tokens=args.max_new_tokens,
        )

        man()

        man.save(save_path + f"/ue_manager_seed{seed}")


def get_density_based_ue_methods(args, model_type):
    """Instantiate the density-based uncertainty estimators.

    Returns an empty list unless ``args.use_density_based_ue`` is set.
    Seq2Seq models get estimators over both encoder and decoder embeddings;
    any other model type gets the decoder-only variants. Fitted statistics
    are cached under ``args.cache_path`` keyed by dataset and model name.
    """
    if not args.use_density_based_ue:
        return []

    # Derive a filesystem-friendly cache key from the dataset and model names.
    if isinstance(args.dataset, str):
        dataset_name = args.dataset
    else:
        dataset_name = '_'.join(args.dataset)
    dataset_name = dataset_name.split("/")[-1].split(".")[0]
    model_name = args.model.split("/")[-1]
    parameters_path = f"{args.cache_path}/density_stats/{dataset_name}/{model_name}"

    parts = ["encoder", "decoder"] if model_type == "Seq2SeqLM" else ["decoder"]

    estimators = []
    # Mahalanobis-family estimators, one per embedding location.
    for factory in (MahalanobisDistanceSeq, RelativeMahalanobisDistanceSeq, RDESeq):
        for part in parts:
            estimators.append(factory(part, parameters_path=parameters_path))
    # Perplexity x Mahalanobis hybrids, both MD and RMD flavours.
    for part in parts:
        for md_type in ("MD", "RMD"):
            estimators.append(
                PPLMDSeq(part, md_type=md_type, parameters_path=parameters_path)
            )
    return estimators


def get_ue_methods(args, model):
    """Instantiate the white-box uncertainty estimators selected in *args*.

    Args:
        args: config object with boolean flags ``use_seq_ue`` / ``use_tok_ue``
            and optional ``additional_estimators`` (mapping of module name to
            a list of class names) plus ``additional_estimators_kwargs``
            (mapping of class name to constructor kwargs).
        model: loaded WhiteboxModel; only token-level estimators use it.

    Returns:
        list of instantiated estimator objects.

    Raises:
        TypeError: if an additional estimator class is listed without a
            matching entry in ``additional_estimators_kwargs``.
    """
    estimators = []
    if args.use_seq_ue:
        estimators += [
            MaximumSequenceProbability(),
            Perplexity(),
            MeanTokenEntropy(),
            MeanPointwiseMutualInformation(),
            MeanConditionalPointwiseMutualInformation(),
            PTrue(),
            PTrueSampling(),
            MonteCarloSequenceEntropy(),
            MonteCarloNormalizedSequenceEntropy(),
            LexicalSimilarity(metric="rouge1"),
            LexicalSimilarity(metric="rouge2"),
            LexicalSimilarity(metric="rougeL"),
            LexicalSimilarity(metric="BLEU"),
            NumSemSets(),
            EigValLaplacian(similarity_score="NLI_score", affinity="entail"),
            EigValLaplacian(similarity_score="NLI_score", affinity="contra"),
            EigValLaplacian(similarity_score="Jaccard_score"),
            DegMat(similarity_score="NLI_score", affinity="entail"),
            DegMat(similarity_score="NLI_score", affinity="contra"),
            DegMat(similarity_score="Jaccard_score"),
            Eccentricity(similarity_score="NLI_score", affinity="entail"),
            Eccentricity(similarity_score="NLI_score", affinity="contra"),
            Eccentricity(similarity_score="Jaccard_score"),
            SemanticEntropy(),
        ]
    if args.use_tok_ue:
        estimators += [
            MaximumTokenProbability(),
            TokenEntropy(),
            PointwiseMutualInformation(),
            ConditionalPointwiseMutualInformation(),
            SemanticEntropyToken(model.model_path, args.cache_path),
        ]

    # User-supplied plugin estimators: import each module and instantiate
    # the requested classes with their configured kwargs.
    additional_estimators = getattr(args, "additional_estimators", {})
    additional_estimators_kwargs = getattr(args, "additional_estimators_kwargs", {})

    for module_name, estimator_classes in additional_estimators.items():
        module = importlib.import_module(module_name)
        for estimator_class in estimator_classes:
            try:
                estimator_kwargs = additional_estimators_kwargs[estimator_class]
            except KeyError:
                # BUGFIX: the original f-string referenced the undefined name
                # `estimator`, so this error path crashed with a NameError
                # instead of raising the intended TypeError.
                raise TypeError(
                    f'Arguments for {estimator_class} were not passed'
                ) from None

            estimators.append(getattr(module, estimator_class)(**estimator_kwargs))

    return estimators


if __name__ == "__main__":
    # hydra injects the config object; main() takes no explicit arguments here.
    main()
