''' finetune roberta on GYAFC data '''

import argparse
import os

import torch
from torch.utils.data import Dataset
import random

from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments, Trainer, EarlyStoppingCallback
import wandb
import evaluate
import numpy as np
import json
import os

from datetime import datetime

class TextDataset(Dataset):
    """Map-style dataset that tokenizes one labelled text record per lookup.

    Every element of ``data`` is indexed with the keys ``'text'`` and
    ``'label'``. Tokenization is done lazily inside ``__getitem__``, one
    record at a time, so no padding happens here.
    """

    def __init__(self, data, tokenizer, max_len=128, do_normalize=True):
        # ``do_normalize`` is only stored; nothing in this class reads it.
        self.data, self.tokenizer = data, tokenizer
        self.max_len, self.do_normalize = max_len, do_normalize

    def __len__(self):
        """Return the number of records held by the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record ``idx`` and return it together with its label.

        The tokenizer is called with truncation to ``max_len`` and
        ``return_tensors='pt'``. The returned dict carries the raw ``'text'``,
        the tokenizer's ``'input_ids'`` and ``'attention_mask'``, and the
        record's label under ``'labels'``.
        """
        record = self.data[idx]
        raw_text = record['text']

        encoding = self.tokenizer(
            raw_text,
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt',
        )

        return {
            'text': raw_text,
            'input_ids': encoding['input_ids'],
            'attention_mask': encoding['attention_mask'],
            'labels': record['label'],
        }


def pad_to_length(x, length, pad_token_id=0):
    """Right-pad ``x`` along its last dimension so that it has ``length`` entries.

    The padding block is a ``torch.long`` tensor of shape
    ``(1, length - x.shape[-1])`` filled with ``pad_token_id``, so ``x`` has to
    be concatenable with a tensor whose leading dimension is 1.
    """
    pad_shape = (1, length - x.shape[-1])
    filler = torch.ones(pad_shape, dtype=torch.long) * pad_token_id
    return torch.cat((x, filler), dim=-1)
        
def collate_fn(batch, tokenizer):
    """Collate individually tokenized samples into one right-padded batch.

    Args:
        batch: sequence of dicts, each carrying ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'``.
        tokenizer: object exposing ``pad_token_id``, used to fill the padded
            positions of ``input_ids``.

    Returns:
        Dict with ``'input_ids'`` and ``'attention_mask'`` (per-sample tensors
        padded to the longest sample and concatenated along dim 0) and
        ``'labels'`` as a tensor built from the per-sample labels.
    """
    input_ids = [sample['input_ids'] for sample in batch]
    attention_masks = [sample['attention_mask'] for sample in batch]
    labels = [sample['labels'] for sample in batch]

    max_len = max(ids.shape[-1] for ids in input_ids)

    input_ids = [
        pad_to_length(ids, max_len, pad_token_id=tokenizer.pad_token_id)
        for ids in input_ids
    ]
    # Padded positions must be 0 in the attention mask. Filling them with the
    # tokenizer's pad id is wrong whenever that id is non-zero (e.g. RoBERTa's
    # pad id is 1), because the padding would then be marked as real tokens.
    attention_masks = [
        pad_to_length(mask, max_len, pad_token_id=0)
        for mask in attention_masks
    ]

    return {
        'input_ids': torch.cat(input_ids, 0),
        'attention_mask': torch.cat(attention_masks, 0),
        'labels': torch.tensor(labels)
    }

def compute_metrics(eval_preds):
    """Score argmax predictions with accuracy, F1, recall and precision.

    ``eval_preds`` unpacks into ``(logits, labels)``; the predicted class is
    the argmax of the logits over the last axis. F1, recall and precision are
    computed with ``average='binary'``. The four metrics are loaded through
    ``evaluate.load`` on every call.

    Returns a dict keyed ``'eval_accuracy'``, ``'eval_f1'``, ``'eval_recall'``
    and ``'eval_precision'``.
    """
    accuracy, f1, recall, precision = (
        evaluate.load(metric_name)
        for metric_name in ("accuracy", "f1", "recall", "precision")
    )

    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    shared = dict(predictions=predictions, references=labels)
    binary = dict(shared, average='binary')

    return {
        'eval_accuracy': accuracy.compute(**shared)['accuracy'],
        'eval_f1': f1.compute(**binary)['f1'],
        'eval_recall': recall.compute(**binary)['recall'],
        'eval_precision': precision.compute(**binary)['precision'],
    }

def print_trainable_parameters(model):
    """
    Print the name and element count of every parameter that requires
    gradients, followed by a summary line with the trainable total, the
    overall total and the trainable percentage.
    """
    total = 0
    trainable = 0
    for pname, tensor in model.named_parameters():
        count = tensor.numel()
        total += count
        if not tensor.requires_grad:
            # Frozen parameters only contribute to the overall total.
            continue
        print(pname, count)
        trainable += count
    print(
        f"trainable params: {trainable} || all params: {total} || trainable%: {100 * trainable / total}"
    )

def load_data(path, label):
    """Read a text file with one example per line into labelled records.

    Args:
        path: path of the text file to read.
        label: label attached to every line of the file.

    Returns:
        A list with one ``{'text': ..., 'label': label}`` dict per line, in
        file order. Each line is stripped of surrounding whitespace; blank
        lines are kept as records with an empty ``'text'``.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return [{'text': line.strip(), 'label': label} for line in f]

if __name__ == "__main__":

    # Example usage:
    # python training/train.py --training_data_pos data/gyafc/formal/formal_train.txt --training_data_neg data/gyafc/informal/informal_train.txt --val_data_pos data/gyafc/formal/formal_val.txt --val_data_neg data/gyafc/informal/informal_val.txt --model_name roberta-base --seed 42 --learning_rate 5e-5 --batch_size 32 --accumulation_steps 1 --out_dir gyafc_formality_classifier --project_name unfun --eval_metric eval_f1

    parser = argparse.ArgumentParser()
    parser.add_argument('--val_data_pos', help='path to the test data', nargs='+', default=[])
    parser.add_argument('--val_data_neg', help='path to the test data', nargs='+', default=[])
    parser.add_argument('--train_split', help='train split', type=float)
    parser.add_argument('--model_name', default='roberta-base', type=str, help='name of the model')
    parser.add_argument('--seed', default=42, type=int, help='random seed')
    parser.add_argument('--training_data_pos', required=True,  nargs='+', help='path totraining data')
    parser.add_argument('--training_data_neg', required=True,  nargs='+', help='path totraining data')
    parser.add_argument('--learning_rate', default=5e-5, type=float, help='learning rate')
    parser.add_argument('--batch_size', default=32, type=int, help='batch size')
    parser.add_argument('--accumulation_steps', default=1, type=int, help='gradient accumulation steps')
    parser.add_argument('--out_dir', default='gyafc_formality_classifier', type=str, help='output directory')
    parser.add_argument('--project_name', default='unfun', type=str, help='project name')
    parser.add_argument('--eval_metric', required=True, type=str, help='metric to use for early stopping')

    args = parser.parse_args()

    # When no validation files are given, a slice of the (shuffled) training
    # data is held out for validation instead.
    do_train_split = len(args.val_data_pos + args.val_data_neg) == 0
    if do_train_split:
        # Report bad CLI input through argparse rather than `assert`: asserts
        # are stripped under `python -O`, which would let a missing split
        # surface later as an opaque TypeError.
        if args.train_split is None:
            parser.error("train split must be provided if no validation data is provided")
        # A split outside (0, 1) leaves the training or validation set empty.
        if not 0.0 < args.train_split < 1.0:
            parser.error("train split must be strictly between 0 and 1")

    random.seed(int(args.seed))
    np.random.seed(int(args.seed))
    torch.manual_seed(int(args.seed))

    # Output layout: <out_dir>/<model>/<run_name>/<seed>/<timestamp>
    args.out_dir =  os.path.join(args.out_dir, args.model_name.replace('/',''))
    training_names = '-'.join(['_'.join(tname.strip('/').split('/')[-3:]) for tname in args.training_data_pos + args.training_data_neg])
    run_name = '_'.join([args.model_name, str(args.learning_rate), str(args.batch_size), training_names])
    current_date = datetime.now().strftime("%Y-%m-%d-%H_%M_%S")
    args.out_dir = os.path.join(args.out_dir, run_name, str(args.seed), current_date)

    os.makedirs(args.out_dir, exist_ok=True)

    # Persist the resolved arguments next to the checkpoints.
    with open(os.path.join(args.out_dir, 'args.json'), 'w+') as f:
        json.dump(vars(args), f, indent=2)

    training_data = []

    # Positive files are labelled 1, negative files 0.
    for training_dname in args.training_data_pos:
        training_data.extend(load_data(training_dname, 1))

    for training_dname in args.training_data_neg:
        training_data.extend(load_data(training_dname, 0))


    random.shuffle(training_data)

    if do_train_split:
        n_train = int(args.train_split * len(training_data))
        val_data = training_data[n_train:]
        training_data = training_data[:n_train]

    else:
        val_data = []

        for val_dname in args.val_data_pos:
            val_data.extend(load_data(val_dname, 1))

        for val_dname in args.val_data_neg:
            val_data.extend(load_data(val_dname, 0))

        random.shuffle(val_data)

    use_bf16=False
    optim="adamw_torch"

    model = AutoModelForSequenceClassification.from_pretrained(args.model_name, num_labels=2)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, add_eos_token=True)

    if not hasattr(tokenizer, 'pad_token') or tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        # Keep the model config in sync with the fallback pad token; some
        # sequence-classification heads read config.pad_token_id and refuse
        # batches larger than 1 when it is unset.
        model.config.pad_token_id = tokenizer.pad_token_id

    print(training_data[:10])

    train_dataset = TextDataset(training_data, tokenizer)
    val_dataset = TextDataset(val_data, tokenizer)

    wandb.init(
        project=args.project_name,
        config={
            'learning_rate': args.learning_rate,
            'batch_size': args.batch_size,
            'seed': args.seed,
            'model_name': args.model_name,
            'training_data': args.training_data_pos + args.training_data_neg,
            'val_data': args.val_data_pos + args.val_data_neg,
            'accumulation_steps': args.accumulation_steps,
            'eval_metric': args.eval_metric,

        },
        name=run_name

    )

    # Evaluate and checkpoint every 100 steps; the large epoch count is an
    # upper bound, with early stopping on `--eval_metric` ending the run.
    training_args = TrainingArguments(
        output_dir=args.out_dir,
        num_train_epochs=200,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.accumulation_steps,
        evaluation_strategy='steps',
        save_strategy='steps',
        logging_dir='logs',
        logging_steps=50,
        save_total_limit=1,
        load_best_model_at_end=True,
        metric_for_best_model=args.eval_metric,
        # Loss-like metrics are minimized; every other metric is maximized.
        greater_is_better=False if 'loss' in args.eval_metric else True,
        seed=args.seed,
        bf16=use_bf16,
        optim=optim,
        learning_rate=args.learning_rate,
        save_steps=100,
        eval_steps=100,
        report_to="wandb",
        lr_scheduler_type='constant_with_warmup',
        warmup_steps=100,
    )

    trainer = Trainer(
        model,
        training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=lambda x: collate_fn(x, tokenizer=tokenizer),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=15)],
        compute_metrics=compute_metrics,
    )
    # model.config.use_cache = False
    trainer.train()
    
