#!/usr/bin/env python
import tqdm
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial import ConvexHull
from argparse import ArgumentParser

from stollen import StolenProbabilitySearch, ApproxAlgorithms, ExactAlgorithms, lp_chebyshev


# QHULL is extremely slow if run with more than 9 dimensions, so the
# convex-hull cross-check is only attempted up to this dimensionality.
QHULL_MAX_DIM = 9


def parse_args(argv=None):
    """Parse and validate the command-line options of this script.

    Args:
        argv: Optional list of argument strings. ``None`` makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The argparse namespace. ``approx_algorithm`` / ``exact_algorithm``
        are set to ``None`` when the user passed the literal ``'none'``, and
        ``approximate`` is ``True`` exactly when no exact algorithm is used.

    Exits with a usage error (via ``parser.error``) when both algorithms are
    disabled, since there would be nothing to run.
    """
    parser = ArgumentParser(description='Run stolen probability script on Random Softmax Layer.')

    parser.add_argument('-d', '--dim', type=int, required=True,
                        help='Dimensionality of each softmax weight vector.')
    parser.add_argument('-C', '--num-classes', dest='num_classes',
                        type=int, required=True,
                        help='Number of classes in softmax layer.')
    parser.add_argument('-p', '--patience', type=int, default=100,
                        help='Number of swaps to attempt before giving up.')
    parser.add_argument('--approx-algorithm', default=ApproxAlgorithms.default(),
                        choices=ApproxAlgorithms.choices(),
                        help='Choice of approximate algorithm. Default: %s' %
                        ApproxAlgorithms.default())
    parser.add_argument('--exact-algorithm', default=ExactAlgorithms.default(),
                        choices=ExactAlgorithms.choices(),
                        help='Choice of exact algorithm. Default: %s' %
                        ExactAlgorithms.default())
    parser.add_argument('-s', '--seed', type=int, default=None,
                        help='Value to use for random seed')
    # NOTE(review): --verbose is accepted but not consumed by the code
    # visible in this script.
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='Whether to print additional info '
                        'when running algorithm.')

    args = parser.parse_args(argv)

    # The literal choice 'none' switches the corresponding algorithm off.
    if args.approx_algorithm == 'none':
        args.approx_algorithm = None
    if args.exact_algorithm == 'none':
        args.exact_algorithm = None

    # Without an exact algorithm the result is reported as "potentially"
    # bounded only.
    args.approximate = args.exact_algorithm is None

    # Use parser.error rather than assert: asserts vanish under `python -O`
    # and would otherwise surface as a bare traceback.
    if args.approx_algorithm is None and args.exact_algorithm is None:
        parser.error('at least one of --approx-algorithm and '
                     '--exact-algorithm must not be "none"')

    return args


def build_convex_hull(weights, max_dim=QHULL_MAX_DIM):
    """Return the convex hull of the rows of ``weights``, or ``None``.

    ``None`` is returned when the hull is not attempted: fewer than 2
    dimensions, more than ``max_dim`` dimensions (too slow), or not more
    points than dimensions (Qhull cannot build an initial simplex).

    Args:
        weights: 2-D array with one point per row.
        max_dim: Largest dimensionality for which the hull is computed.
    """
    num_points, dim = weights.shape
    if dim < 2 or dim > max_dim or num_points <= dim:
        return None
    return ConvexHull(weights)


def interior_classes(hull, num_classes):
    """Return the set of point indices that are not vertices of ``hull``."""
    return set(range(num_classes)) - set(hull.vertices)


def agreement_stats(found, reference):
    """Compare two index sets and return ``(precision, recall)`` in percent.

    Precision is the share of ``found`` that is also in ``reference``;
    recall is the share of ``reference`` that is also in ``found``.
    Returns ``None`` when either set is empty, because one of the ratios
    would then be undefined.
    """
    if not found or not reference:
        return None
    agreement = found.intersection(reference)
    precision = 100 * (len(agreement) / len(found))
    recall = 100 * (len(agreement) / len(reference))
    return precision, recall


def report_qhull_agreement(hull, bounded_set, num_classes):
    """Print the qhull-derived bounded classes and how well they match."""
    bounded_qhull = interior_classes(hull, num_classes)
    print('Bounded according to qhull            : %r' % sorted(bounded_qhull))

    stats = agreement_stats(bounded_set, bounded_qhull)
    if stats is None:
        print('One of the two hulls is empty.')
        return
    print('Precision: %.2f%%' % stats[0])
    print('Recall   : %.2f%%' % stats[1])


def plot_weights(weights, hull=None):
    """Show a labelled scatter plot of 2-D or 3-D weight vectors.

    Each point is annotated with its row index. Hull edges are drawn only
    for 2-D data and only when ``hull`` is given.
    """
    dim = weights.shape[1]

    plt.figure(figsize=(8, 8))
    ax = plt.axes(projection='3d') if dim == 3 else plt.axes()
    ax.scatter(*weights.T, s=20)
    for index, row in enumerate(weights):
        ax.text(*row, '%d' % index)

    # Decide once, outside the loop, whether edges are drawn at all.
    if dim == 2 and hull is not None:
        for simplex in hull.simplices:
            ax.plot(weights[simplex, 0], weights[simplex, 1], 'k-')
    plt.show()


def main(argv=None):
    """Sample a random softmax layer and report its bounded classes.

    Draws a ``(num_classes, dim)`` weight matrix from a standard normal
    distribution, runs ``StolenProbabilitySearch.find_bounded_classes`` on
    it, prints the classes flagged as bounded, cross-checks them against
    the convex hull when that is feasible, and plots 2-D / 3-D layers.
    """
    args = parse_args(argv)

    if args.seed is not None:
        np.random.seed(args.seed)

    num_classes = args.num_classes
    dim = args.dim

    softmax_w = np.random.normal(0, 1, (num_classes, dim))
    sp_search = StolenProbabilitySearch(softmax_w, b=None)
    results = sp_search.find_bounded_classes(
        exact_algorithm=args.exact_algorithm,
        approx_algorithm=args.approx_algorithm,
        patience=args.patience)

    bounded = [r['index'] for r in results if r['is_bounded']]
    bounded_set = set(bounded)
    if args.approximate:
        print('Potentially %d bounded in probability    : %s' % (len(bounded), bounded))
    else:
        print('Bounded in probability    : %s' % bounded)

    hull = build_convex_hull(softmax_w)
    if hull is not None:
        report_qhull_agreement(hull, bounded_set, num_classes)
    elif dim <= QHULL_MAX_DIM:
        # Previously this configuration crashed inside ConvexHull.
        print('Skipping qhull comparison: it needs at least 2 dimensions '
              'and more classes than dimensions.')

    if dim in (2, 3):
        plot_weights(softmax_w, hull)


if __name__ == "__main__":
    main()
