from pathlib import Path
from functools import partial
import random


class DataPreprocessor:
    """Turn sentence-edit records into tokenized classification inputs.

    Configuration (label mapping, tokenizer, max length, input format, label
    flag) is stored on the instance by ``preprocess_data`` and read back by
    ``create_input_text_and_label``, so the latter must only be called after
    the former (or after those attributes have been set manually).
    """

    # Instruction prepended to the text by the ``inst_*`` input types.
    _INSTRUCTION = (
        "Classify the intent of the following sentence edit. "
        "The possible labels are: Grammar, Clarity, Fact/Evidence, Claim, Other. "
    )

    # Columns that survive when ``remove_unused_col`` is requested.
    _KEEP_COLUMNS = ("text", "label", "input_ids_text", "attention_mask_text")

    def __init__(self) -> None:
        """Create a preprocessor; all configuration happens in ``preprocess_data``."""
        print('Preprocessing the data...ft_llm_c')

    def preprocess_data(self, dataset, label2id, tokenizer, max_length=1024,
                        input_type='text_st_on', has_label=True,
                        remove_unused_col=False, shuffle=True, seed=42):
        """Build model inputs for every record of ``dataset``.

        :param dataset: object exposing ``column_names``, ``map`` and
            ``shuffle`` (presumably a Hugging Face ``datasets.Dataset``).
            Records must contain ``text_src`` and ``text_tgt``, plus
            ``label`` when ``has_label`` is true.
        :param label2id (dict): maps label strings to integer ids.
        :param tokenizer: tokenizer exposing ``encode_plus``
            (e.g. a Hugging Face AutoTokenizer).
        :param max_length (int): every sequence is padded/truncated to this
            many tokens.
        :param input_type (str): one of 'text_st_on', 'text_on',
            'inst_text_st_on', 'inst_text_on'.
        :param has_label (bool): convert each ``label`` through ``label2id``.
        :param remove_unused_col (bool): drop every column except text,
            label, input_ids_text and attention_mask_text.
        :param shuffle (bool): shuffle the mapped dataset.
        :param seed (int): shuffle seed (42 reproduces the historical order).
        :return: the mapped, and optionally shuffled, dataset.
        """
        # Stash configuration for create_input_text_and_label.
        self.label2id = label2id
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.input_type = input_type
        self.has_label = has_label
        print("Preprocessing dataset...ft_llm_c")
        print("input_type", input_type)
        print('max_length', max_length)
        remove_cols = [n for n in dataset.column_names if n not in self._KEEP_COLUMNS]
        print(remove_cols)

        build = partial(self.create_input_text_and_label, input_type=self.input_type)
        if remove_unused_col:
            dataset = dataset.map(build, remove_columns=remove_cols)
        else:
            dataset = dataset.map(build)

        if not shuffle:
            return dataset
        return dataset.shuffle(seed=seed)

    def create_input_text_and_label(self, sample, input_type):
        """Format one record's edit as model input text, then tokenize it.

        Reads ``self.label2id``, ``self.tokenizer``, ``self.max_length`` and
        ``self.has_label`` (set by ``preprocess_data``).

        :param sample (dict): record with ``text_src`` and ``text_tgt``
            (``None`` values are treated as empty strings) and, when
            ``self.has_label`` is true, a ``label`` key present in
            ``self.label2id``.
        :param input_type (str): one of 'text_st_on', 'text_on',
            'inst_text_st_on', 'inst_text_on'. An unrecognised value keeps
            any ``text`` already present in ``sample``.
        :return: the same ``sample`` dict, mutated to hold ``text``,
            ``input_ids_text``, ``attention_mask_text`` and, if labelled, an
            integer ``label``.
        :raises ValueError: if ``input_type`` is unrecognised and ``sample``
            has no existing ``text`` to fall back on.
        """
        text_src = sample['text_src'] if sample['text_src'] is not None else ''
        text_tgt = sample['text_tgt'] if sample['text_tgt'] is not None else ''

        # NOTE(review): <old> wraps text_tgt and <new> wraps text_src; confirm
        # this matches the dataset's src/tgt convention.
        tagged = f"<old> {text_tgt} </old>" + '\n ' + f"<new> {text_src} </new>"
        plain = text_tgt + '\n ' + text_src
        formats = {
            'text_st_on': tagged,
            'text_on': plain,
            'inst_text_st_on': self._INSTRUCTION + '\n ' + tagged,
            'inst_text_on': self._INSTRUCTION + '\n ' + plain,
        }
        if input_type in formats:
            sample["text"] = formats[input_type]
        elif "text" not in sample:
            raise ValueError(
                f"Unknown input_type {input_type!r}; expected one of {sorted(formats)}"
            )

        if self.has_label:
            sample["label"] = self.label2id[sample['label']]

        # Tokenize once and reuse the result for both ids and mask.
        # return_tensors="pt" is kept from the original; it presumably yields a
        # leading batch dimension of 1 — confirm downstream code expects that.
        encoded = self.tokenizer.encode_plus(
            sample["text"],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        sample['input_ids_text'] = encoded["input_ids"]
        sample['attention_mask_text'] = encoded["attention_mask"]
        return sample
    
        

def main():
    """Script entry point; currently a placeholder that performs no work."""
    return None


if __name__ == "__main__":
    main()