from scipy.stats import ttest_rel
import pandas as pd

from analysis.util import load_results_table

def main(table=None):
    """Paired one-sided t-tests of each model's best condition against the rest.

    For every model in the results table, the condition (``description``)
    with the highest mean ``attribution_metric`` is taken as the best
    condition. Its metric values are compared against every other condition
    of the same model with ``scipy.stats.ttest_rel(..., alternative='greater')``.
    Rows whose ``task`` is ``'govreport'`` are excluded from the tests.
    Results are printed to stdout.

    Args:
        table: Optional results DataFrame with at least the columns ``model``,
            ``description``, ``task`` and ``attribution_metric``. When omitted,
            it is loaded with ``load_results_table()``.

    Raises:
        ValueError: if ``attribution_metric`` has missing values in the tested
            rows, or if two conditions cannot be paired task-by-task.
    """
    if table is None:
        table = load_results_table()

    # Conditions are visited in the order they first appear in the whole table.
    conditions = table['description'].unique()

    for model_name in table['model'].unique():
        print('*' * 80)
        print(f'Model: {model_name}')
        model_data = table.loc[table['model'] == model_name]

        # NOTE(review): the best condition is picked from the mean over *all*
        # tasks (govreport included; pandas skips NaNs), while the tests below
        # exclude govreport. Confirm this asymmetry is intended.
        best_condition = (
            model_data.groupby('description')['attribution_metric'].mean().idxmax()
        )
        print(f'Best condition: {best_condition}')
        best_rows = _tested_rows(model_data, best_condition)

        for condition in conditions:
            if condition == best_condition:
                continue

            condition_rows = _tested_rows(model_data, condition)
            if condition_rows.empty:
                # The condition exists for some other model only (or only on
                # govreport); an empty sample cannot be paired.
                print(f'Skipping {condition}: no non-govreport rows for {model_name}')
                continue

            _check_pairing(best_rows, condition_rows, best_condition, condition)
            result = ttest_rel(
                best_rows['attribution_metric'].to_numpy(),
                condition_rows['attribution_metric'].to_numpy(),
                alternative='greater',
            )

            print(f'Statistics for alternative hypothesis {best_condition} > {condition}: {result}')


def _tested_rows(model_data, condition):
    """Return the non-govreport ``task``/``attribution_metric`` rows of ``condition``.

    The rows are sorted by ``task`` with a stable sort. This lines samples up by
    task for the paired test regardless of how the table is ordered. Rows
    within one task keep their original relative order.

    Raises:
        ValueError: if ``attribution_metric`` has missing values in these rows.
    """
    rows = model_data.loc[
        (model_data['description'] == condition) & (model_data['task'] != 'govreport'),
        ['task', 'attribution_metric'],
    ].sort_values('task', kind='stable')
    if rows['attribution_metric'].isna().any():
        raise ValueError(f'Missing attribution_metric values for condition {condition!r}')
    return rows


def _check_pairing(best_rows, condition_rows, best_condition, condition):
    """Raise ValueError unless both row sets list the same tasks in the same order.

    ``ttest_rel`` pairs samples purely by position, so every position must
    refer to the same task in both conditions. Within a single task, rows are
    paired in table order. That part cannot be verified from these columns.
    """
    best_tasks = best_rows['task'].tolist()
    condition_tasks = condition_rows['task'].tolist()
    if best_tasks != condition_tasks:
        raise ValueError(
            f'Cannot pair {best_condition!r} ({len(best_tasks)} rows) with '
            f'{condition!r} ({len(condition_tasks)} rows): tasks do not line up'
        )


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()