"""Single entry-point for ..."""
from typing import Dict
from operator import or_
from functools import reduce
import json
import sys
import os
import argparse
import time
import matplotlib as mpl
import matplotlib.pyplot as plt
from itertools import count
import logging
from requests import ReadTimeout, ConnectTimeout, ConnectionError
import wandb
import pandas as pd
import numpy as np
from scipy.stats import norm
from tqdm.auto import tqdm

from run_glue import main as main_glue
from run_drop_glue import main as main_drop_glue
from run_qa import main as main_qa
from args import GLUEArgs, QAArgs


# Module-level logger; the handlers are configured in the ``__main__`` section.
logger = logging.getLogger(__name__)


# Per-task metadata, keyed by the task name stored in the wandb run config.
# Each value is a 4-tuple:
#   [0] display name used as subplot title,
#   [1] metric name used to build the wandb history columns
#       ``eval/final_<metric>`` and ``test/<metric>``,
#   [2] short metric name used as a column name by the plot/table helpers
#       and as a suffix in table headers,
#   [3] human readable metric name used as y-axis label.
PER_TASK = {
    "rte": ("RTE", "accuracy", "acc", "Accuracy"),
#    "wnli": ("WNLI", "accuracy", "acc", "Accuracy"),
    "mrpc": ("MRPC", "f1", "f1", "F1 Score"),
    "qqp": ("QQP", "f1", "f1", "F1 Score"),
    "sst2": ("SST-2", "accuracy", "acc", "Accuracy"),
    "qnli": ("QNLI", "accuracy", "acc", "Accuracy"),
    "mnli": ("MNLI", "accuracy", "acc", "Accuracy"),
    "cola": ("CoLA", "matthews_correlation", "mat_cor", "Matthews Correlation"),
    "stsb": ("STS-B", "spearmanr", "spearmanr", "Spearmanr"),
    "squad": ("SQUAD", "f1", "f1", "F1 Score")
}

# Every value that `collect_data` may store in the 'mode' column of a run.
MODES = [
    'baseline (Adapter)',
    'baseline (Adapters left out == BERT)',
    'baseline (BERT only)',
    'with regularization',
    'with regularization and relu',
    'without skip-connections',
    'without skip-connections and relu',
    'without skip-connections at training time',
    'without skip-connections at training time and relu',
    'fixed with regularization',
    'fixed without skip-connections',
    'fixed without skip-connections at training time',
]

# Subset of MODES assigned to baseline runs.
MODES_BASELINE = [
    'baseline (Adapter)',
    'baseline (Adapters left out == BERT)',
    'baseline (BERT only)',
]

# Subset of MODES assigned to runs that use switches; the table helpers
# filter on this list.
MODES_SWITCHES = [
    'with regularization',
    'with regularization and relu',
    'without skip-connections',
    'without skip-connections and relu',
    'without skip-connections at training time',
    'without skip-connections at training time and relu',
]


def _get_last_notna_value(df):
    """Return the last entry of *df* that is not NaN/None.

    Raises ``IndexError`` (from ``iloc``) when no such entry exists.
    """
    present = df.notna()
    return df[present].iloc[-1]


def _get_wandb_api(timeout=11):
    """Create a ``wandb.Api`` client configured with the given *timeout*."""
    client = wandb.Api(timeout=timeout)
    return client


def _get_runs(project, filters=None, timeout=11):
    """Yield the runs of a wandb *project* that match *filters*.

    Args:
        project: Project identifier passed to ``wandb.Api.runs``.
        filters: Filter dictionary forwarded to wandb; ``None`` means no
            filtering (an empty dict is sent).
        timeout: Timeout handed to the ``wandb.Api`` constructor.

    Yields:
        The run objects returned by ``api.runs`` (fetched 10 per page).
    """
    # A fresh dict per call instead of a shared mutable default argument.
    if filters is None:
        filters = {}
    api = wandb.Api(timeout=timeout)
    yield from api.runs(project, filters=filters, per_page=10)


def get_runs(project, filters=None, limit=None, timeout=11):
    """Yield wandb runs with a cached ``history()``, retrying on network errors.

    For each run the history DataFrame is loaded from ``.cache/<run.id>`` when
    that file exists and the run is finished; otherwise it is fetched from
    wandb and, for finished runs, written to the cache if the ``.cache``
    directory already exists. ``run.history`` is then replaced by a function
    returning that DataFrame, so consumers do not hit the network again.

    Args:
        project: Project identifier passed to ``wandb.Api.runs``.
        filters: Filter dictionary forwarded to wandb; ``None`` means no
            filtering (an empty dict is sent).
        limit: Optional upper bound on the number of runs yielded.
        timeout: Timeout handed to the ``wandb.Api`` constructor.

    Yields:
        Run objects whose ``history()`` returns the (possibly cached) frame.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    if filters is None:
        filters = {}

    api = wandb.Api(timeout=timeout)
    runs = api.runs(project, filters=filters, per_page=10)

    total = None
    while total is None:
        try:
            total = len(runs)
        # ConnectionError is retried here as well, matching the loop below.
        except (ReadTimeout, ConnectTimeout, ConnectionError):
            logger.info("Error connecting, waiting 5 seconds before trying again.")
            time.sleep(5)

    if limit is not None:
        total = min(limit, total)

    with tqdm(total=total, smoothing=0.01) as bar:
        idx = 0
        while idx < total:
            try:
                run = runs[idx]
                filename = f".cache/{run.id}"
                if os.path.exists(filename) and run.state == "finished":
                    # NOTE(review): unpickling is only safe because the cache
                    # is written by this very function; never point it at
                    # files from an untrusted source.
                    df = pd.read_pickle(filename)
                else:
                    df = run.history()
                    if run.state == "finished" and os.path.exists('.cache'):
                        df.to_pickle(filename)
                # Bind the frame as a default value: a plain ``lambda: df``
                # would look up ``df`` late and return the frame of a *later*
                # run once the generator has advanced.
                run.history = lambda history=df: history
                yield run
            except (ReadTimeout, ConnectTimeout, ConnectionError):
                logger.info("Error connecting, waiting 5 seconds before trying again.")
                time.sleep(5)
                continue
            except Exception as e:
                # The same index is retried after a pause (best effort).
                logger.error(f"Unexpected expected exception of type {type(e)}:")
                logger.error(e)
                time.sleep(5)
                continue

            idx += 1
            bar.update(1)


def collect_data(wandb_project, limit=None, task_name=None, seed=None, **kwargs):
    """Download finished runs from wandb and flatten them into a DataFrame.

    Args:
        wandb_project: Project handed to `get_runs`.
        limit: Optional maximum number of runs, forwarded to `get_runs`.
        task_name: A single task name, the string ``'all'`` (no task filter),
            a non-empty list of task names, or ``None`` (no task filter).
        seed: A single int, a non-empty list of ints, or anything else, in
            which case seeds 1 to 10 are selected.
        **kwargs: Optional extra filters. Recognised keys are
            ``low_resources``, ``baseline``, ``without_skip_connections``,
            ``without_skip_connections_training_only`` and
            ``with_regularization``. The last three force
            ``config.baseline`` to ``False``, overriding ``baseline``.

    Returns:
        A DataFrame with one row per run. The ``layer_<i>.prob.1`` columns
        are binarised (value > 0.5) and renamed to ``Layer No. XX``; a
        ``num_adapters`` column holds their per-row sum.
    """

    filters = {
        'state': 'finished',
#        'config.adapters_at': {'$in': [None, '[]']}
    }

    # Seed filter: a single value, an explicit list, or the default 1..10.
    if isinstance(seed, int):
        filters['config.seed'] = seed
    elif isinstance(seed, list) and len(seed) > 0:
        filters['config.seed'] = {'$in': seed[:]}
    else:
        filters['config.seed'] = {'$in': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}

    # Task filter: 'all', None and an empty list leave the tasks unfiltered.
    if isinstance(task_name, str):
        if task_name != 'all':
            filters['config.task_name'] = task_name
    elif isinstance(task_name, list) and len(task_name) > 0:
        filters['config.task_name'] = {'$in': task_name[:]}

    # A missing low_resources value is matched against the string 'None'.
    if 'low_resources' in kwargs:
        if kwargs['low_resources'] is None:
            filters['config.low_resources'] = 'None'
        else:
            filters['config.low_resources'] = kwargs['low_resources']

    if 'baseline' in kwargs:
        filters['config.baseline'] = kwargs['baseline']

    # Each of the following options overwrites 'config.baseline' with False.
    if 'without_skip_connections' in kwargs:
        filters['config.baseline'] = False
        if kwargs['without_skip_connections']:
            filters['config.adapter_drop_skip_connections'] = True
        else:
            filters['config.adapter_drop_skip_connections'] = False

    if 'without_skip_connections_training_only' in kwargs:
        filters['config.baseline'] = False
        if kwargs['without_skip_connections_training_only']:
            filters['config.adapter_drop_skip_connections_training_only'] = True
        else:
            filters['config.adapter_drop_skip_connections_training_only'] = False

    if 'with_regularization' in kwargs:
        filters['config.baseline'] = False
        if kwargs['with_regularization']:
            filters['config.switch_regularization'] = 'square'
        else:
            filters['config.switch_regularization'] = 'None'

    # History columns of which only the last non-NaN value is stored.
    LAST_FLOAT = [
        "eval/final_runtime",
        "eval/final_steps_per_second",
        "eval/final_samples_per_second",
        "test/runtime",
        "test/steps_per_second",
        "test/samples_per_second",
        "train/runtime",
        "train/steps_per_second",
        "train/samples_per_second",
    ]

    # History columns stored as the full list of their non-NaN values.
    HIST_FLOAT = [
        "eval/steps_per_second",
        "eval/samples_per_second",
        "eval/runtime"
    ]

    data = []
    print(filters)
    for run in get_runs(wandb_project, filters, limit=limit):
        hist: pd.DataFrame = run.history()
        # NOTE: from here on ``task_name`` is the task of the current run and
        # shadows the function argument.
        task_name = run.config['task_name']

        # NOTE(review): a task that is not a key of PER_TASK raises KeyError.
        perf_eval_col = f"eval/final_{PER_TASK[task_name][1]}"
        perf_test_col = f"test/{PER_TASK[task_name][1]}"

        # Get the performance at test time and the best evaluation.
        # NOTE(review): these columns are read before the missing-column check
        # below, so a run lacking them raises KeyError instead of being skipped.
        perf_eval = _get_last_notna_value(hist[perf_eval_col])
        perf_test = _get_last_notna_value(hist[perf_test_col])
        # Every task except 'squad' is scaled by 100.
        if task_name != 'squad':
            perf_eval *= 100
            perf_test *= 100

        item = {
            'task_name': task_name,
            'seed': run.config['seed'],
            'baseline': run.config['baseline'],
            'baseline_bert': run.config.get('baseline_bert', False),
            'baseline_leave_out_all': run.config.get('baseline_leave_out_all', False),
            'use_switches': run.config['use_switches'],
            'without_skip_connections': run.config['adapter_drop_skip_connections'],
            'low_resources': run.config['low_resources'],
            'required_params': int(
                _get_last_notna_value(hist["eval/final_required_params"])
            ),
            'switch_regularization': run.config['switch_regularization'],
            'perf_eval': perf_eval,
            'perf_test': perf_test
        }

        # Report every missing LAST_FLOAT column, then skip the run entirely.
        missing_data = False
        for name in LAST_FLOAT:
            if name not in hist:
                missing_data = True
                print(f"Column {name} not in history of {run}.")

        if missing_data:
            continue

        for name in LAST_FLOAT:
            # Never true at this point: runs with a missing column were
            # skipped above.
            if name not in hist:
                continue
            item[name] = float(_get_last_notna_value(hist[name]))

        for name in HIST_FLOAT:
            item[name] = hist[name].dropna().apply(float).to_list()

        # Runs whose config lacks this key are treated as False.
        if 'adapter_drop_skip_connections_training_only' in run.config:
            item['without_skip_connections_training_only'] = run.config['adapter_drop_skip_connections_training_only']
        else:
            item['without_skip_connections_training_only'] = False

        # For each of the 12 layers collect prob.0, prob.1, ... until the
        # first index that has no column in the history.
        for i in range(12):
            for j in count():
                tag = f'eval/final_layer.{i}.prob.{j}'
                if tag in hist:
                    item[f'layer_{i}.prob.{j}'] = _get_last_notna_value(hist[tag])
                else:
                    break

        # Mode for plots and tables.
        if item['baseline_bert'] is True:
            item['mode'] = 'baseline (BERT only)'
        elif item['baseline_leave_out_all'] is True:
            item['mode'] = 'baseline (Adapters left out == BERT)'
        elif item['baseline'] is True:
            item['mode'] = 'baseline (Adapter)'
        elif item['use_switches']:
            if item['switch_regularization'] == 'square':
                item['mode'] = 'with regularization'
            elif item['without_skip_connections'] is True:
                if item['without_skip_connections_training_only'] is True:
                    item['mode'] = 'without skip-connections at training time'
                else:
                    item['mode'] = 'without skip-connections'
            else:
                raise Exception(f"Unknown mode for run {run}.")

            # 'switch_inputs' is parsed as JSON after turning single quotes
            # into double quotes; a value that cannot be parsed leaves the
            # mode untouched.
            try:
                switch_inputs = json.loads(
                    run.config['switch_inputs'].replace("'", "\"")
                )
                if switch_inputs[1] == 'pfeiffer:relu':
                    item['mode'] += ' and relu'
            except json.JSONDecodeError:
                pass
        else:
            # Neither a baseline nor a switch run.
            if item['without_skip_connections'] is True:
                if item['without_skip_connections_training_only'] is True:
                    item['mode'] = 'fixed without skip-connections at training time'
                else:
                    item['mode'] = 'fixed without skip-connections'
            else:
                item['mode'] = 'fixed with regularization'

        data.append(item)

    df = pd.DataFrame(data)
    # Binarise the 'prob.1' column of every layer, count the ones per row and
    # give the columns readable names.
    # NOTE(review): selecting these columns raises KeyError when one of them
    # is absent, e.g. when no run was collected at all.
    tags = {f'layer_{i}.prob.1': f'Layer No. {i+1:02d}' for i in range(12)}
    df[list(tags.keys())] = (df[list(tags.keys())] > 0.5).astype(int)
    df['num_adapters'] = df[list(tags.keys())].sum(axis=1)
    return df.rename(columns=tags)


def _plot_norm(ax: plt.Axes, data):
    """Overlay a normal distribution fitted to *data* on *ax*.

    A dashed vertical line marks the fitted mean and the density is drawn in
    black over the current x-range of the axes.
    """
    # Current extent of the axes in both directions.
    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()

    # Parameters of the fitted normal distribution.
    mean, sigma = norm.fit(data)

    # Central line at the mean.
    ax.vlines([mean], y_lo, y_hi, linestyles='dashed', linewidth=2)

    # Density curve sampled at 100 points across the x-range.
    grid = np.linspace(x_lo, x_hi, 100)
    ax.plot(grid, norm.pdf(grid, mean, sigma), 'k', linewidth=2)


def create_subplots(
    dfs: Dict[str, pd.DataFrame],
    title: str,
    sharex=False,
    sharey=False,
    hspace=0.3,
    wspace=0.3
):
    """Create a 3x3 figure and yield one axes per non-empty task frame.

    Yields ``(task_name, frame sorted by 'seed', axes)`` tuples. The axes
    paired with an empty frame is switched off and not yielded. The subplot
    spacing is adjusted once the iteration is exhausted.
    """
    fig, grid = plt.subplots(
        3, 3, figsize=(14, 14), dpi=92, sharex=sharex, sharey=sharey
    )
    fig.suptitle(title)

    for ax, item in zip(grid.flat, dfs.items()):
        task_name, frame = item
        if len(frame) != 0:
            yield task_name, frame.sort_values('seed'), ax
        else:
            ax.set_axis_off()

    plt.subplots_adjust(hspace=hspace, wspace=wspace)


def _plot_baseline_hist(dfs: Dict[str, pd.DataFrame]):
    """Plot, per task, a density histogram of the baseline performance.

    A fitted normal distribution is drawn on top of each histogram and all
    subplots end up with the same y-range ``(0, highest bar)``.

    Args:
        dfs: Mapping from task name to the frame holding the runs of that task.
    """
    axis = []
    ymax = None
    for task_name, df, ax in create_subplots(dfs, "Baseline Performance"):
        df = df[df['baseline']]
        # NOTE(review): the column is looked up under the short metric name
        # (e.g. 'acc'); confirm that the frames passed in really carry it.
        perf_col = PER_TASK[task_name][2]
        ax.set_title(PER_TASK[task_name][0])
        ax.set_ylabel(PER_TASK[task_name][3])
        ax.grid()
        n, _, _ = ax.hist(df[perf_col], bins=11, alpha=0.6, color='g', density=True)
        ymax = max(n) if ymax is None else max(max(n), ymax)
        _plot_norm(ax, df[perf_col])
        # Remember the axes: without this the loop below never runs and the
        # common y-limit is silently not applied.
        axis.append(ax)
    for ax in axis:
        ax.set_ylim((0, ymax))


def _plot_params_vs_performance(
    dfs: Dict[str, pd.DataFrame],
    skip_connections=False,
    skip_connections_training_only=False,
    switch_regularization=None,
    sharex=True
):
    """Plot, per task, required parameters against the gain over the baseline.

    The switch runs are narrowed down by the three selectors (skip
    connections, training-only skip connections, regularization); the
    y-value is the element-wise difference between the selected runs and the
    baseline runs, both sorted by seed.
    """
    # Figure title describing the selected configuration.
    if skip_connections and switch_regularization is not None:
        title = "Switches with the skip-connections and regularization."
    elif skip_connections:
        title = "Switches with the skip-connections."
    elif skip_connections_training_only:
        title = "Switches without the skip-connections at training only."
    else:
        title = "Switches without the skip-connections."

    # Runs without regularization store the string 'None'.
    wanted_reg = 'None' if switch_regularization is None else switch_regularization

    for task_name, df, ax in create_subplots(dfs, title, sharex=sharex):
        display_name, _, perf, y_label = PER_TASK[task_name]
        ax.grid()
        ax.set_title(display_name)
        ax.set_xlabel("Number of Parameters")
        ax.set_ylabel(y_label)

        df_b = df[df['baseline']]
        df_s = df[~df['baseline']]

        # Keep the runs matching the requested skip-connection setting.
        without_sc = df_s['without_skip_connections']
        df_s = df_s[~without_sc] if skip_connections else df_s[without_sc]

        # Same for the "training only" variant.
        training_only = df_s['without_skip_connections_training_only']
        df_s = df_s[training_only] if skip_connections_training_only else df_s[~training_only]

        df_s = df_s[df_s['switch_regularization'] == wanted_reg]

        df_s = df_s.sort_values('seed')
        df_b = df_b.sort_values('seed')
        print(df_s)
        print(df_b)
        gain = df_s[perf].values - df_b[perf].values
        ax.plot(df_s['required_params'], gain, 'bo')


def _plot_matshow(ax: plt.Axes, arr, fill=False):
    """Draw a hatched unit square on *ax* for every truthy cell of *arr*.

    Columns of *arr* run along the x-axis (tick labels 1..ncols) and rows
    along the y-axis.
    """
    n_rows, n_cols = arr.shape[0], arr.shape[1]

    # Column-major walk over the cells, one rectangle per truthy cell.
    for col, row in ((c, r) for c in range(n_cols) for r in range(n_rows)):
        if not arr[row, col]:
            continue
        square = mpl.patches.Rectangle(
            (col + 0.5, row), 1, 1,
            hatch="////\\\\\\\\",
            color='k',
            fill=fill,
            snap=True,
            lw=0.0
        )
        ax.add_patch(square)

    # Hack: major ticks label the columns, minor ticks sit on the column
    # boundaries so that the minor grid separates the columns.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticks([k + 1.5 for k in range(n_cols - 1)], minor=True)
    ax.set_xticks([k + 1 for k in range(n_cols)])

    ax.grid(which='minor')
    ax.set_xlim((0.5, n_cols + 0.5))
    ax.set_ylim((0, n_rows))


def _plot_switch_results(
    dfs: Dict[str, pd.DataFrame], skip_conn=False, skip_conn_to=False, reg=None, fill=False
):
    """Plot, per task, which layers selected the adapter in every switch run.

    Args:
        dfs: Mapping from task name to the frame holding the runs of that task.
        skip_conn: Select the runs that kept the skip-connections.
        skip_conn_to: Select the runs that dropped the skip-connections at
            training time only.
        reg: Regularization to select; ``None`` selects the runs whose
            ``switch_regularization`` is the string ``'None'``.
        fill: Forwarded to `_plot_matshow` (filled vs. hatched-only squares).
    """

    if skip_conn:
        if reg is not None:
            title = "Switches configuration with the skip-connections and regularization."
        else:
            title = "Switches configuration with the skip-connections."
    elif skip_conn_to:
        title = "Switches configuration without the skip-connections at training only."
    else:
        title = "Switches configuration without the skip-connections."

    for task_name, df, ax in create_subplots(dfs, title, wspace=0.1, hspace=0.4):
        df_s = df[~df['baseline']]

        if skip_conn:
            df_s = df_s[~df_s['without_skip_connections']]
        else:
            df_s = df_s[df_s['without_skip_connections']]

        if skip_conn_to:
            df_s = df_s[df_s['without_skip_connections_training_only']]
        else:
            df_s = df_s[~df_s['without_skip_connections_training_only']]

        if reg is None:
            df_s = df_s[df_s['switch_regularization'] == 'None']
        else:
            df_s = df_s[df_s['switch_regularization'] == reg]

        # `collect_data` renames 'layer_<i>.prob.1' to 'Layer No. XX' before
        # returning, so the raw name is only used when it is still present;
        # indexing the raw names unconditionally raised KeyError.
        tags = [
            raw if raw in df_s.columns else f'Layer No. {i+1:02d}'
            for i, raw in enumerate(f'layer_{i}.prob.1' for i in range(12))
        ]
        df = df_s[tags]

        # The comparison works for raw probabilities as well as for the
        # already binarised 0/1 values.
        _plot_matshow(ax, (df > 0.5).to_numpy(), fill=fill)

        ax.set_title(PER_TASK[task_name][0])
        ax.set_xlabel('Layer No.')
        ax.xaxis.set_ticks_position('bottom')


def _plot_performance_all(dfs: Dict[str, pd.DataFrame], sharex=True):
    """Plot, per task, all switch experiments against the baseline.

    Three series are drawn (required parameters vs. difference to the
    baseline): without skip-connections at training only, without
    skip-connections, and with regularization.
    """
    for task_name, df, ax in create_subplots(
        dfs, "All switches experiments", sharex=sharex
    ):
        display_name, _, perf, y_label = PER_TASK[task_name]
        ax.grid()
        ax.set_title(display_name)
        ax.set_xlabel("Number of Parameters")
        ax.set_ylabel(y_label)

        df_b = df[df['baseline']]
        df_s = df[~df['baseline']]

        # Boolean selectors over the switch runs.
        training_only = df_s['without_skip_connections_training_only']
        sel_without_sc = df_s['without_skip_connections']
        sel_without_sc_to = sel_without_sc & training_only
        sel_without_sc_full = sel_without_sc & ~training_only
        sel_with_sc = ~sel_without_sc
        sel_reg_none = df_s['switch_regularization'] == 'None'
        sel_reg = (df_s['switch_regularization'] == 'square')

        def pick(mask):
            # Runs selected by *mask*, ordered by seed.
            return df_s[mask].sort_values('seed')

        # Computed as in the original code although it is not plotted.
        df_with_sc = pick(sel_with_sc & sel_reg_none)
        df_without_sc_to = pick(sel_without_sc_to & sel_reg_none)
        df_without_sc_full = pick(sel_without_sc_full & sel_reg_none)
        df_reg = pick(sel_with_sc & sel_reg)

        series = (
            (df_without_sc_to, 'b3', "without skip-connections at training only"),
            (df_without_sc_full, 'r+', "without skip-connections"),
            (df_reg, 'gx', "with regularization"),
        )
        for frame, fmt, label in series:
            ax.plot(
                frame['required_params'],
                frame[perf].values - df_b[perf].values,
                fmt,
                mfc='none',
                label=label
            )
        ax.legend()


def fixed(args):
    """Train with adapters fixed at the layers a previous switch run selected.

    The single wandb run matching the command-line options is looked up, the
    layers whose switch selected the adapter are extracted, and a GLUE or QA
    training (for task 'squad') is started with adapters at exactly those
    layers.

    Args:
        args: Parsed command-line namespace of the ``fixed`` sub-command.

    Raises:
        AssertionError: If the lookup does not return exactly one run.
    """
    df = collect_data(
        args.wandb_project,
        limit=None,
        task_name=args.task_name,
        seed=args.seed,
        low_resources=args.low_resources,
        without_skip_connections=args.without_skip_connections,
        with_regularization=args.with_regularization,
        without_skip_connections_training_only=args.without_skip_connections_training_only
    )

    assert len(df) == 1, "We need only one entry!!"

    # `collect_data` renames 'layer_<i>.prob.1' to 'Layer No. XX' before
    # returning; indexing the raw names unconditionally raised KeyError.
    # The raw name is still preferred when present.
    row = df.iloc[0]
    adapters_at = []
    for i in range(12):
        raw = f'layer_{i}.prob.1'
        column = raw if raw in df.columns else f'Layer No. {i+1:02d}'
        # Works for raw probabilities and for binarised 0/1 values alike.
        if row[column] > 0.5:
            adapters_at.append(i)

    if args.task_name == 'squad':
        qa_args = QAArgs(
            baseline=False,
            use_switches=False,
            adapters_at=adapters_at,
            adapter_drop_skip_connections=args.without_skip_connections,
            adapter_drop_skip_connections_training_only=args.without_skip_connections_training_only,
            **qa_default_args(args)
        )
        main_qa(qa_args)
    else:
        glue_args = GLUEArgs(
            baseline=False,
            use_switches=False,
            adapters_at=adapters_at,
            adapter_drop_skip_connections=args.without_skip_connections,
            adapter_drop_skip_connections_training_only=args.without_skip_connections_training_only,
            **default_args(args)
        )
        main_glue(glue_args)


def plots(args):
    """Recreate the plots in the paper and more"""

    runs = collect_data(
        args.wandb_project,
        limit=None,
        task_name=args.task_name,
        low_resources=args.low_resources,
        seed=args.seed
    )

    # One frame per known task.
    dfs = {}
    for task in PER_TASK:
        dfs[task] = runs[runs['task_name'] == task]

    # Histogram of the baseline performance first.
    _plot_baseline_hist(dfs)

    # Parameters vs. performance for each switch configuration, then all
    # configurations together.
    _plot_params_vs_performance(dfs)
    _plot_params_vs_performance(dfs, skip_connections_training_only=True)
    _plot_params_vs_performance(dfs, skip_connections=True, switch_regularization='square')
    _plot_performance_all(dfs)

    # Switch configurations chosen per layer.
    _plot_switch_results(dfs, fill=True)
    _plot_switch_results(dfs, skip_conn_to=True, fill=True)
    _plot_switch_results(dfs, skip_conn=True, reg='square', fill=True)

    plt.show()


def default_args(args):
    """Build the keyword arguments shared by all GLUE trainings.

    Reads ``task_name``, ``batch_size``, ``seed``, ``low_resources`` and
    ``lr_for_switches`` from *args*. When *args* has an
    ``adapter_non_linearity`` attribute it replaces the default
    ``rational:one`` in the second switch input. The output directory name
    contains a timestamp.
    """
    # Second switch input: taken from the namespace when it provides one.
    non_linearity = 'rational:one'
    if hasattr(args, 'adapter_non_linearity'):
        non_linearity = args.adapter_non_linearity

    stamp = int(time.time() * 1e7)
    batch_size = args.batch_size

    return dict(
        task_name=args.task_name,
        train_adapter=True,
        do_train=True,
        do_eval=True,
        do_predict=True,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        output_dir=f'final_output/adaptable-adapter/{args.task_name}_{args.seed}_{stamp}',
        overwrite_output_dir=True,
        seed=args.seed,
        logging_steps=20,
        low_resources=args.low_resources,
        save_total_limit=2,
        evaluation_strategy='epoch',
        learning_rate=1e-4,
        num_train_epochs=20,
        load_best_model_at_end=True,
        metric_for_best_model='eval_accuracy',
        lr_for_switches=args.lr_for_switches,
        lr_for_rational_activations=0.01,
        switch_regularization_weight=0.01,
        switch_inputs=['minimal:identity', f'pfeiffer:{non_linearity}'],
    )

def qa_default_args(args):
    """Build the keyword arguments shared by all question-answering trainings.

    Reads ``task_name``, ``batch_size``, ``seed``, ``low_resources`` and
    ``lr_for_switches`` from *args*. When *args* has an
    ``adapter_non_linearity`` attribute it replaces the default
    ``rational:one`` in the second switch input. The output directory name
    contains a timestamp. No explicit train/validation/test files are set.
    """
    # Second switch input: taken from the namespace when it provides one.
    non_linearity = 'rational:one'
    if hasattr(args, 'adapter_non_linearity'):
        non_linearity = args.adapter_non_linearity

    stamp = int(time.time() * 1e7)
    batch_size = args.batch_size

    return dict(
        task_name=args.task_name,
        train_adapter=True,
        do_train=True,
        do_eval=True,
        do_predict=True,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        output_dir=f'final_output/adaptable-adapter/{args.task_name}_{args.seed}_{stamp}',
        overwrite_output_dir=True,
        seed=args.seed,
        logging_steps=20,
        low_resources=args.low_resources,
        save_total_limit=2,
        evaluation_strategy='epoch',
        learning_rate=1e-4,
        num_train_epochs=10,
        max_seq_length=384,
        doc_stride=128,
        load_best_model_at_end=True,
        lr_for_switches=args.lr_for_switches,
        lr_for_rational_activations=0.01,
        switch_regularization_weight=0.01,
        switch_inputs=['minimal:identity', f'pfeiffer:{non_linearity}'],
    )

def switches(args):
    """Run a training with switches enabled (QA for 'squad', GLUE otherwise)."""
    # Options that are identical for both kinds of training.
    switch_options = dict(
        baseline=False,
        use_switches=True,
        adapter_drop_skip_connections=args.without_skip_connections,
        adapter_drop_skip_connections_training_only=args.without_skip_connections_training_only,
        switch_regularization='square' if args.with_regularization else None,
        switch_regularization_inputs_costs=[0, 2],
    )

    if args.task_name != 'squad':
        main_glue(GLUEArgs(**switch_options, **default_args(args)))
        return

    main_qa(QAArgs(**switch_options, **qa_default_args(args)))


def baseline(args):
    """Run a baseline training (QA for 'squad', GLUE otherwise)."""
    # Options that are identical for both kinds of training.
    baseline_options = dict(
        baseline=True,
        baseline_bert=args.bert_only,
        baseline_leave_out_all=args.leave_out_all,
    )

    if args.task_name != 'squad':
        main_glue(GLUEArgs(**baseline_options, **default_args(args)))
        return

    main_qa(QAArgs(**baseline_options, **qa_default_args(args)))

def drop(args):
    """Run the DROP variant of the GLUE baseline training.

    The GLUE arguments are built on top of the question-answering defaults.
    """
    shared = qa_default_args(args)
    glue_args = GLUEArgs(
        baseline=True,
        baseline_bert=args.bert_only,
        baseline_leave_out_all=args.leave_out_all,
        **shared
    )
    main_drop_glue(glue_args)

def _single_table(df, item=None):
    """Add one ``"<task> (<metric>)": "mean ± std"`` entry per task to *item*.

    *item* is updated in place and returned; a new dict is created when it is
    ``None``.
    """
    per_task = df.groupby('task_name')
    means = per_task.mean()
    stds = per_task.std()

    result = {} if item is None else item

    for mean_row, std_row in zip(means.itertuples(), stds.itertuples()):
        # The first field of each tuple is the group key.
        assert mean_row[0] == std_row[0], "Both should have the same task_name."
        task_name = mean_row[0]
        perf = PER_TASK[task_name][2]

        # Fancy title and value.
        avg, spread = getattr(mean_row, perf), getattr(std_row, perf)
        result[f"{task_name} ({perf})"] = f"{avg:.2f} ± {spread:.2f}"

    return result


def _mean_and_std(df, a='task_name', cols=None):
    """Return ``"mean ± std"`` strings of *cols* after grouping *df* by *a*.

    Args:
        df: Input frame.
        a: Column name or list of column names to group by.
        cols: Iterable of column names to summarise. ``None`` selects every
            numeric column that is not a grouping key.

    Returns:
        A DataFrame indexed by the groups with one string column per entry
        of *cols*, formatted with two decimals.
    """
    if cols is None:
        # The former default crashed on ``list(None)``.
        keys = [a] if isinstance(a, str) else list(a)
        cols = [c for c in df.select_dtypes('number').columns if c not in keys]
    else:
        # Materialise once: the names are needed twice below and the argument
        # may be a one-shot iterable.
        cols = list(cols)

    grouped = df.groupby(a)[cols]
    means = grouped.mean()
    stds = grouped.std()

    def two_decimals(x):
        return f"{x:.2f}"

    # Column-wise ``Series.map`` instead of the deprecated
    # ``DataFrame.applymap``.
    return pd.DataFrame({
        c: means[c].map(two_decimals).str.cat(stds[c].map(two_decimals), sep=" ± ")
        for c in cols
    })


def _mean_to_perc(df, a='task_name', cols=None):
    """Return the per-group mean of *cols* formatted as a percentage string.

    Args:
        df: Input frame.
        a: Column name or list of column names to group by.
        cols: Iterable of column names to summarise. ``None`` selects every
            numeric column that is not a grouping key.

    Returns:
        A DataFrame indexed by the groups with one string column per entry
        of *cols*; a mean of 0.5 is rendered as ``"50.00 %"``.
    """
    if cols is None:
        # The former default failed when used to select columns.
        keys = [a] if isinstance(a, str) else list(a)
        cols = [c for c in df.select_dtypes('number').columns if c not in keys]
    else:
        cols = list(cols)

    means = df.groupby(a)[cols].mean()

    def as_percent(x):
        return f"{100 * x:03.2f} %"

    # Column-wise ``Series.map`` instead of the deprecated
    # ``DataFrame.applymap``.
    return pd.DataFrame({c: means[c].map(as_percent) for c in cols})


def _print_table(df, index=False):
    """Print *df* as a markdown table followed by blank lines.

    The columns of a DataFrame are put in sorted order first; any other
    object is printed as it is.
    """
    table = df
    if isinstance(table, pd.DataFrame):
        ordered = sorted(table.columns)
        table = table.reindex(ordered, axis=1)
    markdown = table.to_markdown(index=index)
    print(markdown)
    print("\n")


def _fix_prob(df):
    """Return, per task, the fraction of runs that selected each layer.

    The ``layer_<i>.prob.1`` columns of the 12 layers are binarised
    (value > 0.5), averaged per ``task_name`` and renamed to ``Layer # XX``.
    The input frame is left untouched.
    """
    tags = {f'layer_{i}.prob.1': f'Layer # {i+1:02d}' for i in range(12)}
    raw = list(tags)

    # Work on a copy so that the caller's frame is not overwritten.
    work = df.copy()
    work[raw] = (work[raw] > 0.5).astype(int)

    # Select the layer columns *before* averaging: averaging every column
    # fails as soon as the frame holds non-numeric data.
    return work.groupby('task_name')[raw].mean().rename(columns=tags)


def _only_switches_mode(df):
    """Keep the rows of *df* whose 'mode' is one of MODES_SWITCHES."""
    uses_switches = df['mode'].isin(MODES_SWITCHES)
    return df[uses_switches]


def _tables_performance_per_required_params(df):
    """Print, per switch mode, the performance by number of required params.

    For every mode in MODES_SWITCHES that has runs, two markdown tables
    (evaluation and test performance) are printed with one row per
    ``required_params`` value and one column per task.
    """

    def with_metric(task):
        # Column header: task name followed by its short metric name.
        return f"{task} ({PER_TASK[task][2]})"

    for mode in MODES_SWITCHES:
        runs = df[df['mode'] == mode]
        if len(runs) == 0:
            continue

        columns = {
            "perf_eval": f"Performance {mode} across tasks (best evaluation model).",
            "perf_test": f"Performance {mode} across tasks (test split).",
        }
        summary = _mean_and_std(
            runs, a=['task_name', 'required_params'], cols=columns.keys()
        )

        for col, heading in columns.items():
            print(f"\n### {heading}\n")
            per_task = summary[col].unstack('task_name')
            _print_table(
                per_task.rename(mapper=with_metric, axis='columns'),
                index=True
            )
            print("\n")


def _tables_switch(df, mode=None):
    """Print the tables describing the results of the switch runs.

    First the performance per number of required parameters, then the
    mean ± std number of layers that selected an adapter, per mode and task.

    Args:
        df: Frame of runs as produced by `collect_data`.
        mode: Optional label of the subset in *df*. When given it is printed
            as a heading. The parameter exists because `tables` passes such a
            label as second positional argument, which used to raise
            ``TypeError``.
    """
    if mode is not None:
        print(f"\n# Switch results {mode}\n")

    _tables_performance_per_required_params(df)

    columns = {
        "num_adapters": "Total number of layers selecting adapters"
    }
    df_ = _mean_and_std(_only_switches_mode(df), a=['task_name', 'mode'], cols=columns.keys())
    for col in columns:
        print(f"\n## {columns[col]}\n")
        _print_table(df_[col].unstack('task_name'), index=True)
        print("\n")

    # NOTE(review): a "how often are adapters selected per layer" table used
    # to follow an unconditional ``return`` here. It was unreachable and
    # referred to names that were never defined in this function, so it has
    # been removed.


def _tables_runtime(df: pd.DataFrame):
    """Print one markdown table per runtime metric (mode rows, task columns)."""
    headings = {
        "eval/final_runtime": "Runtime in seconds (eval)",
        "eval/final_samples_per_second": "Samples per second (eval)",
        "test/runtime": "Runtime (test) in seconds",
        "test/samples_per_second": "Samples per second (test)"
    }
    summary = _mean_and_std(df, a=['task_name', 'mode'], cols=headings.keys())
    for metric, heading in headings.items():
        print(f"\n## {heading}\n")
        per_task = summary[metric].unstack('task_name')
        _print_table(per_task, index=True)
        print("\n")


def _tables_performance(df: pd.DataFrame):
    """Print the evaluation and test performance (mode rows, task columns)."""

    def with_metric(task):
        # Column header: task name followed by its short metric name.
        return f"{task} ({PER_TASK[task][2]})"

    headings = {
        "perf_eval": "Performance across tasks (best evaluation model).",
        "perf_test": "Performance across tasks (test split).",
    }
    summary = _mean_and_std(df, a=['task_name', 'mode'], cols=headings.keys())

    for metric, heading in headings.items():
        print(f"\n## {heading}\n")
        per_task = summary[metric].unstack('task_name')
        _print_table(
            per_task.rename(mapper=with_metric, axis='columns'),
            index=True
        )
        print("\n")


def tables(args):
    """Reproduce the tables in the paper (and more)

    Prints the runtime, performance and switch tables for all collected runs.
    With ``--switches`` the switch tables are printed again for two subsets:
    runs without skip-connections at training time, and runs with
    regularization.
    """

    df = collect_data(
        args.wandb_project,
        limit=None,
        task_name=args.task_name,
        seed=args.seed,
        low_resources=args.low_resources
    )

    # First the runtimes.
    _tables_runtime(df)

    # Performance.
    _tables_performance(df)

    # Results of switches.
    _tables_switch(df)

    if args.switches:
        print(df.columns)

        # The subset label is printed here and `_tables_switch` is called
        # with the frame only: passing the label as a second positional
        # argument raised TypeError.
        df_ = df[~df['baseline']]
        df_ = df_[df_['without_skip_connections']]
        df_ = df_[df_['without_skip_connections_training_only']]
        df_ = df_[df_['switch_regularization'] == 'None']
        print("\n# Switches without skip-connections at training\n")
        _tables_switch(df_)

        # Switch with regularization
        df_ = df[~df['baseline']]
        df_ = df_[~df_['without_skip_connections']]
        df_ = df_[df_['switch_regularization'] == 'square']
        print("\n# Switches with regularizations\n")
        _tables_switch(df_)



class EnvDefault(argparse.Action):
    """argparse action whose default can be supplied by an environment variable.

    When *envvar* is set in the environment its value replaces *default*.
    An option that ends up with a truthy default is no longer required.
    """

    def __init__(self, envvar, required=True, default=None, **kwargs):
        # The environment wins over the default given in the code.
        default = os.environ.get(envvar, default)
        # A usable default makes the option optional.
        if default and required:
            required = False
        super().__init__(required=required, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # Store the value given on the command line unchanged.
        setattr(namespace, self.dest, values)


def _add_id_arguments(parser, single=True):
    """Register the options that identify runs on *parser*.

    With ``single=True`` exactly one ``--task_name`` and one ``--seed`` are
    required; otherwise both are optional and accept one or more values.
    ``--low_resources`` and ``--lr_for_switches`` are always added.
    """
    if not single:
        parser.add_argument('--seed', type=int, nargs='+')
        parser.add_argument('--task_name', type=str, nargs='+')
    else:
        parser.add_argument('--task_name', type=str, required=True)
        parser.add_argument('--seed', type=int, required=True)

    for flag, kind, fallback in (
        ('--low_resources', int, None),
        ('--lr_for_switches', float, 0.05),
    ):
        parser.add_argument(flag, type=kind, default=fallback)


def _add_train_arguments(parser):
    """Register the training options (currently only ``--batch_size``)."""
    batch_size_options = dict(
        type=int,
        default=32,
        help="Set the 'per_device_{train,eval}_batch_size' argument."
    )
    parser.add_argument('--batch_size', **batch_size_options)


if __name__ == "__main__":

    # Send all log records to stdout, replacing handlers installed earlier.
    logging.basicConfig(
        level=logging.INFO,
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        force=True
    )

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--wandb_project',
        envvar='WANDB_PROJECT',
        default="huggingface",
        action=EnvDefault
    )

    # One sub-command per entry point, created in the order they are listed
    # in the help output.
    subparsers = parser.add_subparsers(required=True)
    baseline_p, dbaseline_p, switches_p, tables_p, plots_p, fixed_p = (
        subparsers.add_parser(command)
        for command in ('baseline', 'drop', 'switches', 'tables', 'plots', 'fixed')
    )

    # Baseline and Drop take exactly the same options.
    for sub_parser, handler in ((baseline_p, baseline), (dbaseline_p, drop)):
        _add_id_arguments(sub_parser)
        _add_train_arguments(sub_parser)
        for flag in ('--bert_only', '--leave_out_all'):
            sub_parser.add_argument(flag, action='store_true')
        sub_parser.set_defaults(func=handler)

    # Switches.
    _add_id_arguments(switches_p)
    _add_train_arguments(switches_p)
    switches_p.add_argument('--adapter_non_linearity', type=str, default='rational:one')
    for flag in (
        '--without_skip_connections',
        '--without_skip_connections_training_only',
        '--with_regularization',
    ):
        switches_p.add_argument(flag, action='store_true')
    switches_p.set_defaults(func=switches)

    # Tables
    _add_id_arguments(tables_p, single=False)
    for flag in ('--baseline', '--switches'):
        tables_p.add_argument(flag, action='store_true')
    tables_p.set_defaults(func=tables)

    # Plots
    _add_id_arguments(plots_p, single=False)
    plots_p.set_defaults(func=plots)

    # Training with the adapters fixed to the result of a switch run.
    _add_id_arguments(fixed_p)
    _add_train_arguments(fixed_p)
    for flag in (
        '--without_skip_connections',
        '--without_skip_connections_training_only',
        '--with_regularization',
    ):
        fixed_p.add_argument(flag, action='store_true')
    fixed_p.set_defaults(func=fixed)

    # Parse and call the right function; a TypeError raised while parsing
    # shows the help text and exits with -1.
    try:
        args = parser.parse_args()
    except TypeError:
        parser.print_help()
        exit(-1)
    args.func(args)
