
from argparse import ArgumentParser
import pandas as pd
import warnings, os

import pytorch_lightning as pl

from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor

from src.Models.Model import *
from src.log import *
from src.Dataset.DataModule import *

import config as cfg

from sklearn.metrics import f1_score, accuracy_score
# Suppress all warnings process-wide (equivalent to filterwarnings(action='ignore')).
warnings.simplefilter("ignore")


# Back-channel class names. They are used as the per-class F1 keys and as the
# confusion-matrix labels in cli_main.
bc_category = [
    "NoBC",
    "continuer",
    "understanding",
    "empathetic",
]

def program_config(parser, argv=None):
    """Register the training options on *parser* and parse them.

    Every option defaults to the same-named attribute of the project ``config``
    module (imported as ``cfg``). Running without flags therefore uses exactly
    the values configured there.

    Args:
        parser: ``argparse.ArgumentParser`` to add the options to.
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes ``parse_args`` read ``sys.argv[1:]``, which is the
            original behaviour. Passing a list lets callers and tests parse
            without touching the real command line.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    # Optimisation / training-loop settings.
    parser.add_argument('--batch_size', default=cfg.batch_size, type=int)
    parser.add_argument('--epochs', default=cfg.epochs, type=int)
    parser.add_argument('--max_tokens', default=cfg.max_tokens, type=int)
    parser.add_argument("--accumulate_num", default=cfg.accumulate_num, type=int)
    parser.add_argument('--model_save_name', default=cfg.model_save_name, type=str)
    parser.add_argument("--learning_rate", default=cfg.learning_rate, type=float)
    parser.add_argument("--pretrain_LR", default=cfg.pretrain_LR, type=float)
    parser.add_argument("--dropout_rate", default=cfg.dropout_rate, type=float)
    parser.add_argument("--weight_decay", default=cfg.weight_decay, type=float)

    # Model-architecture switches.
    parser.add_argument("--memory_size", default=cfg.memory_size, type=int)
    parser.add_argument("--selective_history_type", default=cfg.selective_history_type, type=str, choices=["", "TAA", "MHA"])
    # str2bool is not defined here; presumably it comes from one of the star
    # imports (src.log / src.Models / src.Dataset). TODO confirm its source.
    parser.add_argument("--use_holistic_history", default=cfg.use_holistic_history, type=str2bool)
    parser.add_argument("--use_speaker_ids", default=cfg.use_speaker_ids, type=str2bool)
    parser.add_argument("--classifier_pool_dim", default=cfg.classifier_pool_dim, type=str)
    # The flag name is misspelled ("accustic"), but it is kept as-is because it
    # also becomes the Namespace attribute consumed downstream.
    parser.add_argument("--accustic_feature", default=cfg.accustic_feature, type=str, choices=["", "wav2vec", "rnn"])

    return parser.parse_args(argv)

def _train_and_test_fold(args, kfold):
    """Train on one cross-validation fold, then test the best saved checkpoint.

    Args:
        args: Namespace from ``program_config``. It is also forwarded as
            keyword arguments to ``BackChannelDataModule`` and
            ``BackChannelModel``.
        kfold: Fold index passed to ``BackChannelDataModule``.

    Returns:
        Tuple ``(file_index, pred, target)`` of lists for this fold's test
        split. They come from ``dm.test_dataset.file_index`` and the loaded
        model's ``outputs["test"]``.

    Raises:
        RuntimeError: If ModelCheckpoint never saved a checkpoint, for example
            because the monitored metric was not logged.
    """
    cur_save_name = f'{args.model_save_name}-k-{kfold}'
    hparams = vars(args)  # same dict as args.__dict__
    dm = BackChannelDataModule(kfold, cfg.data_path, cfg.bert_path, cfg.audio_path, **hparams)
    model = BackChannelModel(cfg.bert_path, **hparams)

    log_dir = f'{cfg.log_save_path}/{args.model_save_name}'
    os.makedirs(log_dir, exist_ok=True)
    logger = TensorBoardLogger(
        save_dir=log_dir,
        name=cur_save_name,
        default_hp_metric=False,
    )

    # Keep only the single best checkpoint, where higher is better for the
    # validation metric.
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        monitor=cfg.validation_metric,
        mode="max",
        dirpath=f"{cfg.model_save_path}/{args.model_save_name}",
        filename=cur_save_name,
        save_weights_only=True,
    )

    trainer = pl.Trainer(
        num_sanity_val_steps=0,
        max_epochs=args.epochs,
        callbacks=[
            EarlyStopping(
                monitor=cfg.validation_metric,
                patience=args.epochs // 2,
                mode='max',
            ),
            checkpoint_callback,
            LearningRateMonitor(logging_interval='step'),
        ],
        accelerator="gpu",
        devices=1,
        logger=logger,
    )
    trainer.fit(model, datamodule=dm)

    # Use the path ModelCheckpoint actually wrote. If "<cur_save_name>.ckpt"
    # already existed (e.g. from a previous run), Lightning saves the new file
    # with a "-v1" suffix, and a hand-built path would load the stale one.
    best_path = checkpoint_callback.best_model_path
    if not best_path:
        raise RuntimeError(
            f"No checkpoint was saved for fold {kfold}; check that "
            f"'{cfg.validation_metric}' is logged during validation."
        )
    # load_from_checkpoint is a classmethod: call it on the class. Newer
    # Lightning versions reject calls on an instance.
    best_model = BackChannelModel.load_from_checkpoint(best_path)

    trainer.test(model=best_model, datamodule=dm)
    # Read outputs from the module we passed in, not trainer.model, which a
    # strategy may wrap.
    test_outputs = best_model.outputs["test"]
    return (
        list(dm.test_dataset.file_index),
        list(test_outputs["pred"]),
        list(test_outputs["target"]),
    )


def _summarize_and_save(args, output_file_index, output_pred, output_target):
    """Pool predictions from all folds, compute metrics and write the artifacts.

    Writes three things: the confusion-matrix image, the metric summary via
    ``save_results``, and a CSV of per-sample predictions. The image and CSV go
    under ``{cfg.model_save_path}/{args.model_save_name}/``.
    """
    import torch  # explicit; the original relied on a star import for torch

    log_results = {}
    log_results["acc"] = accuracy_score(output_target, output_pred)
    log_results["w_f1"] = f1_score(output_target, output_pred, average="weighted")
    log_results["f1"] = f1_score(output_target, output_pred, average="macro")
    # Pass explicit labels so per-class scores stay aligned with bc_category
    # even when a class never appears in pred or target. sklearn would
    # otherwise drop that class and shift the names. This assumes the labels
    # are the integer indices of bc_category, as the confusion-matrix call
    # below implies. TODO confirm against the dataset.
    class_labels = list(range(len(bc_category)))
    per_class_f1 = f1_score(output_target, output_pred, labels=class_labels, average=None)
    for c, f1 in zip(bc_category, per_class_f1):
        log_results[f"{c}_F1"] = f1

    result_dir = f"{cfg.model_save_path}/{args.model_save_name}"
    make_confusion_matrix(
        f"{result_dir}/{args.model_save_name}_confusion_matrix.png",
        torch.tensor(output_pred), torch.tensor(output_target), bc_category
    )

    save_results(
        args=args,
        save_name=args.model_save_name,
        test_result=log_results,
        log_except_list=cfg.log_except_list,
        result_file_name=cfg.result_file_name,
    )

    pred_results = {
        "file_index": output_file_index,
        "pred": output_pred,
        "target": output_target,
    }
    pd.DataFrame(pred_results).to_csv(
        f"{result_dir}/{args.model_save_name}_predict_results.csv",
        index=False,
        encoding="utf-8-sig",
    )


def cli_main():
    """Run 5-fold training/evaluation and save pooled results.

    Parses the CLI via ``program_config`` and seeds everything with 1. It then
    trains and tests one model per fold (0-4), concatenates every fold's
    test-set predictions, and computes and saves the pooled metrics and
    artifacts.
    """
    pl.seed_everything(1)

    args = program_config(ArgumentParser())
    print(args)

    output_file_index = []
    output_pred = []
    output_target = []
    for kfold in range(5):
        file_index, pred, target = _train_and_test_fold(args, kfold)
        output_file_index += file_index
        output_pred += pred
        output_target += target

    _summarize_and_save(args, output_file_index, output_pred, output_target)

if __name__ == "__main__":
    # Only start training when run as a script, not when imported.
    cli_main()

