import numpy as np
# from rouge_score import rouge_scorer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

from typing import List, Dict
from .generation_metric import GenerationMetric
# from ctc_score.scorer import Scorer
from .ctc_score.factual_consistency_scorer import FactualConsistencyScorer



class CTCFactConsistency(GenerationMetric):
    """
    Sequence-level generation metric that scores each model-generated text
    against its target text with a ``FactualConsistencyScorer``.
    """

    def __init__(self, depend=None, set_align='D-mix-albert'):
        """
        Parameters:
            depend (List[str] | None): names of the statistics this metric
                depends on; forwarded to ``GenerationMetric.__init__``.
                Defaults to ``["greedy_texts"]`` when omitted or None.
            set_align (str): value passed as ``align`` to
                ``FactualConsistencyScorer``.
        """
        # Build the default per instance: a list literal in the signature would
        # be one shared object that every instance hands to the base class.
        if depend is None:
            depend = ["greedy_texts"]
        super().__init__(depend, "sequence")
        self.scorer = FactualConsistencyScorer(align=set_align)

    def __str__(self):
        return "ctc_fact_consistency"

    def _score(self, hypo, grounding):
        """
        Scores a single hypothesis against its grounding text.

        Parameters:
            hypo: the text being evaluated (the model generation).
            grounding: the text the hypothesis is checked against.
        Returns:
            Whatever ``self.scorer.score(grounding, hypo)`` returns.
        """
        # The scorer takes the grounding first, then the hypothesis.
        return self.scorer.score(grounding, hypo)

    def __call__(
        self,
        stats: Dict[str, np.ndarray],
        target_texts: List[str],
        target_tokens: List[List[int]],
        white,
    ) -> np.ndarray:
        """
        Scores every generated text in `stats` against the target text at the
        same position.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes:
                * model-generated texts in 'greedy_texts' (read when `white` is truthy)
                * model-generated texts in 'blackbox_greedy_texts' (read when `white` is falsy)
            target_texts (List[str]): ground-truth texts, used as the grounding
                for each generated text
            target_tokens (List[List[int]]): corresponding token splits for each
                target text; not used by this metric
            white: selects which of the two keys above is read from `stats`
        Returns:
            np.ndarray: one score per (generated text, target text) pair.
        """
        # NOTE(review): 'blackbox_greedy_texts' is not part of the default
        # `depend` list -- confirm callers supply it when `white` is falsy.
        greedy_text_key = "greedy_texts" if white else "blackbox_greedy_texts"
        # zip() stops at the shorter of the two sequences.
        return np.array(
            [
                self._score(hyp, ref)
                for hyp, ref in zip(stats[greedy_text_key], target_texts)
            ]
        )
