from pathlib import Path
from typing import Type

import click
from pydantic import BaseModel

from align.app import load_config, setup_logging
from align.config.app_config import AppConfig, DatasetName
from align.exploration.compute_datasets_stats import compute_datasets_stats
from align.exploration.deepl_translate import (
    translate_LIMA_all_languages,
    translate_MT_bench_all_languages,
    translate_MT_bench_ref_answers_languages,
    translate_MTbench_judge_prompts,
)
from align.exploration.format_to_fastchat import format_bactrianx_to_fastchat_and_create_val
from align.exploration.format_to_lima import format_to_lima
from align.exploration.language_detection import (
    _annotate_lima_dataset,
    detect_language,
    detect_language_all,
)
from align.preprocess.dataset_creation import (
    create_multi_lingual_lima_datasets_for_fastchat,
    create_multi_lingual_stackexchange_val_sets,
)
from align.utils import PROJECT_ROOT, jsonl_to_csv, lima_csv_to_jsonl, lima_val_csv_to_jsonl


@click.group()
def main() -> None:
    """Root command group; sub-commands live under explore, preprocess and util."""


@main.group()
def explore() -> None:
    """Exploration commands, e.g. compute_datasets_stats and detect_language."""


@main.group()
def preprocess() -> None:
    """Preprocessing commands, e.g. annotate_lima_dataset and translate_LIMA_all."""


@main.group()
def util() -> None:
    """Utility commands, e.g. lima_csv_to_jsonl and format_to_lima."""


# Shared ``--config_dir_path`` option; the commands in this module reuse it as
# a decorator so they all accept the same config directory argument.
config_option = click.option(
    "--config_dir_path",
    help="Path to a folder with all YAML config files.",
    default="align/config_files/default",
    type=Path,
)


def _setup_entry_point(
    config_dir_path: Path,
    relative_config_file_path: str,
    class_type: Type[BaseModel],
) -> BaseModel:
    """Set up logging and load one config file from a config directory.

    Args:
        config_dir_path: Directory containing ``logging.yaml`` and the config file.
        relative_config_file_path: Config file location relative to
            ``config_dir_path``.
        class_type: Model class forwarded to ``load_config``.

    Returns:
        Whatever ``load_config`` returns for the given file and ``class_type``.
    """
    config_root = Path(config_dir_path)

    # Logging is configured first, before the config file itself is loaded.
    logging_config_path = config_root / "logging.yaml"
    setup_logging(logging_config_path)

    target_config_path = config_root / relative_config_file_path
    return load_config(config_path=target_config_path, class_type=class_type)


@util.command(name="lima_csv_to_jsonl")
@config_option
def run_lima_csv_to_jsonl(config_dir_path: Path) -> None:
    """Run lima_csv_to_jsonl on the LIMA dataset path from app_config.yaml."""
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    # The configured path is joined onto PROJECT_ROOT before being passed on.
    lima_csv_to_jsonl(PROJECT_ROOT / config.data.dataset_name_to_path[DatasetName.lima])


@util.command(name="lima_val_csv_to_jsonl")
@config_option
def run_lima_val_csv_to_jsonl(config_dir_path: Path) -> None:
    """Run lima_val_csv_to_jsonl on the LIMA dataset path from app_config.yaml."""
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    # The configured path is joined onto PROJECT_ROOT before being passed on.
    lima_val_csv_to_jsonl(PROJECT_ROOT / config.data.dataset_name_to_path[DatasetName.lima])


@util.command(name="jsonl_to_csv")
@config_option
def run_jsonl_to_csv(config_dir_path: Path) -> None:
    """Run jsonl_to_csv on the path of the dataset selected by data.dataset_name."""
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    # Unlike the LIMA-specific commands, the dataset is chosen by the config.
    jsonl_to_csv(PROJECT_ROOT / config.data.dataset_name_to_path[config.data.dataset_name])


@util.command(name="format_to_lima")
@config_option
def run_format_to_lima(config_dir_path: Path) -> None:
    """Run format_to_lima with the dataset-name-to-path mapping from app_config.yaml."""
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    # The whole mapping is handed over, not a single dataset path.
    format_to_lima(config.data.dataset_name_to_path)


@explore.command(name="compute_datasets_stats")
@config_option
def run_compute_datasets_stats(config_dir_path: Path) -> None:
    """Run compute_datasets_stats with the app config loaded from app_config.yaml."""
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    compute_datasets_stats(config=config)


@util.command(name="format_bactrianx_to_fastchat_and_create_val")
@config_option
def run_format_bactrianx_to_fastchat_and_create_val(config_dir_path: Path) -> None:
    """Run format_bactrianx_to_fastchat_and_create_val on the configured bactrianx path."""
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    # NOTE(review): unlike e.g. run_lima_csv_to_jsonl, this path is not joined
    # onto PROJECT_ROOT -- confirm that is intended.
    format_bactrianx_to_fastchat_and_create_val(config.data.dataset_name_to_path[DatasetName.bactrianx])


@explore.command(name="detect_language")
@config_option
def run_detect_language(config_dir_path: Path) -> None:
    """Run detect_language on the dataset selected by data.dataset_name in app_config.yaml."""
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    selected_dataset = config.data.dataset_name
    detect_language(
        dataset_name=selected_dataset,
        # Key the mapping by the DatasetName member itself, the same way
        # run_jsonl_to_csv indexes it; the former ``.value`` lookup used a
        # different key type than every other command in this module.
        # NOTE(review): the path is not joined onto PROJECT_ROOT here, unlike
        # run_jsonl_to_csv -- confirm that is intended.
        dataset_path=config.data.dataset_name_to_path[selected_dataset],
    )


@explore.command(name="detect_language_all")
@config_option
def run_detect_all_languages(config_dir_path: Path) -> None:
    """Run detect_language_all with the app config loaded from app_config.yaml."""
    # Annotated as AppConfig to match the other commands in this module.
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    detect_language_all(config=config)


@preprocess.command(name="annotate_lima_dataset")
@config_option
def run_annotate_lima_dataset(config_dir_path: Path) -> None:
    """Run _annotate_lima_dataset on the LIMA dataset path from app_config.yaml."""
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    # The configured path is joined onto PROJECT_ROOT before being passed on.
    _annotate_lima_dataset(PROJECT_ROOT / config.data.dataset_name_to_path[DatasetName.lima])


@preprocess.command(name="translate_LIMA_all")
@config_option
def run_translate_LIMA_all(config_dir_path: Path) -> None:
    """Run translate_LIMA_all_languages on the configured LIMA dataset.

    The DeepL API key and target language codes are read from the translation
    section of app_config.yaml.
    """
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    translate_LIMA_all_languages(
        dataset_name=DatasetName.lima,
        dataset_path=PROJECT_ROOT / config.data.dataset_name_to_path[DatasetName.lima],
        deepl_api_key=config.translation.api_key,
        target_lang_codes=config.translation.target_lang_codes,
    )


@preprocess.command(name="translate_MTbench_all")
@config_option
def run_translate_MTbench_all(config_dir_path: Path) -> None:
    """Run translate_MT_bench_all_languages on the configured mt_bench dataset.

    The DeepL API key and target language codes are read from the translation
    section of app_config.yaml.
    """
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    translate_MT_bench_all_languages(
        dataset_name=DatasetName.mt_bench,
        dataset_path=PROJECT_ROOT / config.data.dataset_name_to_path[DatasetName.mt_bench],
        deepl_api_key=config.translation.api_key,
        target_lang_codes=config.translation.target_lang_codes,
    )


@preprocess.command(name="translate_MTbench_reference_answers")
@config_option
def run_translate_MTbench_reference_answers(config_dir_path: Path) -> None:
    """Run translate_MT_bench_ref_answers_languages on the configured mt_bench_ref_answers dataset.

    The DeepL API key and target language codes are read from the translation
    section of app_config.yaml.
    """
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    translate_MT_bench_ref_answers_languages(
        dataset_name=DatasetName.mt_bench_ref_answers,
        dataset_path=PROJECT_ROOT / config.data.dataset_name_to_path[DatasetName.mt_bench_ref_answers],
        deepl_api_key=config.translation.api_key,
        target_lang_codes=config.translation.target_lang_codes,
    )


@preprocess.command(name="translate_MTbench_judge_prompts")
@config_option
def run_translate_MTbench_judge_prompts(config_dir_path: Path) -> None:
    """Run translate_MTbench_judge_prompts on the configured mt_bench_judge_prompts dataset.

    The DeepL API key and target language codes are read from the translation
    section of app_config.yaml.
    """
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    translate_MTbench_judge_prompts(
        dataset_name=DatasetName.mt_bench_judge_prompts,
        dataset_path=PROJECT_ROOT / config.data.dataset_name_to_path[DatasetName.mt_bench_judge_prompts],
        deepl_api_key=config.translation.api_key,
        target_lang_codes=config.translation.target_lang_codes,
    )


@preprocess.command(name="create_multi_lingual_lima_datasets_for_fastchat")
@config_option
def run_create_multi_lingual_lima_datasets_for_fastchat(config_dir_path: Path) -> None:
    """Run create_multi_lingual_lima_datasets_for_fastchat on the configured mulima path.

    The languages per dataset are read from app_config.yaml; the column prefix
    is fixed to "conversations".
    """
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    # NOTE(review): unlike the translate_* commands, this path is not joined
    # onto PROJECT_ROOT -- confirm that is intended.
    create_multi_lingual_lima_datasets_for_fastchat(
        dataset_path=config.data.dataset_name_to_path[DatasetName.mulima],
        languages_per_dataset=config.languages_per_dataset,
        column_prefix="conversations",
    )


@preprocess.command(name="create_multi_lingual_stackexchange_val_sets")
@config_option
def run_create_multi_lingual_stackexchange_val_sets(config_dir_path: Path) -> None:
    """Run create_multi_lingual_stackexchange_val_sets on the configured own_stackexchange path.

    The configured LIMA path is passed as the train dataset path, and
    max_num_samples_per_source is fixed to 5.
    """
    config: AppConfig = _setup_entry_point(
        config_dir_path=config_dir_path,
        relative_config_file_path="app_config.yaml",
        class_type=AppConfig,
    )
    # NOTE(review): unlike the translate_* commands, neither path is joined
    # onto PROJECT_ROOT -- confirm that is intended.
    create_multi_lingual_stackexchange_val_sets(
        dataset_path=config.data.dataset_name_to_path[DatasetName.own_stackexchange],
        train_dataset_path=config.data.dataset_name_to_path[DatasetName.lima],
        max_num_samples_per_source=5,
    )


# Invoke the root click group when this file is executed as a script.
if __name__ == "__main__":
    main()
