from datasets import load_dataset, DatasetDict
from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import random
from datasets import Dataset, DatasetDict
import pandas as pd




class T5Preprocessor:
    """Base class that loads a Hugging Face dataset and rewrites it for T5.

    Subclasses must implement :meth:`process_example`, which turns one raw
    example plus a textual prefix into the fields that should be added to
    that example.

    Note that ``self.dataset`` starts out as ``None`` and nothing in this
    base class populates it (``prepare_dataset`` and ``split_dataset`` are
    placeholders). A subclass or the caller has to assign a mapping of
    split name -> dataset before the transformation helpers are used.
    """

    # Dataset-name lists kept for callers that decide how to split data.
    # They are not read anywhere inside this class.
    small_datasets_without_all_splits = ["cola", "wnli", "rte", "superglue-cb", "superglue-copa", "superglue-multirc",
                                         "superglue-wic", "superglue-wsc.fixed", "superglue-rte", "mrpc", "stsb",
                                         "superglue-boolq", "xsum", "scitail"]
    large_data_without_all_splits = ["qqp", "qnli", "superglue-record", "sst2", "squad", "snli", "anli",
                                     "amazon_polarity", "yelp_polarity", "winogrande", "newsqa", "searchqa", "triviaqa",
                                     "nq", "hotpotqa"]

    def __init__(self, tokenizer, prefixes, preprocessing_modes, dataset_name_1, dataset_name_2=None):
        """Load the raw dataset and remember the preprocessing configuration.

        Args:
            tokenizer: Stored on the instance; not used by this base class.
            prefixes: Sequence of prefixes handed to ``process_example``.
            preprocessing_modes: Mapping of split name -> mode (1, 2 or 3),
                consumed by :meth:`transform_dataset_for_t5`.
            dataset_name_1: First positional argument for ``load_dataset``.
            dataset_name_2: Optional second positional argument for
                ``load_dataset``; when given it also becomes
                ``self.dataset_name``.
        """
        # NOTE(review): ``script_version`` is kept as in the original call;
        # newer ``datasets`` releases appear to prefer ``revision`` - confirm
        # against the installed version before changing it.
        if dataset_name_2 is None:
            self.dataset_name = dataset_name_1
            print('input only one name, we will use dataset %s'%(dataset_name_1))
            self.raw_dataset = load_dataset(dataset_name_1, script_version="master")
        else:
            self.dataset_name = dataset_name_2
            print('input two name, we will use dataset %s , %s' % (dataset_name_1,dataset_name_2))
            self.raw_dataset = load_dataset(dataset_name_1, dataset_name_2, script_version="master")

        self.tokenizer = tokenizer
        self.prefixes = prefixes
        self.preprocessing_modes = preprocessing_modes
        self.need_split_val = False

        print('data_load_ready')
        print('original data format is :')
        print(self.raw_dataset)

        # Filled in later by a subclass or the caller; see the class docstring.
        self.dataset = None

    def prepare_dataset(self):
        """Placeholder hook; currently does nothing.

        An earlier draft decided here whether the raw dataset needed to be
        split and then assigned ``self.dataset``. That logic is disabled.
        """
        pass

    def split_dataset(self):
        """Placeholder hook; currently does nothing.

        An earlier draft built train/validation/test splits with
        ``train_test_split`` and stored them in ``self.dataset``. That logic
        is disabled.
        """
        pass

    def transform_dataset_for_t5(self):
        """Apply the configured preprocessing mode to each known split.

        Splits named in ``self.preprocessing_modes`` but absent from
        ``self.dataset`` are skipped. A mode other than 1, 2 or 3 leaves the
        split untouched and is reported as such.

        Raises:
            ValueError: If ``self.dataset`` has not been populated yet.
        """
        if self.dataset is None:
            # Fail with a clear message instead of the opaque TypeError that
            # the membership test below would raise on None.
            raise ValueError("Dataset has not been prepared; set self.dataset before transforming.")

        print('we begin to transform the dataset')
        for split_name, mode in self.preprocessing_modes.items():
            if split_name not in self.dataset:
                continue

            if mode == 1:
                apply_mode = self.apply_mode_1
            elif mode == 2:
                apply_mode = self.apply_mode_2
            elif mode == 3:
                apply_mode = self.apply_mode_3
            else:
                # Do not claim success when nothing was applied.
                print("unknown preprocessing mode %r for %s, data left unchanged" % (mode, split_name))
                continue

            print("now we transform %s"%(split_name))
            self.dataset[split_name] = apply_mode(self.dataset[split_name])
            print("%s data has been transformed " % (split_name))

    def apply_mode_1(self, dataset):
        """Mode 1: process every example with a randomly chosen prefix.

        Returns whatever ``dataset.map`` returns.
        """
        def transform(example):
            prefix = random.choice(self.prefixes)
            return self.process_example(prefix, example)

        return dataset.map(transform)

    def apply_mode_2(self, dataset):
        """Mode 2: pair every example with every prefix.

        Each output row contains the original example's fields overlaid with
        the fields returned by ``process_example`` (processed fields win on
        key collisions). The result is a new ``Dataset`` built from a pandas
        DataFrame, with ``len(dataset) * len(self.prefixes)`` rows.
        """
        new_samples = [
            {**example, **self.process_example(prefix, example)}
            for example in dataset
            for prefix in self.prefixes
        ]
        return Dataset.from_pandas(pd.DataFrame(new_samples))

    def apply_mode_3(self, dataset):
        """Mode 3: process every example with the first prefix in the list.

        Returns whatever ``dataset.map`` returns.
        """
        def transform(example):
            prefix = self.prefixes[0]
            return self.process_example(prefix, example)

        return dataset.map(transform)

    def process_example(self, prefix, example):
        """Combine ``prefix`` and ``example``; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def show_sample(self, index, split='train'):
        """Return the example at ``index`` of ``split``, or a message string.

        Negative indices count from the end, as with normal sequence
        indexing. Instead of raising, a descriptive string is returned when
        the split is missing (including when no dataset has been prepared)
        or the index is out of range in either direction.
        """
        if self.dataset is None or split not in self.dataset:
            return f"Dataset split '{split}' not found."

        size = len(self.dataset[split])
        # Check the lower bound as well so that e.g. -size - 1 reports an
        # out-of-range message rather than raising IndexError.
        if -size <= index < size:
            return self.dataset[split][index]
        return f"Index {index} is out of range for the {split} dataset."

    def create_few_shot_dataset(self, num_samples_per_class):
        """Build a k-shot training split and store it as ``'few-shot_train'``.

        For every distinct value in the ``'label'`` column of the training
        split, up to ``num_samples_per_class`` rows are sampled at random
        without replacement. Classes with fewer rows contribute all the rows
        they have.

        Raises:
            ValueError: If there is no ``'train'`` split available.
        """
        if self.dataset is None or 'train' not in self.dataset:
            raise ValueError("Training dataset not found.")

        train_df = pd.DataFrame(self.dataset['train'])

        # Name of the class column; adjust for datasets that use another name.
        class_column = 'label'

        few_shot_samples = []
        for label in train_df[class_column].unique():
            class_rows = train_df[train_df[class_column] == label]
            # Cap the request so that small classes do not make sample() raise.
            take = min(num_samples_per_class, len(class_rows))
            few_shot_samples.append(class_rows.sample(n=take))

        # ignore_index gives a fresh RangeIndex, so from_pandas does not add
        # the old row positions as an extra "__index_level_0__" column.
        few_shot_df = pd.concat(few_shot_samples, ignore_index=True)

        self.dataset['few-shot_train'] = Dataset.from_pandas(few_shot_df)

    def get_dataset(self):
        """Return the current ``self.dataset`` (``None`` until populated)."""
        return self.dataset



