#!/usr/bin/env python
import tqdm
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial import ConvexHull
from argparse import ArgumentParser

from stollen import StolenProbabilitySearch, ApproxAlgorithms, ExactAlgorithms, lp_chebyshev


def make_softmax_params(num_classes, dim, use_bias=False, prevent=True):
    """Sample softmax weights (and optionally a bias) from a standard normal.

    The weights are always drawn first, then the bias (if it is random), so
    results are reproducible for a fixed ``np.random`` seed.

    Args:
        num_classes: Number of rows (classes) of the weight matrix.
        dim: Dimensionality of each weight vector.
        use_bias: Whether to return a bias term alongside the weights.
        prevent: Whether to constrain the parameters so that stolen
            probability is prevented.

    Returns:
        A ``(softmax_w, softmax_b)`` tuple. ``softmax_w`` has shape
        ``(num_classes, dim)``. ``softmax_b`` has shape ``(num_classes, 1)``
        when ``use_bias`` is true and is ``None`` otherwise.
    """
    softmax_w = np.random.normal(0, 1, (num_classes, dim))

    if use_bias:
        if prevent:
            # Prevent stolen probability with a bias term
            # (see Appendix D in paper): b_i = -1/2 * ||w_i||^2
            squared_norms = np.linalg.norm(softmax_w, axis=1, keepdims=True) ** 2
            softmax_b = - .5 * squared_norms
        else:
            # Unconstrained random bias.
            softmax_b = np.random.normal(0, 1, (num_classes, 1))
        return softmax_w, softmax_b

    if prevent:
        # Prevent stolen probability when no bias term is present
        # by normalising the weight vectors (restrict them to unit hypersphere)
        softmax_w = softmax_w / np.linalg.norm(softmax_w, axis=1, keepdims=True)
    return softmax_w, None


def main():
    """Parse the command line, build a softmax layer and check for bounded classes.

    Raises:
        AssertionError: If any class is reported as bounded.
    """
    parser = ArgumentParser(description='Verify we can prevent stolen probability.')

    parser.add_argument('-d', '--dim', type=int, required=True,
                        help='Dimensionality of each softmax weight vector.')
    parser.add_argument('-C', '--num-classes', dest='num_classes',
                        type=int, required=True,
                        help='Number of classes in softmax layer.')
    parser.add_argument('-s', '--seed', type=int, default=None,
                        help='Value to use for random seed')
    parser.add_argument('-b', '--use-bias', action='store_true',
                        help='Whether to include a bias term')
    parser.add_argument('--do-not-prevent', action='store_true',
                        help='Whether to avoid forcing stollen probability not to occur.')
    # NOTE: --verbose is accepted but not currently used by this script.
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='Whether to print additional info '
                        'when running algorithm.')

    args = parser.parse_args()

    if args.seed is not None:
        np.random.seed(args.seed)

    softmax_w, softmax_b = make_softmax_params(
        args.num_classes, args.dim,
        use_bias=args.use_bias,
        prevent=not args.do_not_prevent)

    sp_search = StolenProbabilitySearch(softmax_w, b=softmax_b)
    results = sp_search.find_bounded_classes(
        exact_algorithm='lp_chebyshev',
        approx_algorithm=None)
    bounded = [r['index'] for r in results if r['is_bounded']]
    print('Bounded classes:', bounded)

    # Explicit raise instead of an ``assert`` statement so that the check is
    # not stripped when Python runs with -O. The check is enforced regardless
    # of --do-not-prevent, so the script fails whenever a bounded class is found.
    if bounded:
        raise AssertionError(
            'Found %d bounded class(es): %s' % (len(bounded), bounded))


if __name__ == "__main__":
    main()
