import os, datetime
import pandas as pd
import matplotlib.pyplot as plt
from pretty_confusion_matrix import pp_matrix
from torchmetrics.classification.confusion_matrix import ConfusionMatrix
    
def make_confusion_matrix(output_path, pred, target, categories):
    """Render a multiclass confusion matrix with ``pp_matrix`` and save it to disk.

    Args:
        output_path: Destination file passed to ``Figure.savefig``.
        pred: Predictions passed to torchmetrics' ``ConfusionMatrix``
            (assumed to be a torch tensor; class indices or per-class scores —
            confirm with the caller).
        target: Ground-truth labels passed to ``ConfusionMatrix``.
        categories: Iterable of class names. Its length sets ``num_classes``
            and its items label both the rows and the columns of the matrix.
    """
    # Materialise once: a generator would otherwise break ``len()`` and could
    # only be iterated a single time for both the index and the columns.
    labels = list(categories)
    # Put the metric's internal state on the same device as the inputs so
    # GPU tensors do not clash with CPU-resident state.
    metric = ConfusionMatrix(task="multiclass", num_classes=len(labels)).to(pred.device)
    confusion = metric(pred, target)
    # pandas cannot read device (e.g. CUDA) tensors; copy to a host NumPy array.
    df_cm = pd.DataFrame(confusion.cpu().numpy(), index=labels, columns=labels)
    pp_matrix(df_cm, cmap='PuRd')
    # Save pyplot's current figure (assumed to be the one pp_matrix just drew),
    # then close it so repeated calls do not accumulate open figures.
    fig = plt.gcf()
    fig.savefig(output_path)
    plt.close(fig)


def make_data_frame(args, save_name, test_result, log_except_list):
    """Build a one-row DataFrame describing a single evaluation run.

    Columns appear in this order:
      - every entry of ``test_result``;
      - the run name, under ``"save_name"``;
      - every attribute of ``args`` whose name is not in ``log_except_list``;
      - the current local time, under ``"date"``.

    Args:
        args: Object whose instance attributes (``vars(args)``) are logged,
            e.g. an ``argparse.Namespace``.
        save_name: Identifier stored in the ``"save_name"`` column.
        test_result: Mapping of metric name -> value. It is copied, not modified.
        log_except_list: Names of ``args`` attributes to leave out. ``None`` or
            an empty collection means every attribute is logged.

    Returns:
        pandas.DataFrame with exactly one row.
    """
    # Work on a copy so the caller's dict is not mutated as a side effect.
    row = dict(test_result)
    row["save_name"] = save_name
    excluded = set(log_except_list or ())
    for name, value in vars(args).items():
        if name not in excluded:
            row[name] = value
    # Assigned last, so it always wins over an ``args`` attribute named "date".
    row["date"] = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    return pd.DataFrame([row])


def save_results(args, save_name, test_result, log_except_list, result_file_name):
    """Append one run's results as a row of a CSV log, creating the file if needed.

    The new row comes from ``make_data_frame``. If the CSV already holds data,
    it is read back and concatenated with the new row. Columns are aligned by
    name, and a column missing on either side is left empty. The whole table
    is then rewritten.

    Args:
        args, save_name, test_result, log_except_list: Forwarded unchanged to
            ``make_data_frame``.
        result_file_name: Path of the CSV file to create or extend.
    """
    result_df = make_data_frame(args, save_name, test_result, log_except_list)

    # A file that exists but is zero bytes would make read_csv raise
    # EmptyDataError, so treat it the same as a missing file.
    if os.path.exists(result_file_name) and os.path.getsize(result_file_name) > 0:
        # Read with the same BOM-prefixed UTF-8 encoding used for writing.
        prev_df = pd.read_csv(result_file_name, index_col=False, encoding="utf-8-sig")
        result_df = pd.concat([prev_df, result_df], axis=0, ignore_index=True)

    result_df.to_csv(result_file_name, index=False, encoding="utf-8-sig")
