from textattack.datasets import HuggingFaceDataset
from data.instance import InputInstance
from typing import List, Dict
import random
from textattack.attack_recipes import TextFoolerJin2019,HotFlipEbrahimi2017,DeepWordBugGao2018,TextBuggerLi2018,PSOZang2020,BERTAttackLi2020
from textattack.transformations import WordSwapEmbedding
from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.constraints.overlap import MaxWordsPerturbed
from textattack.goal_functions import UntargetedClassification
from textattack.constraints.pre_transformation import (
    InputColumnModification,
    RepeatModification,
    StopwordModification,
)
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.constraints.overlap import LevenshteinEditDistance
from textattack import Attack
from textattack.models.wrappers import ModelWrapper,PyTorchModelWrapper
import torch

class CustomModelWrapper(PyTorchModelWrapper):
    """TextAttack model wrapper that tokenizes raw text and runs the model.

    Calling the wrapper tokenizes ``text_input_list`` with the stored
    tokenizer, moves the encoded batch to the device of the model's first
    parameter, runs a forward pass without gradient tracking and returns
    the part of the model output selected by its type (see ``__call__``).
    """

    def __init__(self, model, tokenizer):
        super().__init__(model, tokenizer)

    def __call__(self, text_input_list):
        """Run the wrapped model on a batch of text inputs.

        Args:
            text_input_list: Batch of inputs handed straight to the tokenizer.

        Returns:
            - the last element, if the model returned a ``tuple``;
            - the tensor itself, if the model returned a ``torch.Tensor``;
            - the output unchanged, if its first element is a ``str``;
            - otherwise the output's ``logits`` attribute.
        """
        encoded = self.tokenizer(
            text_input_list,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )
        # The return value of ``.to`` is not used: the encoded batch is
        # expected to move its tensors in place.
        # NOTE(review): relies on the tokenizer's return type doing so — confirm.
        encoded.to(next(self.model.parameters()).device)

        with torch.no_grad():
            result = self.model(**encoded)

        if isinstance(result, tuple):
            return result[-1]
        if isinstance(result, torch.Tensor):
            return result

        # A list of string predictions (as returned by sequence-to-sequence
        # models) is passed through whole; any other output object is
        # expected to expose its scores as ``.logits``.
        return result if isinstance(result[0], str) else result.logits

class CustomTextAttackDataset(HuggingFaceDataset):
    """Exposes in-memory ``InputInstance`` objects as a TextAttack dataset.

    ``HuggingFaceDataset.__init__`` is deliberately not called: nothing is
    loaded through ``datasets``; the attributes are filled in directly from
    the given instances instead.

    - name: the dataset name.
    - instances: non-empty list of ``InputInstance``. The first instance
      decides the layout: if ``is_nli()`` is true, each example has
      ``premise``/``hypothesis`` columns (from ``text_a``/``text_b``),
      otherwise a single ``text`` column (from ``text_a``). Labels are
      converted with ``int(instance.label)``.
    - label_map: Mapping if output labels should be re-mapped. Useful
      if model was trained with a different label arrangement than
      provided in the dataset. Its sorted keys become ``label_names``;
      when omitted, ``label_names`` is ``None``.
    - output_scale_factor (float): Factor to divide ground-truth outputs by.
        Generally, TextAttack goal functions require model outputs
        between 0 and 1. Some datasets test the model's correlation
        with ground-truth output, instead of its accuracy, so these
        outputs may be scaled arbitrarily.
    - dataset_columns: accepted for signature compatibility; not used here.
    - shuffle (bool): Whether to shuffle the examples on load.

    Raises:
        ValueError: if ``instances`` is ``None`` or empty.
    """

    def __init__(
            self,
            name,
            instances: List[InputInstance],
            label_map: Dict[str, int] = None,
            output_scale_factor=None,
            dataset_columns=None,
            shuffle=False,
    ):
        # Explicit check instead of ``assert``: asserts vanish under ``-O``
        # and the layout detection below needs at least one instance.
        if instances is None or len(instances) == 0:
            raise ValueError("instances must be a non-empty list of InputInstance")

        self._name = name
        self._i = 0
        self.label_map = label_map
        self.output_scale_factor = output_scale_factor
        # ``label_map`` is optional, so only derive names when it is given.
        self.label_names = sorted(label_map) if label_map is not None else None

        if instances[0].is_nli():
            self.input_columns, self.output_column = ("premise", "hypothesis"), "label"
            self.examples = [
                {
                    "premise": instance.text_a,
                    "hypothesis": instance.text_b,
                    "label": int(instance.label),
                }
                for instance in instances
            ]
        else:
            self.input_columns, self.output_column = ("text",), "label"
            self.examples = [
                {"text": instance.text_a, "label": int(instance.label)}
                for instance in instances
            ]

        if shuffle:
            random.shuffle(self.examples)
        self.shuffled = shuffle

    @classmethod
    def from_instances(cls, name: str, instances: List[InputInstance],
                       labels: Dict[str, int]) -> "CustomTextAttackDataset":
        """Alternate constructor: build the dataset with ``labels`` as ``label_map``."""
        return cls(name, instances, labels)

def build_attacker(model, args):
    """Build the TextAttack recipe selected by ``args['attack_method']``.

    Args:
        model: Model (wrapper) handed to the recipe's ``build``.
        args: Mapping read for ``'attack_method'`` and, for every method
            except ``'hotflip'``, ``'modify_ratio'``.

    Returns:
        The object returned by the chosen recipe's ``build(model)``, with
        extra constraints appended as described below.

    Raises:
        NotImplementedError: if ``attack_method`` is not one of
            ``hotflip``, ``textfooler``, ``textbugger``, ``bertattack``,
            ``deepwordbug`` or ``pso``.
    """
    method = args['attack_method']

    # NOTE(review): HotFlip is returned as built, before any ``modify_ratio``
    # handling, so it never receives the MaxWordsPerturbed constraint below.
    # Open question carried over from the original: does ``build`` return an
    # ``Attack`` object?
    if method == 'hotflip':
        return HotFlipEbrahimi2017.build(model)

    # PWWS (PWWSRen2019) is currently not wired in.
    if method == 'textfooler':
        recipe = TextFoolerJin2019
    elif method == 'textbugger':
        recipe = TextBuggerLi2018
    elif method == 'bertattack':
        recipe = BERTAttackLi2020
    elif method == 'deepwordbug':
        recipe = DeepWordBugGao2018
    elif method == "pso":
        recipe = PSOZang2020
    else:
        raise NotImplementedError

    attack = recipe.build(model)

    ratio = args['modify_ratio']
    if ratio != 0:
        # A non-zero ratio caps the fraction of words that may be perturbed.
        attack.constraints.append(MaxWordsPerturbed(max_percent=ratio))
    elif method == 'deepwordbug':
        # With no ratio given, DeepWordBug is bounded by edit distance instead.
        attack.constraints.append(LevenshteinEditDistance(5))
    return attack