import numpy as np
# from rouge_score import rouge_scorer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

from typing import List, Dict
from .generation_metric import GenerationMetric

from .eval_metric.utils import convert_to_json
from .eval_metric.evaluator import get_evaluator
import json
import os


class UniEval(GenerationMetric):
    """
    Scores model-generated texts with the UniEval multi-dimensional evaluator.

    One instance reports a single dimension (``selected_key``) of the scores
    returned by the evaluator. To avoid running the evaluator once per
    dimension, instances cooperate through a JSON-lines cache file:

    * the ``'overall'`` instance runs the evaluator and appends every
      dimension's score for each sample to ``file_name``;
    * instances with any other ``selected_key`` do not run the evaluator in
      ``__call__``; they read the scores back from ``file_name``.

    Consequently the ``'overall'`` metric has to be computed before the other
    dimensions are requested for the same inputs.
    """

    def __init__(self, task='summarization', selected_key='overall', file_name='./assist_res/unieval_res.json', depend=None):
        """
        Parameters:
            task (str): UniEval task passed to ``get_evaluator``. The original
                code lists 'summarization', 'dialogue' and 'fact'.
            selected_key (str): score dimension reported by this instance. The
                original code lists 'coherence', 'consistency', 'fluency',
                'relevance' and 'overall'.
            file_name (str): path of the JSON-lines cache shared between the
                'overall' instance (writer) and the other instances (readers).
            depend (List[str]): statistics this metric depends on, forwarded to
                the base class. Defaults to ``["greedy_texts"]``.
        """
        # A fresh list per instance instead of a shared mutable default.
        if depend is None:
            depend = ["greedy_texts"]
        super().__init__(depend, "sequence")

        self.task = task
        self.selected_key = selected_key
        self.file_name = file_name
        # Only non-'overall' instances read their scores from the cache file.
        self.to_read_json_status = self.selected_key != 'overall'
        # The 'overall' instance is the cache writer: start from a clean file
        # so that scores from a previous run are not mixed into this one.
        if not self.to_read_json_status and os.path.isfile(self.file_name):
            os.remove(self.file_name)

    def __str__(self):
        return f"UniEval_{self.task}_{self.selected_key}"

    def _evaluate(self, output_list, src_list, ref_list):
        """Run the task-specific evaluator and return its per-sample score dicts."""
        # The 'fact' task is prepared without reference texts.
        if self.task == 'fact':
            data = convert_to_json(output_list=output_list, src_list=src_list)
        else:
            data = convert_to_json(output_list=output_list,
                                   src_list=src_list, ref_list=ref_list)
        evaluator = get_evaluator(self.task)
        # One dict per sample; an example recorded for summarization:
        # {'coherence': 0.35, 'consistency': 0.83, 'fluency': 0.79,
        #  'relevance': 0.50, 'overall': 0.62}
        return evaluator.evaluate(data, print_result=True)

    def _append_scores(self, sources, eval_scores):
        """Append one ``{source_text: scores}`` JSON line per sample to the cache file."""
        parent_dir = os.path.dirname(self.file_name)
        if parent_dir:
            # Opening in append mode does not create missing directories.
            os.makedirs(parent_dir, exist_ok=True)
        with open(self.file_name, 'a+', encoding='utf-8') as f:
            f.writelines(
                json.dumps({source: scores}) + '\n'
                for source, scores in zip(sources, eval_scores)
            )

    def _get_single_unieval_socre(self, pd_sent, gt_sent=None):
        """
        Scores a single prediction and returns its ``selected_key`` value.

        Parameters:
            pd_sent (str): model-generated text.
            gt_sent (str): ground-truth text. It is used both as the reference
                and as the source document, because the original input text is
                not available to this method.
        Returns:
            the ``selected_key`` entry of the evaluator's score dict.
        """
        eval_scores = self._evaluate(
            output_list=[pd_sent], src_list=[gt_sent], ref_list=[gt_sent]
        )
        return eval_scores[0][self.selected_key]

    def list_to_json(self, pred_list, label_list, file_name):
        """
        Writes predictions and labels to ``file_name`` as a single-line JSON
        object with keys 'pred' and 'label', replacing any existing file.
        """
        if os.path.isfile(file_name):
            os.remove(file_name)

        res = {'pred': pred_list, 'label': label_list}

        with open(file_name, 'w', encoding='utf-8') as f:
            f.write(json.dumps(res))
        print(f"{file_name} has saved all sampled predictions ['pred'] and labels ['label'] .")

    def json_to_list(self, file_name):
        """
        Parses the first line of ``file_name`` as JSON and returns the result
        (the counterpart of ``list_to_json``, i.e. a dict with 'pred'/'label').
        """
        with open(file_name, 'r', encoding='utf-8') as ini_f:
            return json.loads(ini_f.readline())

    def read_res(self):
        """
        Loads the JSON-lines cache at ``self.file_name``.

        Every line holds one ``{key: scores}`` object; all lines are merged
        into one dict (a later line overrides an earlier one with the same
        key). The merged dict is stored in ``self.dict_res`` and also returned.
        """
        final_res = {}
        with open(self.file_name, 'r', encoding='utf-8') as ini_f:
            for ini_str in ini_f:
                if not ini_str.strip():
                    continue  # tolerate blank lines
                final_res.update(json.loads(ini_str))
        self.dict_res = final_res
        return final_res

    def mapping_res(self, hyp, ref, sou, key_word):
        """
        Looks up a cached score loaded by ``read_res``.

        Parameters:
            hyp (str): model-generated text (unused, kept for the interface).
            ref (str): ground-truth text.
            sou (str): input/source text.
            key_word (str): score dimension to return.
        Returns:
            the cached ``key_word`` score for this sample.
        Raises:
            KeyError: if the sample or the dimension is not in the cache.
        """
        # The cache is written keyed by the source text alone. `sou + ref` is
        # still tried afterwards so files keyed that way remain readable.
        record = self.dict_res.get(sou)
        if record is None:
            try:
                record = self.dict_res[sou + ref]
            except KeyError:
                raise KeyError(
                    f"no cached UniEval scores for this sample in {self.file_name}"
                ) from None
        return record[key_word]

    def _get_batch_unieval_socre(self, pd_sent, gt_sent=None, sr_sent=None):
        """
        Scores a batch of predictions with a single evaluator run.

        Parameters:
            pd_sent (List[str]): model-generated texts.
            gt_sent (List[str]): ground-truth texts.
            sr_sent (List[str]): input/source texts.
        Returns:
            List: the ``selected_key`` score of every sample.

        When ``selected_key`` is 'overall', the complete score dict of each
        sample is additionally appended to the cache file, keyed by its source
        text, for the other dimensions to read later.
        """
        eval_scores = self._evaluate(
            output_list=pd_sent, src_list=sr_sent, ref_list=gt_sent
        )
        # Built for every selected_key, so the return value is always defined.
        selected_scores = [scores[self.selected_key] for scores in eval_scores]

        if self.selected_key == 'overall':
            self._append_scores(sr_sent, eval_scores)

        return selected_scores

    def __call__(
        self,
        stats: Dict[str, np.ndarray],
        target_texts: List[str],
        target_tokens: List[List[int]],
        white=None,
    ) -> np.ndarray:
        """
        Calculates the UniEval ``selected_key`` score for each sample.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes:
                * model-generated texts in 'greedy_texts' (used when `white` is truthy)
                  or 'blackbox_greedy_texts' (used otherwise)
                * input texts in 'input_texts'
            target_texts (List[str]): ground-truth texts
            target_tokens (List[List[int]]): corresponding token splits for each target text (unused)
            white: selects which generated texts are scored, see `stats`.
        Returns:
            np.ndarray: UniEval scores for each sample in input.
        """
        greedy_text_key = "greedy_texts" if white else "blackbox_greedy_texts"

        if self.selected_key == 'overall':
            return np.array(
                self._get_batch_unieval_socre(
                    stats[greedy_text_key], target_texts, stats["input_texts"]
                )
            )

        # Other dimensions were already computed by the 'overall' instance:
        # reload the cache and look each sample up.
        if self.to_read_json_status:
            self.json_res = self.read_res()
        return np.array(
            [
                self.mapping_res(hyp, ref, sou, self.selected_key)
                for hyp, ref, sou in zip(stats[greedy_text_key], target_texts, stats["input_texts"])  # prediction, ground truth, input
            ]
        )




