import argparse
import subprocess
import itertools
from tqdm import tqdm
from os.path import join
import pickle
from functools import partial

# Hyper-parameter grid explored by hyper_search (every n_tree x max_depth pair);
# each value is forwarded to calib_exp/run_exp.py as --arg_n_tree / --arg_max_depth.
n_trees_choices = [100, 150, 200, 250, 300, 350, 400, 450, 500]
max_depth_choices = [4, 6, 8, 10, 15, 20]

def _parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--n_run', type=int, default=20)
    parser.add_argument('--train_size', type=int, default=500)
    parser.add_argument('--split', type=str, default=None)
    parser.add_argument('--do_val', default=False, action='store_true')
    args = parser.parse_args()
    return args

def run_exp(args, method, n_tree, max_depth, train_size, force_dev_size=0):
    """Run one ``calib_exp/run_exp.py`` configuration in a subprocess and parse its summary.

    Args:
        args: parsed CLI namespace; ``dataset``, ``n_run`` and (if not None)
            ``split`` are forwarded to the child script.
        method: 'maxprob', 'baseline' and 'bow' map to the dedicated
            ``--do_maxprob`` / ``--do_baseline`` / ``--do_bow`` flags; any other
            value is forwarded as ``--method <method>``.
        n_tree: forwarded as ``--arg_n_tree``.
        max_depth: forwarded as ``--arg_max_depth``.
        train_size: forwarded as ``--train_size``.
        force_dev_size: forwarded as ``--force_dev_size`` only when > 0.

    Returns:
        dict with ``acc`` (first value), ``auc`` (second value) and
        ``f1_points`` (remaining values), parsed as floats from the last
        non-empty stdout line after dropping its first comma-separated field.

    Raises:
        subprocess.CalledProcessError: if the child process exits non-zero.
        ValueError: if the child prints nothing, or its last line carries fewer
            than two numeric fields.
    """
    cmd = ['python', 'calib_exp/run_exp.py', '--dataset', args.dataset, '--n_run', str(args.n_run),
           '--arg_n_tree', str(n_tree), '--arg_max_depth', str(max_depth), '--train_size', str(train_size)]
    if args.split is not None:
        cmd.extend(['--split', args.split])
    if force_dev_size > 0:
        cmd.extend(['--force_dev_size', str(force_dev_size)])

    method_flags = {'maxprob': '--do_maxprob', 'baseline': '--do_baseline', 'bow': '--do_bow'}
    if method in method_flags:
        cmd.append(method_flags[method])
    else:
        cmd.extend(['--method', method])

    # stderr is discarded, so a failing child surfaces only as CalledProcessError.
    out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode('utf-8')

    # The summary is the last non-empty line; don't depend on exactly one trailing newline.
    lines = [line for line in out.splitlines() if line.strip()]
    if not lines:
        raise ValueError('run_exp.py produced no output for: {}'.format(' '.join(cmd)))
    vals = [float(x) for x in lines[-1].split(',')[1:]]
    if len(vals) < 2:
        raise ValueError('unexpected summary line {!r} for: {}'.format(lines[-1], ' '.join(cmd)))
    return {'acc': vals[0], 'auc': vals[1], 'f1_points': vals[2:]}
    
    
# def grid_search(args):
#     all_result = []
#     for n_tree, max_depth in tqdm(itertools.product(n_trees_choices, max_depth_choices), total=n_trees_choices * max_depth_choices):
#         results = run_exp(args, 'lime', n_tree, max_depth_choices)
#         results['n_tree'] = n_tree
#         results['max_depth'] = max_depth
#         all_result.append(results)
#     return all_result

def hyper_search(args, method, search_on_dev=False):
    """Grid-search every (n_tree, max_depth) pair for one method via ``run_exp``.

    Args:
        args: parsed CLI namespace; ``args.train_size`` is the training size.
        method: method name forwarded to ``run_exp``.
        search_on_dev: if True, pass a forced dev-split size to ``run_exp``,
            looked up from ``args.train_size`` (only 100, 300 and 500 are known).

    Returns:
        list of ``run_exp`` result dicts, each extended with ``n_tree`` and
        ``max_depth``, in ``itertools.product`` grid order.

    Raises:
        ValueError: if ``search_on_dev`` is set and ``args.train_size`` has no
            configured dev size (previously an UnboundLocalError).
    """
    train_size = args.train_size
    # Forced dev-split size per supported train size.
    dev_size_by_train_size = {500: 100, 300: 60, 100: 25}
    if search_on_dev:
        if train_size not in dev_size_by_train_size:
            raise ValueError(
                'search_on_dev: no dev size configured for train_size={}; expected one of {}'.format(
                    train_size, sorted(dev_size_by_train_size)))
        force_dev_size = dev_size_by_train_size[train_size]
    else:
        force_dev_size = 0

    grid = list(itertools.product(n_trees_choices, max_depth_choices))
    all_result = []
    # Counter is 1-based and printed after each run, so the last line reads [N/N].
    for done, (n_tree, max_depth) in enumerate(grid, start=1):
        results = run_exp(args, method, n_tree, max_depth, train_size, force_dev_size)
        print('[{}/{}]  N_TREE: {}  MAX_DEPTH: {}'.format(
            done, len(grid), n_tree, max_depth), results)
        results['n_tree'] = n_tree
        results['max_depth'] = max_depth
        all_result.append(results)
    return all_result

# def hyper_search_on_test(args, method, size):
#     pass

def print_search_result(results):
    """Print the best configuration by accuracy and the best by AUC.

    Ties keep the earliest entry in ``results``. An empty input prints a
    notice instead of raising.

    Args:
        results: sequence of dicts with at least ``acc`` and ``auc`` keys.
    """
    results = list(results)
    if not results:
        print('\tNo results')
        return

    # max() returns the first maximal entry, like sorted(..., reverse=True)[0].
    top_acc = max(results, key=lambda x: x['acc'])
    print('\tTop ACC', top_acc)

    top_auc = max(results, key=lambda x: x['auc'])
    print('\tTop AUC', top_auc)

def complete_exp(args, search_on_dev=False):
    """Run every method's experiments and pickle the collected report.

    MAXPROB is run once. KAMATH (method ``baseline``), BOW, LIME and SHAP each
    run a full ``hyper_search`` grid. The report dict (keys ``maxprob``,
    ``kamath``, ``bow``, ``lime``, ``shap``) is pickled to
    ``calib_exp/reports/<dataset>_<n_run>_<train_size>[_forcedev].bin``.

    Args:
        args: parsed CLI namespace (``dataset``, ``n_run``, ``train_size``, ...).
        search_on_dev: forwarded to ``hyper_search``; also appends ``_forcedev``
            to the report file name.
    """
    import os

    search = partial(hyper_search, search_on_dev=search_on_dev)

    report = {}
    # MAXPROB is called with n_tree=max_depth=0 (presumably unused in that mode).
    result = run_exp(args, 'maxprob', 0, 0, args.train_size)
    print('MAXPROB', result)
    report['maxprob'] = result

    # (banner label, report key, method name passed to run_exp)
    for label, key, method in (('KAMATH', 'kamath', 'baseline'),
                               ('BOW', 'bow', 'bow'),
                               ('LIME', 'lime', 'lime'),
                               ('SHAP', 'shap', 'shap')):
        print('----------{}----------'.format(label))
        results = search(args, method)
        print_search_result(results)
        report[key] = results

    fname = '{}_{}_{}'.format(args.dataset, args.n_run, args.train_size)
    if search_on_dev:
        fname += '_forcedev'
    report_dir = 'calib_exp/reports'
    # Create the directory so a missing folder doesn't discard the whole run at the final write.
    os.makedirs(report_dir, exist_ok=True)
    with open(join(report_dir, fname + '.bin'), 'wb') as fh:
        pickle.dump(report, fh)


def choosing_search_result(results):
    """Print the ten best configurations by accuracy, then by AUC, then every entry.

    Each line shows the tree count and depth followed by the comma-joined
    ``acc, auc, *f1_points`` values.

    Args:
        results: list of dicts with ``acc``, ``auc``, ``f1_points`` (list),
            ``n_tree`` and ``max_depth`` keys.
    """
    def _describe(entry):
        # Returns the (header, numbers) pair printed after any ranking prefix.
        header = 'NTree: {} Depth: {},   '.format(entry['n_tree'], entry['max_depth'])
        metrics = [entry['acc'], entry['auc']] + entry['f1_points']
        return header, ','.join(str(value) for value in metrics)

    for label, metric in (('Top ACC', 'acc'), ('Top AUC', 'auc')):
        ranked = sorted(results, key=lambda entry, m=metric: entry[m], reverse=True)
        for rank, entry in enumerate(ranked[:10]):
            print('\t' + label, rank, *_describe(entry))

    for entry in results:
        print(*_describe(entry))

def inspect_exp(args, search_on_dev=False):
    """Load the report written by ``complete_exp`` and print ranked results per method.

    Reads ``calib_exp/reports/<dataset>_<n_run>_<train_size>[_forcedev].bin``,
    prints the MAXPROB result, then calls ``choosing_search_result`` for
    KAMATH, BOW, LIME and SHAP in that order.

    Args:
        args: parsed CLI namespace (``dataset``, ``n_run``, ``train_size``).
        search_on_dev: selects the ``_forcedev`` report file.

    Raises:
        FileNotFoundError: if the report has not been generated yet.
    """
    fname = '{}_{}_{}'.format(args.dataset, args.n_run, args.train_size)
    if search_on_dev:
        fname += '_forcedev'
    # NOTE: pickle.load can execute arbitrary code; only open reports this script produced.
    with open(join('calib_exp/reports', fname + '.bin'), 'rb') as fh:
        report = pickle.load(fh)

    print('MAXPROB', report['maxprob'])

    # (banner label, report key) in the same order complete_exp writes them.
    for label, key in (('KAMATH', 'kamath'), ('BOW', 'bow'), ('LIME', 'lime'), ('SHAP', 'shap')):
        print('----------{}----------'.format(label))
        choosing_search_result(report[key])
    

if __name__=='__main__':
    args = _parse_args()
    # Stale debugging call: run_exp now also requires a train_size argument.
    # print(run_exp(args, 'maxprob', 0, 0))
    # NOTE(review): hyper_search_with_dev is not defined in this file as seen; stale call.
    # hyper_search_with_dev()
    # Run the full search first; it pickles the report that inspect_exp reads.
    # complete_exp(args)
    # Print ranked results from the saved report; --do_val selects the '_forcedev' report.
    inspect_exp(args, args.do_val)
    
