from dataprocess.data_preprocess import T5Preprocessor
#from transformers import T5Tokenizer

# Source tasks
# MNLI, QNLI, QQP, SST2, SQuAD, and ReCoRD
# Label-id -> label-word tables, keyed by "<benchmark>:<task>".
# NOTE(review): the QQP and QNLI words here use underscores ('not_duplicate',
# 'not_entailment'), while QQPPreprocessor / QNLIPreprocessor in this file use
# spaces ('not duplicate', 'not entailment') -- confirm which spelling is intended.
labeldict = {
    'glue:mnli': {0: 'entailment', 1: 'neutral', 2: 'contradiction'},
    'glue:qqp': {0: 'not_duplicate', 1: 'duplicate'},
    'glue:qnli': {0: 'entailment', 1: 'not_entailment'},
    'glue:sst2': {0: 'negative', 1: 'positive'},
}


class SST2Preprocessor(T5Preprocessor):
    """Convert SST-2 examples into tokenized text-to-text pairs.

    The input text is the prompt prefix followed by the example's ``sentence``;
    the target text is the label word looked up in ``label_dict``.
    """

    def __init__(self, tokenizer, prefixes, preprocessing_modes, dataset_name_1, dataset_name_2=None):
        """Set the label vocabulary, then delegate to ``T5Preprocessor``.

        All arguments are forwarded unchanged to ``T5Preprocessor.__init__``,
        which defines their meaning.
        """
        # Assigned before the base constructor runs; presumably the base class
        # may call process_example() during construction -- TODO confirm.
        self.label_dict = {0: 'negative', 1: 'positive'}
        super().__init__(tokenizer, prefixes, preprocessing_modes, dataset_name_1, dataset_name_2)

    def process_example(self, prefix, example):
        """Encode one example.

        Args:
            prefix: Prompt text placed in front of the sentence.
            example: Mapping providing the ``'sentence'`` and ``'label'`` keys.

        Returns:
            Dict with ``"input"`` and ``"target"`` (the results of
            ``self.tokenizer.encode``) and the ``"prefix"`` that was used.
        """
        source_text = f"{prefix} {example['sentence']}"
        # Label ids missing from label_dict fall back to the word "unknown".
        target_text = self.label_dict.get(example['label'], "unknown")

        # Both sides are truncated and padded to a fixed length.
        shared_kwargs = {"truncation": True, "padding": 'max_length'}
        input_ids = self.tokenizer.encode(source_text, max_length=128, **shared_kwargs)
        target_ids = self.tokenizer.encode(target_text, max_length=3, **shared_kwargs)
        return {"input": input_ids, "target": target_ids, "prefix": prefix}
# Example usage
# tokenizer = the tokenizer of your T5 model
# prefixes = ["sst2 sentence:", ...]
# label_dict = {1: "positive", 0: "negative"}  # built into the class; not a constructor argument
# preprocessing_modes = {"few-shot_train": 2, "train": 1, "validation": 3, "test": 3}
# sst2_preprocessor = SST2Preprocessor(tokenizer, prefixes, preprocessing_modes, dataset_name_1, dataset_name_2)


class MNLIPreprocessor(T5Preprocessor):
    """Convert MNLI examples into tokenized text-to-text pairs.

    The input text combines the prompt prefix with the example's hypothesis
    and premise; the target text is the label word from ``label_dict``.
    """

    def __init__(self, tokenizer, prefixes, preprocessing_modes, dataset_name_1, dataset_name_2=None):
        """Set the label vocabulary, then delegate to ``T5Preprocessor``.

        All arguments are forwarded unchanged to ``T5Preprocessor.__init__``,
        which defines their meaning.
        """
        # Assigned before the base constructor runs; presumably the base class
        # may call process_example() during construction -- TODO confirm.
        self.label_dict = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}
        super().__init__(tokenizer, prefixes, preprocessing_modes, dataset_name_1, dataset_name_2)

    def process_example(self, prefix, example):
        """Encode one example.

        Args:
            prefix: Prompt text placed in front of the sentence pair.
            example: Mapping providing the ``'hypothesis'``, ``'premise'`` and
                ``'label'`` keys.

        Returns:
            Dict with ``"input"`` and ``"target"`` (the results of
            ``self.tokenizer.encode``) and the ``"prefix"`` that was used.
        """
        # Build the input text from the hypothesis followed by the premise.
        source_text = f"{prefix} hypothesis: {example['hypothesis']} premise: {example['premise']}"

        # Map the numeric label to its word; unmapped ids become "unknown".
        target_text = self.label_dict.get(example['label'], "unknown")

        # Both sides are truncated and padded to a fixed length.
        shared_kwargs = {"truncation": True, "padding": 'max_length'}
        input_ids = self.tokenizer.encode(source_text, max_length=256, **shared_kwargs)
        target_ids = self.tokenizer.encode(target_text, max_length=3, **shared_kwargs)
        return {"input": input_ids, "target": target_ids, "prefix": prefix}

# Example usage
# tokenizer = the tokenizer of your T5 model
# prefixes = ["mnli:", ...]
# label_dict = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}  # built into the class; not a constructor argument
# preprocessing_modes = {"few-shot_train": 2,"train": 1, "validation_matched": 3, "validation_mismatched": 3, "test_matched": 3, "test_mismatched": 3}
# preprocessing_modes = {"few-shot_train": 2,"train": 1, "validation_matched": 3, "test_matched": 3}
# mnli_preprocessor = MNLIPreprocessor(tokenizer, prefixes, preprocessing_modes, dataset_name_1, dataset_name_2)



class QQPPreprocessor(T5Preprocessor):
    """Convert QQP examples into tokenized text-to-text pairs.

    The input text combines the prompt prefix with the example's two
    questions; the target text is the label word from ``label_dict``.
    """

    def __init__(self, tokenizer, prefixes, preprocessing_modes, dataset_name_1, dataset_name_2=None):
        """Set the label vocabulary, then delegate to ``T5Preprocessor``.

        All arguments are forwarded unchanged to ``T5Preprocessor.__init__``,
        which defines their meaning.
        """
        # Assigned before the base constructor runs; presumably the base class
        # may call process_example() during construction -- TODO confirm.
        self.label_dict = {0: 'not duplicate', 1: 'duplicate'}
        super().__init__(tokenizer, prefixes, preprocessing_modes, dataset_name_1, dataset_name_2)

    def process_example(self, prefix, example):
        """Encode one example.

        Args:
            prefix: Prompt text placed in front of the question pair.
            example: Mapping providing the ``'question1'``, ``'question2'`` and
                ``'label'`` keys (the QQP feature names listed in the schema
                notes at the end of this file).

        Returns:
            Dict with ``"input"`` and ``"target"`` (the results of
            ``self.tokenizer.encode``) and the ``"prefix"`` that was used.
        """
        # QQP examples carry 'question1'/'question2'; they have no
        # 'hypothesis'/'premise' fields, so reading those raised KeyError.
        processed_input = f"{prefix} question1: {example['question1']} question2: {example['question2']}"

        # Map the numeric label to its word; unmapped ids become "unknown".
        processed_target = self.label_dict.get(example['label'], "unknown")

        # NOTE(review): max_length=3 may truncate multi-token label words such
        # as "not duplicate" -- confirm against the tokenizer in use.
        return {
            "input": self.tokenizer.encode(processed_input, truncation=True, padding='max_length', max_length=256),
            "target": self.tokenizer.encode(processed_target, truncation=True, padding='max_length', max_length=3),
            "prefix": prefix
        }


class QNLIPreprocessor(T5Preprocessor):
    """Convert QNLI examples into tokenized text-to-text pairs.

    The input text combines the prompt prefix with the example's question
    and sentence; the target text is the label word from ``label_dict``.
    """

    def __init__(self, tokenizer, prefixes, label_dict, preprocessing_modes, dataset_name_1, dataset_name_2=None):
        """Set the label vocabulary, then delegate to ``T5Preprocessor``.

        Args:
            label_dict: Accepted for call compatibility but ignored; the label
                vocabulary is fixed inside this class.

        The remaining arguments are forwarded unchanged to
        ``T5Preprocessor.__init__``, which defines their meaning.
        """
        # Assigned before the base constructor runs; presumably the base class
        # may call process_example() during construction -- TODO confirm.
        self.label_dict = {0: 'entailment', 1: 'not entailment'}
        super().__init__(tokenizer, prefixes, preprocessing_modes, dataset_name_1, dataset_name_2)

    def process_example(self, prefix, example):
        """Encode one example.

        Args:
            prefix: Prompt text placed in front of the question/sentence pair.
            example: Mapping providing the ``'question'``, ``'sentence'`` and
                ``'label'`` keys (the QNLI feature names listed in the schema
                notes at the end of this file).

        Returns:
            Dict with ``"input"`` and ``"target"`` (the results of
            ``self.tokenizer.encode``) and the ``"prefix"`` that was used.
        """
        # QNLI examples carry 'question'/'sentence'; they have no
        # 'hypothesis'/'premise' fields, so reading those raised KeyError.
        processed_input = f"{prefix} question: {example['question']} sentence: {example['sentence']}"

        # Map the numeric label to its word; unmapped ids become "unknown".
        processed_target = self.label_dict.get(example['label'], "unknown")

        # NOTE(review): max_length=3 may truncate multi-token label words such
        # as "not entailment" -- confirm against the tokenizer in use.
        return {
            "input": self.tokenizer.encode(processed_input, truncation=True, padding='max_length', max_length=256),
            "target": self.tokenizer.encode(processed_target, truncation=True, padding='max_length', max_length=3),
            "prefix": prefix
        }

# Example usage
# tokenizer = the tokenizer of your T5 model
# prefixes = ["mnli:", ...]
# label_dict = {0: "entailment", 1: "neutral", 2: "contradiction"}  # built into the class; not a constructor argument
# preprocessing_modes = {"train": 1, "validation_matched": 3, "validation_mismatched": 3, "test_matched": 3, "test_mismatched": 3}
# mnli_preprocessor = MNLIPreprocessor(tokenizer, prefixes, preprocessing_modes, dataset_name_1, dataset_name_2)



# Target tasks
# SuperGLUE: MultiRC, BoolQ, WiC, WSC, and CB
# GLUE: RTE, CoLA, STS-B, MRPC, MNLI, QQP, QNLI, and SST-2
# MRQA: Natural Questions, HotpotQA, NewsQA, and SearchQA
# Others: WinoGrande, Yelp-2, SciTail, and PAWS-Wiki


#squad
# Using the latest cached version of the module from C:\Users\95455\.cache\huggingface\modules\datasets_modules\datasets\squad\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453 (last modified on Thu Dec 14 00:06:41 2023) since it couldn't be found locally at squad., or remotely on the Hugging Face Hub.
# DatasetDict({
#     train: Dataset({
#         features: ['id', 'title', 'context', 'question', 'answers'],
#         num_rows: 87599
#     })
#     validation: Dataset({
#         features: ['id', 'title', 'context', 'question', 'answers'],
#         num_rows: 10570
#     })
# })
# super_glue record
# DatasetDict({
#     train: Dataset({
#         features: ['passage', 'query', 'entities', 'entity_spans', 'answers', 'idx'],
#         num_rows: 100730
#     })
#     validation: Dataset({
#         features: ['passage', 'query', 'entities', 'entity_spans', 'answers', 'idx'],
#         num_rows: 10000
#     })
#     test: Dataset({
#         features: ['passage', 'query', 'entities', 'entity_spans', 'answers', 'idx'],
#         num_rows: 10000
#     })
# })
# glue mnli
# DatasetDict({
#     train: Dataset({
#         features: ['premise', 'hypothesis', 'label', 'idx'],
#         num_rows: 392702
#     })
#     validation_matched: Dataset({
#         features: ['premise', 'hypothesis', 'label', 'idx'],
#         num_rows: 9815
#     })
#     validation_mismatched: Dataset({
#         features: ['premise', 'hypothesis', 'label', 'idx'],
#         num_rows: 9832
#     })
#     test_matched: Dataset({
#         features: ['premise', 'hypothesis', 'label', 'idx'],
#         num_rows: 9796
#     })
#     test_mismatched: Dataset({
#         features: ['premise', 'hypothesis', 'label', 'idx'],
#         num_rows: 9847
#     })
# })
# glue qqp
# DatasetDict({
#     train: Dataset({
#         features: ['question1', 'question2', 'label', 'idx'],
#         num_rows: 363846
#     })
#     validation: Dataset({
#         features: ['question1', 'question2', 'label', 'idx'],
#         num_rows: 40430
#     })
#     test: Dataset({
#         features: ['question1', 'question2', 'label', 'idx'],
#         num_rows: 390965
#     })
# })
# glue qnli
# DatasetDict({
#     train: Dataset({
#         features: ['question', 'sentence', 'label', 'idx'],
#         num_rows: 104743
#     })
#     validation: Dataset({
#         features: ['question', 'sentence', 'label', 'idx'],
#         num_rows: 5463
#     })
#     test: Dataset({
#         features: ['question', 'sentence', 'label', 'idx'],
#         num_rows: 5463
#     })
# })
# glue sst2
# DatasetDict({
#     train: Dataset({
#         features: ['sentence', 'label', 'idx'],
#         num_rows: 67349
#     })
#     validation: Dataset({
#         features: ['sentence', 'label', 'idx'],
#         num_rows: 872
#     })
#     test: Dataset({
#         features: ['sentence', 'label', 'idx'],
#         num_rows: 1821
#     })
# })

