import argparse
import os
import logging
import random

import json
import time

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from transformers import (
    get_linear_schedule_with_warmup, BertTokenizer, RobertaTokenizer,
    BertConfig, RobertaConfig, RobertaModel, AutoTokenizer, AutoModel
)

from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from model import DMBert, RawBert, DMBertRelation, RawBertRelation, DMBertArg, RawBertArg, DMBertArgRelation
from dataloader import EFPDataset, EFPDatasetR, EFPDatasetArg, EFPDatasetArgR



def parse_args():
    """Declare the training command-line options and parse ``sys.argv``.

    Every option is optional. The flags, their types and their defaults are
    listed in ``option_table`` below, in the order they are registered.

    Returns:
        argparse.Namespace with one attribute per option (flag name without
        the leading dashes).
    """
    option_table = (
        # (flag,            value type, default)
        ("--train_data", str, "../data/train.jsonl"),
        ("--eval_data", str, "../data/valid.jsonl"),
        # NOTE(review): unlike train/eval, this default is not under ../data/ — confirm.
        ("--test_data", str, "../test.jsonl"),
        ("--model_dir", str, "../models"),
        ("--log_dir", str, "../logs"),
        ("--model_name", str, "roberta-large"),
        ("--model", str, "DMBert"),
        ("--ckpt", str, "roberta-large"),
        ("--batch_size", int, 32),
        ("--max_length", int, 160),
        ("--lr", float, 1e-5),
        ("--epochs", int, 10),
        ("--gpu", int, 0),
        ("--dropout", float, 0.1),
        ("--seed", int, 42),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, default_value in option_table:
        cli.add_argument(flag, type=value_type, default=default_value)
    return cli.parse_args()


def evaluate(preds, labels, mode):
    """Compute one-vs-rest precision, recall and F1 for a class or group of classes.

    Args:
        preds: iterable of predicted label ids, aligned element-wise with ``labels``.
            It is consumed exactly once, so generators/iterators are supported.
        labels: iterable of gold label ids.
        mode: a single class name ("CT+", "CT-", "PS+", "PS-", "Uu") or a group
            name ("CT", "PS", "p", "n") whose member classes are all treated as
            the positive class.

    Returns:
        Tuple ``(precision, recall, f1)``. Each value is ``0`` when its
        denominator is zero.

    Raises:
        ValueError: if ``mode`` is neither a known class nor a known group.
    """
    gold_label = {"CT+": 0, "CT-": 1, "PS+": 2, "PS-": 3, "Uu": 4}
    gold_label_pair = {"CT": [0, 1], "PS": [2, 3], "p": [0, 2], "n": [1, 3]}

    # Resolve the mode to the label ids that count as "positive". A tuple is
    # used deliberately: membership uses ==, exactly as the per-element
    # comparisons did, so numpy scalars behave the same way.
    if mode in gold_label:
        positive = (gold_label[mode],)
    elif mode in gold_label_pair:
        positive = tuple(gold_label_pair[mode])
    else:
        raise ValueError("Invalid evaluation mode")

    # Single pass: count true positives, false positives and false negatives together.
    tp = fp = fn = 0
    for p, l in zip(preds, labels):
        pred_positive = p in positive
        gold_positive = l in positive
        if pred_positive and gold_positive:
            tp += 1
        elif pred_positive:
            fp += 1
        elif gold_positive:
            fn += 1

    precision = tp / (tp + fp) if tp + fp != 0 else 0
    recall = tp / (tp + fn) if tp + fn != 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0
    return precision, recall, f1

        

def _build_model(args, tokenizer, device):
    """Instantiate the classifier named by ``args.model`` and move it to ``device``.

    Names not in the table fall back to ``DMBert`` (the script's default model).
    """
    model_classes = {
        "RawBert": RawBert,
        "DMBertRelation": DMBertRelation,
        "RawBertRelation": RawBertRelation,
        "DMBertArg": DMBertArg,
        "RawBertArg": RawBertArg,
        "DMBertArgRelation": DMBertArgRelation,
    }
    model_cls = model_classes.get(args.model, DMBert)
    return model_cls(args.ckpt, dropout=args.dropout, tokenizer_size=len(tokenizer), num_labels=5).to(device)


def _dataset_class(model_type):
    """Pick the dataset class whose extra fields match the model variant's inputs."""
    has_relation = "Relation" in model_type
    has_arg = "Arg" in model_type
    if has_relation and has_arg:
        return EFPDatasetArgR
    if has_relation:
        return EFPDatasetR
    if has_arg:
        return EFPDatasetArg
    return EFPDataset


def _forward(model, data, model_type, device):
    """Run ``model`` on one collated batch and return its logits.

    Every variant gets input_ids/attention_mask/maskL/maskR. Variants whose name
    contains "Arg" also get arg_ids/arg_mask, and variants whose name contains
    "Relation" also get cause_ids/cause_mask/precondition_ids/precondition_mask.
    All tensors are moved to ``device`` first.
    """
    keys = ["input_ids", "attention_mask", "maskL", "maskR"]
    if "Arg" in model_type:
        keys += ["arg_ids", "arg_mask"]
    if "Relation" in model_type:
        keys += ["cause_ids", "cause_mask", "precondition_ids", "precondition_mask"]
    return model(**{key: data[key].to(device) for key in keys})


def _predict(model, dataloader, model_type, device, desc):
    """Return ``(preds, labels)`` numpy arrays covering every batch in ``dataloader``.

    The caller is expected to have put the model in eval mode and disabled autograd.
    """
    all_preds, all_labels = [], []
    for data in tqdm(dataloader, desc=desc):
        logits = _forward(model, data, model_type, device)
        all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        all_labels.extend(data['labels'].cpu().numpy())
    return np.array(all_preds), np.array(all_labels)


def _log_split_results(logger, header, preds, labels):
    """Log per-class P/R/F1, macro F1 and accuracy for one split; return the macro F1."""
    logger.info(header)
    for mode in ("CT+", "CT-", "PS+", "PS-", "Uu"):
        precision, recall, f1 = evaluate(preds, labels, mode=mode)
        logger.info(f"{mode} precision: {precision}, recall: {recall}, f1: {f1}")
    macro_f1 = f1_score(labels, preds, average='macro')
    accuracy = accuracy_score(labels, preds)
    logger.info(f"Macro F1: {macro_f1}")
    logger.info(f"Accuracy: {accuracy}")
    return macro_f1


def main():
    """Fine-tune the selected classifier, evaluating on the validation and test sets each epoch.

    Checkpoints are written to ``<model_dir>/<model_name>/``:
      * ``best_val_<model>_<model_name>.pt``: weights from the epoch with the best
        validation macro F1. This is the checkpoint to report results with.
      * ``best_tst_<model>_<model_name>.pt``: weights from the epoch with the best
        test macro F1. It is kept for backward compatibility; selecting on test
        data leaks test information into model selection.
    """
    args = parse_args()

    # Honour --gpu; it was previously parsed but ignored. With the default
    # (0) this selects the same device as before.
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    model_save_path = os.path.join(args.model_dir, args.model_name)
    os.makedirs(model_save_path, exist_ok=True)
    log_save_path = os.path.join(args.log_dir, args.model_name)
    os.makedirs(log_save_path, exist_ok=True)

    logger = logging.getLogger()
    handler = logging.FileHandler(os.path.join(log_save_path, f'log_{args.model}_bs{args.batch_size}_ml{args.max_length}_lr{args.lr}.log'))
    logger.setLevel(logging.DEBUG)
    logger.addHandler(handler)

    # Seed every RNG this script touches (Python, NumPy, torch CPU/CUDA).
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    tokenizer = AutoTokenizer.from_pretrained(args.ckpt)
    tokenizer.add_special_tokens({'additional_special_tokens': ['<e>', '</e>', '<p>', '</p>', '<c>', '</c>']})

    model = _build_model(args, tokenizer, device)

    dataset_cls = _dataset_class(args.model)
    train_dataset = dataset_cls(data_dir=args.train_data, tokenizer=tokenizer, max_length=args.max_length, split='train')
    eval_dataset = dataset_cls(data_dir=args.eval_data, tokenizer=tokenizer, max_length=args.max_length, split='eval')
    test_dataset = dataset_cls(data_dir=args.test_data, tokenizer=tokenizer, max_length=args.max_length, split='test')

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    loss_fn = torch.nn.CrossEntropyLoss()

    best_macro_f1 = 0
    best_tst_macro_f1 = 0
    best_val_epoch = None
    tst_f1_at_best_val = None

    model.train()
    for epoch in range(args.epochs):
        start_time = time.time()
        for data in tqdm(train_dataloader, desc=f"Epoch {epoch} training: "):
            optimizer.zero_grad()
            logits = _forward(model, data, args.model, device)
            loss = loss_fn(logits, data['labels'].to(device))
            loss.backward()
            optimizer.step()
        end_time = time.time()
        logger.info(f"Epoch {epoch} training time: {end_time - start_time}")

        model.eval()
        with torch.no_grad():
            eval_preds, eval_labels = _predict(model, eval_dataloader, args.model, device, f"Epoch {epoch} eval: ")
            macro_f1 = _log_split_results(logger, f"Epoch {epoch} eval results:", eval_preds, eval_labels)

            is_best_val = macro_f1 > best_macro_f1
            if is_best_val:
                best_macro_f1 = macro_f1
                best_val_epoch = epoch
                # Persist the validation-selected model; previously it was only logged.
                torch.save(model.state_dict(), os.path.join(model_save_path, f'best_val_{args.model}_{args.model_name}.pt'))
                logger.info(f"Best val model at epoch {epoch}")

            test_preds, test_labels = _predict(model, test_dataloader, args.model, device, f"Epoch {epoch} test: ")
            tst_macro_f1 = _log_split_results(logger, f"Epoch {epoch} test results:", test_preds, test_labels)

            if is_best_val:
                tst_f1_at_best_val = tst_macro_f1

            # Test-set selection is kept for backward compatibility only (see docstring).
            if tst_macro_f1 > best_tst_macro_f1:
                best_tst_macro_f1 = tst_macro_f1
                torch.save(model.state_dict(), os.path.join(model_save_path, f'best_tst_{args.model}_{args.model_name}.pt'))
                logger.info(f"Best test model at epoch {epoch}")
        model.train()

    if best_val_epoch is not None:
        logger.info(f"Best val macro F1 {best_macro_f1} at epoch {best_val_epoch}; test macro F1 at that epoch: {tst_f1_at_best_val}")


if __name__ == "__main__":
    main()








            

