from pathlib import Path
import shutil


def _task_type_for_method(method):
    """Return the task-type string handed to the fine-tuner for ``method``.

    ``finetuning_llm`` (Gen) maps to ``"CAUSAL_LM"`` and ``finetuning_llm_c``
    (SeqC) maps to ``"SEQ_CLS"``; every other method (e.g. the XNet / SNet
    variants) yields ``None``.
    """
    # NOTE(review): XNet/SNet receive task_type=None — confirm the fine-tuner
    # for those methods expects that.
    return {
        'finetuning_llm': "CAUSAL_LM",
        'finetuning_llm_c': "SEQ_CLS",
    }.get(method)


def _build_model_folder_name(*, model_name, model_version, lora_r, lora_alpha,
                             lora_dropout, per_device_train_batch_size,
                             train_epochs, train_type, test_type,
                             target_modules, max_length, num_cls_layers,
                             emb_type, learning_rate, input_type):
    """Encode the run configuration into a result-folder name.

    The format is kept byte-identical to earlier versions of this script
    (including the double underscore before ``lora``) so that existing
    result folders keep the same names.

    Returns:
        The folder name as a ``str``.
    """
    parts = [
        f'{model_name}-{model_version}__lora-r{lora_r}-a{lora_alpha}-d{lora_dropout}',
        f'_bs{per_device_train_batch_size}_ep{train_epochs}',
        f'_{train_type}',
        f'_{test_type}',
    ]
    if target_modules != "all-linear":
        # An explicit module list is abbreviated to the first letter of each entry.
        parts.append('_' + ''.join(module[0] for module in target_modules))
    parts += [
        f'_ml{max_length}',
        f'_cls{num_cls_layers}',
        f'_{emb_type}',
        f'_lr{learning_rate}',
        f'_{input_type}',
    ]
    return ''.join(parts)


def _prepare_output_dir(results_root, task_name, method, model_folder_name,
                        create_dir=True):
    """Return ``results_root / task_name / method / model_folder_name``.

    The parent directories (``results_root/task_name/method``) are always
    created if missing. When ``create_dir`` is true, the leaf folder is
    deleted if it already exists and then recreated empty — i.e. results of
    a previous run with the same configuration are discarded.
    """
    method_dir = Path(results_root) / task_name / method
    method_dir.mkdir(parents=True, exist_ok=True)
    output_dir = method_dir / model_folder_name
    if create_dir:
        if output_dir.exists():
            # Deliberately wipe a previous run with the identical configuration.
            shutil.rmtree(output_dir)
        output_dir.mkdir()
    return output_dir


def main():
    """Run the task pipeline end to end with the hard-coded settings below.

    Steps: (1) load the task dataset splits and labels, (2) load the model
    and tokenizer, (3) preprocess the splits, (4) fine-tune with LoRA into a
    configuration-specific folder under ``./results``, and (5) evaluate on
    the test split. All settings live in the ``### arguments ###`` sections.
    """
    ### arguments ###
    task_name = 'edit_intent_classification'
    method = 'finetuning_llm_c'  # Gen:finetuning_llm, SeqC:finetuning_llm_c, XNet:finetuning_llm_c_cross, SNet:finetuning_llm_c_siamese
    train_type = 'train'  # training data name in the folder data/Re3-Sci/tasks/<task_name>
    val_type = 'val'  # validation data name in the folder data/Re3-Sci/tasks/<task_name>
    test_type = 'test'  # test data name in the folder data/Re3-Sci/tasks/<task_name>
    ### arguments ###

    # 1. load task dataset
    from tasks.task_data_loader import TaskDataLoader
    task_data_loader = TaskDataLoader(task_name=task_name, train_type=train_type, val_type=val_type, test_type=test_type)
    train_ds, val_ds, test_ds = task_data_loader.load_data()
    labels, label2id, id2label = task_data_loader.get_labels()

    print('train_ds: ', train_ds)
    print('val_ds: ', val_ds)
    print('test_ds: ', test_ds)
    print('labels: ', labels)

    # 2. load model and tokenizer
    ### arguments ###
    model_root_path, model_name, model_version = '', '', ''  # local locations of the model and tokenizer, to be defined
    device_map = "auto"
    num_cls_layers = 1
    emb_type = None  # for XNet and SNet: ['diffABS', 'diff', 'n-diffABS', 'n-o', 'n-diffABS-o']
    input_type = 'text_st_on'  # 'text_st_on': So + Sn in structured input format, 'inst_text_st_on': inst + So + Sn in structured input format
    ### arguments ###
    from tasks.task_model_loader import TaskModelLoader
    model_loader = TaskModelLoader(task_name=task_name, method=method).model_loader
    model, tokenizer = model_loader.load_model_from_path_name_version(model_root_path, model_name, model_version,
                                                                      labels, label2id, id2label,
                                                                      device_map=device_map)

    print('model: ', model)
    print('tokenizer: ', tokenizer)
    print('model.config.id2label', model.config.id2label)
    print('model.config.label2id', model.config.label2id)
    print('model.config', model.config)

    # 3. preprocess dataset
    ### arguments ###
    max_length = 1024  # 512 for SNet, 1024 for XNet, Gen, SeqC
    shuffle = True  # set to True
    ### arguments ###
    from tasks.task_data_preprocessor import TaskDataPreprocessor
    data_preprocessor = TaskDataPreprocessor(task_name=task_name, method=method).data_preprocessor
    # NOTE(review): the same shuffle flag is applied to val and test splits — confirm intended.
    if train_ds is not None:
        train_ds = data_preprocessor.preprocess_data(train_ds, label2id, tokenizer, max_length=max_length,
                                                     input_type=input_type, shuffle=shuffle)
    if val_ds is not None:
        val_ds = data_preprocessor.preprocess_data(val_ds, label2id, tokenizer, max_length=max_length,
                                                   input_type=input_type, shuffle=shuffle)
    if test_ds is not None:
        test_ds = data_preprocessor.preprocess_data(test_ds, label2id, tokenizer, max_length=max_length,
                                                    input_type=input_type, shuffle=shuffle)

    # 4. fine-tune model
    ### arguments ###
    lora_r = 128  # LoRA parameters
    lora_alpha = 128  # Alpha parameter for LoRA scaling
    lora_dropout = 0.1  # Dropout probability for LoRA layers
    bias = "none"  # Bias
    learning_rate = 2e-4  # 2e-4
    per_device_train_batch_size = 32  # Batch size per GPU for training
    train_epochs = 10  # Number of epochs to train
    target_modules = "all-linear"
    create_dir = True  # True: wipe and recreate the output folder for this configuration
    ### arguments ###

    task_type = _task_type_for_method(method)

    # create model dir: ./results/<task_name>/<method>/<configuration-encoded name>
    model_folder_name = _build_model_folder_name(
        model_name=model_name,
        model_version=model_version,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        per_device_train_batch_size=per_device_train_batch_size,
        train_epochs=train_epochs,
        train_type=train_type,
        test_type=test_type,
        target_modules=target_modules,
        max_length=max_length,
        num_cls_layers=num_cls_layers,
        emb_type=emb_type,
        learning_rate=learning_rate,
        input_type=input_type,
    )
    output_dir = _prepare_output_dir(Path("./results"), task_name, method,
                                     model_folder_name, create_dir=create_dir)
    print('output_dir: ', output_dir)

    # fine-tune
    from tasks.task_model_finetuner import TaskModelFinetuner
    model_finetuner = TaskModelFinetuner(task_name=task_name, method=method).model_finetuner
    model_finetuner.fine_tune(model, tokenizer, train_ds, val_ds,
                              lora_r, lora_alpha, lora_dropout, bias, task_type,
                              per_device_train_batch_size, output_dir, train_epochs,
                              target_modules=target_modules, learning_rate=learning_rate)

    # 5. evaluate fine-tuned model
    from tasks.task_evaluater import TaskEvaluater
    evaluater = TaskEvaluater(task_name=task_name, method=method).evaluater
    evaluater.evaluate(test_ds, labels, label2id, id2label, output_dir, emb_type=emb_type, input_type=input_type)


if __name__ == "__main__":
    main()
