#!/usr/bin/env python
import os
import torch
import tqdm
import datetime
import numpy as np
import huggingface_hub as hf
# import networkx as nx

from argparse import ArgumentParser
from transformers import pipeline
# from collections import defaultdict

from stollen import StolenProbabilitySearch, ApproxAlgorithms, ExactAlgorithms
from stollen.utils import is_valid_token
from stollen.server import create_app
from stollen.server.data_model import Experiment, Model, Result, Solution


def get_vocab(tokenizer):
    """Return the vocabulary mapping of ``tokenizer``.

    If the tokenizer exposes a truthy ``get_tgt_vocab`` attribute (per the
    original note, Facebook models do), that mapping is returned with its
    keys and values swapped. Otherwise the result of
    ``tokenizer.get_vocab()`` is returned unchanged.
    """
    tgt_vocab_getter = getattr(tokenizer, 'get_tgt_vocab', None)
    if not tgt_vocab_getter:
        # Regular tokenizers: use the vocabulary as provided.
        return tokenizer.get_vocab()

    # Target vocab has its key/values swapped, so invert the mapping.
    return {value: key for key, value in tgt_vocab_getter().items()}


# Script entry point: load a Hugging Face model, extract the weight (and bias)
# of its output embedding layer, and search for classes (tokens) that are
# bounded in probability. Optionally store the outcome in the database.
if __name__ == "__main__":

    parser = ArgumentParser(description='Run stolen probability script on '
                            'Huggingface model.')

    parser.add_argument('-u', '--url', type=str, dest='url', required=True,
                        help='Url of hugging face model to check.')
    parser.add_argument('--approx-algorithm', default=ApproxAlgorithms.default(),
                        choices=ApproxAlgorithms.choices(),
                        help='Choice of approximate algorithm. Default: %s' %
                        ApproxAlgorithms.default())
    parser.add_argument('--exact-algorithm', default=ExactAlgorithms.default(),
                        choices=ExactAlgorithms.choices(),
                        help='Choice of exact algorithm. Default: %s' %
                        ExactAlgorithms.default())
    parser.add_argument('--check-tokens', nargs='+', default=None)
    parser.add_argument('--filter-lang', type=str, default=None)
    parser.add_argument('-p', '--patience', type=int, default=100,
                        help='Number of swaps to attempt before giving up '
                        'when using approximate algorithm.')
    parser.add_argument('--logit-upper-bound', type=float, default=100.,
                        help='Largest possible logit activation')
    parser.add_argument('--logit-lower-bound', type=float, default=-100.,
                        help='Smallest possible logit activation')
    parser.add_argument('--save-db', action='store_true',
                        help='Whether to save results to database.')
    # NOTE: --verbose is parsed but not read anywhere in this script.
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='Whether to print additional info '
                        'when running algorithm.')

    args = parser.parse_args()

    # The literal value 'none' switches the corresponding algorithm off.
    if args.approx_algorithm == 'none':
        args.approx_algorithm = None

    # Without an exact algorithm the results are only approximate.
    if args.exact_algorithm == 'none':
        args.exact_algorithm = None
        args.approximate = True
    else:
        args.approximate = False

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.approx_algorithm is None and args.exact_algorithm is None:
        parser.error('At least one of --approx-algorithm and '
                     '--exact-algorithm must not be "none".')

    # Extract information from hugging face URL
    hb = hf.HfApi()
    URL_PREFIX = 'https://huggingface.co/'
    # Only strip the prefix when it is actually present, so that a bare
    # model id (e.g. "org/name") is not truncated into garbage.
    if args.url.startswith(URL_PREFIX):
        repo_id = args.url[len(URL_PREFIX):]
    else:
        repo_id = args.url
    repo_id = repo_id.strip('/')
    model_info = hb.model_info(repo_id)
    args.task = model_info.pipeline_tag
    args.model = model_info.modelId

    print('Loading *%s* model "%s"...' % (args.task, args.model))
    print('Preprocessing...')
    model = pipeline(args.task, args.model)
    L = model.model.get_output_embeddings()
    if L is None:
        # Fail with a clear message rather than an AttributeError below.
        raise SystemExit('Model "%s" does not expose output embeddings.'
                         % args.model)

    # `.cpu()` is a no-op for CPU tensors and makes `.numpy()` safe should
    # the pipeline have been placed on an accelerator.
    W = L.weight.detach().cpu().numpy()
    num_classes, _ = W.shape
    print('\tWeight matrix found with dim %s' % repr(W.shape))
    if getattr(model.model, 'final_logits_bias', None) is not None:
        b = model.model.final_logits_bias.detach().cpu().numpy().T
        print('\tBias vector found as "final_logits_bias" with dim %s' % repr(b.shape))
    elif getattr(L, 'bias', None) is not None:
        b = L.bias.detach().cpu().numpy().reshape(-1, 1)
        print('\tBias vector found as "output_embeddings" with dim %s' % repr(b.shape))
    else:
        print('##########################################')
        print('#####  WARNING!!! no bias term found.#####')
        print('##########################################')
        b = None

    # The bias must be a column vector with one entry per row of W.
    if b is not None and b.shape != (W.shape[0], 1):
        raise ValueError('Bias has shape %s, expected %s'
                         % (repr(b.shape), repr((W.shape[0], 1))))

    # Construct vocabulary
    vocab = get_vocab(model.tokenizer)
    inv_vocab = {v: k for k, v in vocab.items()}

    if args.check_tokens is not None:
        # Only the requested tokens are checked.
        with model.tokenizer.as_target_tokenizer():
            class_list = []
            for t in args.check_tokens:
                ids = model.tokenizer.encode(t)
                # Exactly two ids are required; the first one is used as the
                # class index of the token.
                if len(ids) != 2:
                    raise ValueError('Token %r was encoded into %d ids, '
                                     'expected exactly 2.' % (t, len(ids)))
                class_list.append(ids[0])
    else:
        if args.filter_lang is not None:
            class_list = []
            ignored = []
            if len(vocab) != num_classes:
                raise ValueError('Vocabulary size %d does not match number '
                                 'of classes %d.' % (len(vocab), num_classes))
            # Read the special tokens once instead of on every iteration.
            # Kept as a list because the values are not guaranteed hashable.
            special_tokens = list(model.tokenizer.special_tokens_map.values())
            for token, i in vocab.items():
                if token in special_tokens:
                    class_list.append(i)
                elif is_valid_token(args.filter_lang, token):
                    class_list.append(i)
                else:
                    ignored.append(token)
            print('\tRetained %d/%d tokens in vocab...' % (len(class_list), num_classes))

            # Make sure the output directory exists before writing to it.
            ignored_dir = 'ignored'
            os.makedirs(ignored_dir, exist_ok=True)
            # Model ids contain '/', which is not os.path.sep on every OS.
            safe_name = args.model.replace(os.path.sep, '_').replace('/', '_')
            outfile = os.path.join(ignored_dir, '%s.txt' % safe_name)
            print('\tWriting ignored tokens to %s...' % outfile)
            # Tokens are frequently non-ASCII, so fix the encoding explicitly.
            with open(outfile, 'w', encoding='utf-8') as f:
                f.write('\n'.join(ignored))
        else:
            class_list = list(range(num_classes))

    time_start = datetime.datetime.now()
    print('Asserting whether some classes are bounded in probability ...')
    if args.approx_algorithm is not None:
        print('\tUsing approximate algorithm *%s* ...' % args.approx_algorithm)
    if args.exact_algorithm is not None:
        print('\tUsing exact algorithm *%s* ...' % args.exact_algorithm)
    sp_search = StolenProbabilitySearch(W, b=b)
    results = sp_search.find_bounded_classes(class_list=class_list,
                                             exact_algorithm=args.exact_algorithm,
                                             approx_algorithm=args.approx_algorithm,
                                             lb=args.logit_lower_bound,
                                             ub=args.logit_upper_bound,
                                             patience=args.patience)

    # Add token to the results
    for r in results:
        r['token'] = inv_vocab[r['index']]

    num_bounded = 0

    if args.approximate:
        print('Potentially bounded in probability:\n')
    else:
        print('Bounded in probability:\n')
    for each in results:
        if each['is_bounded']:
            num_bounded += 1
            print('\t%d\t%s' % (each['index'], each['token']))
    print('*%d/%d* in total were found to be bounded' % (num_bounded, len(results)))
    time_end = datetime.datetime.now()

    if args.save_db:
        # Prepare to write results to db
        db_results = []
        for r in results:
            # 'point' is stored as a separate Solution object, so it is
            # removed from the keyword arguments passed to Result.
            point = r.pop('point', None)
            if point is not None:
                solution = Solution(point)
            else:
                solution = None
            db_results.append(Result(**r, solution=solution))

        # Named `db_model` so the loaded pipeline in `model` is not shadowed.
        db_model = Model(name=args.model,
                         task=args.task,
                         vocab_size=W.shape[0],
                         embed_dim=W.shape[1],
                         has_bias=b is not None,
                         url=args.url)

        # Record the algorithms actually used, e.g. "<approx>+<exact>".
        algos = [args.approx_algorithm, args.exact_algorithm]
        algorithm = '+'.join([a for a in algos if a])

        exp = Experiment(algorithm=algorithm,
                         started=time_start,
                         finished=time_end,
                         patience=args.patience,
                         model=db_model,
                         num_bounded=sum(r.is_bounded for r in db_results),
                         results=db_results)

        app, db = create_app()
        with app.app_context():
            print('Saving experiment results to database..')
            db.session.add(exp)
            db.session.commit()
