from pathlib import Path
from functools import partial
import random

# Static marker strings used to assemble the prompt template.

# Natural-language ('nl') markers: plain-text section labels. The closing
# markers are empty strings.
INSTRUCTION_KEY = "### Instruction:"
INSTRUCTION_KEY_END = ""
INPUT_KEY = "INPUT:"
INPUT_KEY_END = ""
NEW_START = "NEW:"
NEW_END = ""
OLD_START = "OLD:"
OLD_END = ""
RESPONSE_KEY = "RESPONSE:"
END_KEY = "### End"

# Structured ('st') markers: tag-style opening/closing pairs.
INSTRUCTION_KEY_ST = "<instruction>"
INSTRUCTION_KEY_END_ST = "</instruction>"
INPUT_KEY_ST = "<input>"
INPUT_KEY_END_ST = "</input>"
NEW_START_ST = "<new>"
NEW_END_ST = "</new>"
OLD_START_ST = "<old>"
OLD_END_ST = "</old>"
RESPONSE_KEY_ST = "<response>"
END_KEY_ST = "</response>"

# Task instruction text. The trailing space is part of the string.
TASK_PROMPT0 = (
    "Classify the intent of the following sentence edit. "
    "The possible labels are: Grammar, Clarity, Fact/Evidence, Claim, Other. "
)

# Task instructions, looked up by prompt type.
PROMPT_DIC = {
    "pt0": TASK_PROMPT0,
}

# Marker sets, looked up by prompt structure type. Each list has ten entries
# in a fixed order (instruction start/end, input start/end, old start/end,
# new start/end, response key, end key); DataPreprocessor unpacks them
# positionally, so the order must not change.
PROMPT_ST_DIC = {
    "nl": [
        INSTRUCTION_KEY,
        INSTRUCTION_KEY_END,
        INPUT_KEY,
        INPUT_KEY_END,
        OLD_START,
        OLD_END,
        NEW_START,
        NEW_END,
        RESPONSE_KEY,
        END_KEY,
    ],
    "st": [
        INSTRUCTION_KEY_ST,
        INSTRUCTION_KEY_END_ST,
        INPUT_KEY_ST,
        INPUT_KEY_END_ST,
        OLD_START_ST,
        OLD_END_ST,
        NEW_START_ST,
        NEW_END_ST,
        RESPONSE_KEY_ST,
        END_KEY_ST,
    ],
}

class DataPreprocessor:
    """
    Builds prompt strings for samples of a sentence-edit dataset.

    Each sample gets a "text" field made of the task instruction, the two
    sentence versions and a response marker. Training prompts additionally
    carry the sample's label and the end marker.
    """

    def __init__(self) -> None:
        print('Preprocessing the data...')

    def preprocess_data(self, dataset, is_train: bool = True, prompt_type: str = 'pt0', prompt_st_type: str = 'nl'):
        """
        Adds a prompt to every sample, drops all other columns and shuffles.

        :param dataset: dataset to be preprocessed; must provide ``map``,
            ``column_names``, ``remove_columns`` and ``shuffle``
            (presumably a Hugging Face ``datasets.Dataset`` -- confirm with caller)
        :param is_train: whether the dataset is for training or testing
        :param prompt_type: type of the task instruction (key of PROMPT_DIC)
        :param prompt_st_type: input types, 'st' for structured input, 'nl' for natural language input
        :return: tuple of (preprocessed dataset, response marker of the chosen input type)
        :raises KeyError: if ``prompt_st_type`` is not a key of PROMPT_ST_DIC
        """
        # The per-sample formatters read these attributes, so they must be
        # set before dataset.map() is called.
        self.prompt_type = prompt_type
        self.prompt_st_type = prompt_st_type
        # Only the response marker (second-to-last entry) is needed here.
        *_, response_key, _ = PROMPT_ST_DIC[self.prompt_st_type]

        # Add prompt to each sample
        print("Preprocessing dataset...")
        formatter = self.create_prompt_formats_train if is_train else self.create_prompt_formats_test
        dataset = dataset.map(formatter)

        # Keep only the prompt text and the label
        remove_cols = [n for n in dataset.column_names if n not in ("text", "label")]
        dataset = dataset.remove_columns(remove_cols)

        # Shuffle with a fixed seed so the order is reproducible
        seed = 42
        dataset = dataset.shuffle(seed=seed)
        return dataset, response_key

    def create_task_prompt(self, prompt_type):
        """
        Creates a task prompt based on the prompt type

        :param prompt_type: Type of the prompt (key of PROMPT_DIC)
        :return: the task instruction string
        :raises AssertionError: if ``prompt_type`` is not a key of PROMPT_DIC
        """
        if prompt_type not in PROMPT_DIC:
            # Raised explicitly instead of `assert False`: assert statements are
            # removed under `python -O`, which would let an invalid type fall
            # through to an unbound local. AssertionError is kept so the
            # exception type seen by callers does not change.
            raise AssertionError(f"Invalid prompt type: {prompt_type}")
        return PROMPT_DIC[prompt_type]

    def _build_prompt(self, sample, include_label):
        """
        Assembles the prompt string shared by the train and test formatters.

        :param sample: sample from the dataset; read keys are 'text_src',
            'text_tgt' and, when ``include_label`` is true, 'label'
        :param include_label: append the label and the end marker (training)
            or stop after the bare response marker (testing)
        :return: the formatted prompt string
        """
        (instruction_key, instruction_key_end, input_key, input_key_end,
         old_start, old_end, new_start, new_end,
         response_key, end_key) = PROMPT_ST_DIC[self.prompt_st_type]
        task_prompt = self.create_task_prompt(self.prompt_type)
        # Combine the task prompt with the static strings
        instruction = f"{instruction_key} {task_prompt} {instruction_key_end}"

        # Missing sentences (None) are rendered as empty strings
        text_src = sample['text_src'] if sample['text_src'] is not None else ''
        text_tgt = sample['text_tgt'] if sample['text_tgt'] is not None else ''
        # NOTE(review): text_tgt fills the OLD slot and text_src the NEW slot.
        # This pairing is kept exactly as it was; confirm it matches the
        # dataset's src/tgt convention.
        input_context = f"{input_key}\n {old_start} {text_tgt} {old_end}\n {new_start} {text_src} {new_end}\n{input_key_end}"

        if include_label:
            parts = [instruction, input_context, f"{response_key}{sample['label']}", f"{end_key}"]
        else:
            parts = [instruction, input_context, f"{response_key}"]
        # Empty elements are skipped so they do not produce blank lines
        return "\n".join(part for part in parts if part)

    def create_prompt_formats_train(self, sample):
        """
        Creates a formatted training prompt (with label and end marker) for a sample

        :param sample: sample from the dataset
        :return: the same sample with the prompt stored under the key "text"
        """
        sample["text"] = self._build_prompt(sample, include_label=True)
        return sample

    def create_prompt_formats_test(self, sample):
        """
        Creates a formatted test prompt (ending at the response marker) for a sample

        :param sample: sample from the dataset
        :return: the same sample with the prompt stored under the key "text"
        """
        sample["text"] = self._build_prompt(sample, include_label=False)
        return sample
        

def main():
    """Script entry point; currently performs no work."""
    return None

# Call main() only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()