#!/usr/bin/env python
import os
import yaml
import torch
import langid
import datetime
import numpy as np
from argparse import ArgumentParser

from stollen import StolenProbabilitySearch, ApproxAlgorithms, ExactAlgorithms
from stollen.utils import load_vocab
from stollen.server import create_app
from stollen.server.data_model import Experiment, Model, Result, Solution


def get_model_dict(filename):
    """Load a model's entries from disk as a name -> value mapping.

    The format is selected by file extension:

    * ``.npz`` -- the object returned by ``np.load`` is handed back as-is.
    * ``.pt``  -- the file is read with ``torch.load`` onto the CPU and its
      ``'model'`` entry is used; every value that can be converted via
      ``.detach().numpy()`` is replaced in place by the resulting numpy
      array, and every other value is left untouched.

    Args:
        filename: Path to the model file.

    Returns:
        Mapping from entry name to its value.

    Raises:
        ValueError: If ``filename`` ends with neither ``.npz`` nor ``.pt``.
    """
    if filename.endswith('.npz'):
        return np.load(filename)

    if filename.endswith('.pt'):
        # NOTE(review): torch.load relies on pickle; only feed it checkpoints
        # from a trusted source (see the ``weights_only`` option of torch.load).
        model = torch.load(filename, map_location=torch.device('cpu'))['model']
        for name in model:
            try:
                model[name] = model[name].detach().numpy()
            except (AttributeError, TypeError, RuntimeError):
                # Best effort: entries that are not tensors (no ``detach``)
                # or that numpy cannot represent keep their original value.
                # Only conversion errors are ignored, so KeyboardInterrupt
                # and SystemExit still propagate.
                pass
        return model

    raise ValueError('Unexpected file format %s' % filename)


if __name__ == "__main__":
    # Script entry point. Steps:
    #   1. parse command-line options,
    #   2. load the vocabulary and the softmax weights (and optional bias),
    #   3. run StolenProbabilitySearch over every class index,
    #   4. print the indices reported as bounded,
    #   5. optionally store the experiment in the database (--save-db).

    parser = ArgumentParser(description='Run stolen probability script on Random Softmax Layer.')

    parser.add_argument('-f', '--numpy-file', type=str, dest='numpy_file', required=True,
                        help='Path to model file: numpy archive (.npz) '
                        'or PyTorch checkpoint (.pt).')
    parser.add_argument('--W', type=str, required=True,
                        help='Name for softmax weight W entry in npz.')
    parser.add_argument('--b', type=str, default=None,
                        help='Name for softmax bias entry in npz.')
    parser.add_argument('--vocab', type=str, default=None,
                        help='Path to spm vocab file.')
    parser.add_argument('--approx-algorithm', default=ApproxAlgorithms.default(),
                        choices=ApproxAlgorithms.choices(),
                        help='Choice of approximate algorithm. Default: %s' %
                        ApproxAlgorithms.default())
    parser.add_argument('--exact-algorithm', default=ExactAlgorithms.default(),
                        choices=ExactAlgorithms.choices(),
                        help='Choice of exact algorithm. Default: %s' %
                        ExactAlgorithms.default())
    parser.add_argument('--logit-upper-bound', type=float, default=100.,
                        help='Largest possible logit activation')
    parser.add_argument('--logit-lower-bound', type=float, default=-100.,
                        help='Smallest possible logit activation')
    parser.add_argument('-p', '--patience', type=int, default=100,
                        help='Number of swaps to attempt before giving up.')
    parser.add_argument('--save-db', action='store_true',
                        help='Whether to save results to database.')
    # NOTE(review): --verbose is parsed but not read anywhere in this script.
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='Whether to print additional info '
                        'when running algorithm.')

    args = parser.parse_args()

    # The literal value 'none' on the command line disables an algorithm.
    if args.approx_algorithm == 'none':
        args.approx_algorithm = None

    if args.exact_algorithm == 'none':
        args.exact_algorithm = None

    # Without an exact algorithm the findings are only reported as
    # "potentially" bounded further down.
    args.approximate = args.exact_algorithm is None

    # Reject the "both disabled" combination with a usage error rather than an
    # assert, which would be stripped when running under ``python -O``.
    if args.approx_algorithm is None and args.exact_algorithm is None:
        parser.error('At least one of --approx-algorithm and '
                     '--exact-algorithm must be enabled.')

    # Model name recorded with the experiment: '<parent directory>:<file name>'.
    modelpath, modelfile = os.path.split(args.numpy_file)
    _, modelname = os.path.split(modelpath)
    modelname = '%s:%s' % (modelname, modelfile)
    print('Processing model %s...' % modelname)

    time_start = datetime.datetime.now()

    # NOTE(review): --vocab defaults to None and is passed on unconditionally;
    # confirm that load_vocab accepts None.
    vocab = load_vocab(args.vocab)
    # Reverse mapping, used below to look up the token for a class index.
    inv_vocab = {v: k for k, v in vocab.items()}
    print('Loading vocabulary of size %d...' % len(vocab))

    print('Loading weights from "%s"...' % (args.numpy_file))
    model = get_model_dict(args.numpy_file)
    W = model[args.W]
    # Unpacking also enforces that W is two-dimensional.
    num_classes, embed_dim = W.shape
    if args.b is not None:
        # Bias is reshaped to a column: one row per entry.
        b = model[args.b].reshape(-1, 1)
    else:
        b = None
    print('\tWeight matrix found with dim %s' % repr(W.shape))
    if b is not None:
        # Explicit check instead of an assert so that it survives ``-O``.
        if b.shape[0] != num_classes:
            raise ValueError('Bias has %d entries but weight matrix has %d rows'
                             % (b.shape[0], num_classes))
        print('\tBias vector found with dim %s' % repr(b.shape))

    sp_search = StolenProbabilitySearch(W, b=b)
    results = sp_search.find_bounded_classes(class_list=tuple(range(num_classes)),
                                             exact_algorithm=args.exact_algorithm,
                                             approx_algorithm=args.approx_algorithm,
                                             lb=args.logit_lower_bound,
                                             ub=args.logit_upper_bound,
                                             patience=args.patience)
    # Add token to the results; -1 marks an index missing from the vocabulary.
    for r in results:
        r['token'] = inv_vocab.get(r['index'], -1)
        if r['token'] == -1:
            print('%d not found!' % r['index'])

    num_bounded = 0

    if args.approximate:
        print('Potentially bounded in probability:\n')
    else:
        print('Bounded in probability:\n')
    for each in results:
        if each['is_bounded']:
            num_bounded += 1
            print('\t%d' % (each['index']))
    print('*%d/%d* in total were found to be bounded' % (num_bounded, len(results)))
    time_end = datetime.datetime.now()

    if args.save_db:
        # Prepare to write results to db
        db_results = []
        for r in results:
            # 'point' is removed from the result dict and, when present,
            # wrapped in a Solution; the remaining keys are passed to Result
            # as keyword arguments.
            point = r.pop('point', None)
            if point is not None:
                solution = Solution(point)
            else:
                solution = None
            db_results.append(Result(**r, solution=solution))

        # From here on ``model`` is the database record, not the loaded weights.
        model = Model(name=modelname,
                      task='text2text-generation',
                      vocab_size=num_classes,
                      embed_dim=embed_dim,
                      has_bias=b is not None,
                      url=None)

        # Algorithm label: the enabled algorithm names joined with '+'.
        algos = [args.approx_algorithm, args.exact_algorithm]
        algorithm = '+'.join([a for a in algos if a])

        exp = Experiment(algorithm=algorithm,
                         started=time_start,
                         finished=time_end,
                         patience=args.patience,
                         model=model,
                         num_bounded=sum(r.is_bounded for r in db_results),
                         results=db_results)

        app, db = create_app()
        with app.app_context():
            print('Saving experiment results to database..')
            db.session.add(exp)
            db.session.commit()
