# -*- coding: utf-8 -*-
"""
Usage example:
accelerate launch --num_processes {number of GPUs} llm_ets_classification_train_dev_tune.py --config llm_ets-longformer-hyperparameter-tunning.yaml

"""
import argparse
import yaml
import os
from tqdm import tqdm
import shutil
import os
import pandas as pd
import json
import numpy as np
from datasets import (
    Dataset,
    DatasetDict,
    Features,
    ClassLabel,
)
from accelerate import Accelerator

from transformers import (AutoTokenizer, DataCollatorWithPadding,
                          AutoModelForSequenceClassification,
                          TrainingArguments, Trainer, EarlyStoppingCallback,
                          set_seed)

from sklearn.metrics import accuracy_score, f1_score

import numpy as np
from datetime import datetime

parser = argparse.ArgumentParser()
parser.add_argument('--config', required=True,
                    help='Path to the YAML file holding the run configuration.')
args = parser.parse_args()

# Load the run configuration from YAML.
with open(args.config, 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# PyYAML parses a value such as 2e-5 (no decimal point) as a string but 2.0e-5
# as a float. float() accepts both forms and, unlike eval(), never executes
# text taken from the config file.
config['learning_rate'] = float(config['learning_rate'])

if config['debug']:
    # Tiny batches and a single epoch so a debug run finishes quickly.
    config['per_device_train_batch_size'] = 2
    config['per_device_eval_batch_size'] = 2
    config['num_train_epochs'] = 1


accelerator = Accelerator()

# Seed all RNGs before anything else that may consume randomness.
set_seed(config['seed'])

os.makedirs(config['output_dir'], exist_ok=True)
print("data")

# Train/dev file names default to train64.jsonl / dev64.jsonl but can be
# overridden with ``train_data_name`` / ``dev_data_name`` in the config.
test_data_name = config['test_data_name']
train_data_name = config.get('train_data_name', 'train64.jsonl')
dev_data_name = config.get('dev_data_name', 'dev64.jsonl')
train_df = pd.read_json(f"{config['input_dir']}/{train_data_name}", lines=True)
dev_df = pd.read_json(f"{config['input_dir']}/{dev_data_name}", lines=True)
test_df = pd.read_json(f"{config['input_dir']}/{test_data_name}", lines=True)


train_dataset = Dataset.from_pandas(train_df)
dev_dataset = Dataset.from_pandas(dev_df)
test_dataset = Dataset.from_pandas(test_df)
print('data_done:')
# Shuffle the training set with the configured seed so the order is reproducible.
train_dataset = train_dataset.shuffle(seed=config['seed'])

if config['debug']:
    # Keep at most 32 rows per split; min() avoids an out-of-range selection
    # on splits that are smaller than that.
    train_dataset = train_dataset.select(range(min(32, len(train_dataset))))
    dev_dataset = dev_dataset.select(range(min(32, len(dev_dataset))))
    test_dataset = test_dataset.select(range(min(32, len(test_dataset))))

dataset = DatasetDict({
    'train': train_dataset,
    'dev': dev_dataset,
    'test': test_dataset,
})

# NOTE(review): add_prefix_space=True is presumably needed for the BPE
# (RoBERTa/Longformer-style) tokenizer of the configured model — confirm.
tokenizer = AutoTokenizer.from_pretrained(config['model_dir'],
                                          add_prefix_space=True)

# NOTE(review): `features` does not appear to be referenced later in this
# script — confirm before removing.
features = Features(
    {'label': ClassLabel(num_classes=5, names=['0', '1', '2', '3', '4'])})


def preprocess_function(key, examples, max_seq_length):
    """Tokenize one raw example and attach its numeric target.

    Every split is processed identically. The original code branched on a
    leaked global ``split`` (instead of this function's ``key`` argument) into
    two identical branches, so a single path is used here.

    Args:
        key: Name of the split the example comes from ('train', 'dev' or
            'test'). Kept for interface compatibility; it does not change the
            processing.
        examples: A single row mapping with at least the "text" and "version"
            fields.
        max_seq_length: Length every sequence is truncated / padded to.

    Returns:
        The tokenizer output for ``examples["text"]`` with an added float
        ``"label"`` taken from ``examples["version"]``.
    """
    # Pad to a fixed length so every stored example has the same size.
    model_inputs = tokenizer(examples["text"],
                             max_length=max_seq_length,
                             padding='max_length',
                             truncation=True)

    # Float label: model_init builds a single-output head (num_labels=1).
    model_inputs["label"] = float(examples["version"])
    return model_inputs


# Tokenize every split eagerly, one example at a time, into plain dicts.
tokenized_dataset = {}
for split in dataset.keys():
    tokenized_dataset[split] = [
        preprocess_function(split, example, config['max_seq_length'])
        for example in tqdm(dataset[split])
    ]
print('token')

# preprocess_function returns only the tokenizer outputs plus "label", so no
# raw columns need dropping. (The removed loop here deleted raw *column* names
# from this dict keyed by *split* names, which could only ever remove a split.)
tokenized_train_ds = Dataset.from_list(tokenized_dataset['train'])
tokenized_dev_ds = Dataset.from_list(tokenized_dataset['dev'])
tokenized_test_ds = Dataset.from_list(tokenized_dataset['test'])

# Collator handed to the Trainer for batching.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

tokenized_dataset = DatasetDict({
    "train": tokenized_train_ds,
    "dev": tokenized_dev_ds,
    "test": tokenized_test_ds
})


def compute_metrics(eval_pred):
    """Convert continuous model scores into class ids and score them.

    Args:
        eval_pred: ``(predictions, labels)`` pair supplied by the Trainer.

    Returns:
        dict with ``"accuracy"`` and support-weighted ``"f1"``.
    """
    raw_scores, gold = eval_pred
    # Snap each score to the nearest integer, then keep it inside [0, 4].
    predicted = np.clip(np.round(raw_scores), 0, 4)
    return {
        "accuracy": accuracy_score(gold, predicted),
        "f1": f1_score(gold, predicted, average='weighted'),
    }


# Class id <-> class name maps for the five labels "0".."4".
id2label = {idx: str(idx) for idx in range(5)}
label2id = {name: idx for idx, name in id2label.items()}

if config['debug']:
    # Small step intervals so a debug run logs, evaluates and saves quickly.
    config.update(
        trainer_logging_steps=2,
        trainer_save_steps=12,
        trainer_eval_steps=4,
    )


# Define model
def model_init():
    """Load a fresh model from ``config['model_dir']`` for the Trainer.

    ``num_labels=1`` gives a single output unit, so the label is predicted as
    one continuous score. No id2label / label2id mapping is attached.
    """
    checkpoint = config['model_dir']
    return AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=1)


def delete_optimizer_scheduler_files(directory):
    """Remove every ``optimizer.pt`` and ``scheduler.pt`` under *directory*.

    The whole tree is walked recursively and each removed path is printed.
    All other files are left alone. If *directory* does not exist, a notice is
    printed and nothing else happens.
    """
    if os.path.exists(directory):
        doomed = {"optimizer.pt", "scheduler.pt"}
        for current_dir, _subdirs, filenames in os.walk(directory):
            for name in filenames:
                if name not in doomed:
                    continue
                target = os.path.join(current_dir, name)
                os.remove(target)
                print(f"Deleted {target}")
    else:
        print("The directory does not exist.")


def delete_checkpoints(directory):
    """Empty *directory* by removing every file and folder directly inside it.

    The directory itself is kept. A line is printed for each removed entry.
    If *directory* does not exist, a single notice is printed instead.
    """
    if os.path.exists(directory):
        for entry in os.listdir(directory):
            entry_path = os.path.join(directory, entry)
            if not os.path.isdir(entry_path):
                os.remove(entry_path)
                print(f"Deleted file: {entry_path}")
                continue
            shutil.rmtree(entry_path)
            print(f"Deleted folder: {entry_path}")
    else:
        print("Directory does not exist.")

def trainer_predict(test_data, input_road, config):
    """Predict on *test_data* and save the predictions next to the source records.

    Uses the module-level ``trainer``. Every record of the JSONL file at
    *input_road* is written to ``config['predict_saved_dir']`` with two added
    fields:

    * ``predicted_label``: the score rounded and clipped to [0, 4].
    * ``predicted_score``: the raw first model output.

    Args:
        test_data: Tokenized dataset whose rows follow the record order of
            *input_road*.
        input_road: Path to the original JSONL file the test set was built from.
        config: Run configuration; only ``predict_saved_dir`` is read.

    Raises:
        ValueError: If the number of records differs from the number of
            predictions.
    """
    print("#" * 50)

    # Every process calls predict(); only writing the file is restricted below.
    predictions = trainer.predict(test_data).predictions

    predicted_labels = np.clip(np.round(predictions), 0, 4)

    # Skip blank lines: json.loads('') would raise on a stray empty line.
    with open(input_road, 'r', encoding='utf-8') as jsonl_file:
        original_data = [json.loads(line) for line in jsonl_file if line.strip()]

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(original_data) != len(predicted_labels):
        raise ValueError(
            f"{input_road} has {len(original_data)} records but "
            f"{len(predicted_labels)} predictions were produced")

    # Under a multi-process launch every process reaches this point; only the
    # main one writes, so the writes to the same file cannot interleave.
    if not trainer.is_world_process_zero():
        return

    # Explicit utf-8: ensure_ascii=False may emit characters the locale codec rejects.
    with open(config['predict_saved_dir'], 'w', encoding='utf-8') as jsonl_file:
        for record, score, label in zip(original_data, predictions, predicted_labels):
            record['predicted_label'] = int(label[0])
            record['predicted_score'] = float(score[0])
            jsonl_file.write(json.dumps(record, ensure_ascii=False) + '\n')

    print("Saved the new jsonl file with predictions.")


# Learning-rate sweep: one complete Trainer run per value. The default grid can
# be replaced with an optional ``learning_rate_grid`` list in the YAML config.
for learning_rate in config.get('learning_rate_grid', [1e-5, 2e-5, 3e-5, 5e-5]):
    config['learning_rate'] = learning_rate

    print(f">>> Current learning rate: {learning_rate}")

    # Build the run name once so the checkpoint and TensorBoard directories
    # share one timestamp (two separate datetime.now() calls can straddle a
    # second). normpath + basename also copes with a trailing '/' in model_dir.
    model_name = os.path.basename(os.path.normpath(config['model_dir']))
    run_name = (f"run_lr{learning_rate}"
                f"_patience{config['early_stopping_patience']}"
                f"_save@{config['save_strategy']}"
                f"_seed{config['seed']}"
                f"_{datetime.now().strftime('%b%d_%H-%M-%S')}")
    cur_output_dir = f"{config['output_dir']}/{model_name}/{run_name}"
    cur_logging_dir = f"{config['logging_dir']}/{model_name}/{run_name}"

    training_args = TrainingArguments(
        output_dir=cur_output_dir,
        learning_rate=learning_rate,
        per_device_train_batch_size=config['per_device_train_batch_size'],
        # Use the dedicated eval batch size; fall back to the train size if absent.
        per_device_eval_batch_size=config.get(
            'per_device_eval_batch_size', config['per_device_train_batch_size']),
        num_train_epochs=config['num_train_epochs'],
        weight_decay=config['weight_decay'],
        evaluation_strategy=config['evaluation_strategy'],
        save_strategy=config['save_strategy'],
        logging_steps=config['trainer_logging_steps'],
        save_steps=config['trainer_save_steps'],
        eval_steps=config['trainer_eval_steps'],
        warmup_steps=config['warmup_steps'],
        save_total_limit=config['save_total_limit'],
        load_best_model_at_end=True,
        report_to="tensorboard",
        logging_dir=cur_logging_dir,
        metric_for_best_model=config['metric_for_best_model'],
    )

    # 'output' mode only needs raw predictions, so metric computation is
    # skipped there (without rebinding the module-level compute_metrics).
    metrics_fn = (lambda eval_pred: {}) if config['mode'] == 'output' else compute_metrics

    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["dev"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=metrics_fn,
        callbacks=[
            EarlyStoppingCallback(
                early_stopping_patience=config['early_stopping_patience'])
        ],
    )

    if config['mode'] == 'train':
        trainer.train()

    if config['mode'] == 'test':
        # Evaluate once on the test split and stop; no sweep in test mode.
        test_results = trainer.evaluate(tokenized_dataset["test"])
        print(test_results)
        break

    if config['mode'] == 'output':
        # Write predictions once and stop; no sweep in output mode.
        trainer_predict(tokenized_dataset["test"], f"{config['input_dir']}/{test_data_name}", config)
        break

    # File clean-up runs on the main process only, so several processes never
    # delete the same files concurrently.
    if accelerator.is_main_process:
        # Drop optimizer/scheduler states to save disk; other checkpoint files stay.
        delete_optimizer_scheduler_files(cur_output_dir)

    if config['debug']:
        if accelerator.is_main_process:
            delete_checkpoints(cur_output_dir)
        print('DEBUGGING CHECK COMPLETE!')
        break
