"""
TODO: document this
"""
import contextlib
import copy
import time

import numpy as np
import sklearn.metrics
import sklearn.neighbors
import sklearn.neural_network
import sklearn.svm

import skemb.dataset
import skemb.embedding
import skemb.tokenizer
import skemb.sgpattern_miner
import skemb.sgpattern
import skemb.pipeline


def learn_patterns(ctx, dataset, params):
    """Mine patterns from a labelled text dataset.

    The dataset is read as ``(label, text)`` pairs, every text is tokenized,
    and the resulting sequences are handed to
    ``skemb.sgpattern_miner.SGPatternMiner`` together with the labels and the
    raw texts.

    Args:
        ctx: Execution context. If ``ctx.workdir`` is truthy, the miner's log
            is written to ``ctx.workdir / 'log.txt'``.
        dataset: Passed as-is to ``skemb.dataset.read_dataset``.
        params: Mapping whose ``'PARAMS'`` entry holds extra keyword
            arguments for the miner constructor.

    Returns:
        Whatever ``SGPatternMiner.run()`` returns.
    """
    labels = []
    texts = []
    for label, text in skemb.dataset.read_dataset(dataset):
        labels.append(label)
        texts.append(text)
    sequences = list(map(skemb.tokenizer.tokenize, texts))

    with contextlib.ExitStack() as stack:
        # The log file is optional; the stack closes it (if any) on exit.
        logfile = None
        if ctx.workdir:
            logfile = stack.enter_context(open(ctx.workdir / 'log.txt', 'w'))

        miner = skemb.sgpattern_miner.SGPatternMiner(
            sequences=sequences,
            labels=labels,
            str_representations=texts,
            **params['PARAMS'],
        )
        miner.logfile = logfile
        return miner.run()


def learn_embedding(ctx, dataset, params, patterns=None):
    """Instantiate and train the embedding method selected by ``params``.

    Args:
        ctx: Execution context. If ``ctx.workdir`` is truthy, the training
            log is written to ``ctx.workdir / 'log.txt'``.
        dataset: Passed as-is to ``skemb.dataset.read_dataset``; read as
            ``(label, text)`` pairs.
        params: Mapping with ``'METHOD'`` (one of ``'BinaryBagOfPatterns'``,
            ``'PatternsProbabilities'``, ``'Tok1'``, ``'Tok2'``) and
            ``'PARAMS'`` (keyword arguments forwarded to ``set_params``).
        patterns: Patterns handed to ``train`` when the selected method is a
            ``skemb.embedding.SGPatternEmbedding``.

    Returns:
        The trained embedding method, with its ``logfile`` attribute reset
        to ``None``.
    """
    print('Generating embedding model')

    embedding_classes = {
        'BinaryBagOfPatterns': skemb.embedding.BinaryBagOfPatterns,
        'PatternsProbabilities': skemb.embedding.PatternsProbabilities,
        'Tok1': skemb.embedding.Tok1BagOfWords,
        'Tok2': skemb.embedding.Tok2BagOfWords,
    }

    labels = []
    texts = []
    for label, text in skemb.dataset.read_dataset(dataset):
        labels.append(label)
        texts.append(text)
    sequences = list(map(skemb.tokenizer.tokenize, texts))

    with contextlib.ExitStack() as stack:
        # The log file is optional; the stack closes it (if any) on exit.
        logfile = None
        if ctx.workdir:
            logfile = stack.enter_context(open(ctx.workdir / 'log.txt', 'w'))

        embedding_method = embedding_classes[params['METHOD']]()
        embedding_method.set_params(**params['PARAMS'])
        embedding_method.logfile = logfile

        train_args = (texts, sequences, labels)
        if isinstance(embedding_method, skemb.embedding.SGPatternEmbedding):
            if not patterns:
                # NOTE(review): whether ctx.error raises is not visible here;
                # if it returns, training proceeds with the falsy patterns.
                ctx.error('patterns are required when using a SGPatternEmbedding method')
            train_args += (patterns,)
        embedding_method.train(*train_args)

        # Detach the log handle before it is closed on leaving the block.
        embedding_method.logfile = None

        return embedding_method


def extract_embeddings(ctx, model, dataset):
    """Compute the embedding of every text in ``dataset`` using ``model``.

    Tokenization and embedding extraction are timed separately; both the
    wall-clock and the process time of each phase are printed.

    Args:
        ctx: Execution context (not used by this function).
        model: Object exposing ``extract(text, sequence)``.
        dataset: Passed as-is to ``skemb.dataset.read_dataset``; read as
            ``(label, text)`` pairs.

    Returns:
        ``np.ndarray`` built from the per-text results of ``model.extract``,
        in dataset order.
    """
    print('Extracting embeddings')
    print('Tokenizing')
    t0 = time.perf_counter()
    p0 = time.process_time()

    # Only the texts are needed here; the labels are ignored.
    texts = [text for _, text in skemb.dataset.read_dataset(dataset)]
    sequences = [skemb.tokenizer.tokenize(text) for text in texts]

    dt = time.perf_counter() - t0
    pdt = time.process_time() - p0
    print(f'Tokenization done in {dt} (process: {pdt}) seconds')

    print('Extracting embeddings')
    t0 = time.perf_counter()
    p0 = time.process_time()

    data = np.array([
        model.extract(text, seq)
        for text, seq in zip(texts, sequences)
    ])

    dt = time.perf_counter() - t0
    pdt = time.process_time() - p0
    print(f'Embedding done in {dt} (process: {pdt}) seconds')

    return data


def fit_classification_model(ctx, dataset, embeddings, params, val_embeddings, val_dataset):
    """Fit a scikit-learn classifier on precomputed embeddings.

    ``params['METHOD']`` selects the estimator (``'knn'``, ``'mlp'`` or
    ``'svm'``) and ``params['PARAMS']`` holds its constructor keyword
    arguments.

    For ``'mlp'`` the model is trained one iteration per ``fit`` call
    (``warm_start=True``, ``max_iter=1``) for at most the configured
    ``max_iter`` (default 200) rounds. Validation accuracy is measured after
    every round, a snapshot of the best-scoring model is kept, and training
    stops early after 10 consecutive rounds without improvement. The other
    methods are fitted once and the validation data is not used.

    Args:
        ctx: Execution context (not used by this function).
        dataset: Training dataset, read as ``(label, text)`` pairs via
            ``skemb.dataset.read_dataset``; only the labels are used.
        embeddings: Training features passed to ``fit``.
        params: Mapping with ``'METHOD'`` and ``'PARAMS'``. It is not
            modified.
        val_embeddings: Validation features passed to ``score``.
        val_dataset: Validation dataset; only the labels are used.

    Returns:
        The fitted estimator (for ``'mlp'``, the best snapshot found).
    """
    print('Generating classification model')
    method_mapping = {
        'knn': sklearn.neighbors.KNeighborsClassifier,
        'mlp': sklearn.neural_network.MLPClassifier,
        'svm': sklearn.svm.LinearSVC,
    }

    ds = list(skemb.dataset.read_dataset(dataset))
    labels = np.array([label for label, _ in ds])

    val_ds = list(skemb.dataset.read_dataset(val_dataset))
    val_labels = np.array([label for label, _ in val_ds])

    print('Fitting model')
    t0 = time.perf_counter()
    if params['METHOD'] == 'mlp':
        # Number of consecutive non-improving rounds tolerated before stopping.
        patience = 10

        # Work on a copy: overriding max_iter/warm_start directly in the
        # caller's dict would leave it with max_iter=1, so a later call with
        # the same params would train for a single round only.
        model_params = dict(params['PARAMS'])
        max_iter = model_params.get('max_iter', 200)
        model_params['warm_start'] = True
        model_params['max_iter'] = 1

        model = method_mapping[params['METHOD']](**model_params)
        best_model = model
        best_val_acc = 0
        iterations_since_best = 0
        for i in range(max_iter):
            # With warm_start, each call continues from the previous weights.
            model.fit(embeddings, labels)
            val_acc = model.score(val_embeddings, val_labels)
            print(f'    i={i} train_loss={model.loss_:.6f} val_acc={val_acc:.6f}')
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                # Snapshot, because `model` keeps being trained in place.
                best_model = copy.deepcopy(model)
                iterations_since_best = 0
            else:
                iterations_since_best += 1
                if iterations_since_best >= patience:
                    print(f'Stopping training because of {patience} iterations without improving validation accuracy')
                    print(f'Best validation accuracy: {best_val_acc:.6f}')
                    break
        model = best_model
    else:
        model = method_mapping[params['METHOD']](**params['PARAMS'])
        model.fit(embeddings, labels)

    dt = time.perf_counter() - t0
    print(f'Model fit in {dt} seconds')

    return model


def predict(ctx, embeddings, model):
    """Run ``model.predict`` on ``embeddings`` and print the elapsed wall time.

    Args:
        ctx: Execution context (not used by this function).
        embeddings: Input forwarded unchanged to ``model.predict``.
        model: Object exposing a ``predict`` method.

    Returns:
        The value returned by ``model.predict(embeddings)``.
    """
    print('Prediction running')
    started = time.perf_counter()
    result = model.predict(embeddings)
    elapsed = time.perf_counter() - started
    print(f'Prediction done in {elapsed} seconds')
    return result


def calc_metrics(ctx, predictions, dataset):
    """Compute the accuracy of ``predictions`` against the dataset labels.

    Args:
        ctx: Execution context; ``ctx.unit.info["unit_type"]`` is used as the
            prefix of the printed report line.
        predictions: Predicted labels, in dataset order.
        dataset: Passed as-is to ``skemb.dataset.read_dataset``; read as
            ``(label, text)`` pairs, of which only the labels are used.

    Returns:
        A dict with a single ``'accuracy'`` entry.
    """
    true_labels = np.array(
        [label for label, _ in skemb.dataset.read_dataset(dataset)]
    )
    accuracy = sklearn.metrics.accuracy_score(true_labels, predictions)
    # The original format string had a stray literal "f" before the value.
    print(f'{ctx.unit.info["unit_type"]} accuracy: {accuracy}')
    return {
        'accuracy': accuracy,
    }
