import json
from pathlib import Path
from typing import Tuple, Dict

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import pearsonr

from analysis.util import (
    selective_prediction,
    load_predictions_with_scores,
    infer_base_data_path,
    load_results_table,
    CLASSES,
    MODEL_SHORT_NAMES,
    TASK_SHORT_NAMES,
    METRIC_SHORT_NAMES,
    EXPERIMENT_SHORT_NAMES,
    ANSWER_METRICS
)

# Task identifiers whose runs are kept when filtering the results table.
TASKS = [
    'qasper',
    'natural_questions',
    'evidence_inference',
    'wice',
    'contract_nli',
]

# Short names of the experiments kept when filtering the results table
# (looked up from the long experiment names via EXPERIMENT_SHORT_NAMES).
EXPERIMENT_ORDER = [
    EXPERIMENT_SHORT_NAMES['full-post-hoc'],
    EXPERIMENT_SHORT_NAMES['full-citation'],
]

def main() -> pd.DataFrame:
    """
    For each run, check the correlation between the number of required evidence
    passages and (1) attributability (2) evidence F1

    The results table is restricted to the models in MODEL_SHORT_NAMES, the
    experiments in EXPERIMENT_ORDER and the tasks in TASKS. For every remaining
    run the correlations are computed with compute_n_evidence_correlations and
    collected in long format, then pivoted.

    :return: table indexed by (model, description) with one column per
        (task, metric_name) pair, holding correlation coefficients and p-values
    """
    # Load results table that contains all hashes
    results_table = load_results_table()
    # Adapt descriptions to short names in results table
    results_table['description'] = results_table['description'].apply(
        lambda x: EXPERIMENT_SHORT_NAMES[x]
    )
    # Use only certain models, experiments and tasks
    selected = (
        results_table['model'].isin(MODEL_SHORT_NAMES)
        & results_table['description'].isin(EXPERIMENT_ORDER)
        & results_table['task'].isin(TASKS)
    )
    results_table = results_table.loc[selected]

    # One long-format record per (run, correlation statistic)
    records = []
    for _, run in results_table.iterrows():
        correlation_results = compute_n_evidence_correlations(
            run['hash'],
            run['task'],
            run['answer_metric_name']
        )
        records.extend(
            {
                'task': run['task'],
                'model': run['model'],
                'description': run['description'],
                'metric_name': statistic_name,
                'value': value,
            }
            for statistic_name, value in correlation_results.items()
        )

    # Columns are given explicitly so the frame keeps its schema (and the pivot
    # below still works) when no run survived the filtering.
    correlation_results_table = pd.DataFrame(
        records,
        columns=['task', 'model', 'description', 'metric_name', 'value']
    )

    # Pivot to get p and r as separate columns
    return correlation_results_table.pivot(
        index=['model', 'description'],
        columns=['task', 'metric_name'],
        values='value'
    )

def compute_n_evidence_correlations(
        hash: str,
        task_name: str,
        metric_name: str
) -> Dict[str, float]:
    """
    Compute Pearson correlations between the number of gold evidence passages
    and (1) evidence F1 (2) attribution for a single run.

    Only predictions with 'Predicted Answerability' == 1 are used. The number
    of evidence passages of an instance is the mean length of the entries in
    its 'Gold Extraction Node Idxs' cell (presumably one list of node indices
    per annotation - confirm against the data loader).

    :param hash: identifier of the run whose predictions are loaded
    :param task_name: name of the task of the run (currently unused, kept for
        interface compatibility)
    :param metric_name: answer metric name passed on to the prediction loader
    :return: dict with the correlation coefficient and p-value for evidence F1
        and for attribution; all values are NaN when the run has no
        'Attribution' column or fewer than two answerable predictions
    """
    undefined = (np.nan, np.nan)
    correlation_n_evidence_evidence_f1 = undefined
    correlation_n_evidence_attribution = undefined

    df = load_predictions_with_scores(hash, metric_name, eval_idx='latest')
    if 'Attribution' not in df.columns:
        print(f'Attribution data missing for hash {hash}')
    else:
        # Use answerable predictions only, as attribution is not defined for non-answerable
        answerable = df.loc[df['Predicted Answerability'] == 1]

        # Kept as a separate series instead of a new column, so nothing is
        # assigned onto a filtered slice (avoids SettingWithCopyWarning).
        n_evidence = answerable['Gold Extraction Node Idxs'].apply(
            lambda annotations: np.mean([len(idxs) for idxs in annotations])
        )

        if len(answerable) < 2:
            # pearsonr raises a ValueError for fewer than two observations
            print(f'Fewer than two answerable predictions for hash {hash}')
        else:
            correlation_n_evidence_evidence_f1 = pearsonr(
                n_evidence,
                answerable['Evidence F1']
            )
            correlation_n_evidence_attribution = pearsonr(
                n_evidence,
                answerable['Attribution']
            )

    return {
        'correlation_n_evidence_evidence_f1': correlation_n_evidence_evidence_f1[0],
        'p_n_evidence_evidence_f1': correlation_n_evidence_evidence_f1[1],
        'correlation_n_evidence_attribution': correlation_n_evidence_attribution[0],
        'p_n_evidence_attribution': correlation_n_evidence_attribution[1]
    }


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()