import pandas as pd
import numpy as np
import argparse
import random
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW, get_scheduler
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU
    and CUDA generators with ``seed``, then switches cuDNN to deterministic
    mode with auto-tuning (benchmarking) turned off.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Command-line options, declared as (flag, add_argument kwargs) pairs:
#   --seed             seed handed to set_seed() and embedded in result file names
#   --train_label      which label set the checkpoint was trained on; selects
#                      the checkpoint and result file prefix
#   --train_num_class  number of output classes of the classification head
#   --test_num_class   parsed, but not referenced elsewhere in this script
#   --model            which pretrained architecture/tokenizer to instantiate
_CLI_OPTIONS = (
    ('--seed', dict(type=int, default=42)),
    ('--train_label', dict(type=str, default='ours', choices=['ours', 'gdelt'])),
    ('--train_num_class', dict(type=int, default=2)),
    ('--test_num_class', dict(type=int, default=2)),
    ('--model', dict(type=str, default='bert',
                     choices=['bert', 'roberta', 'albert', 'xlnet', 'distilbert'])),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

args = parser.parse_args()
# Fixed experiment tag used in checkpoint and result file names.
args.exp_name = 'emnlp'
set_seed(args.seed)

# Load the train / validation / test splits from ./data.
train_data = pd.read_csv('./data/train_data.csv')
valid_data = pd.read_csv('./data/valid_data.csv')
test_data = pd.read_csv('./data/test_data.csv')

# Columns used below: the test texts, two test label columns, and the
# training label column (later used to fit the label encoder).
test_texts = test_data['text']
test_labels_1 = test_data['RE_label']
train_labels_2 = train_data['CE_QuadClass_Label']
test_labels_3 = test_data['Expert_label']

# Keep only the test rows whose RE_label AND Expert_label are both 0 or 1;
# the same row mask is applied to the texts and to the Expert labels so the
# two stay aligned.
_re_is_binary = (test_labels_1 == 1) | (test_labels_1 == 0)
_expert_is_binary = (test_labels_3 == 1) | (test_labels_3 == 0)
_keep_rows = _re_is_binary & _expert_is_binary

test_texts = test_texts[_keep_rows]
test_labels_3 = test_labels_3[_keep_rows]


class RelationDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with labels.

    ``encodings`` is a mapping from field name to an indexable collection
    holding one entry per example; ``labels`` is an indexable collection
    with one label per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        """Return the number of examples (the number of labels)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors.

        The dict holds one tensor per encoding field plus the example's
        label under the ``'labels'`` key.
        """
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

# For each supported --model value: the tokenizer class, the
# sequence-classification class, and the pretrained checkpoint to load.
_MODEL_SPECS = {
    'bert': ('BertTokenizer', 'BertForSequenceClassification', 'bert-base-uncased'),
    'roberta': ('RobertaTokenizer', 'RobertaForSequenceClassification', 'roberta-base'),
    'albert': ('AlbertTokenizer', 'AlbertForSequenceClassification', 'albert-base-v2'),
    'distilbert': ('DistilBertTokenizer', 'DistilBertForSequenceClassification', 'distilbert-base-uncased'),
    'xlnet': ('XLNetTokenizer', 'XLNetForSequenceClassification', 'xlnet-base-cased'),
}

if args.model in _MODEL_SPECS:
    import transformers

    _tokenizer_cls_name, _model_cls_name, _checkpoint = _MODEL_SPECS[args.model]
    # Only the classes for the selected architecture are resolved.
    tokenizer = getattr(transformers, _tokenizer_cls_name).from_pretrained(_checkpoint)
    # The classification head is sized by --train_num_class.
    model = getattr(transformers, _model_cls_name).from_pretrained(
        _checkpoint,
        num_labels=args.train_num_class,
    )


from sklearn.preprocessing import LabelEncoder

# Tokenize all filtered test texts in a single call, truncating anything
# longer than 512 tokens and padding the batch.
test_encodings = tokenizer(
    test_texts.tolist(),
    max_length=512,
    padding=True,
    truncation=True,
)

# The encoder is fitted on the training labels and then reused for the test
# labels, so both go through the same label -> integer id mapping.
label_encoder = LabelEncoder()
train_labels_encoded_2 = label_encoder.fit_transform(train_labels_2.tolist())
test_labels_encoded_3 = label_encoder.transform(test_labels_3.tolist())

test_dataset_3 = RelationDataset(
    encodings=test_encodings,
    labels=test_labels_encoded_3.tolist(),
)

def worker_init_fn(worker_id):
    """Seed NumPy's global RNG with ``args.seed`` offset by ``worker_id``.

    Passed to the DataLoader as its ``worker_init_fn`` so that each worker
    gets its own, reproducible NumPy seed.
    """
    worker_seed = args.seed + worker_id
    np.random.seed(worker_seed)

# Iterate the test set in its stored order, 16 examples per batch.
test_loader_3 = DataLoader(
    test_dataset_3,
    batch_size=16,
    shuffle=False,
    worker_init_fn=worker_init_fn,
)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)

def evaluate_model(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and compute loss and metrics.

    The model is put into eval mode (and left there) and every batch is
    evaluated without gradient tracking.

    Args:
        model: Classifier called as
            ``model(input_ids, attention_mask=..., labels=...)``; the result
            must expose ``.logits`` and ``.loss``.
        test_loader: Iterable of batches, each a mapping holding
            ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` tensors.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        A tuple ``(average_loss, accuracy, f1, conf_matrix)`` where
        ``average_loss`` is the per-sample mean loss, ``f1`` is the
        macro-averaged F1 score and ``conf_matrix`` is the confusion matrix
        of true vs. predicted labels.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    predictions, true_labels = [], []
    total_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

            # The model-reported loss is taken to be a mean over the batch
            # (the default reduction of Hugging Face classification heads),
            # so weight it by the batch size: otherwise a smaller final batch
            # would count as much as a full one in the average.
            batch_size = labels.shape[0]
            total_loss += outputs.loss.item() * batch_size
            num_samples += batch_size

            preds = torch.argmax(outputs.logits, dim=-1)
            predictions.extend(preds.cpu().tolist())
            true_labels.extend(labels.cpu().tolist())

    if num_samples == 0:
        # Fail with a clear message instead of a ZeroDivisionError (or an
        # opaque metrics error) further down.
        raise ValueError("evaluate_model: test_loader yielded no samples")

    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average='macro')
    conf_matrix = confusion_matrix(true_labels, predictions)
    average_loss = total_loss / num_samples

    return average_loss, accuracy, f1, conf_matrix

# Load the best checkpoint for the selected training label set, evaluate it
# on the Expert-labelled test set, and write the metrics and the confusion
# matrix to separate files under results/.
if args.train_label in ('ours', 'gdelt'):
    _run_tag = f'best_{args.train_label}'

    # map_location lets a checkpoint that was saved from a GPU be loaded on a
    # CPU-only machine (the script falls back to CPU when CUDA is missing).
    _state_dict = torch.load(f'save/{_run_tag}_{args.exp_name}.pth', map_location=device)
    model.load_state_dict(_state_dict)

    test_loss, test_accuracy, test_f1, test_conf_matrix = evaluate_model(model, test_loader_3, device)

    with open(f'results/{_run_tag}_{args.exp_name}_{args.seed}.txt', 'w') as f:
        f.write(f"test_loss: {test_loss}\n")
        f.write(f"test_accuracy: {test_accuracy}\n")
        f.write(f"test_f1: {test_f1}\n")

    with open(f'results/{_run_tag}_conf_{args.exp_name}_{args.seed}.txt', 'w') as f:
        f.write(f"{test_conf_matrix}\n")