import pathlib
from functools import partial

import IPython
import hydra
from typing import *
from itertools import cycle
import os

import omegaconf
import torch
import pytorch_lightning as pl
import torch.nn as nn
import pandas as pd
import numpy as np
# from loguru import logger
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import confusion_matrix, f1_score

from allennlp.predictors.predictor import Predictor
import allennlp_models.pair_classification

import logging
# Module-level logger; main() reports per-split metrics through it.
logger = logging.getLogger(__name__)


def _process_nli(row, _predictor):
    return pd.Series(max(zip(
        _predictor.predict(
            hypothesis=row['question'],
            premise=row['context']
        )['label_probs'],
        ['Entailment', 'Contradiction', 'Neutral'],
        [1, 0, 1]
    ), key=lambda t: t[0]), index=['probability', 'TE_label', 'predicted_label'])


@hydra.main(config_path='../Configs/LMBenchEval.yaml')
def main(config: omegaconf.dictconfig.DictConfig):
    """Evaluate a pretrained AllenNLP textual-entailment model on benchmark splits.

    For each split in ``('test', 'train', 'eval')`` this reads
    ``<config["benchmark_path"]>/<split>.csv`` (expected columns: ``question``,
    ``context`` and a ``label`` assumed to be 0/1), predicts a binary label for
    every row via ``_process_nli``, logs accuracy, macro F1, binary F1 (positive
    class 1) and a 2x2 confusion matrix, and writes misclassified rows to
    ``<split>_errors.csv`` relative to the current working directory.

    Args:
        config: Hydra config; only ``benchmark_path`` is read here.
    """
    predictor = Predictor.from_path(
        "https://storage.googleapis.com/allennlp-public-models/decomposable-attention-elmo-2020.04.09.tar.gz",
        "textual_entailment"
    )
    # Loop-invariant: bind the predictor once rather than per split.
    classify = partial(_process_nli, _predictor=predictor)

    for split in ('test', 'train', 'eval'):
        csv_path = pathlib.Path(config["benchmark_path"]) / f'{split}.csv'
        df: pd.DataFrame = pd.read_csv(csv_path).fillna('')

        if df.empty:
            # DataFrame.apply on an empty frame does not yield the prediction
            # columns, and accuracy would divide by zero — skip the split.
            logger.warning(f'{split}: no rows in {csv_path}, skipping')
            continue

        out = df.apply(axis=1, func=classify)
        y_true = df['label']
        y_pred = out['predicted_label']

        _acc = (y_pred == y_true).mean()
        # f1_score's default average='binary' scores only the positive class;
        # macro F1 averages the per-class scores, which is what the
        # "f1_macro" log key promises.
        _f1_macro = f1_score(y_true=y_true, y_pred=y_pred, average='macro')
        _f1_binary = f1_score(y_true=y_true, y_pred=y_pred)
        # labels=[0, 1] forces a 2x2 matrix even when a split (or its
        # predictions) contains a single class; without it sklearn returns a
        # 1x1 matrix and the 2-element index/columns below raise.
        _conf_matrix = pd.DataFrame(
            confusion_matrix(y_true=y_true, y_pred=y_pred, labels=[0, 1]),
            columns=[0, 1],
            index=[0, 1],
        )

        logger.info(f'{split}_acc: {_acc}')
        logger.info(f'{split}_f1_macro: {_f1_macro}')
        logger.info(f'{split}_f1_binary: {_f1_binary}')
        logger.info(f'{split}_conf_matrix: \n{_conf_matrix}')

        # Build the error table directly so its columns are fact/context/type
        # even when there are no misclassifications (row-wise apply on an
        # empty frame would keep the original columns instead).
        wrong = y_true != y_pred
        errors = pd.DataFrame({
            'fact': df.loc[wrong, 'question'],
            'context': df.loc[wrong, 'context'],
            'type': (df.loc[wrong, 'label'] == 1).map(
                {True: "False Negative", False: "False Positive"}
            ),
        })
        errors.to_csv(f'{split}_errors.csv')


if __name__ == '__main__':
    # Called with no arguments: the @hydra.main decorator on main() is
    # responsible for constructing and passing the `config` argument.
    main()