import argparse
import pathlib
import sys

import numpy as np

import skemb.dataset


# Command-line interface. Every dest below matches a parameter name of
# ``run`` so the parsed namespace can be forwarded to it as keyword arguments.
# String defaults are converted by argparse through ``type`` (pathlib.Path).
parser = argparse.ArgumentParser(
    description='Plot embeddings in two dimensions using UMAP.',
)
parser.add_argument(
    '--dataset',
    type=pathlib.Path,
    default='dataset/train',
    help='dataset the point labels are read from (default: %(default)s)',
)
parser.add_argument(
    '--embeddings',
    type=pathlib.Path,
    default='embedding/embeddings.npy',
    help='NumPy file holding the embeddings to plot (default: %(default)s)',
)
parser.add_argument(
    '--output', '-o',
    type=pathlib.Path,
    default='embedding/plot.pdf',
    help='file the plot is saved to (default: %(default)s)',
)
parser.add_argument(
    '--force', '-f',
    action='store_true',
    help='overwrite the output file if it already exists',
)
parser.add_argument(
    '--interactive-plot',
    action='store_true',
    help='also display the plot in a window after saving it',
)


def msg(*k, **kw):
    """Print a message to standard error.

    Takes the same positional and keyword arguments as ``print``. A ``file``
    keyword supplied by the caller is ignored: output always goes to
    ``sys.stderr``.
    """
    # Later keys win, so this replaces any caller-provided ``file``.
    print_options = dict(kw, file=sys.stderr)
    print(*k, **print_options)


def run(dataset, embeddings, output, force, interactive_plot):
    """Project the embeddings to 2D with UMAP and save the plot to ``output``.

    Args:
        dataset: Location handed to ``skemb.dataset.read_dataset``; the first
            element of every item it yields is used as that point's label.
        embeddings: File loaded with ``np.load``; it must hold one row per
            dataset item.
        output: ``pathlib.Path`` of the image file to write.
        force: When true, an existing ``output`` file is replaced.
        interactive_plot: When true, ``plt.show()`` is called after the plot
            has been saved.

    Raises:
        IsADirectoryError: ``output`` is an existing directory.
        FileExistsError: ``output`` exists and ``force`` is false.
        ValueError: The number of labels differs from the number of
            embedding rows.
    """
    # Validate the destination up front, but do NOT delete anything yet: if a
    # later step fails, the previous plot must survive.
    if output.exists():
        if output.is_dir():
            raise IsADirectoryError(f'{output} is a directory')
        if not force:
            raise FileExistsError(f'{output} exists. Use -f to overwrite it')

    msg('Loading dataset')
    labels = [label for label, _ in skemb.dataset.read_dataset(dataset)]

    msg('Loading embeddings')
    data = np.load(embeddings)

    # Fail fast: a mismatch would otherwise only surface at plotting time,
    # after the expensive UMAP fit has already run.
    if len(labels) != data.shape[0]:
        raise ValueError(
            f'Dataset has {len(labels)} labels but embeddings have '
            f'{data.shape[0]} rows'
        )

    # NOTE: importing these only here because of the overhead associated. If
    # the user wants just to see a help message, then this point is not
    # reached.
    msg('Importing libraries')
    import matplotlib.pyplot as plt
    import umap
    import umap.plot

    # Make a plot of the embeddings using UMAP
    msg('Running UMAP')
    mapper = umap.UMAP(verbose=True).fit(data)
    umap.plot.points(mapper, labels=np.array(labels), color_key_cmap='Paired', theme='fire')

    # Only now that a new plot is ready is the old file removed (force mode).
    if force and output.exists():
        output.unlink()
    plt.savefig(output)
    msg(f'Plot saved at {output}')

    if interactive_plot:
        plt.show()


if __name__ == '__main__':
    # The parser's dests mirror run()'s parameter names, so the parsed
    # namespace can be expanded directly into keyword arguments.
    run(**vars(parser.parse_args()))
