import argparse
import logging
import traceback
import numpy as np
from pandas import DataFrame
from pathlib import Path
from collections import Counter, defaultdict
from itertools import combinations

from evaluator import Evaluator
from evaluator import constants


class Tester:
    def __init__(self, models, tests, figures_dir: Path, use_latex: bool):
        """Create the evaluator and the column layouts of the result tables.

        Args:
            models: Model identifiers; kept sorted in ``self.models``.
            tests: Test names, handed unchanged to ``Evaluator``.
            figures_dir: Base directory below which tables are saved.
            use_latex: Selects LaTeX output (printed and saved) over markdown
                output (printed only) in ``_print_table``.
        """
        self.models = sorted(models)
        self.evaluator = Evaluator(models, tests)
        self.save_dir = figures_dir
        self.use_latex = use_latex

        # Column layouts; the first entry is always the model-name column.
        self.invariance_header = [
            'Model', 'Acc.', 'Original Acc.', 'Perturbed Acc.', 'Self-cons.', 'Comp. Acc.',
        ]
        self.negation_header = [
            'Model', 'Acc.', 'Original Acc.', 'Perturbed Acc.', 'Self-cons.', 'Comp. Acc.',
        ]
        self.perturbation_header = [
            'Model', 'Original Acc.', 'Perturbed Acc.', 'Self-cons.', 'Comp. Acc.',
        ]
        self.masked_header = ['Model', 'Original Acc.', 'Self-cons.']
        self.sigma_accuracy_header = [
            'Model', 'Original Acc.',
            r'$\sigma$ 3 Acc.', r'$\sigma$ 6 Acc.', r'$\sigma$ 9 Acc.',
        ]
        self.sigma_comp_accuracy_header = [
            'Model', 'Original Acc.',
            r'$\sigma$ 3 Comp. Acc.', r'$\sigma$ 6 Comp. Acc.', r'$\sigma$ 9 Comp. Acc.',
        ]
        self.sigma_invariance_header = [
            'Model', r'$\sigma$ 3 Self-cons', r'$\sigma$ 6 Self-cons', r'$\sigma$ 9 Self-cons',
        ]
        self.sigma_consistency_header = [
            'Model', r'$\sigma$ 3 Self-cons', r'$\sigma$ 6 Self-cons', r'$\sigma$ 9 Self-cons',
        ]
        self.prediction_distribution_header = ['Model', 'yes', 'no', 'other']

    def _generate_label(self, label):
        """Turn a snake_case label into a caption: lower-cased, first letter
        capitalised, underscores replaced by spaces."""
        sentence = label.lower().capitalize()
        return sentence.replace('_', ' ')

    def _to_table(self, pd_table, label):
        """Render ``pd_table`` as a LaTeX table string.

        Floats are formatted as percentages with two decimals, the LaTeX label
        is ``app:<label>`` and the caption is derived from ``label`` via
        ``_generate_label``.
        """
        def as_percentage(value):
            return f'{(100 * value):.2f}'

        latex_options = dict(
            index=False,
            na_rep='-',
            float_format=as_percentage,
            index_names=True,
            bold_rows=True,
            escape=False,
            label='app:' + label,
            caption=self._generate_label(label),
            position='H',
        )
        return pd_table.to_latex(**latex_options)

    def _to_markdown(self, pd_table):
        """Render ``pd_table`` as a GitHub-flavoured markdown table with the
        index included and floats shown with three decimals."""
        return pd_table.to_markdown(index=True, tablefmt="github", floatfmt=".3f")

    def _print_table(self, pd_table, label):
        if self.use_latex:
            latex_table = self._to_table(pd_table, label=label)
            print(latex_table)
            outfile_name = Path(self.save_dir, 'latex_tables', label).with_suffix('.tex')
            outfile_name.parent.mkdir(parents=True, exist_ok=True)
            with outfile_name.open('w+') as out:
                out.write(latex_table)
            outfile_name = Path(self.save_dir, 'csv_tables', label).with_suffix('.tsv')
            outfile_name.parent.mkdir(parents=True, exist_ok=True)
            pd_table.to_csv(outfile_name, sep='\t', index=False)
        else:
            markdown_table = self._to_markdown(pd_table)
            print(self._generate_label(label))
            print(markdown_table)

    def _do_invariance_test(self, test_name):
        header = self.invariance_header
        table = list()
        table_counts = defaultdict(dict)
        for model in self.models:
            accuracy, table_counts['Acc.'][model] = self.evaluator.accuracy(model, test_name)
            base_accuracy, table_counts['Original Acc.'][model] = self.evaluator.base_accuracy(model, test_name)
            dual_accuracy, table_counts['Perturbed Acc.'][model] = self.evaluator.dual_accuracy(model, test_name)
            invariance, table_counts['Self-cons.'][model] = self.evaluator.invariance(model, test_name)
            comp_acc, table_counts['Comp. Acc.'][model] = self.evaluator.comprehensive_accuracy(model, test_name)
            row = [constants.PRETTY_MODEL_NAMES[model], accuracy, base_accuracy, dual_accuracy, invariance, comp_acc]
            table.append(row)
        pd_table = DataFrame(table, columns=header)
        return pd_table, table_counts

    def _do_ontology_test(self, test_name):
        header = ['Model', 'Acc.', 'Original Acc.', 'Perturbed Acc.', 'Positive Invariance',
                  'Negative Invariance', 'Self-cons.', 'Comp. Acc.']
        table = list()
        table_counts = defaultdict(dict)
        for model in self.models:
            accuracy, table_counts['Acc.'][model] = self.evaluator.accuracy(model, test_name)
            base_accuracy, table_counts['Original Acc.'][model] = self.evaluator.base_accuracy(model, test_name)
            dual_accuracy, table_counts['Perturbed Acc.'][model] = self.evaluator.dual_accuracy(model, test_name)
            positive_invariance, table_counts['Positive Invariance'][model] = self.evaluator.positive_invariance(model, test_name)
            negative_invariance, table_counts['Negative Invariance'][model] = self.evaluator.negative_invariance(model, test_name)
            invariance, table_counts['Self-cons.'][model] = self.evaluator.invariance(model, test_name)
            comp_acc, table_counts['Comp. Acc.'][model] = self.evaluator.comprehensive_accuracy(model, test_name)
            row = [constants.PRETTY_MODEL_NAMES[model], accuracy, base_accuracy, dual_accuracy, positive_invariance,
                   negative_invariance, invariance, comp_acc]
            table.append(row)
        pd_table = DataFrame(table, columns=header)
        return pd_table, table_counts

    def _do_directional_test(self, test_name):
        header = self.negation_header
        table = list()
        table_counts = defaultdict(dict)
        for model in self.models:
            accuracy, table_counts['Acc.'][model] = self.evaluator.accuracy(model, test_name)
            base_accuracy, table_counts['Original Acc.'][model] = self.evaluator.base_accuracy(model, test_name)
            dual_accuracy, table_counts['Perturbed Acc.'][model] = self.evaluator.dual_accuracy(model, test_name)
            consistency, table_counts['Self-cons.'][model] = self.evaluator.directional(model, test_name)
            comp_acc, table_counts['Comp. Acc.'][model] = self.evaluator.comprehensive_accuracy(model, test_name)
            row = [constants.PRETTY_MODEL_NAMES[model], accuracy, base_accuracy, dual_accuracy, consistency, comp_acc]
            table.append(row)
        pd_table = DataFrame(table, columns=header)
        return pd_table, table_counts

    def _do_predictions_distribution(self, test_name):
        header = self.prediction_distribution_header
        table = list()
        table_counts = defaultdict(dict)
        for model in self.models:
            dist = Counter(self.evaluator.models_test_preds[model][test_name].values())
            total = sum(dist.values())
            dist = defaultdict(int, {k: v/total for k, v in dist.items()})
            for val in ['yes', 'no', 'other']:
                table_counts[val][model] = total
            row = [constants.PRETTY_MODEL_NAMES[model], dist['yes'], dist['no'], 1 - dist['yes'] - dist['no']]
            table.append(row)
        pd_table = DataFrame(table, columns=header)
        return pd_table, table_counts

    def _do_perturbation(self, test_name, perturbation):
        header = self.perturbation_header
        table = list()
        table_counts = defaultdict(dict)
        for model in self.models:
            base_accuracy, table_counts['Original Acc.'][model] = self.evaluator.accuracy(model, test_name)
            perturbed_accuracy, table_counts['Perturbed Acc.'][model] = self.evaluator.accuracy(model, test_name + f'+{perturbation}')
            consistency, table_counts['Self-cons.'][model] = self.evaluator.perturbed_invariance(model, test_name, perturbation)
            comprehensive_acc, table_counts['Comp. Acc.'][model] = self.evaluator.perturbed_comprehensive_accuracy(model, test_name, perturbation)
            row = [constants.PRETTY_MODEL_NAMES[model], base_accuracy, perturbed_accuracy, consistency, comprehensive_acc]
            table.append(row)
        pd_table = DataFrame(table, columns=header)
        return pd_table, table_counts

    def _do_masked(self, test_name, perturbation):
        header = self.masked_header
        table = list()
        table_counts = defaultdict(dict)
        for model in self.models:
            base_accuracy, table_counts['Original Acc.'][model] = self.evaluator.accuracy(model, test_name)
            consistency, table_counts['Self-cons.'][model] = self.evaluator.masked_consistency(model, test_name, perturbation)
            row = [constants.PRETTY_MODEL_NAMES[model], base_accuracy, consistency]
            table.append(row)
        pd_table = DataFrame(table, columns=header)
        return pd_table, table_counts

    def _do_sigma_accuracy(self, test_name):
        header = self.sigma_accuracy_header
        table = list()
        table_counts = defaultdict(dict)
        for model in self.models:
            base_accuracy, table_counts['Original Acc.'][model] = self.evaluator.accuracy(model, test_name)
            sigma_3_perturbed_accuracy, table_counts[r'$\sigma$ 3 Acc.'][model] = self.evaluator.accuracy(model, test_name + '+sigma3')
            sigma_6_perturbed_accuracy, table_counts[r'$\sigma$ 6 Acc.'][model] = self.evaluator.accuracy(model, test_name + '+sigma6')
            sigma_9_perturbed_accuracy, table_counts[r'$\sigma$ 9 Acc.'][model] = self.evaluator.accuracy(model, test_name + '+sigma9')
            row = [constants.PRETTY_MODEL_NAMES[model], base_accuracy, sigma_3_perturbed_accuracy,
                   sigma_6_perturbed_accuracy, sigma_9_perturbed_accuracy]
            table.append(row)
        pd_table = DataFrame(table, columns=header)
        return pd_table, table_counts

    def _do_sigma_comp_accuracy(self, test_name):
        header = self.sigma_comp_accuracy_header
        table = list()
        table_counts = defaultdict(dict)
        for model in self.models:
            base_accuracy, table_counts['Original Acc.'][model] = self.evaluator.accuracy(model, test_name)
            sigma_3_perturbed_accuracy, table_counts[r'$\sigma$ 3 Comp. Acc.'][model] = self.evaluator.perturbed_comprehensive_accuracy(model, test_name, 'sigma3')
            sigma_6_perturbed_accuracy, table_counts[r'$\sigma$ 6 Comp. Acc.'][model] = self.evaluator.perturbed_comprehensive_accuracy(model, test_name, 'sigma6')
            sigma_9_perturbed_accuracy, table_counts[r'$\sigma$ 9 Comp. Acc.'][model] = self.evaluator.perturbed_comprehensive_accuracy(model, test_name, 'sigma9')
            row = [constants.PRETTY_MODEL_NAMES[model], base_accuracy, sigma_3_perturbed_accuracy,
                   sigma_6_perturbed_accuracy, sigma_9_perturbed_accuracy]
            table.append(row)
        pd_table = DataFrame(table, columns=header)
        return pd_table, table_counts

    def _do_sigma_invariance(self, test_name):
        header = self.sigma_invariance_header
        table = list()
        table_counts = defaultdict(dict)
        for model in self.models:
            sigma_3_perturbed_consistency, table_counts[r'$\sigma$ 3 Self-cons'][model] = self.evaluator.perturbed_invariance(model, test_name, 'sigma3')
            sigma_6_perturbed_consistency, table_counts[r'$\sigma$ 6 Self-cons'][model] = self.evaluator.perturbed_invariance(model, test_name, 'sigma6')
            sigma_9_perturbed_consistency, table_counts[r'$\sigma$ 9 Self-cons'][model] = self.evaluator.perturbed_invariance(model, test_name, 'sigma9')
            row = [constants.PRETTY_MODEL_NAMES[model], sigma_3_perturbed_consistency, sigma_6_perturbed_consistency,
                   sigma_9_perturbed_consistency]
            table.append(row)
        pd_table = DataFrame(table, columns=header)
        return pd_table, table_counts

    def _do_masked_sigma_consistency(self, test_name):
        header = self.sigma_consistency_header
        table = list()
        table_counts = defaultdict(dict)
        for model in self.models:
            sigma_3_perturbed_consistency, table_counts[r'$\sigma$ 3 Self-cons'][model] = self.evaluator.masked_consistency(model, test_name, 'sigma3')
            sigma_6_perturbed_consistency, table_counts[r'$\sigma$ 6 Self-cons'][model] = self.evaluator.masked_consistency(model, test_name, 'sigma6')
            sigma_9_perturbed_consistency, table_counts[r'$\sigma$ 9 Self-cons'][model] = self.evaluator.masked_consistency(model, test_name, 'sigma9')
            row = [constants.PRETTY_MODEL_NAMES[model], sigma_3_perturbed_consistency, sigma_6_perturbed_consistency,
                   sigma_9_perturbed_consistency]
            table.append(row)
        pd_table = DataFrame(table, columns=header)
        return pd_table, table_counts


    def attribute_verification_antonym_test(self):
        """Run the directional test on 'attribute_verification_antonym_test',
        print the table and return it together with its label."""
        label = 'attribute_antonym_directional_expectation'
        table, _counts = self._do_directional_test('attribute_verification_antonym_test')
        self._print_table(table, label)
        return table, label

    def object_verification_conjunction_template_test(self):
        """Run the invariance test on 'object_verification_conjunction_template_test',
        print the table and return it together with its label."""
        label = 'conjunction_phrasal_invariance'
        table, _counts = self._do_invariance_test('object_verification_conjunction_template_test')
        self._print_table(table, label)
        return table, label

    def object_verification_conjunction_negation_test(self):
        """Run the directional test on 'object_verification_conjunction_negation_test',
        print the table and return it together with its label."""
        label = 'conjunction_negation_directional_test'
        table, _counts = self._do_directional_test('object_verification_conjunction_negation_test')
        self._print_table(table, label)
        return table, label

    def object_verification_conjunction_symmetric_test(self):
        """Run the invariance test on 'object_verification_conjunction_symmetric_test',
        print the table and return it together with its label."""
        label = 'conjunction_symmetry_invariance'
        table, _counts = self._do_invariance_test('object_verification_conjunction_symmetric_test')
        self._print_table(table, label)
        return table, label

    def object_verification_conjunction_perturbation(self):
        """Print every perturbation table for the conjunction verification test.

        Tables are printed in this order: sigma3, sigma6, sigma9, avg, crop,
        then the blur accuracy and blur invariance summaries.

        Returns:
            An 8-tuple of alternating table and label for the avg, crop,
            blur-accuracy and blur-invariance tables (the per-sigma tables
            are printed but not returned).
        """
        test_name = 'object_verification_conjunction_perturbation'
        # Perturbation -> output label, in the order the tables are printed.
        labels = {
            'sigma3': 'verification_conjunction_crop_with_light_blurring',
            'sigma6': 'verification_conjunction_crop_with_medium_blurring',
            'sigma9': 'verification_conjunction_crop_with_heavy_blurring',
            'avg': 'verification_conjunction_image_average_context_perturbation',
            'crop': 'verification_conjunction_image_crop_context_perturbation',
        }
        tables = {}
        for perturbation, label in labels.items():
            tables[perturbation], _ = self._do_perturbation(test_name, perturbation)
            self._print_table(tables[perturbation], label)

        acc_label = 'verification_conjunction_image_blur_context_perturbation_accuracy'
        inv_label = 'verification_conjunction_image_blur_context_perturbation_invariance'
        acc_table, _ = self._do_sigma_accuracy(test_name)
        self._print_table(acc_table, acc_label)
        inv_table, _ = self._do_sigma_invariance(test_name)
        self._print_table(inv_table, inv_label)

        return (tables['avg'], labels['avg'], tables['crop'], labels['crop'],
                acc_table, acc_label, inv_table, inv_label)

    # def object_verification_conjunction_perturbation_masked(self):
    #     test_name = 'object_verification_conjunction_perturbation'
    #     avg_label = 'verification_conjunction_image_average_occlusion_perturbation'
    #     sigma_label = 'verification_conjunction_image_blur_occlusion_perturbation'
    #     masked_avg_pd_table, _ = self._do_masked(test_name, 'avg')
    #     self._print_table(masked_avg_pd_table, avg_label)
    #     masked_sigma_cons_pd_table, _ = self._do_masked_sigma_consistency(test_name)
    #     self._print_table(masked_sigma_cons_pd_table, sigma_label)
    #     return masked_avg_pd_table, avg_label, masked_sigma_cons_pd_table, sigma_label

    def object_verification_disjunction_template_test(self):
        """Run the invariance test on 'object_verification_disjunction_template_test',
        print the table and return it together with its label."""
        label = 'disjunction_phrasal_invariance'
        table, _counts = self._do_invariance_test('object_verification_disjunction_template_test')
        self._print_table(table, label)
        return table, label

    def object_verification_disjunction_negation_test(self):
        """Run the directional test on 'object_verification_disjunction_negation_test',
        print the table and return it together with its label."""
        label = 'disjunction_negation_directional_consistency'
        table, _counts = self._do_directional_test('object_verification_disjunction_negation_test')
        self._print_table(table, label)
        return table, label

    def object_verification_disjunction_symmetric_test(self):
        """Run the invariance test on 'object_verification_disjunction_symmetric_test',
        print the table and return it together with its label."""
        label = 'disjunction_symmetry_invariance'
        table, _counts = self._do_invariance_test('object_verification_disjunction_symmetric_test')
        self._print_table(table, label)
        return table, label

    def object_verification_disjunction_perturbation(self):
        """Print every perturbation table for the disjunction verification test.

        Tables are printed in this order: sigma3, sigma6, sigma9, avg, crop,
        then the blur accuracy and blur invariance summaries.

        Returns:
            An 8-tuple of alternating table and label for the avg, crop,
            blur-accuracy and blur-invariance tables (the per-sigma tables
            are printed but not returned).
        """
        test_name = 'object_verification_disjunction_perturbation'
        # Perturbation -> output label, in the order the tables are printed.
        labels = {
            'sigma3': 'verification_disjunction_crop_with_light_blurring',
            'sigma6': 'verification_disjunction_crop_with_medium_blurring',
            'sigma9': 'verification_disjunction_crop_with_heavy_blurring',
            'avg': 'verification_disjunction_image_average_context_perturbation',
            'crop': 'verification_disjunction_image_crop_context_perturbation',
        }
        tables = {}
        for perturbation, label in labels.items():
            tables[perturbation], _ = self._do_perturbation(test_name, perturbation)
            self._print_table(tables[perturbation], label)

        # These labels previously said "conjunction" (copy-paste), which made
        # the saved disjunction tables overwrite the conjunction ones because
        # the label is also the output file name.
        sigma_acc_label = 'verification_disjunction_image_blur_context_perturbation_accuracy'
        sigma_inv_label = 'verification_disjunction_image_blur_context_perturbation_invariance'
        sigma_pd_table, _ = self._do_sigma_accuracy(test_name)
        self._print_table(sigma_pd_table, sigma_acc_label)
        sigma_inv_pd_table, _ = self._do_sigma_invariance(test_name)
        self._print_table(sigma_inv_pd_table, sigma_inv_label)

        return (tables['avg'], labels['avg'], tables['crop'], labels['crop'],
                sigma_pd_table, sigma_acc_label, sigma_inv_pd_table, sigma_inv_label)

    # def object_verification_disjunction_perturbation_masked(self):
    #     test_name = 'object_verification_disjunction_perturbation'
    #     avg_label = 'verification_disjunction_image_average_occlusion_perturbation'
    #     sigma_label = 'verification_disjunction_image_blur_occlusion_perturbation'
    #     masked_avg_pd_table, _ = self._do_masked(test_name, 'avg')
    #     self._print_table(masked_avg_pd_table, avg_label)
    #     masked_sigma_cons_pd_table, _ = self._do_masked_sigma_consistency(test_name)
    #     self._print_table(masked_sigma_cons_pd_table, sigma_label)
    #     return masked_avg_pd_table, avg_label, masked_sigma_cons_pd_table, sigma_label

    def object_verification_template_test(self):
        """Run the invariance test on 'object_verification_template_test',
        print the table and return it together with its label."""
        label = 'phrasal_invariance'
        table, _counts = self._do_invariance_test('object_verification_template_test')
        self._print_table(table, label)
        return table, label

    def object_verification_negation_test(self):
        """Run the directional test on 'object_verification_negation_test',
        print the table and return it together with its label."""
        label = 'negation_directional_test'
        table, _counts = self._do_directional_test('object_verification_negation_test')
        self._print_table(table, label)
        return table, label

    def object_verification_hypernym_test(self):
        """Run the invariance test on 'object_verification_hypernym_test',
        print the table and return it together with its label."""
        label = 'hypernym_invariance_test'
        table, _counts = self._do_invariance_test('object_verification_hypernym_test')
        self._print_table(table, label)
        return table, label

    def object_verification_hyponym_test(self):
        """Run the invariance test on 'object_verification_hyponym_test',
        print the table and return it together with its label."""
        label = 'hyponym_invariance_test'
        table, _counts = self._do_invariance_test('object_verification_hyponym_test')
        self._print_table(table, label)
        return table, label

    def object_verification_ontology_ontology_test(self):
        test_name = 'object_verification_ontology_ontology_test'
        label = 'ontological_invariance_test'
        pd_table, _ = self._do_ontology_test(test_name)
        self._print_table(pd_table, label)
        return pd_table, label
        # header = ['Model', 'Acc.', 'Original Acc.', 'Perturbed Acc.', 'Hypernym Invariance', 'Hyponym Invariance']
        # table = list()
        # for model in self.models:
        #     accuracy, _ = self.evaluator.accuracy(model, test_name)
        #     base_accuracy, _ = self.evaluator.base_accuracy(model, test_name)
        #     dual_accuracy, _ = self.evaluator.dual_accuracy(model, test_name)
        #     results_dict, _ = self.evaluator.directional(model, test_name)
        #     hypernym_invariance = results_dict['hypernym_invariance']
        #     hyponym_invariance = results_dict['hyponym_invariance']
        #     row = [constants.PRETTY_MODEL_NAMES[model], accuracy, base_accuracy, dual_accuracy, hypernym_invariance,
        #            hyponym_invariance]
        #     table.append(row)
        # pd_table = DataFrame(table, columns=header)
        # self._print_table(pd_table, label)
        # return pd_table, label

    def object_verification_perturbation(self):
        """Print every perturbation table for the plain object verification test.

        Tables are printed in this order: sigma3, sigma6, sigma9, avg, crop,
        then the blur accuracy and blur invariance summaries.

        Returns:
            An 8-tuple of alternating table and label for the avg, crop,
            blur-accuracy and blur-invariance tables (the per-sigma tables
            are printed but not returned).
        """
        test_name = 'object_verification_perturbation'
        # Perturbation -> output label, in the order the tables are printed.
        labels = {
            'sigma3': 'object_verification_crop_with_light_blurring',
            'sigma6': 'object_verification_crop_with_medium_blurring',
            'sigma9': 'object_verification_crop_with_heavy_blurring',
            'avg': 'object_verification_average_image_context_perturbation',
            'crop': 'object_verification_crop_image_context_perturbation',
        }
        tables = {}
        for perturbation, label in labels.items():
            tables[perturbation], _ = self._do_perturbation(test_name, perturbation)
            self._print_table(tables[perturbation], label)

        # These labels previously carried a copy-pasted
        # "verification_conjunction_" prefix, mislabelling the captions and
        # file names of the plain object verification tables.
        sigma_acc_label = 'object_verification_blur_image_context_perturbation_accuracy'
        sigma_inv_label = 'object_verification_blur_image_context_perturbation_invariance'
        sigma_pd_table, _ = self._do_sigma_accuracy(test_name)
        self._print_table(sigma_pd_table, sigma_acc_label)
        sigma_inv_pd_table, _ = self._do_sigma_invariance(test_name)
        self._print_table(sigma_inv_pd_table, sigma_inv_label)

        return (tables['avg'], labels['avg'], tables['crop'], labels['crop'],
                sigma_pd_table, sigma_acc_label, sigma_inv_pd_table, sigma_inv_label)

    # def object_verification_perturbation_masked(self):
    #     test_name = 'object_verification_perturbation'
    #     avg_label = 'object_verification_image_average_occlusion_perturbation'
    #     sigma_label = 'object_verification_image_blur_occlusion_perturbation'
    #     masked_avg_pd_table, _ = self._do_masked(test_name, 'avg')
    #     self._print_table(masked_avg_pd_table, avg_label)
    #     masked_sigma_cons_pd_table, _ = self._do_masked_sigma_consistency(test_name)
    #     self._print_table(masked_sigma_cons_pd_table, sigma_label)
    #     return masked_avg_pd_table, avg_label, masked_sigma_cons_pd_table, sigma_label

    def _get_array_from_counts(self, counts, datasets, metrics):
        """Stack ``counts[dataset][metric][model]`` into an array.

        The nested lists are built as dataset x metric x model and then
        transposed, so the result is indexed ``[model][metric][dataset]``
        with models in ``self.models`` order.
        """
        nested = [
            [[counts[dataset][metric][model] for model in self.models] for metric in metrics]
            for dataset in datasets
        ]
        return np.array(nested).T

    def _get_array_from_tables(self, tables, datasets, metrics):
        arr = list()
        for dataset in datasets:
            sub_arr = list()
            for metric in metrics:
                sub_arr.append(tables[dataset][metric].values)
            arr.append(sub_arr)
        return np.array(arr).T

    def _do_mean_results(self, group_name, header, label, results_func):
        """Print and return the count-weighted mean of a test group's tables.

        Args:
            group_name: Key into ``constants.TEST_GROUPS`` selecting the tests.
            header: Output columns; the first is the model-name column and the
                rest are the metrics to average. Every metric must be present
                in each per-test table and counts mapping.
            label: Label handed to ``_print_table``.
            results_func: Callable taking a test name and returning
                ``(DataFrame, counts)`` with ``counts[metric][model]``.

        Returns:
            ``(DataFrame, label)`` with one row per model in ``self.models``.

        Raises:
            AssertionError: If the per-test tables disagree on their columns
                or on the order of the 'Model' column.
        """
        all_tables = dict()
        count_dicts = dict()
        for inv_test in constants.TEST_GROUPS[group_name]:
            pd_table, count_dicts[inv_test] = results_func(inv_test)
            all_tables[inv_test] = pd_table

        for dataset1, dataset2 in combinations(all_tables.keys(), r=2):
            assert list(all_tables[dataset1].columns) == list(all_tables[dataset2].columns)
            assert list(all_tables[dataset1]['Model']) == list(all_tables[dataset2]['Model'])

        datasets = list(count_dicts.keys())
        metrics = header[1:]

        # Arrays are indexed [model][metric][dataset]; each dataset is
        # weighted by its share of the total count for that model and metric.
        counts_arr = self._get_array_from_counts(count_dicts, datasets, metrics)
        weights_arr = counts_arr / counts_arr.sum(axis=2)[:, :, np.newaxis]
        results_arr = self._get_array_from_tables(all_tables, datasets, metrics)
        mean_arr = np.multiply(results_arr, weights_arr).sum(axis=2)

        # Build all rows first: DataFrame.append was removed in pandas 2.0.
        rows = [
            [constants.PRETTY_MODEL_NAMES[model], *mean_arr[row]]
            for row, model in enumerate(self.models)
        ]
        pd_table = DataFrame(rows, columns=header)
        self._print_table(pd_table, label)
        return pd_table, label

    def _do_mean_perturbation_results(self, pertubations, header, label, results_func):
        all_tables = dict()
        count_dicts = dict()
        for inv_test in ['object_verification_conjunction_perturbation',
                         'object_verification_disjunction_perturbation',
                         'object_verification_perturbation']:
            for perturbation in pertubations:
                pd_table, count_dicts[inv_test] = results_func(inv_test, perturbation)
                all_tables[inv_test] = pd_table

        for dataset1, dataset2 in combinations(all_tables.keys(), r=2):
            assert list(all_tables[dataset1].columns) == list(all_tables[dataset2].columns)
            assert list(all_tables[dataset1]['Model']) == list(all_tables[dataset2]['Model'])

        datasets = list(count_dicts.keys())
        metrics = header[1:]

        counts_arr = self._get_array_from_counts(count_dicts, datasets, metrics)
        weights_arr = counts_arr / counts_arr.sum(axis=2)[:, :, np.newaxis]
        results_arr = self._get_array_from_tables(all_tables, datasets, metrics)
        mean_arr = np.multiply(results_arr, weights_arr).sum(axis=2)

        pd_table = DataFrame(columns=header)
        for row, model in enumerate(self.models):
            entry = {'Model': constants.PRETTY_MODEL_NAMES[model]}
            for col, metric in enumerate(metrics):
                entry[metric] = mean_arr[row][col]
            pd_table = pd_table.append(entry, ignore_index=True)
        self._print_table(pd_table, label)
        return pd_table, label

    def _do_mean_sigma_results(self, header, label, results_func):
        all_tables = dict()
        count_dicts = dict()
        for inv_test in ['object_verification_conjunction_perturbation',
                         'object_verification_disjunction_perturbation',
                         'object_verification_perturbation']:
            pd_table, count_dicts[inv_test] = results_func(inv_test)
            all_tables[inv_test] = pd_table

        for dataset1, dataset2 in combinations(all_tables.keys(), r=2):
            assert list(all_tables[dataset1].columns) == list(all_tables[dataset2].columns)
            assert list(all_tables[dataset1]['Model']) == list(all_tables[dataset2]['Model'])

        datasets = list(count_dicts.keys())
        metrics = header[1:]

        counts_arr = self._get_array_from_counts(count_dicts, datasets, metrics)
        weights_arr = counts_arr / counts_arr.sum(axis=2)[:, :, np.newaxis]
        results_arr = self._get_array_from_tables(all_tables, datasets, metrics)
        mean_arr = np.multiply(results_arr, weights_arr).sum(axis=2)

        pd_table = DataFrame(columns=header)
        for row, model in enumerate(self.models):
            entry = {'Model': constants.PRETTY_MODEL_NAMES[model]}
            for col, metric in enumerate(metrics):
                entry[metric] = mean_arr[row][col]
            pd_table = pd_table.append(entry, ignore_index=True)
        self._print_table(pd_table, label)
        return pd_table, label

    def mean_phrasal_invariance(self):
        """Mean invariance table over the 'phrasal_invariance' test group,
        produced by ``_do_mean_results``."""
        return self._do_mean_results(
            'phrasal_invariance',
            self.invariance_header,
            'mean_phrasal_invariance',
            self._do_invariance_test,
        )

    def mean_phrasal_prediction_distribution(self):
        """Mean prediction distribution over the
        'phrasal_prediction_distribution' test group, produced by
        ``_do_mean_results``."""
        return self._do_mean_results(
            'phrasal_prediction_distribution',
            self.prediction_distribution_header,
            'Rephrasing_invariance_prediction_distribution',
            self._do_predictions_distribution,
        )

    def mean_ontological_consistency(self):
        """Mean invariance-test results over the 'ontological_consistency'
        test group, restricted to the columns of ``perturbation_header``."""
        return self._do_mean_results(
            'ontological_consistency',
            self.perturbation_header,
            'mean_ontological_consistency',
            self._do_invariance_test,
        )

    def mean_antonym_consistency(self):
        """Mean directional-test results over the 'antonym_consistency' test
        group, restricted to the columns of ``perturbation_header``."""
        return self._do_mean_results(
            'antonym_consistency',
            self.perturbation_header,
            'mean_antonym_consistency',
            self._do_directional_test,
        )

    def mean_ontological_prediction_distribution(self):
        """Mean prediction distribution over the
        'ontological_prediction_distribution' test group, produced by
        ``_do_mean_results``."""
        return self._do_mean_results(
            'ontological_prediction_distribution',
            self.prediction_distribution_header,
            'ontological_prediction_distribution',
            self._do_predictions_distribution,
        )

    def mean_symmetric_prediction_distribution(self):
        """Mean prediction distribution over the
        'symmetric_prediction_distribution' test group, produced by
        ``_do_mean_results``."""
        return self._do_mean_results(
            'symmetric_prediction_distribution',
            self.prediction_distribution_header,
            'symmetric_prediction_distribution',
            self._do_predictions_distribution,
        )

    def mean_antonym_prediction_distribution(self):
        """Mean prediction distribution over the
        'antonym_prediction_distribution' test group, produced by
        ``_do_mean_results``."""
        return self._do_mean_results(
            'antonym_prediction_distribution',
            self.prediction_distribution_header,
            'antonym_prediction_distribution',
            self._do_predictions_distribution,
        )

    def mean_negation_prediction_distribution(self):
        """Mean prediction distribution over the
        'negation_prediction_distribution' test group, produced by
        ``_do_mean_results``."""
        return self._do_mean_results(
            'negation_prediction_distribution',
            self.prediction_distribution_header,
            'negation_prediction_distribution',
            self._do_predictions_distribution,
        )

    def mean_visual_obfuscation_prediction_distribution(self):
        """Mean prediction distribution over the
        'visual_obfuscation_prediction_distribution' test group, produced by
        ``_do_mean_results``."""
        return self._do_mean_results(
            'visual_obfuscation_prediction_distribution',
            self.prediction_distribution_header,
            'visual_obfuscation_prediction_distribution',
            self._do_predictions_distribution,
        )

    def mean_symmetry_invariance(self):
        """Mean invariance table over the 'symmetry_invariance' test group,
        produced by ``_do_mean_results``."""
        return self._do_mean_results(
            'symmetry_invariance',
            self.invariance_header,
            'mean_symmetry_invariance',
            self._do_invariance_test,
        )

    def mean_negation_consistency(self):
        """Mean directional-test table over the 'negation_consistency' test
        group, produced by ``_do_mean_results``."""
        return self._do_mean_results(
            'negation_consistency',
            self.negation_header,
            'mean_negation_directional_expectation',
            self._do_directional_test,
        )

    def mean_context_perturbation_invariance(self):
        """Request the mean perturbation table for the 'crop', 'avg', 'sigma3',
        'sigma6' and 'sigma9' perturbations from
        ``_do_mean_perturbation_results``."""
        perturbations = ['crop', 'avg', 'sigma3', 'sigma6', 'sigma9']
        return self._do_mean_perturbation_results(
            perturbations,
            self.perturbation_header,
            'mean_context_perturbation_invariance',
            self._do_perturbation,
        )

    def mean_context_occlusion_invariance(self):
        """Request the mean perturbation table for the 'avg' perturbation from
        ``_do_mean_perturbation_results``."""
        return self._do_mean_perturbation_results(
            ['avg'],
            self.perturbation_header,
            'mean_context_occlusion_invariance',
            self._do_perturbation,
        )

    def mean_context_crop_invariance(self):
        """Request the mean perturbation table for the 'crop' perturbation from
        ``_do_mean_perturbation_results``."""
        return self._do_mean_perturbation_results(
            ['crop'],
            self.perturbation_header,
            'mean_context_crop_invariance',
            self._do_perturbation,
        )

    def mean_context_blur_invariance(self):
        """Request the mean perturbation table for the 'sigma3', 'sigma6' and
        'sigma9' perturbations from ``_do_mean_perturbation_results``."""
        blur_levels = ['sigma3', 'sigma6', 'sigma9']
        return self._do_mean_perturbation_results(
            blur_levels,
            self.perturbation_header,
            'mean_context_blur_invariance',
            self._do_perturbation,
        )

    def visual_obfuscation_invariance_with_light_blurring(self):
        """Request the mean perturbation table for the 'sigma3' perturbation
        from ``_do_mean_perturbation_results``."""
        return self._do_mean_perturbation_results(
            ['sigma3'],
            self.perturbation_header,
            'visual_obfuscation_invariance_with_light_blurring',
            self._do_perturbation,
        )

    def visual_obfuscation_invariance_with_medium_blurring(self):
        """Request the mean perturbation table for the 'sigma6' perturbation
        from ``_do_mean_perturbation_results``."""
        return self._do_mean_perturbation_results(
            ['sigma6'],
            self.perturbation_header,
            'visual_obfuscation_invariance_with_medium_blurring',
            self._do_perturbation,
        )

    def visual_obfuscation_invariance_with_heavy_blurring(self):
        """Request the mean perturbation table for the 'sigma9' perturbation
        from ``_do_mean_perturbation_results``."""
        return self._do_mean_perturbation_results(
            ['sigma9'],
            self.perturbation_header,
            'visual_obfuscation_invariance_with_heavy_blurring',
            self._do_perturbation,
        )

    # def mean_content_occlusion_consistency(self):
    #     label = 'mean_content_occlusion_consistency'
    #     header = self.masked_header
    #     return self._do_mean_perturbation_results(['avg', 'sigma3', 'sigma6', 'sigma9'], header, label, self._do_masked)

    def mean_sigma_context_perturbation_accuracy(self):
        """Mean blur-accuracy table across the perturbation tests, produced by
        ``_do_mean_sigma_results`` from ``_do_sigma_accuracy``."""
        return self._do_mean_sigma_results(
            self.sigma_accuracy_header,
            'mean_sigma_context_perturbation_accuracy',
            self._do_sigma_accuracy,
        )

    def mean_sigma_context_perturbation_invariance(self):
        """Mean blur-invariance table across the perturbation tests, produced
        by ``_do_mean_sigma_results`` from ``_do_sigma_invariance``."""
        return self._do_mean_sigma_results(
            self.sigma_invariance_header,
            'mean_sigma_context_perturbation_invariance',
            self._do_sigma_invariance,
        )

    def visual_obfuscation_comprehensive_accuracy_with_all_blurring(self):
        """Mean blur comprehensive-accuracy table across the perturbation
        tests, produced by ``_do_mean_sigma_results`` from
        ``_do_sigma_comp_accuracy``."""
        return self._do_mean_sigma_results(
            self.sigma_comp_accuracy_header,
            'visual_obfuscation_comprehensive_accuracy_with_all_blurring',
            self._do_sigma_comp_accuracy,
        )

    # def mean_sigma_content_occlusion_consistency(self):
    #     label = 'mean_sigma_content_occlusion_consistency'
    #     header = self.sigma_consistency_header
    #     return self._do_mean_sigma_results(header, label, self._do_masked_sigma_consistency)


def main(models, tests, save_dir, use_latex):
    """Run each requested test method on a freshly built ``Tester``.

    Args:
        models: Model names, forwarded to ``Tester``.
        tests: Iterable of ``Tester`` method names; each is looked up with
            ``getattr`` and called with no arguments, in iteration order.
        save_dir: Output directory, forwarded to ``Tester``.
        use_latex: Forwarded to ``Tester`` as the ``use_latex`` keyword.

    A failing test is reported on stdout (name plus traceback) and the loop
    moves on to the next test, so one broken test does not abort the run.
    """
    tester = Tester(models, tests, save_dir, use_latex=use_latex)
    for test in tests:
        try:
            getattr(tester, test)()
        except Exception:
            # Only swallow ordinary errors: a bare ``except`` would also trap
            # KeyboardInterrupt / SystemExit and make the run impossible to
            # interrupt from the keyboard.
            print(f'Failed for test: {test}')
            print(traceback.format_exc())


if __name__ == '__main__':
    # Configure logging before the first logging call. A module-level
    # logging.info() on an unconfigured root logger installs a default
    # WARNING-level handler, which both hides the INFO messages below and
    # turns any later basicConfig() call into a no-op.
    log_format = '%(asctime)s %(levelname)s: %(message)s'
    logging.basicConfig(format=log_format, level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument('--models', type=str, nargs='*', default=set(),
                        help=f'models to evaluate from: {" ".join(sorted(constants.MODELS))}')
    parser.add_argument('--all-models', action='store_true',
                        help='This overrides the --models flag and evaluates all models.')
    parser.add_argument('--tests', type=str, nargs='*', default=set(),
                        help=f'tests for each model from: {" ".join(constants.FORMAL_TESTS)}')
    parser.add_argument('--all-tests', action='store_true',
                        help='This overrides the --tests flag and evaluates all tests.')
    parser.add_argument('--save-dir', type=str, default='evaluation_output',
                        help='Directory to save tables and figures to.')
    parser.add_argument('--latex', action='store_true',
                        help='Print latex tables instead of markdown.')

    args = parser.parse_args()

    # Always work on a set copy so .difference() is available whatever
    # container type constants.MODELS happens to be.
    models = set(constants.MODELS) if args.all_models else set(args.models)
    if not models:
        logging.info('No models entered. Nothing to do.')
        raise SystemExit(0)
    bad_models = models.difference(constants.MODELS)
    if bad_models:
        raise ValueError('The following models are not available: %s' % ', '.join(sorted(bad_models)))

    tests = set(constants.FORMAL_TESTS) if args.all_tests else set(args.tests)
    if not tests:
        logging.info('No tests entered. Nothing to do.')
        raise SystemExit(0)
    bad_tests = tests.difference(constants.FORMAL_TESTS)
    if bad_tests:
        raise ValueError('The following tests are not available: %s' % ', '.join(sorted(bad_tests)))
    # Run the selected tests in the order constants.FORMAL_TESTS lists them,
    # not in the order they were typed on the command line.
    ordered_tests = [test for test in constants.FORMAL_TESTS if test in tests]

    save_dir = Path(args.save_dir).resolve()
    if not save_dir.exists():
        logging.info(f'Making figures directory: {save_dir.as_posix()}')
        save_dir.mkdir(parents=True, exist_ok=True)

    main(models, ordered_tests, save_dir, args.latex)
