import argparse
import contextlib
import pathlib
import shutil
import sys

import numpy as np

import skemb.dataset
import skemb.sgpattern
import skemb.embedding
import skemb.tokenizer


# Registry of the embedding classes that can be selected by name with the
# ``--method`` command-line option; keys are the accepted option values.
method_mapping = dict(
    BinaryBagOfPatterns=skemb.embedding.BinaryBagOfPatterns,
    PatternsProbabilities=skemb.embedding.PatternsProbabilities,
    Tok1=skemb.embedding.Tok1BagOfWords,
    Tok2=skemb.embedding.Tok2BagOfWords,
)

# Command-line interface. Every option maps one-to-one onto a keyword
# argument of run() (argparse turns ``--output-dir`` into ``output_dir``, etc.).
parser = argparse.ArgumentParser(
    description=(
        'Train an embedding method on a dataset and save the trained model '
        'and the extracted embeddings.'
    ),
)
parser.add_argument(
    '--dataset', type=pathlib.Path, default='dataset/train',
    help='path of the dataset to read (default: %(default)s)',
)
parser.add_argument(
    '--output-dir', '-o', type=pathlib.Path, default='embedding',
    help='directory to create for the model, embeddings, log and plot '
         '(default: %(default)s)',
)
parser.add_argument(
    '--force', '-f', action='store_true',
    help='remove the output directory first if it already exists',
)
parser.add_argument(
    '--method', choices=method_mapping, default='PatternsProbabilities',
    help='embedding method to train (default: %(default)s)',
)
parser.add_argument(
    '--params', type=pathlib.Path,
    help='Python file that is executed to obtain the ARGS list and KEYWORDS '
         'dict passed to the set_params() of the method; only use files you trust',
)
parser.add_argument(
    '--sgpatterns', type=pathlib.Path, default='sgpatterns/patterns.pickle',
    help='patterns file loaded when the method is an SGPatternEmbedding '
         '(default: %(default)s)',
)
parser.add_argument(
    '--plot', action='store_true',
    help='save a UMAP plot of the embeddings as plot.pdf in the output directory',
)
parser.add_argument(
    '--interactive-plot', action='store_true',
    help='save the plot like --plot and additionally show it with matplotlib',
)
parser.add_argument(
    '--no-logging', action='store_true',
    help='do not create log.txt in the output directory',
)


def msg(*k, **kw):
    """Print a message to standard error.

    Accepts the same positional and keyword arguments as ``print``; any
    ``file`` keyword supplied by the caller is replaced by ``sys.stderr``.
    """
    print(*k, **{**kw, 'file': sys.stderr})


def _prepare_output_dir(output_dir, force):
    """Create a fresh ``output_dir``, removing an existing one only when ``force`` is set."""
    if output_dir.exists():
        if not force:
            raise FileExistsError(f'directory {output_dir} exists. Use -f to overwrite it')
        msg(f'Removing existing directory {output_dir}')
        shutil.rmtree(output_dir)
    # parents=True so that nested output paths (e.g. ``-o runs/exp1``) work
    # even when the parent directory does not exist yet.
    output_dir.mkdir(parents=True)


def _load_params(params):
    """Return the ``(args, kwargs)`` for ``set_params`` defined by the ``params`` file."""
    if not params:
        return [], {}
    # SECURITY: the params file is executed as arbitrary Python code. Only
    # pass files you trust.
    namespace = {}
    exec(params.read_text(), namespace)
    return namespace.get('ARGS', []), namespace.get('KEYWORDS', {})


def _plot_embeddings(data, labels, output_dir, interactive):
    """Save a UMAP plot of ``data`` to ``plot.pdf`` and optionally show it."""
    # NOTE: we only import here because there is an overhead associated
    # with importing these packages.
    import matplotlib.pyplot as plt
    import umap
    import umap.plot

    mapper = umap.UMAP(verbose=True).fit(data)
    umap.plot.points(mapper, labels=np.array(labels), color_key_cmap='Paired', theme='fire')
    plt.savefig(output_dir / 'plot.pdf')

    if interactive:
        plt.show()


def run(dataset, output_dir, force, method, params, sgpatterns, plot, interactive_plot, no_logging):
    """Train an embedding method on a dataset and save the model and embeddings.

    The files ``model`` and ``embeddings.npy`` are written into
    ``output_dir``, plus ``log.txt`` unless ``no_logging`` is set and
    ``plot.pdf`` when a plot is requested.

    Args:
        dataset: Passed to ``skemb.dataset.read_dataset``, which must yield
            ``(label, text)`` pairs.
        output_dir: ``pathlib.Path`` of the directory to create for the outputs.
        force: If true, an existing ``output_dir`` is removed first.
        method: Key of ``method_mapping`` selecting the embedding class.
        params: Optional ``pathlib.Path`` of a Python file that is executed;
            its ``ARGS`` and ``KEYWORDS`` globals are forwarded to
            ``set_params``.
        sgpatterns: Patterns file passed to ``skemb.sgpattern.load``; only
            used when the method is an ``SGPatternEmbedding``.
        plot: If true, save a UMAP plot of the embeddings.
        interactive_plot: If true, save the plot and also show it.
        no_logging: If true, no log file is created and the method's
            ``logfile`` attribute is set to ``None``.

    Raises:
        FileExistsError: ``output_dir`` exists and ``force`` is false.
        ValueError: The method needs patterns but ``sgpatterns`` is empty.
    """
    _prepare_output_dir(output_dir, force)

    ds = list(skemb.dataset.read_dataset(dataset))
    labels = [label for label, _ in ds]
    texts = [text for _, text in ds]
    sequences = [skemb.tokenizer.tokenize(text) for text in texts]

    with contextlib.ExitStack() as exit_stack:
        log_file = None
        if not no_logging:
            # The ExitStack closes the log file when the block is left,
            # whether training succeeds or raises.
            log_file = exit_stack.enter_context(open(output_dir / 'log.txt', 'w'))

        msg('Training embedding method')
        params_k, params_kw = _load_params(params)

        embedding_method = method_mapping[method]()
        embedding_method.set_params(*params_k, **params_kw)
        embedding_method.logfile = log_file

        if isinstance(embedding_method, skemb.embedding.SGPatternEmbedding):
            if not sgpatterns:
                raise ValueError('option --sgpatterns is required when using an SGPatternEmbedding method')
            patterns = skemb.sgpattern.load(sgpatterns)
            embedding_method.train(texts, sequences, labels, patterns)
        else:
            embedding_method.train(texts, sequences, labels)
        out = output_dir / 'model'
        embedding_method.save(out)
        msg(f'Embedding model saved to {out}')

        # Extract while the log file is still open: the model keeps a
        # reference to it in ``logfile``, so this stays safe in case
        # extract() writes to the log.
        data = np.array([
            embedding_method.extract(text, seq)
            for text, seq in zip(texts, sequences)
        ])

    out = output_dir / 'embeddings.npy'
    np.save(out, data)
    msg(f'Embeddings saved at {out}')

    if plot or interactive_plot:
        _plot_embeddings(data, labels, output_dir, interactive_plot)


if __name__ == '__main__':
    # Parse the command line and forward every option to run() as a
    # keyword argument of the same name.
    cli_options = vars(parser.parse_args())
    run(**cli_options)
