import argparse
import contextlib
import pathlib
import sys
import time

import skemb.dataset

import numpy as np
import sklearn.metrics
import sklearn.neighbors
import sklearn.neural_network
import sklearn.svm


# Classifier constructors selectable through ``--method``.
method_mapping = dict(
    knn=sklearn.neighbors.KNeighborsClassifier,
    mlp=sklearn.neural_network.MLPClassifier,
    svm=sklearn.svm.LinearSVC,
)

# Command-line interface. Each option's dest matches a parameter of ``run``.
parser = argparse.ArgumentParser()

# Training split: labelled dataset plus its precomputed vectors.
parser.add_argument(
    '--train-dataset',
    type=pathlib.Path,
    default='dataset/train',
)
parser.add_argument(
    '--train-vectors',
    type=pathlib.Path,
    default='embedding/embeddings.npy',
)

# Test split: labelled dataset plus its precomputed vectors.
parser.add_argument(
    '--test-dataset',
    type=pathlib.Path,
    default='dataset/test',
)
parser.add_argument(
    '--test-vectors',
    type=pathlib.Path,
    default='test-embeddings/embeddings.npy',
)

# Classifier selection; valid choices are the keys of ``method_mapping``.
parser.add_argument(
    '--method',
    choices=method_mapping,
    default='knn',
)

# Optional files: hyperparameter script and metrics log (both default to None).
parser.add_argument(
    '--method-hyperparams',
    type=pathlib.Path,
)
parser.add_argument(
    '--log-file',
    type=pathlib.Path,
)


def msg(*k, **kw):
    """Print to standard error.

    Accepts the same arguments as ``print``; any ``file`` keyword supplied
    by the caller is overridden with ``sys.stderr``.
    """
    print(*k, **{**kw, 'file': sys.stderr})


def _load_hyperparams(path):
    """Return ``(args, kwargs)`` for the model constructor read from *path*.

    The file is executed as Python source. Its global ``ARGS`` (default
    ``[]``) supplies positional arguments and ``KEYWORDS`` (default ``{}``)
    supplies keyword arguments. A falsy *path* yields no arguments at all.
    """
    if not path:
        return [], {}
    # SECURITY: exec() runs the file with this process's privileges, so it
    # can execute arbitrary code. Only pass hyperparameter files you trust.
    namespace = {}
    exec(path.read_text(), namespace)
    return namespace.get('ARGS', []), namespace.get('KEYWORDS', {})


def _check_sample_count(split, labels, vectors):
    """Raise ``ValueError`` unless *vectors* has one row per entry of *labels*."""
    shape = getattr(vectors, 'shape', ())
    if tuple(shape[:1]) != (len(labels),):
        raise ValueError(
            f'{split} dataset has {len(labels)} labels but its vectors '
            f'have shape {tuple(shape)}'
        )


def _ratio(count, total):
    """Return ``count / total`` as a float, or NaN when *total* is zero."""
    if not total:
        return float('nan')
    return float(count) / float(total)


def run(train_dataset, train_vectors, test_dataset, test_vectors, method, method_hyperparams, log_file):
    """Fit a classifier on the training vectors and report test-set metrics.

    Progress and timing messages go to standard error via ``msg``. Metric
    lines go to standard error and, when *log_file* is given, to that file.

    Args:
        train_dataset: Location passed to ``skemb.dataset.read_dataset``;
            iterating the result must yield ``(label, item)`` pairs.
        train_vectors: File passed to ``np.load`` holding the training
            vectors, one row per training label.
        test_dataset: Same as *train_dataset*, for the test split.
        test_vectors: Same as *train_vectors*, for the test split.
        method: Key of ``method_mapping`` selecting the classifier class.
        method_hyperparams: Optional path to a Python file defining ``ARGS``
            and/or ``KEYWORDS`` for the classifier constructor. The file is
            executed, so it must be trusted.
        log_file: Optional path; when truthy it is opened for writing
            (truncating any existing file) and receives the metric lines.

    Raises:
        ValueError: If a split's label count differs from the number of
            rows in its vectors. Raised before the model is fitted.

    Reported metrics are accuracy and the Matthews correlation coefficient.
    When the labels seen across both splits are exactly ``ham`` and
    ``spam``, the spam-caught (SC) and blocked-hams (BH) rates are reported
    too; a rate is NaN when the test split has no sample of that class.
    """
    with contextlib.ExitStack() as exit_stack:
        log_stream = None
        if log_file:
            log_stream = exit_stack.enter_context(open(log_file, 'w', encoding='utf-8'))

        def report(line):
            """Write a metric line to stderr and, if enabled, the log file."""
            msg(line)
            if log_stream is not None:
                print(line, file=log_stream)

        msg('Loading datasets and vectors')
        train_labels = np.array([label for label, _ in skemb.dataset.read_dataset(train_dataset)])
        train_vectors = np.load(train_vectors)

        test_labels = np.array([label for label, _ in skemb.dataset.read_dataset(test_dataset)])
        test_vectors = np.load(test_vectors)

        # Fail fast: a mismatch would otherwise only surface inside
        # scikit-learn, for the test split after fitting and predicting.
        _check_sample_count('train', train_labels, train_vectors)
        _check_sample_count('test', test_labels, test_vectors)

        hyper_params_k, hyper_params_kw = _load_hyperparams(method_hyperparams)
        model = method_mapping[method](*hyper_params_k, **hyper_params_kw)

        msg('Fitting model')
        t0 = time.perf_counter()
        model.fit(train_vectors, train_labels)
        dt = time.perf_counter() - t0
        msg(f'Model fit in {dt} seconds')

        msg('Predicting on test data')
        t0 = time.perf_counter()
        predictions = model.predict(test_vectors)
        dt = time.perf_counter() - t0
        msg(f'Prediction done in {dt} seconds')

        accuracy = sklearn.metrics.accuracy_score(test_labels, predictions)
        report(f'Accuracy: {accuracy:.4f}')

        mcc = sklearn.metrics.matthews_corrcoef(test_labels, predictions)
        report(f'MCC: {mcc:.4f}')

        # NOTE: The metrics below are specific for the ham/spam dataset
        label_set = set(train_labels) | set(test_labels)
        if label_set == {'ham', 'spam'}:
            is_spam = test_labels == 'spam'
            is_ham = test_labels == 'ham'

            # The guard above looks at both splits, so the test split alone
            # may lack one class; _ratio then yields NaN instead of 0/0.
            spam_caught = _ratio((is_spam & (predictions == test_labels)).sum(), is_spam.sum())
            report(f'SC: {spam_caught:.4f}')

            blocked_hams = _ratio((is_ham & (predictions != test_labels)).sum(), is_ham.sum())
            report(f'BH: {blocked_hams:.4f}')


if __name__ == '__main__':
    # The parsed namespace is unpacked straight into run() as keyword
    # arguments, so each option's dest must match a run() parameter name.
    run(**vars(parser.parse_args()))
