import os

import dill
import pandas as pd

from project_root import ROOT_DIR
from xaiMetrics.constants import REFERENCE_BASED
from xaiMetrics.evalTools.apply_explainers_on_metrics import apply_explainers_on_metrics
from xaiMetrics.evalTools.explanations_to_scores import explanations_to_scores_and_eval, \
    explanations_to_scores_and_eval_stratified
from xaiMetrics.evalTools.utils.get_explanation_file_names import get_explanation_file_names, get_file_info
from xaiMetrics.explainer.wrappers.Lime import LimeExplainer
from xaiMetrics.explanations.FeatureImportanceExplanation import FeatureImportanceExplanationCollection
from xaiMetrics.metrics.wrappers.BARTScore import BARTScore
from xaiMetrics.metrics.wrappers.BertScore import BertScore
from xaiMetrics.metrics.wrappers.Rouge import RougeMetric
from xaiMetrics.metrics.wrappers.RandomScore import RandomScore

"""
Create new explanations for SummEval, apply the power mean and calculate the new correlations.
Explanations will be saved to xaiMetrics/outputs/raw_explanations
Graphics of p-w-distributions will be saved to xaiMetrics/outputs/experiment_graphs_pdf
and to xaiMetrics/outputs/experiment_results
"""

# Explainers to run, keyed by a display name. The same dict is passed to
# get_explanation_file_names() and apply_explainers_on_metrics() below.
# Commented-out entries are alternatives that can be toggled on; their classes
# are not among the imports visible at the top of this file, so enabling one
# also requires adding the matching import.
explainers = {
    # 'ErasureExplainer': ErasureExplainer(),
    'LimeExplainer': LimeExplainer(),
    # 'ShapExplainer': ShapExplainer(),
    # 'InputMarginalizationExplainer': InputMarginalizationExplainer(delta=0.05),
    # 'RandomExplainer': RandomExplainer()
}

# Metrics to explain, keyed by a display name. Both active metrics are
# constructed in REFERENCE_BASED mode.
# NOTE(review): the commented-out REFERENCE_FREE variants use a constant that is
# not imported in this file (only REFERENCE_BASED is) - import it before enabling them.
metrics = {
    # 'COMET': CometQE()
    # 'Transquest':TransQuest()
    # 'BERTSCORE_REF_FREE_XLMR': BertScore(model_type="xlm-roberta-large", mode=REFERENCE_FREE),
     'BERTSCORE_REF_BASED_ROBERTA_DEFAULT': BertScore(mode=REFERENCE_BASED),
    'BARTScore': BARTScore(mode=REFERENCE_BASED, batch_size=128),
    # 'BERTSCORE_REF_FREE_XNLI': BertScore(model_type="joeddav/xlm-roberta-large-xnli", mode=REFERENCE_FREE,
    #                                    num_layers=16),
    # 'SentenceBLEU_REF_BASED': SentenceBleu(mode=REFERENCE_BASED),
    # 'XLMRSBERT': XlmrCosineSim(),
    # 'XMoverScore_No_Mapping': XMoverScore(),
}

# --- Run settings -----------------------------------------------------------
# Dataset label; passed to the explanation helpers and used as the prefix of the
# dataset names handed to the correlation evaluation further down.
dataset = "SummEval"
# When False, the explanation step is skipped and the script only works with the
# explanation files resolved via get_explanation_file_names() below.
generate_explanations = True
explanationDirectory = os.path.join(ROOT_DIR, "xaiMetrics", "outputs", "raw_explanations")

# Language-pair tag; SummEval is handled under the single tag "en" here.
lp = "en"
# Collects one result table per evaluation call; concatenated at the end of the script.
correlations = []

summ_df = pd.read_json(
    os.path.join(ROOT_DIR, "xaiMetrics", "data", "cnndm", "SummEval.json"))

if generate_explanations:
    print("Applying Explainers and metrics")
    # The return value is intentionally not bound: the explanations are re-read
    # from save_dir further down, so both the "generate" and the "reuse" path go
    # through the same loading code.
    apply_explainers_on_metrics(summ_df,
                                explainers=explainers,
                                metrics=metrics,
                                print_time=True,
                                save_dir=explanationDirectory,
                                dataset=dataset,
                                lp=lp)

# Resolved only after the optional generation step.
# NOTE(review): if get_explanation_file_names() inspects the directory contents,
# calling it before generation would miss freshly written files; if it merely
# builds the expected names, the position makes no difference - confirm.
explanationFilenames = get_explanation_file_names(explanationDirectory, dataset, [lp], explainers, metrics)

# combining with pmeans
# Load every stored explanation file that belongs to the current language pair
# into a nested mapping: attributions[explainer name][metric name] -> collection.
# NOTE(review): dill.load can execute arbitrary code while unpickling; only point
# explanationDirectory at files produced by this project.
attributions = {}
for explanation_path in explanationFilenames:
    _, explainer_name, path_lp, metric_name = get_file_info(explanation_path)
    if path_lp != lp:
        # File belongs to a different language pair; nothing to load.
        continue
    with open(explanation_path, 'rb') as handle:
        per_metric = attributions.setdefault(explainer_name, {})
        per_metric[metric_name] = FeatureImportanceExplanationCollection(dill.load(handle))

# Quality dimensions read from the "expert_avg" entry of every SummEval row.
# The order here is the order in which the evaluations run and in which their
# results are appended to `correlations`.
summ_aspects = ("consistency", "coherence", "fluency", "relevance")

# One independent copy of the dataset per aspect, with that aspect's expert
# average exposed in a "DA" column. Each frame is built once and reused for both
# correlation passes below.
summ_dfs_by_aspect = {}
for aspect in summ_aspects:
    aspect_df = summ_df.copy()
    aspect_df["DA"] = [expert_scores[aspect] for expert_scores in aspect_df["expert_avg"].to_list()]
    summ_dfs_by_aspect[aspect] = aspect_df

# (dataset-name suffix, kendall flag, spearman flag) - all Spearman runs first,
# then all Kendall runs. Only the Spearman runs carry an explicit suffix.
correlation_passes = (
    ("-spearman", False, True),
    ("", True, False),
)

for name_suffix, use_kendall, use_spearman in correlation_passes:
    for aspect in summ_aspects:
        correlations.append(
            explanations_to_scores_and_eval_stratified(summ_dfs_by_aspect[aspect], attributions,
                                                       dataset=dataset + "-" + aspect + name_suffix,
                                                       lp=lp, save_fig=True, only_hyp=False,
                                                       kendall=use_kendall, sys_level=True,
                                                       spearman=use_spearman))

# Stack the result tables of all evaluation calls and print them in full.
df = pd.concat(correlations)
print(df.to_string())
# Mean of the numeric columns per (metric, explainer) pair.
# numeric_only=True is passed explicitly: since pandas 2.0, GroupBy.mean() no
# longer drops non-numeric columns on its own and raises a TypeError for them,
# so this keeps the summary working if the result tables carry any text column.
print(df.groupby(["metric", "explainer"]).mean(numeric_only=True).to_string())
