import pandas as pd
import numpy as np
import argparse
import random
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW, get_scheduler
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Args:
        seed: integer passed unchanged to Python's ``random``, NumPy's global
            RNG, and PyTorch's CPU and CUDA generators.

    Besides seeding, cuDNN is switched to deterministic mode and its
    benchmark mode is turned off.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Command-line interface. Each entry pairs a flag with the keyword arguments
# handed to ``add_argument``; flags are registered in the order listed.
parser = argparse.ArgumentParser()
for flag, options in (
    ('--seed', dict(type=int, default=42)),
    ('--train_label', dict(type=str, default='ours', choices=['ours', 'gdelt'])),
    ('--train_num_class', dict(type=int, default=2)),
    ('--test_num_class', dict(type=int, default=2)),
    ('--model', dict(type=str, default='bert',
                     choices=['bert', 'roberta', 'albert', 'xlnet', 'distilbert'])),
):
    parser.add_argument(flag, **options)

args = parser.parse_args()
# Fixed experiment tag; it is embedded in the checkpoint and result file names.
args.exp_name = 'emnlp'
# Seed all RNGs before any data or model object is created.
set_seed(args.seed)

# Load the three splits from disk.
train_data = pd.read_csv('./data/train_data.csv')
valid_data = pd.read_csv('./data/valid_data.csv')
test_data = pd.read_csv('./data/test_data.csv')


def _binary_relation_rows(frame):
    """Return (text, RE_label_new, CE_QuadClass_Label) for the kept rows.

    A row is kept only when its ``RE_label_new`` value equals 0 or 1. The
    mask is computed once and applied to all three columns so they stay
    aligned; the returned Series keep their original index.
    """
    relation = frame['RE_label_new']
    keep = (relation == 1) | (relation == 0)
    return frame['text'][keep], relation[keep], frame['CE_QuadClass_Label'][keep]


train_texts, train_labels_1, train_labels_2 = _binary_relation_rows(train_data)
valid_texts, valid_labels_1, valid_labels_2 = _binary_relation_rows(valid_data)
test_texts, test_labels_1, test_labels_2 = _binary_relation_rows(test_data)

# NOTE(review): unlike the other test columns, this one is NOT filtered, so it
# is not row-aligned with test_texts / test_labels_* — confirm before using it.
test_labels_3 = test_data['Expert_label']

class RelationDataset(Dataset):
    """Dataset that pairs per-example encodings with one label each.

    Args:
        encodings: mapping from field name to a sequence that can be indexed
            by example position (anything exposing ``.items()``).
        labels: sequence holding one label per example; its length defines
            the dataset length.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The label sequence is the source of truth for the dataset size.
        return len(self.labels)

    def __getitem__(self, idx):
        # Convert every encoding field for this example to a tensor, then
        # attach the label under the 'labels' key.
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

# Resolve the tokenizer class, model class and pretrained checkpoint for the
# architecture requested via --model. Only the selected architecture's classes
# are imported.
if args.model == 'bert':
    from transformers import BertTokenizer, BertForSequenceClassification
    tokenizer_cls, model_cls = BertTokenizer, BertForSequenceClassification
    checkpoint = 'bert-base-uncased'
elif args.model == 'roberta':
    from transformers import RobertaTokenizer, RobertaForSequenceClassification
    tokenizer_cls, model_cls = RobertaTokenizer, RobertaForSequenceClassification
    checkpoint = 'roberta-base'
elif args.model == 'albert':
    from transformers import AlbertTokenizer, AlbertForSequenceClassification
    tokenizer_cls, model_cls = AlbertTokenizer, AlbertForSequenceClassification
    checkpoint = 'albert-base-v2'
elif args.model == 'distilbert':
    from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
    tokenizer_cls, model_cls = DistilBertTokenizer, DistilBertForSequenceClassification
    checkpoint = 'distilbert-base-uncased'
elif args.model == 'xlnet':
    from transformers import XLNetTokenizer, XLNetForSequenceClassification
    tokenizer_cls, model_cls = XLNetTokenizer, XLNetForSequenceClassification
    checkpoint = 'xlnet-base-cased'

# The classification head is sized by --train_num_class.
tokenizer = tokenizer_cls.from_pretrained(checkpoint)
model = model_cls.from_pretrained(checkpoint, num_labels=args.train_num_class)

# Every split is tokenised with the same settings (truncation on, padding on,
# at most 512 tokens per example).
tokenize_kwargs = dict(truncation=True, padding=True, max_length=512)
train_encodings = tokenizer(train_texts.tolist(), **tokenize_kwargs)
valid_encodings = tokenizer(valid_texts.tolist(), **tokenize_kwargs)
test_encodings = tokenizer(test_texts.tolist(), **tokenize_kwargs)

from sklearn.preprocessing import LabelEncoder

# Each label column gets its OWN encoder. Previously a single encoder was fit
# on train_labels_2 and then re-fit on train_labels_1, so the valid/test
# "_2" labels were transformed with the mapping learned from labels_1: wrong
# ids, or a ValueError, whenever the two columns have different class sets.
# NOTE: transform() rejects labels that were not seen during fit, so valid/test
# must not contain classes absent from the training split.

# Encoder for the 'RE_label_new' column ("ours"); keeps the original name.
label_encoder = LabelEncoder()
train_labels_encoded_1 = label_encoder.fit_transform(train_labels_1.tolist())
valid_labels_encoded_1 = label_encoder.transform(valid_labels_1.tolist())
test_labels_encoded_1 = label_encoder.transform(test_labels_1.tolist())

# Encoder for the 'CE_QuadClass_Label' column ("gdelt").
label_encoder_2 = LabelEncoder()
train_labels_encoded_2 = label_encoder_2.fit_transform(train_labels_2.tolist())
valid_labels_encoded_2 = label_encoder_2.transform(valid_labels_2.tolist())
test_labels_encoded_2 = label_encoder_2.transform(test_labels_2.tolist())

# Two datasets per split: identical encodings, different label column.
# NOTE(review): the model head has --train_num_class outputs; evaluating it on
# a label column with more classes than that will fail in the loss — confirm
# both label columns fit the configured number of classes.
train_dataset_1 = RelationDataset(train_encodings, train_labels_encoded_1.tolist())
train_dataset_2 = RelationDataset(train_encodings, train_labels_encoded_2.tolist())

valid_dataset_1 = RelationDataset(valid_encodings, valid_labels_encoded_1.tolist())
valid_dataset_2 = RelationDataset(valid_encodings, valid_labels_encoded_2.tolist())

test_dataset_1 = RelationDataset(test_encodings, test_labels_encoded_1.tolist())
test_dataset_2 = RelationDataset(test_encodings, test_labels_encoded_2.tolist())

def worker_init_fn(worker_id):
    """Seed NumPy's global RNG with the run seed offset by ``worker_id``.

    Args:
        worker_id: integer added to ``args.seed`` to form the seed.
    """
    worker_seed = args.seed + worker_id
    np.random.seed(worker_seed)

# The training set uses the label column selected with --train_label.
if args.train_label == 'ours':
    train_source = train_dataset_1
elif args.train_label == 'gdelt':
    train_source = train_dataset_2
train_loader = DataLoader(train_source, batch_size=16, shuffle=True, worker_init_fn=worker_init_fn)

# Validation and test loaders share one configuration and are never shuffled.
eval_loader_kwargs = dict(batch_size=16, shuffle=False, worker_init_fn=worker_init_fn)
valid_loader_1 = DataLoader(valid_dataset_1, **eval_loader_kwargs)
valid_loader_2 = DataLoader(valid_dataset_2, **eval_loader_kwargs)
test_loader_1 = DataLoader(test_dataset_1, **eval_loader_kwargs)
test_loader_2 = DataLoader(test_dataset_2, **eval_loader_kwargs)

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)
model.train()

# NOTE(review): `transformers.AdamW` is reportedly deprecated in newer
# transformers releases — confirm against the installed version.
optim = AdamW(model.parameters(), lr=2e-5)

# Linear schedule with no warm-up, stepped once per training batch over the
# whole run (batches per epoch * number of epochs).
num_epochs = 20
num_training_steps = num_epochs * len(train_loader)
scheduler = get_scheduler(
    name='linear',
    optimizer=optim,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

def evaluate_model(model, test_loader, device):
    """Evaluate ``model`` on every batch produced by ``test_loader``.

    The model is switched to eval mode and is NOT switched back; callers that
    continue training must call ``model.train()`` themselves.

    Args:
        model: callable as ``model(input_ids, attention_mask=..., labels=...)``
            whose result exposes ``.logits`` and ``.loss``.
        test_loader: iterable of batches; each batch is a mapping with
            ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` tensors.
        device: device the batch tensors are moved to before the forward pass.

    Returns:
        Tuple ``(average_loss, accuracy, f1, conf_matrix)`` where
        ``average_loss`` is the per-sample mean loss, ``f1`` is the
        macro-averaged F1 score and ``conf_matrix`` is the confusion matrix of
        true vs. predicted labels.
    """
    model.eval()
    predictions, true_labels = [], []
    weighted_loss_sum = 0.0
    num_samples = 0

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

            # Weight each batch loss by its sample count so a smaller final
            # batch does not skew the average (a plain mean of batch means
            # over-weights it). Assumes outputs.loss is a per-batch mean, the
            # usual default for sequence-classification heads — TODO confirm.
            batch_size = labels.size(0)
            weighted_loss_sum += outputs.loss.item() * batch_size
            num_samples += batch_size

            # Predicted class = index of the largest logit.
            preds = torch.argmax(outputs.logits, dim=-1)
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average='macro')

    conf_matrix = confusion_matrix(true_labels, predictions)

    average_loss = weighted_loss_sum / num_samples

    return average_loss, accuracy, f1, conf_matrix

import copy
import os

# Create the output directories before training so a missing folder cannot
# make the run fail only after every epoch has finished.
os.makedirs('save', exist_ok=True)
os.makedirs('results', exist_ok=True)

# Best-so-far bookkeeping for validation on label set 1 ("ours").
best_model_state_1 = None
best_valid_loss_1 = float('inf')
best_acc_1 = 0.0
best_f1_1 = 0.0
best_conf_matrix_1 = None

# Best-so-far bookkeeping for validation on label set 2 ("gdelt").
best_model_state_2 = None
best_valid_loss_2 = float('inf')
best_acc_2 = 0.0
best_f1_2 = 0.0
best_conf_matrix_2 = None

print("Starting training...")
for epoch in range(num_epochs):
    total_loss = 0
    # evaluate_model() leaves the model in eval mode, so re-enable training
    # mode at the start of every epoch.
    model.train()
    for batch_idx, batch in enumerate(train_loader):
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optim.step()
        # The scheduler is stepped once per batch, matching num_training_steps.
        scheduler.step()

        total_loss += loss.item()
        if (batch_idx + 1) % 10 == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx+1}, Loss: {loss.item()}")

    average_loss = total_loss / len(train_loader)
    print(f"End of Epoch {epoch+1}, Average Loss: {average_loss}")

    valid_loss_1, valid_accuracy_1, valid_f1_1, valid_conf_matrix_1 = evaluate_model(model, valid_loader_1, device)
    print(f"Validation Metrics (Loader 1) — End of Epoch {epoch+1}: Loss: {valid_loss_1}, Accuracy: {valid_accuracy_1}, F1 Score: {valid_f1_1}")
    print(f"Validation Confusion Matrix (Loader 1):\n{valid_conf_matrix_1}")

    if valid_loss_1 < best_valid_loss_1:
        best_valid_loss_1 = valid_loss_1
        best_acc_1 = valid_accuracy_1
        best_f1_1 = valid_f1_1
        # state_dict() hands back references to the live parameters, so later
        # optimiser steps would silently overwrite the snapshot; deep-copy it
        # to freeze the weights of this epoch.
        best_model_state_1 = copy.deepcopy(model.state_dict())
        best_conf_matrix_1 = valid_conf_matrix_1
        print(f"New best model saved for valid_loader_1 with loss: {valid_loss_1} / with acc: {best_acc_1} / with f1: {best_f1_1}")

    valid_loss_2, valid_accuracy_2, valid_f1_2, valid_conf_matrix_2 = evaluate_model(model, valid_loader_2, device)
    print(f"Validation Metrics (Loader 2) — End of Epoch {epoch+1}: Loss: {valid_loss_2}, Accuracy: {valid_accuracy_2}, F1 Score: {valid_f1_2}")
    print(f"Validation Confusion Matrix (Loader 2):\n{valid_conf_matrix_2}")

    if valid_loss_2 < best_valid_loss_2:
        best_valid_loss_2 = valid_loss_2
        best_acc_2 = valid_accuracy_2
        best_f1_2 = valid_f1_2
        # Same reason as above: keep an independent copy of the weights.
        best_model_state_2 = copy.deepcopy(model.state_dict())
        best_conf_matrix_2 = valid_conf_matrix_2
        print(f"New best model saved for valid_loader_2 with loss: {valid_loss_2} / with acc: {best_acc_2} / with f1: {best_f1_2}")

# Persist the two best snapshots and the weights after the last epoch.
torch.save(best_model_state_1, f'save/best_ours_{args.exp_name}.pth')
torch.save(best_model_state_2, f'save/best_gdelt_{args.exp_name}.pth')

# The final weights are written to disk before load_state_dict() below
# overwrites the live parameters.
final_model_state = model.state_dict()
torch.save(final_model_state, f'save/final_{args.exp_name}.pth')

# Test each label set with the snapshot that was best on its validation set.
model.load_state_dict(best_model_state_1)
test_loss_1, test_accuracy_1, test_f1_1, test_conf_matrix_1 = evaluate_model(model, test_loader_1, device)

model.load_state_dict(best_model_state_2)
test_loss_2, test_accuracy_2, test_f1_2, test_conf_matrix_2 = evaluate_model(model, test_loader_2, device)

# Text reports; the "last_epoch_*" values are the validation metrics from the
# final loop iteration.
with open(f'results/best_ours_{args.exp_name}_{args.seed}.txt', 'w') as f:
    f.write(f"best_valid_loss_1: {best_valid_loss_1}\n")
    f.write(f"best_acc_1: {best_acc_1}\n")
    f.write(f"best_f1_1: {best_f1_1}\n")
    f.write(f"last_epoch_valid_loss_1: {valid_loss_1}\n")
    f.write(f"last_epoch_valid_accuracy_1: {valid_accuracy_1}\n")
    f.write(f"last_epoch_valid_f1_1: {valid_f1_1}\n")
    f.write(f"test_loss_1: {test_loss_1}\n")
    f.write(f"test_accuracy_1: {test_accuracy_1}\n")
    f.write(f"test_f1_1: {test_f1_1}\n")

with open(f'results/best_ours_valid_conf_{args.exp_name}_{args.seed}.txt', 'w') as f:
    f.write(f"{best_conf_matrix_1}\n")

with open(f'results/best_ours_test_conf_{args.exp_name}_{args.seed}.txt', 'w') as f:
    f.write(f"{test_conf_matrix_1}\n")

with open(f'results/best_gdelt_{args.exp_name}_{args.seed}.txt', 'w') as f:
    f.write(f"best_valid_loss_2: {best_valid_loss_2}\n")
    f.write(f"best_acc_2: {best_acc_2}\n")
    f.write(f"best_f1_2: {best_f1_2}\n")
    f.write(f"last_epoch_valid_loss_2: {valid_loss_2}\n")
    f.write(f"last_epoch_valid_accuracy_2: {valid_accuracy_2}\n")
    f.write(f"last_epoch_valid_f1_2: {valid_f1_2}\n")
    f.write(f"test_loss_2: {test_loss_2}\n")
    f.write(f"test_accuracy_2: {test_accuracy_2}\n")
    f.write(f"test_f1_2: {test_f1_2}\n")

with open(f'results/best_gdelt_valid_conf_{args.exp_name}_{args.seed}.txt', 'w') as f:
    f.write(f"{best_conf_matrix_2}\n")

with open(f'results/best_gdelt_test_conf_{args.exp_name}_{args.seed}.txt', 'w') as f:
    f.write(f"{test_conf_matrix_2}\n")