"""
multilingual universal sentence encoder
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""

from textattack.constraints.semantics.sentence_encoders import SentenceEncoder
from textattack.shared.utils import LazyLoader

# Lazy module handle for tensorflow_hub. Presumably the real import is deferred
# until first use (``hub.load`` in ``MultilingualUniversalSentenceEncoder.encode``),
# so tensorflow_hub is only needed when this encoder is actually used — confirm
# against textattack.shared.utils.LazyLoader.
hub = LazyLoader("tensorflow_hub", globals(), "tensorflow_hub")



class MultilingualUniversalSentenceEncoder(SentenceEncoder):
    """Constraint using similarity between sentence encodings of x and x_adv
    where the text embeddings are created using the Multilingual Universal
    Sentence Encoder.

    Args:
        threshold (float): Similarity threshold, forwarded to ``SentenceEncoder``.
        large (bool): If True, use the large multilingual model, otherwise the
            standard-size one. Ignored when ``tfhub_url`` is given.
        metric (str): Similarity metric, forwarded to ``SentenceEncoder``.
        tfhub_url (str, optional): Explicit model handle passed to
            ``tensorflow_hub.load`` (e.g. an official ``https://tfhub.dev/...``
            URL). When None, the hub.tensorflow.google.cn mirror of the model
            selected by ``large`` is used.
        **kwargs: Forwarded to ``SentenceEncoder``.
    """

    def __init__(
        self, threshold=0.8, large=False, metric="angular", tfhub_url=None, **kwargs
    ):
        super().__init__(threshold=threshold, metric=metric, **kwargs)
        # TF Hub's multilingual USE models need tensorflow_text imported before
        # they are loaded (it registers ops the graphs use), so force the
        # lazy import now rather than on first attribute access.
        tensorflow_text = LazyLoader("tensorflow_text", globals(), "tensorflow_text")
        tensorflow_text._load()

        if tfhub_url is None:
            # Default to the mainland-China mirror of TF Hub. The upstream
            # equivalents are
            #   https://tfhub.dev/google/universal-sentence-encoder-multilingual-large/3
            #   https://tfhub.dev/google/universal-sentence-encoder-multilingual/3
            if large:
                tfhub_url = "https://hub.tensorflow.google.cn/google/universal-sentence-encoder-multilingual-large/3"
            else:
                tfhub_url = "https://hub.tensorflow.google.cn/google/universal-sentence-encoder-multilingual/3"

        # TODO add QA SET. Details at: https://tfhub.dev/google/universal-sentence-encoder-multilingual-qa/3
        self._tfhub_url = tfhub_url
        # The model is not loaded here; ``encode`` loads it on first use.
        self.model = None

    def encode(self, sentences):
        """Embed ``sentences`` with the multilingual USE model.

        The model is loaded from ``self._tfhub_url`` on the first call.

        Args:
            sentences: Batch of sentences to embed (presumably a list of str,
                as supplied by ``SentenceEncoder`` — confirm with callers).

        Returns:
            The embeddings as a NumPy array (via the tensor's ``.numpy()``).
        """
        if self.model is None:
            self.model = hub.load(self._tfhub_url)
        encoding = self.model(sentences)

        # When the model returns a dict of outputs rather than a tensor, the
        # embeddings are stored under the "outputs" key.
        if isinstance(encoding, dict):
            encoding = encoding["outputs"]

        return encoding.numpy()

    def __getstate__(self):
        # Exclude the loaded TF model from the pickled state (presumably it is
        # not picklable / too heavy to ship); ``encode`` reloads it on demand.
        state = self.__dict__.copy()
        state["model"] = None
        return state

    def __setstate__(self, state):
        self.__dict__ = state
        # Reset the model so the next ``encode`` call reloads it from
        # ``self._tfhub_url``.
        self.model = None