
import json
from dataclasses import dataclass
import os
import torch
import random
import copy
from torch.utils.data import Dataset
import transformers
import pickle
# Dataset directory name (not referenced in the visible part of this file).
DATASET_ROOT = "datas"

# Prompt templates keyed by reasoning style. The single "{}" placeholder is
# filled with the question text through str.format. The wording is fed to the
# model verbatim, so it must not be edited casually.
PROMPT_DICT = dict(
    prompt_cot="Answer the given questions.Question:{} let's thinking step by step.\n",
    prompt_pot="Answer the given questions.Question:{} Let's break down the code step by step.\n",
    prompt_pcot="Answer the given questions.Question:{} let's write code with think step by step. \n ",
)

# Label value written over prompt positions and used as label padding.
IGNORE_INDEX = -100

# Special-token strings (not referenced in the visible part of this file).
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

def _tokenize_fn(strings, tokenizer):
    """Tokenize every string separately and report ids plus non-pad lengths.

    Args:
        strings: iterable of texts; each text is passed to ``tokenizer`` on its
            own, truncated to ``tokenizer.model_max_length``.
        tokenizer: callable returning an object with an ``input_ids`` tensor;
            must expose ``model_max_length`` and ``pad_token_id``.

    Returns:
        dict with four keys. ``input_ids`` and ``labels`` refer to the *same*
        list holding the first row of each ``input_ids`` tensor;
        ``input_ids_lens`` and ``labels_lens`` refer to the *same* list of
        counts of tokens that differ from ``tokenizer.pad_token_id``.
    """
    encodings = []
    for text in strings:
        encodings.append(
            tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                max_length=tokenizer.model_max_length,
                truncation=True,
            )
        )

    token_rows = []
    for encoded in encodings:
        token_rows.append(encoded.input_ids[0])

    non_pad_counts = []
    for encoded in encodings:
        is_real_token = encoded.input_ids.ne(tokenizer.pad_token_id)
        non_pad_counts.append(is_real_token.sum().item())

    # The same list objects are intentionally shared between the paired keys.
    return {
        "input_ids": token_rows,
        "labels": token_rows,
        "input_ids_lens": non_pad_counts,
        "labels_lens": non_pad_counts,
    }

def preprocess(sources_cot, sources_pot, targets_cot, targets_pot, tokenizer):
    """Tokenize paired CoT / PoT examples and mask the prompt part of the labels.

    Args:
        sources_cot: list of chain-of-thought prompts.
        sources_pot: list of program-of-thought prompts.
        targets_cot: list of chain-of-thought answers, aligned with ``sources_cot``.
        targets_pot: list of program-of-thought answers, aligned with ``sources_pot``.
        tokenizer: tokenizer forwarded to ``_tokenize_fn``.

    Returns:
        dict with ``input_ids=(cot_ids, pot_ids)`` and
        ``labels=(cot_labels, pot_labels)``. Labels are deep copies of the ids
        whose leading prompt tokens are overwritten with ``IGNORE_INDEX``.

    Note:
        Prompts and answers are paired with ``zip``, so surplus items in the
        longer list are silently dropped.
    """
    joined_cot = [prompt + answer for prompt, answer in zip(sources_cot, targets_cot)]
    joined_pot = [prompt + answer for prompt, answer in zip(sources_pot, targets_pot)]

    def _encode(joined_texts, prompts):
        # Tokenize prompt+answer first, then the prompt alone to learn how many
        # leading tokens belong to the prompt.
        full = _tokenize_fn(joined_texts, tokenizer)
        prompt_only = _tokenize_fn(prompts, tokenizer)
        return full["input_ids"], prompt_only["input_ids_lens"]

    input_ids_cot, prompt_lens_cot = _encode(joined_cot, sources_cot)
    input_ids_pot, prompt_lens_pot = _encode(joined_pot, sources_pot)

    labels_cot = copy.deepcopy(input_ids_cot)
    labels_pot = copy.deepcopy(input_ids_pot)
    for masked_rows, prompt_lens in ((labels_cot, prompt_lens_cot), (labels_pot, prompt_lens_pot)):
        for row, prompt_len in zip(masked_rows, prompt_lens):
            row[:prompt_len] = IGNORE_INDEX

    return {
        "input_ids": (input_ids_cot, input_ids_pot),
        "labels": (labels_cot, labels_pot),
    }




class TextDataset(Dataset):
    """Dataset of paired chain-of-thought (CoT) and program-of-thought (PoT) examples.

    After construction ``self.input_ids`` and ``self.labels`` hold the
    ``(cot, pot)`` 2-tuples returned by ``preprocess``; item ``i`` pairs the
    ``i``-th CoT sequence with the ``i``-th PoT sequence.
    """

    def __init__(self, tokenizer, args):
        """Load, subsample and tokenize the training file.

        Args:
            tokenizer: tokenizer forwarded to ``preprocess``.
            args: object providing ``train_data_file`` (path to a JSON list of
                examples, each with an ``input`` key plus ``cot1..cotN`` and
                ``pot1..potN`` keys), ``p`` (fraction of examples to sample)
                and ``n_cpot`` (N, the number of rationales used per example).

        Raises:
            AssertionError: if ``args.train_data_file`` is not an existing file.
            KeyError: if an example lacks one of the required keys.
        """
        assert os.path.isfile(args.train_data_file), (
            f"train_data_file not found: {args.train_data_file}"
        )
        directory = os.path.dirname(args.train_data_file)
        # Feature caching is not used by this dataset: features are rebuilt
        # from the JSON file on every construction.
        print(f"Creating features from dataset file at {directory}")

        # Kept for backward compatibility; nothing in this class fills it.
        self.examples = []
        with open(args.train_data_file, 'r', encoding='utf-8') as f:
            train_set = json.load(f)

        # Keep a random fraction ``p`` of the examples (sampled without replacement).
        sample_size = int(len(train_set) * args.p)
        train_set = random.sample(train_set, sample_size)

        # Prompts and targets are generated in lockstep from ``n_cpot`` so that
        # row k of the inputs always belongs to the same example as row k of
        # the targets.
        inputs_cot, cot_targets = self._expand_examples(
            train_set, PROMPT_DICT["prompt_cot"], "cot", args.n_cpot
        )
        inputs_pot, pot_targets = self._expand_examples(
            train_set, PROMPT_DICT["prompt_pot"], "pot", args.n_cpot
        )

        data_dict = preprocess(inputs_cot, inputs_pot, cot_targets, pot_targets, tokenizer)
        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    @staticmethod
    def _expand_examples(train_set, prompt, target_prefix, n_targets):
        """Return aligned ``(inputs, targets)`` lists with ``n_targets`` rows per example.

        For every example the formatted prompt is repeated once per target
        ``example[f"{target_prefix}1"] .. example[f"{target_prefix}{n_targets}"]``.
        """
        inputs = []
        targets = []
        for example in train_set:
            question = prompt.format(example['input'])
            for idx in range(1, n_targets + 1):
                inputs.append(question)
                targets.append(example[f"{target_prefix}{idx}"])
        return inputs, targets

    def __len__(self):
        """Return the number of CoT sequences (index 0 of ``self.input_ids``)."""
        return len(self.input_ids[0])

    def __getitem__(self, i):
        """Return the ``i``-th example as ``(cot, pot)`` tuples of ids and labels."""
        return dict(
            input_ids=(self.input_ids[0][i], self.input_ids[1][i]),
            labels=(self.labels[0][i], self.labels[1][i]),
        )


def load_and_cache_examples(args, tokenizer):
    """Build and return a ``TextDataset`` for ``args`` and ``tokenizer``.

    The arguments are forwarded in swapped order because ``TextDataset``
    takes the tokenizer first.
    """
    return TextDataset(tokenizer, args)

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Batch paired (CoT, PoT) instances into padded tensors.

    Every instance is a dict whose ``input_ids`` and ``labels`` entries are
    ``(cot, pot)`` 2-tuples of sequence tensors, the shape returned by
    ``TextDataset.__getitem__``.
    """

    # Only ``pad_token_id`` is read from the tokenizer.
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances):
        """Pad a list of instances to common lengths.

        Returns:
            dict with ``input_ids``, ``labels`` and ``attention_mask``, each a
            ``(cot, pot)`` tuple of batch-first tensors. Ids are padded with
            the tokenizer's pad id, labels with ``IGNORE_INDEX``, and the mask
            is True wherever the padded ids differ from the pad id.
        """
        pad_id = self.tokenizer.pad_token_id

        def _pad_column(key, slot, fill_value):
            # Gather one side (0 = CoT, 1 = PoT) of one field and right-pad it.
            column = [instance[key][slot] for instance in instances]
            return torch.nn.utils.rnn.pad_sequence(
                column, batch_first=True, padding_value=fill_value
            )

        input_ids_cot = _pad_column("input_ids", 0, pad_id)
        input_ids_pot = _pad_column("input_ids", 1, pad_id)
        labels_cot = _pad_column("labels", 0, IGNORE_INDEX)
        labels_pot = _pad_column("labels", 1, IGNORE_INDEX)

        return {
            "input_ids": (input_ids_cot, input_ids_pot),
            "labels": (labels_cot, labels_pot),
            "attention_mask": (input_ids_cot.ne(pad_id), input_ids_pot.ne(pad_id)),
        }





def preprocess_single(sources_pcot, targets_pcot, tokenizer):
    """Tokenize prompt/answer pairs and mask the prompt part of the labels.

    Args:
        sources_pcot: list of prompts.
        targets_pcot: list of answers aligned with ``sources_pcot``.
        tokenizer: tokenizer forwarded to ``_tokenize_fn``.

    Returns:
        dict with ``input_ids`` (token ids of prompt + answer) and ``labels``
        (a deep copy whose leading prompt tokens are set to ``IGNORE_INDEX``).

    Note:
        Prompts and answers are paired with ``zip``, so surplus items in the
        longer list are silently dropped.
    """
    full_texts = []
    for prompt, answer in zip(sources_pcot, targets_pcot):
        full_texts.append(prompt + answer)

    tokenized_full = _tokenize_fn(full_texts, tokenizer)
    tokenized_prompts = _tokenize_fn(sources_pcot, tokenizer)

    token_ids = tokenized_full["input_ids"]
    masked_labels = copy.deepcopy(token_ids)
    prompt_lengths = tokenized_prompts["input_ids_lens"]
    for row, prompt_len in zip(masked_labels, prompt_lengths):
        # Positions covered by the prompt are overwritten with IGNORE_INDEX.
        row[:prompt_len] = IGNORE_INDEX

    return {"input_ids": token_ids, "labels": masked_labels}

class TextDataset_single(Dataset):
    """Single-stream dataset: all CoT examples followed by all PoT examples.

    After construction ``self.input_ids`` and ``self.labels`` are flat lists of
    sequences as returned by ``preprocess_single``.
    """

    def __init__(self, tokenizer, args):
        """Load features from a cache file or build them from the training file.

        Args:
            tokenizer: tokenizer forwarded to ``preprocess_single``.
            args: object providing ``train_data_file`` (path to a JSON list of
                examples, each with an ``input`` key plus ``cot1..cotN`` and
                ``pot1..potN`` keys), ``model_name_or_path`` and
                ``max_input_length`` (both only used to name the cache file),
                ``overwrite_cache`` (skip the cache when truthy) and
                ``n_cpot`` (N, the number of rationales used per example).

        A cache file is honoured only when it unpickles to a dict holding
        ``input_ids`` and ``labels``; any other payload is stored on
        ``self.examples`` and the features are rebuilt. This class never
        writes the cache file itself.

        Raises:
            AssertionError: if ``args.train_data_file`` is not an existing file.
            KeyError: if an example lacks one of the required keys.
        """
        assert os.path.isfile(args.train_data_file), (
            f"train_data_file not found: {args.train_data_file}"
        )
        directory, filename = os.path.split(args.train_data_file)
        cached_features_file = os.path.join(
            directory,
            args.model_name_or_path + '_cached_lm_' + str(args.max_input_length) + '_' + filename,
        )

        self.examples = []
        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            print(f"Loading features from cached file {cached_features_file}")
            # NOTE(security): unpickling executes arbitrary code; only point
            # this at cache files you created yourself.
            with open(cached_features_file, 'rb') as handle:
                self.examples = pickle.load(handle)
            if self._adopt_cached_features(self.examples):
                return
            # Without input_ids/labels the dataset would be unusable, so fall
            # through and rebuild instead of failing later in __len__.
            print("Cached file does not contain 'input_ids'/'labels'; rebuilding features")

        print(f"Creating features from dataset file at {directory}")
        with open(args.train_data_file, 'r', encoding='utf-8') as f:
            train_set = json.load(f)

        # A full-size sample is a random permutation of the examples.
        train_set = random.sample(train_set, len(train_set))

        # Prompts and targets are generated in lockstep from ``n_cpot`` so that
        # row k of the inputs always belongs to the same example as row k of
        # the targets.
        inputs_cot, cot_targets = self._expand_examples(
            train_set, PROMPT_DICT["prompt_cot"], "cot", args.n_cpot
        )
        inputs_pot, pot_targets = self._expand_examples(
            train_set, PROMPT_DICT["prompt_pot"], "pot", args.n_cpot
        )
        # NOTE: earlier experiments (a combined "prompt_pcot" prompt with
        # "pcot" targets, and direct fine-tuning on the raw label) were only
        # present as commented-out code and are not supported here.
        inputs_pcot = inputs_cot + inputs_pot
        pcot_targets = cot_targets + pot_targets

        data_dict = preprocess_single(inputs_pcot, pcot_targets, tokenizer)
        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def _adopt_cached_features(self, payload):
        """Use ``payload`` as features if it is a dict with both required keys.

        Returns True when ``self.input_ids`` and ``self.labels`` were set.
        """
        if isinstance(payload, dict) and "input_ids" in payload and "labels" in payload:
            self.input_ids = payload["input_ids"]
            self.labels = payload["labels"]
            return True
        return False

    @staticmethod
    def _expand_examples(train_set, prompt, target_prefix, n_targets):
        """Return aligned ``(inputs, targets)`` lists with ``n_targets`` rows per example.

        For every example the formatted prompt is repeated once per target
        ``example[f"{target_prefix}1"] .. example[f"{target_prefix}{n_targets}"]``.
        """
        inputs = []
        targets = []
        for example in train_set:
            question = prompt.format(example['input'])
            for idx in range(1, n_targets + 1):
                inputs.append(question)
                targets.append(example[f"{target_prefix}{idx}"])
        return inputs, targets

    def __len__(self):
        """Return the number of sequences in the dataset."""
        return len(self.input_ids)

    def __getitem__(self, i):
        """Return the ``i``-th example as a dict of ``input_ids`` and ``labels``."""
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

def load_and_cache_examples_single(args, tokenizer):
    """Build and return a ``TextDataset_single`` for ``args`` and ``tokenizer``.

    The arguments are forwarded in swapped order because ``TextDataset_single``
    takes the tokenizer first.
    """
    return TextDataset_single(tokenizer, args)

@dataclass
class DataCollatorForSupervisedDataset_single(object):
    """Batch single-stream instances into padded tensors.

    Every instance is a dict with ``input_ids`` and ``labels`` sequence
    tensors, the shape returned by ``TextDataset_single.__getitem__``.
    """

    # Only ``pad_token_id`` is read from the tokenizer.
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances):
        """Pad a list of instances to a common length.

        Returns:
            dict with batch-first ``input_ids`` (padded with the tokenizer's
            pad id), ``labels`` (padded with ``IGNORE_INDEX``) and
            ``attention_mask`` (True wherever the padded ids differ from the
            pad id).
        """
        pad_id = self.tokenizer.pad_token_id
        id_rows = [instance["input_ids"] for instance in instances]
        label_rows = [instance["labels"] for instance in instances]

        padded_ids = torch.nn.utils.rnn.pad_sequence(
            id_rows, batch_first=True, padding_value=pad_id
        )
        padded_labels = torch.nn.utils.rnn.pad_sequence(
            label_rows, batch_first=True, padding_value=IGNORE_INDEX
        )

        return {
            "input_ids": padded_ids,
            "labels": padded_labels,
            "attention_mask": padded_ids.ne(pad_id),
        }





