# coding=utf-8
import argparse
import transformers

from utils.get_stopwords import get_stopwords
from utils.dataset import data_processor
from utils.zh_lstm_cnn import ZHLSTMForClassification, ZHWordCNNForClassification

import transformations
import search_methods

from textattack.models.wrappers import HuggingFaceModelWrapper, PyTorchModelWrapper
from textattack.constraints.pre_transformation import RepeatModification, StopwordModification, InputColumnModification
from textattack.constraints.overlap import MaxWordsPerturbed
from textattack.goal_functions import UntargetedClassification

from textattack import Attack, Attacker, AttackArgs
from textattack.datasets import Dataset, HuggingFaceDataset
import textattack

# Fine-tuned HuggingFace victim checkpoints, keyed "<architecture>-<task>".
# Base checkpoint each architecture was fine-tuned from:
#   bert       -> bert-base-chinese
#   roberta    -> roberta-base-wwm-chinese-cluecorpussmall
#   albert     -> albert-base-chinese-cluecorpussmall
#   distilbert -> distilbert-base-multilingual-cased
# Directory names for chinanews/chnsenticorp/ocnli/lcqmc end in "-chinese";
# ctrip and jd have no suffix. Iteration order is architecture-major.
HUGGINGFACE_MODELS = {
    f"{arch}-{task}": f"/root/autodl-tmp/{arch}-base-finetuned-{task}{suffix}"
    for arch in ("bert", "roberta", "albert", "distilbert")
    for task, suffix in (
        ("chinanews", "-chinese"),
        ("chnsenticorp", "-chinese"),
        ("ocnli", "-chinese"),
        ("lcqmc", "-chinese"),
        ("ctrip", ""),
        ("jd", ""),
    )
}
# Project-local word-level victims (LSTM and CNN), keyed "<architecture>-<task>".
# Every checkpoint lives under /root/autodl-tmp/lstm_and_cnn/ in a directory
# named "<architecture>-<task>-chinese". Iteration order: all LSTMs, then CNNs.
TEXTATTACK_MODELS = {
    f"{arch}-{task}": f"/root/autodl-tmp/lstm_and_cnn/{arch}-{task}-chinese"
    for arch in ("lstm", "cnn")
    for task in ("chinanews", "chnsenticorp")
}
# Task name (the part after the hyphen in a victim-model key) -> processor class.
# The classes are stored uncalled; build_model_and_dataset instantiates one on demand.
DATA_PROCESSOR = dict(
    chinanews=data_processor.ChinanewsProcessor,
    chnsenticorp=data_processor.ChnsenticorpProcessor,
    ocnli=data_processor.OcnliProcessor,
    lcqmc=data_processor.LcqmcProcessor,
    ctrip=data_processor.CtripHotelReviewsProcessor,
    jd=data_processor.JDComProductReviewsProcessor,
)
# Registry of --transformation choices, passed to Attack() as-is by main().
# The values are built when the module is imported, so every constructor call
# below runs on each invocation, whichever transformation is selected.
# NOTE(review): "argot", "es" (classes, by their CamelCase names) and "mix-ssc"
# are stored without being called, unlike the other entries. If Attack expects
# an instance, selecting these would hand it a class. Confirm against the
# transformations package whether this is intentional.
TRANSFORMATION_CLASS_NAMES = {
    "shuffle": transformations.ChineseShuffleCharacterSwap(),
    "split": transformations.ChineseSplittingCharacterSwap(),
    "synonym": transformations.ChineseSynonymWordSwap(),
    "same-pinyin": transformations.ChineseHomophoneCharacterSwap(),
    "sim-pinyin": transformations.ChineseSimilarPinyinCharacterSwap(),
    "glyph": transformations.ChineseGlyphCharacterSwap(),
    "argot": transformations.ChineseArgotWordSwap,
    # 'cmn' is presumably the ISO 639-3 / Open Multilingual WordNet code for Mandarin.
    "wordnet": transformations.ChineseWordSwapWordNet('cmn'),
    "hownet": transformations.ChineseWordSwapHowNet(),
    "mix-ssc": transformations.mix_ssc,
    "mlm": transformations.ChineseWordSwapMaskedLM(), 
    "es": transformations.ChineseExpandingScopeWordSwap,
}
# Registry of --search-method choices. Each value is a ready-made search-method
# instance, constructed at import time in the listed order. The "delete", "unk"
# and "pwws" entries are the same TextAttack greedy word-importance search,
# differing only in wir_method. "pso" and "ia" come from the project's
# search_methods package.
SEARCH_METHOD_CLASS_NAMES = dict(
    [
        ("beam-search", textattack.search_methods.BeamSearch()),
        ("greedy", textattack.search_methods.GreedySearch()),
        ("ga", textattack.search_methods.AlzantotGeneticAlgorithm()),
        ("delete", textattack.search_methods.GreedyWordSwapWIR(wir_method="delete")),
        ("unk", textattack.search_methods.GreedyWordSwapWIR(wir_method="unk")),
        ("pso", search_methods.FasterParticleSwarmOptimization()),
        ("ia", search_methods.ImmuneAlgorithm()),
        ("pwws", textattack.search_methods.GreedyWordSwapWIR(wir_method="weighted-saliency")),
    ]
)


def build_model_and_dataset(model_name_or_path):
    """Load a victim model wrapper and the test split of its task.

    Args:
        model_name_or_path: registry key of the form ``"<architecture>-<task>"``
            (e.g. ``"bert-chinanews"``, ``"lstm-chnsenticorp"``). It must be a
            key of ``HUGGINGFACE_MODELS`` or ``TEXTATTACK_MODELS``. The
            ``<task>`` part selects the processor from ``DATA_PROCESSOR``.

    Returns:
        ``(model_wrapper, dataset)``:
        - ``model_wrapper`` is a ``HuggingFaceModelWrapper`` or a
          ``PyTorchModelWrapper`` around the loaded checkpoint.
        - ``dataset`` is a TextAttack ``Dataset`` built from the processor's
          test examples under ``/root/autodl-tmp/attack_datasets/<task>``.

    Raises:
        ValueError: if ``model_name_or_path`` is not a registered victim model,
            or a ``TEXTATTACK_MODELS`` entry names an architecture other than
            ``cnn``/``lstm``.
    """
    # Validate the argument actually passed in (not a global) up front, so an
    # unknown name fails before any dataset or checkpoint is read from disk.
    if (model_name_or_path not in HUGGINGFACE_MODELS
            and model_name_or_path not in TEXTATTACK_MODELS):
        raise ValueError(f"{model_name_or_path} is not support!")

    # Registry keys are "<architecture>-<task>" with a single hyphen.
    model_name, _, task_name = model_name_or_path.partition('-')

    processor = DATA_PROCESSOR[task_name]()
    label_names = processor.get_labels()
    input_columns = processor.get_input_columns()

    examples = processor.get_test_examples(
        f'/root/autodl-tmp/attack_datasets/{task_name}')
    dataset = Dataset(examples, input_columns=input_columns, label_names=label_names)

    if model_name_or_path in HUGGINGFACE_MODELS:
        model_path = HUGGINGFACE_MODELS[model_name_or_path]
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        model = transformers.AutoModelForSequenceClassification.from_pretrained(
            model_path)
        model_wrapper = HuggingFaceModelWrapper(model, tokenizer)
    else:
        model_path = TEXTATTACK_MODELS[model_name_or_path]
        if model_name == 'cnn':
            model_cls = ZHWordCNNForClassification
        elif model_name == 'lstm':
            model_cls = ZHLSTMForClassification
        else:
            # Without this branch an unknown architecture would leave
            # model_wrapper unbound and fail with UnboundLocalError below.
            raise ValueError(
                f"{model_name_or_path}: unsupported architecture {model_name!r}")
        model = model_cls.from_pretrained(model_path)
        # These project models expose their own tokenizer as an attribute.
        model_wrapper = PyTorchModelWrapper(model, model.tokenizer)

    return model_wrapper, dataset


def build_constraints(model_name_or_path):
    """Return the pre-transformation constraints for a victim model.

    Every attack gets ``RepeatModification`` and a ``StopwordModification``
    built from ``get_stopwords()``. For the sentence-pair tasks (ocnli, lcqmc),
    an ``InputColumnModification`` is also appended. Its first argument lists
    the expected input columns; its second is the set of columns TextAttack
    should leave unmodified (the first sentence of each pair).

    Args:
        model_name_or_path: victim-model key ``"<architecture>-<task>"``.
            Only the task part is used.

    Returns:
        A new list of constraint objects.
    """
    stopwords = get_stopwords()
    constraints = [
        RepeatModification(),
        StopwordModification(stopwords=stopwords),
    ]
    # A MaxWordsPerturbed(max_percent=0.15) perturbation budget is currently disabled.

    # task -> (input column labels, columns excluded from perturbation)
    pair_task_columns = {
        'ocnli': (["premise", "hypothesis"], {"premise"}),
        'lcqmc': (["text_a", "text_b"], {"text_a"}),
    }

    _, task_name = model_name_or_path.split('-')
    pair_columns = pair_task_columns.get(task_name)
    if pair_columns is not None:
        constraints.append(InputColumnModification(*pair_columns))

    return constraints


def main():
    """Run one attack configured from the module-level CLI ``args``.

    Seeds TextAttack, loads the victim model and dataset, then assembles an
    untargeted attack from the selected transformation and search method. The
    goal function caps queries at 50000. The attack runs over
    ``args.num_examples`` examples, and the victim model name is printed once
    the run finishes.
    """
    seed = 718
    textattack.shared.utils.set_seed(seed)

    victim = args.victim_model
    model_wrapper, dataset = build_model_and_dataset(victim)
    constraints = build_constraints(victim)

    goal = UntargetedClassification(model_wrapper, query_budget=50000)
    attack = Attack(
        goal,
        constraints,
        TRANSFORMATION_CLASS_NAMES[args.transformation],
        SEARCH_METHOD_CLASS_NAMES[args.search_method],
    )

    n_examples = args.num_examples
    # The checkpoint interval equals the example count, presumably so that one
    # checkpoint is written when the run completes. NOTE(review): upstream
    # TextAttack's AttackArgs has `checkpoint_dir`, not `checkpoint_path`;
    # this presumably targets a patched TextAttack. Confirm.
    attack_args = AttackArgs(
        num_examples=n_examples,
        random_seed=seed,
        enable_advance_metrics=True,
        checkpoint_interval=n_examples,
        checkpoint_path=args.checkpoint_path,
    )

    Attacker(attack, dataset, attack_args).attack_dataset()
    print(f"victim model: {victim}")


if __name__ == "__main__":
    # Command-line entry point: parse the attack configuration, then run main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformation",
        "-t",
        type=str,
        default='mix-ssc',
        choices=TRANSFORMATION_CLASS_NAMES.keys(),
        help="apply a transformations for attack"
    )
    parser.add_argument(
        "--search-method",
        "-s",
        type=str,
        default='ia',
        choices=SEARCH_METHOD_CLASS_NAMES.keys(),
        help="apply a search method for attack"
    )
    parser.add_argument(
        "--victim-model",
        "-m",
        type=str,
        default="bert-chinanews",
        # Reject unknown victims at parse time, with the valid names listed,
        # rather than failing later inside model/dataset loading.
        choices=[*HUGGINGFACE_MODELS, *TEXTATTACK_MODELS],
        help="apply a victim model for attack"
    )
    parser.add_argument(
        "--num-examples",
        "-n",
        type=int,
        default=100,
        help="number of examples to attack"
    )
    parser.add_argument(
        "--checkpoint-path",
        "-ckp",
        type=str,
        default=None,
        help="path to load checkpoint files"
    )
    # Assigned at module scope, so `args` is already the global that main()
    # reads. A `global` statement is a no-op here.
    args = parser.parse_args()

    main()
