"""
TODO: Document this.
"""
import itertools
import pathlib
import typing

import pydantic

from . import pipeline
from . import runners


class DatasetConf(pydantic.BaseModel):
    """
    Configuration of a single dataset and the files of its three splits.
    """
    # Name of the dataset; it is carried along in the `info` of every
    # execution unit built for this dataset (via `.dict()`).
    NAME: str
    # File used for learning patterns, the embedding model and the
    # classification model.
    TRAIN: pathlib.Path
    # File passed as validation data when fitting the classification model.
    VALIDATION: pathlib.Path
    # File of the test split; embeddings, predictions and metrics are also
    # produced for it.
    TEST: pathlib.Path


class PatternMinerConf(pydantic.BaseModel):
    """
    Configuration of one pattern miner run.

    The meaning of the individual parameters is defined by the pattern
    mining runner, which is not part of this module.
    """
    # Defaults used when the configuration does not provide PARAMS.
    PARAMS: dict = {
        'skipgram_k': 6,
        'min_support': 10,
        'max_iterations': 6,
        'pattern_buffer_size': 6,
    }


class EmbeddingConf(pydantic.BaseModel):
    """
    Configuration of one embedding model.
    """
    # Identifier of the embedding method; interpreted by the embedding runner
    # (not visible in this module).
    METHOD: str
    # Method-specific parameters; empty when not given.
    PARAMS: dict = {}


class ClassificationConf(pydantic.BaseModel):
    """
    Configuration of one classification model.
    """
    # Identifier of the classification method; interpreted by the
    # classification runner (not visible in this module).
    METHOD: str
    # Method-specific parameters; empty when not given.
    PARAMS: dict = {}


class PipelineConf(pydantic.BaseSettings):
    """
    Top-level settings describing a whole experiment.

    Every field is a list of alternatives; `pipeline_from_conf` creates
    execution units for the combinations of datasets, pattern miners,
    embedding models and classification methods.
    """
    # Datasets to run the experiment on.
    DATASETS: typing.List[DatasetConf]

    # Pattern miner configurations, each applied to every dataset.
    PATTERN_MINER: typing.List[PatternMinerConf]

    # Embedding configurations, each applied to every pattern miner result.
    EMBEDDING: typing.List[EmbeddingConf]

    # Classification configurations, each applied to every embedding model.
    CLASSIFICATION: typing.List[ClassificationConf]


def pipeline_from_conf(conf):
    """
    Build a Pipeline object from a configuration object.

    The configuration is expanded like a tree, one level per stage:

    - first, one pattern-mining unit per (pattern miner, dataset) pair;
    - then, for each pattern-mining unit, a set of units per embedding
      configuration, which take the "parent" mining unit as input;
    - finally, for each embedding model, a set of units per classification
      configuration.

    Between stages, a list called the *fringe* carries what the next level
    needs from the previous one (one dictionary per parent).

    :param conf: object exposing ``DATASETS``, ``PATTERN_MINER``,
        ``EMBEDDING`` and ``CLASSIFICATION`` sequences (e.g. a
        ``PipelineConf``).
    :return: the populated ``pipeline.Pipeline``.
    """
    pl = pipeline.Pipeline()

    fringe = [{'dataset_conf': ds_conf} for ds_conf in conf.DATASETS]
    fringe = _add_pattern_units(pl, conf.PATTERN_MINER, fringe)
    fringe = _add_embedding_units(pl, conf.EMBEDDING, fringe)
    _add_classification_units(pl, conf.CLASSIFICATION, fringe)

    return pl


# Dataset splits, named after the corresponding fields of the dataset
# configuration; the order defines the order in which units are created.
_DATASET_SPLITS = ('TRAIN', 'VALIDATION', 'TEST')


def _unit_info(unit_type, **confs):
    """
    Build the ``info`` dictionary attached to an execution unit.

    ``confs`` maps info keys to configuration objects; each one is stored as
    a fresh ``.dict()`` so that no two units share an info object.
    """
    info = {'unit_type': unit_type}
    for key, unit_conf in confs.items():
        info[key] = unit_conf.dict()
    return info


def _add_pattern_units(pl, miner_confs, fringe):
    """
    Add one pattern-mining unit per (miner configuration, dataset) pair.

    Returns the fringe for the embedding stage.
    """
    next_fringe = []
    for pm_conf, parent in itertools.product(miner_confs, fringe):
        dataset_conf = parent['dataset_conf']
        # NOTE(review): the whole configuration dict is passed as 'params'
        # (not only its PARAMS entry), and the miner configuration is not
        # recorded in `info` -- confirm that both are intended.
        unit = pl.unit(
            kw={
                'dataset': pipeline.PathIn(dataset_conf.TRAIN),
                'params': pm_conf.dict(),
            },
            info=_unit_info('patterns', dataset=dataset_conf),
            runner=runners.learn_patterns,
        )
        next_fringe.append({
            'dataset_conf': dataset_conf,
            'patterns_unit': unit,
        })
    return next_fringe


def _add_embedding_units(pl, embedding_confs, fringe):
    """
    Add embedding units for every (embedding configuration, patterns) pair.

    For each pair, one unit learns the embedding model and three more extract
    embeddings from the train, validation and test files. Returns the fringe
    for the classification stage.
    """
    next_fringe = []
    for embedding_conf, parent in itertools.product(embedding_confs, fringe):
        dataset_conf = parent['dataset_conf']

        # Unit for generating the embedding model
        embedding_model_unit = pl.unit(
            kw={
                'patterns': parent['patterns_unit'],
                'dataset': pipeline.PathIn(dataset_conf.TRAIN),
                'params': embedding_conf.dict(),
            },
            info=_unit_info(
                'embedding-model',
                dataset=dataset_conf,
                embedding=embedding_conf,
            ),
            runner=runners.learn_embedding,
        )

        # Units for extracting embeddings from each dataset split
        extractor_units = {}
        for ds_type in _DATASET_SPLITS:
            extractor_units[ds_type] = pl.unit(
                kw={
                    'model': embedding_model_unit,
                    'dataset': pipeline.PathIn(getattr(dataset_conf, ds_type)),
                },
                info=_unit_info(
                    f'{ds_type.lower()}-embedding',
                    dataset=dataset_conf,
                    embedding=embedding_conf,
                ),
                runner=runners.extract_embeddings,
            )

        next_fringe.append({
            'extractor_units': extractor_units,
            'dataset_conf': dataset_conf,
            'embedding_conf': embedding_conf,
        })
    return next_fringe


def _add_classification_units(pl, classification_confs, fringe):
    """
    Add classification units for every (classification, embedding) pair.

    For each pair, one unit fits the model; then, for each dataset split, one
    unit computes predictions and another computes metrics from them. This is
    the last stage, so no fringe is returned.
    """
    pairs = itertools.product(classification_confs, fringe)
    for classification_conf, parent in pairs:
        extractor_units = parent['extractor_units']
        dataset_conf = parent['dataset_conf']
        embedding_conf = parent['embedding_conf']
        confs = {
            'dataset': dataset_conf,
            'embedding': embedding_conf,
            'classification': classification_conf,
        }

        # Unit for generating the classification model
        model_unit = pl.unit(
            kw={
                'embeddings': extractor_units['TRAIN'],
                'val_embeddings': extractor_units['VALIDATION'],
                'dataset': pipeline.PathIn(dataset_conf.TRAIN),
                'val_dataset': pipeline.PathIn(dataset_conf.VALIDATION),
                'params': classification_conf.dict(),
            },
            info=_unit_info('classification-model', **confs),
            runner=runners.fit_classification_model,
        )

        for ds_type in _DATASET_SPLITS:
            prediction_unit = pl.unit(
                kw={
                    'embeddings': extractor_units[ds_type],
                    'model': model_unit,
                },
                info=_unit_info(f'{ds_type.lower()}-prediction', **confs),
                runner=runners.predict,
            )
            # The metrics unit is a leaf: nothing consumes its result, so the
            # returned unit is not kept.
            # NOTE(review): `always=True` presumably forces this unit to run
            # every time -- confirm against `Pipeline.unit`.
            pl.unit(
                kw={
                    'predictions': prediction_unit,
                    'dataset': pipeline.PathIn(getattr(dataset_conf, ds_type)),
                },
                info=_unit_info(f'{ds_type.lower()}-metrics', **confs),
                runner=runners.calc_metrics,
                always=True,
            )
