import argparse
import contextlib
import pathlib
import shutil
import sys
import time

import numpy as np

import skemb.dataset
import skemb.embedding
import skemb.tokenizer


# Command-line interface; the parsed namespace is splatted into run() below,
# so every `dest` here must match a run() parameter name.
parser = argparse.ArgumentParser(
    description='Extract embeddings for a labeled dataset with a trained '
                'embedding model and save them to OUTPUT_DIR/embeddings.npy.',
)
parser.add_argument(
    '--dataset', type=pathlib.Path, default='dataset/test',
    help='dataset to embed, read with skemb.dataset.read_dataset '
         '(default: %(default)s)',
)
parser.add_argument(
    '--output-dir', '-o', type=pathlib.Path, default='test-embeddings',
    help='directory to create for the results (default: %(default)s)',
)
parser.add_argument(
    '--force', '-f', action='store_true',
    help='delete OUTPUT_DIR first if it already exists',
)
parser.add_argument(
    '--model', type=pathlib.Path, default='embedding/model',
    help='embedding model to load with skemb.embedding.load '
         '(default: %(default)s)',
)
parser.add_argument(
    '--plot', action='store_true',
    help='also save a UMAP plot of the embeddings to OUTPUT_DIR/plot.pdf',
)
parser.add_argument(
    '--interactive-plot', action='store_true',
    help='like --plot, but also open the plot in an interactive window',
)


def msg(*args, **kwargs):
    """Print a diagnostic line to standard error.

    Takes the same arguments as ``print``; a caller-supplied ``file``
    keyword is ignored and replaced by ``sys.stderr``, which is looked up
    at call time.
    """
    print(*args, **{**kwargs, 'file': sys.stderr})


@contextlib.contextmanager
def _timed(start_msg, done_msg):
    """Report *start_msg*, run the ``with`` body, then report elapsed times.

    Wall-clock time (``time.perf_counter``) and process CPU time
    (``time.process_time``) are both reported after *done_msg*. If the body
    raises, the exception propagates and no completion message is printed.
    """
    msg(start_msg)
    t0 = time.perf_counter()
    p0 = time.process_time()
    yield
    dt = time.perf_counter() - t0
    pdt = time.process_time() - p0
    msg(f'{done_msg} in {dt} (process: {pdt}) seconds')


def _prepare_output_dir(output_dir, force):
    """Create a fresh, empty *output_dir*, replacing an existing one only if *force*.

    Raises:
        FileExistsError: if *output_dir* already exists and *force* is false.
    """
    # is_symlink() also catches dangling symlinks, for which exists() is False
    # but mkdir() would still fail.
    if output_dir.exists() or output_dir.is_symlink():
        if not force:
            raise FileExistsError(f'directory {output_dir} exists. Use -f to overwrite it')
        msg(f'Removing existing directory {output_dir}')
        # rmtree() refuses plain files and symlinks; unlink those instead
        # (unlinking a symlink never touches its target).
        if output_dir.is_dir() and not output_dir.is_symlink():
            shutil.rmtree(output_dir)
        else:
            output_dir.unlink()
    # parents=True so nested output paths such as out/run1 work.
    output_dir.mkdir(parents=True)


def run(dataset, output_dir, force, model, plot, interactive_plot):
    """Tokenize a dataset, embed every text, and save the embeddings.

    Args:
        dataset: path handed to ``skemb.dataset.read_dataset``, which must
            yield ``(label, text)`` pairs.
        output_dir: directory to create; receives ``embeddings.npy`` and,
            when plotting, ``plot.pdf``.
        force: if true, an existing *output_dir* is deleted first;
            otherwise its presence is an error.
        model: path handed to ``skemb.embedding.load``; the loaded object
            must provide ``extract(text, sequence)``.
        plot: also save a UMAP scatter plot of the embeddings.
        interactive_plot: like *plot*, and additionally show the figure.

    Raises:
        FileExistsError: if *output_dir* exists and *force* is false.
    """
    _prepare_output_dir(output_dir, force)

    with _timed('Tokenizing', 'Tokenization done'):
        ds = list(skemb.dataset.read_dataset(dataset))
        labels = [label for label, _ in ds]
        texts = [text for _, text in ds]
        sequences = [skemb.tokenizer.tokenize(text) for text in texts]

    with _timed('Loading embedding model', 'Model loaded'):
        embedder = skemb.embedding.load(model)

    with _timed('Extracting embeddings', 'Embedding done'):
        # One extract() result per input text, stacked by np.array;
        # presumably a (n_texts, dim) float array — confirm against skemb.
        data = np.array([
            embedder.extract(text, seq)
            for text, seq in zip(texts, sequences)
        ])

    out = output_dir / 'embeddings.npy'
    np.save(out, data)
    msg(f'Embeddings saved at {out}')

    if plot or interactive_plot:
        # Imported lazily: these are heavy, optional dependencies only needed
        # for plotting.
        import matplotlib.pyplot as plt
        import umap
        import umap.plot

        # Project the embeddings to 2-D with UMAP and colour points by label.
        mapper = umap.UMAP(verbose=True).fit(data)
        umap.plot.points(mapper, labels=np.array(labels), color_key_cmap='Paired', theme='fire')
        # Saves the current pyplot figure, presumably the one umap.plot drew on.
        plt.savefig(output_dir / 'plot.pdf')

        if interactive_plot:
            plt.show()


if __name__ == '__main__':
    # Each parsed option's dest matches a run() keyword parameter.
    options = vars(parser.parse_args())
    run(**options)
