import numpy as np
from sentence_transformers import SentenceTransformer, util

from typing import List, Dict
from .generation_metric import GenerationMetric

from scipy import stats


class P_K_Correlation(GenerationMetric):
    """
    Rank-correlation metric between sentence embeddings of generated and target texts.

    Every generated text and its target text are encoded with a multilingual
    SBERT model, and the Spearman or Kendall rank correlation between the two
    embeddings is reported as the score for that sample.
    """

    # Single source of truth for the supported correlation types: used both to
    # validate `cor_type` and to pick the SciPy routine that computes it.
    _CORRELATION_FUNCS = {
        'spearmanr': stats.spearmanr,
        'kendalltau': stats.kendalltau,
    }

    def __init__(self, depend=None, cor_type=''):
        """
        Parameters:
            depend (List[str]): names of the statistics this metric depends on.
                Defaults to ["greedy_texts"] when omitted.
            cor_type (str): correlation to compute, 'spearmanr' or 'kendalltau'.
                It has to be passed explicitly: the default '' is rejected.
        Raises:
            ValueError: if `cor_type` is not a supported correlation type.
        """
        # Validate with an explicit exception (asserts vanish under `python -O`)
        # and do it before the SBERT model is loaded, so a bad configuration
        # fails immediately instead of after an expensive model load.
        if cor_type not in self._CORRELATION_FUNCS:
            raise ValueError(
                f"cor_type must be one of {sorted(self._CORRELATION_FUNCS)}, "
                f"got {cor_type!r}"
            )
        # Build a fresh list per instance rather than sharing one mutable default.
        if depend is None:
            depend = ["greedy_texts"]
        super().__init__(depend, "sequence")
        # Multilingual model; "all-mpnet-base-v2" is the single-language alternative.
        self.sbert = SentenceTransformer("sentence-transformers/distiluse-base-multilingual-cased-v2")
        self.cor_type = cor_type

    def __str__(self):
        """Return the metric name, which includes the configured correlation type."""
        return f"correlation_{self.cor_type}"

    def _cal_single_corr(
        self,
        gen_text,
        ref_text
    ) -> float:
        """
        Computes the rank correlation between the embeddings of two texts.

        Parameters:
            gen_text (str): model-generated text
            ref_text (str): ground-truth text
        Returns:
            float: the `statistic` of the SciPy correlation result computed
                between the two SBERT embeddings.
        Raises:
            ValueError: if `self.cor_type` is no longer a supported type
                (e.g. it was reassigned after construction).
        """
        try:
            corr_func = self._CORRELATION_FUNCS[self.cor_type]
        except KeyError:
            raise ValueError(f'self.cor_type={self.cor_type} is wrongly set!') from None
        hypo_embedding = self.sbert.encode(gen_text)
        ref_embedding = self.sbert.encode(ref_text)
        return corr_func(hypo_embedding, ref_embedding).statistic

    def __call__(
        self,
        prov_stats: Dict[str, np.ndarray],
        target_texts: List[str],
        target_tokens: List[List[int]],
        white,
    ) -> np.ndarray:
        """
        Calculates the embedding rank correlation between generated texts and target_texts.

        Parameters:
            prov_stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes:
                * model-generated texts in 'greedy_texts' (read when `white` is truthy)
                * model-generated texts in 'blackbox_greedy_texts' (read otherwise)
            target_texts (List[str]): ground-truth texts
            target_tokens (List[List[int]]): corresponding token splits for each target text
                (not used by this metric)
            white: selects which generated texts to score; truthy reads
                'greedy_texts', falsy reads 'blackbox_greedy_texts'.
        Returns:
            np.ndarray: correlation score for each sample in input. Texts are
                paired positionally, so extra items in the longer input are ignored.
        """
        greedy_text_key = "greedy_texts" if white else "blackbox_greedy_texts"
        return np.array(
            [
                self._cal_single_corr(hyp, ref)
                for hyp, ref in zip(prov_stats[greedy_text_key], target_texts)
            ]
        )

# if __name__ == "__main__":
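#     """
#     Ad-hoc smoke test, kept only until there is a real test suite.
#     NOTE: this snippet is stale -- it instantiates SbertMetric, which is not
#     defined in this module, and calls the metric without the `white` argument
#     that P_K_Correlation.__call__ requires.
#     """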
#     metric = SbertMetric()
#     stats = {
#         "greedy_texts": [
#             "Apple",
#             "Orange",
#             "Car",
#             "The best drink is a beer",
#             "January is before February",
#         ]
#     }
#     target_texts = ["Apple", "Apple", "Apple", "Octoberfest", "Octoberfest"]
#
#     scores = metric(stats, target_texts, None)
#     print(scores)
#
#     assert scores.shape == (5,)
#     assert scores[0] - 1 < 1e-5
#     assert scores[1] > scores[2]
#     assert scores[3] > scores[4]
