import contextlib
import json
import os
from collections import OrderedDict, defaultdict
from datetime import datetime
from functools import partial

import torch.cuda
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm

from evaluate import evaluate0
from iter import ITER, ITERConfig
from iter.datasets import CoNLL, sparse_collate_fn
from iter.datasets.training import Hparams
from iter.datasets.util import move_to_device_and_de_sparsify_collate_fn
from iter.misc.config import with_argparse
from iter.misc.seeding import setup_seed
from iter.optimizing_iter import get_grouped_parameters, get_scheduler_lambda


@with_argparse
def train(
        transformer: str,
        dataset: str,
        seed: int = 42,
        use_bfloat16: bool = False,
        num_epochs: int = 0,
        verbose: bool = True,
        dont_ckpt: bool = False,
):
    """Train an ITER model on a CoNLL-style dataset, then evaluate it on the test split.

    After every epoch the model is evaluated on the ``eval`` split. Whenever the
    selected F1 score improves, a checkpoint is written to
    ``models/<dataset>/<transformer>/<timestamp>``, unless ``dont_ckpt`` is set. The
    score is chosen by ``hparams.optimize_for`` and ``hparams.metric_average``.
    Training stops early once ``hparams.patience`` epochs pass without improvement.

    Finally, the best checkpoint is reloaded if one was saved; otherwise the
    current in-memory weights are used. That model is evaluated on the ``test``
    split, and both the per-epoch eval metrics and the test metrics are written to
    ``<model_path>/metrics.json``.

    Args:
        transformer: Transformer name/path handed to ``ITERConfig``.
        dataset: Dataset name understood by ``CoNLL.from_name`` / ``Hparams.from_name``.
        seed: Seed passed to ``setup_seed``.
        use_bfloat16: Run the forward pass under bfloat16 autocast.
        num_epochs: Number of training epochs; ``0`` means use ``hparams.num_epochs``.
        verbose: Print the model, hyperparameters, features and per-epoch eval metrics.
        dont_ckpt: Do not save checkpoints; the final test then uses the last-epoch weights.
    """
    setup_seed(seed)

    device = "cpu" if not torch.cuda.is_available() else "cuda:0"
    dataset_config = CoNLL.from_name(dataset)
    hparams = Hparams.from_name(dataset)
    config = ITERConfig(
        transformer,
        transformer_config={"max_length": dataset_config.max_length},
        num_types=dataset_config.num_types,
        num_links=dataset_config.num_links,
        features=dataset_config.features,
        dataset=dataset_config.name,
        max_nest_depth=dataset_config.entity_nest_depth,
        dropout=hparams.dropout,
        activation_fn=hparams.activation_fn,
    )
    model: ITER = ITER(config)
    if verbose:
        print(model)
        print(hparams.to_json())
        model.list_features()
    # model = torch.compile(model, fullgraph=False, dynamic=True)
    tokenizer = model.tokenizer
    # Move the model before building the optimizer: PyTorch recommends constructing
    # optimizers only after the parameters live on their final device.
    model = model.to(device)

    dataset = CoNLL.from_name(dataset, tokenizer=tokenizer)
    dataset.setup_dataset()

    collate_fn = partial(sparse_collate_fn, tokenizer=dataset.tokenizer)
    collate_fn = move_to_device_and_de_sparsify_collate_fn(collate_fn, device)
    if use_bfloat16:
        # `torch.autocast` with the real device type: the deprecated
        # `torch.cuda.amp.autocast` silently disables itself when running on CPU.
        autocast_device = "cuda" if device.startswith("cuda") else "cpu"
        compute_loss_context_mngr = torch.autocast(device_type=autocast_device, dtype=torch.bfloat16)
    else:
        compute_loss_context_mngr = contextlib.nullcontext()

    # Total optimizer steps for the LR schedules, computed ahead of time.
    # NOTE(review): this uses `hparams.max_epochs`, while the loop below runs
    # `num_epochs or hparams.num_epochs` epochs — confirm the mismatch is intended.
    num_samples_per_epoch = len(dataset["train"])
    effective_batch_size = hparams.batch_size * hparams.gradient_accumulation
    num_steps = (hparams.max_epochs * num_samples_per_epoch) // effective_batch_size

    optimizer = torch.optim.AdamW(get_grouped_parameters(model, hparams), fused=False, weight_decay=0.1, lr=1e-4)
    lr_scheduler = get_scheduler_lambda(hparams.lr_scheduler, hparams.warmup_steps, num_steps)
    task_lr_scheduler = get_scheduler_lambda(hparams.task_lr_scheduler, hparams.task_warmup_steps, num_steps)
    # One lambda per parameter group. This assumes get_grouped_parameters returns
    # exactly four groups (two transformer, two task) — TODO confirm.
    lr_scheduler = LambdaLR(optimizer, [lr_scheduler, lr_scheduler, task_lr_scheduler, task_lr_scheduler])

    def train_step():
        """Clip gradients, apply one optimizer + scheduler step, and reset gradients."""
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1, error_if_nonfinite=False)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad(set_to_none=True)

    def train_epoch(ep: int):
        """Run one shuffled pass over the train split with gradient accumulation."""
        model.train()
        dataloader = DataLoader(
            dataset["train"], batch_size=hparams.batch_size, shuffle=True, collate_fn=collate_fn, )
        grad_steps = 0
        for batch in (tq := tqdm(dataloader, leave=False, desc=f"Training - Epoch {ep}")):
            with compute_loss_context_mngr:
                output = model(**batch)
                loss = output.loss
            loss.backward()
            grad_steps += 1
            if grad_steps >= hparams.gradient_accumulation:
                train_step()
                grad_steps = 0
            loss_val_per_element = loss.item() / batch["input_ids"].size(0)
            tq.set_postfix(OrderedDict(loss=loss_val_per_element))
        # Gradients left over from an incomplete accumulation window are deliberately
        # NOT applied here, to mimic the original training setup. They carry into the
        # next epoch's first window.

    best_f1 = 0.0
    best_f1_epoch = 0
    saved_checkpoint = False
    outcomes = defaultdict(list)
    date_fmt = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    model_path = f"models/{dataset.name}/{transformer}/{date_fmt}"

    def select_f1(val_metrics) -> float:
        """Return the F1 score used for model selection, as configured in ``hparams``."""
        if hparams.metric_average == "micro":
            return val_metrics.ere_f1 if hparams.optimize_for == "ere" else val_metrics.ner_f1
        return val_metrics.macro_ere_f1 if hparams.optimize_for == "ere" else val_metrics.macro_ner_f1

    def post_train_epoch(ep: int):
        """Evaluate on the eval split, record metrics, and checkpoint on improvement."""
        nonlocal best_f1, best_f1_epoch, saved_checkpoint
        val_metrics = evaluate0(
            model,
            dataset,
            split="eval",
            batch_size=hparams.eval_batch_size or hparams.batch_size
        )
        if verbose:
            print(val_metrics)
        outcome_f1 = select_f1(val_metrics)
        if outcome_f1 > best_f1:
            best_f1 = outcome_f1
            best_f1_epoch = ep
            if not dont_ckpt:
                print(f"Found new best checkpoint ({hparams.optimize_for}): {outcome_f1}; "
                      f"saving to {model_path}")
                model.save_pretrained(model_path)
                saved_checkpoint = True
        for k, v in val_metrics.to_dict().items():
            outcomes[k].append(v)

    for epoch in range(num_epochs or hparams.num_epochs):
        train_epoch(epoch)
        post_train_epoch(epoch)
        if best_f1_epoch + hparams.patience <= epoch:
            print(f"No improvement for {hparams.patience} epochs, aborting training")
            break

    if saved_checkpoint:
        # The optimizer (including AdamW moment buffers) and the scheduler still
        # reference the parameters. Drop them too, so the memory is actually freed
        # before the best checkpoint is loaded.
        del model, optimizer, lr_scheduler
        torch.cuda.empty_cache()
        model = ITER.from_pretrained(model_path)
        model = model.to(device)
    elif not dont_ckpt:
        # Previously this path tried to load a checkpoint that was never written.
        print("No checkpoint was saved (F1 never improved over 0.0); testing the current weights")
    print("TESTING")
    test_metrics = evaluate0(
        model,
        dataset,
        split="test",
        batch_size=hparams.eval_batch_size or hparams.batch_size
    )
    print(test_metrics)
    test_outcomes = test_metrics.to_dict()
    # The directory only exists if save_pretrained ran, so create it explicitly.
    # Otherwise runs with dont_ckpt (or without any saved checkpoint) fail here.
    os.makedirs(model_path, exist_ok=True)
    with open(os.path.join(model_path, "metrics.json"), "w") as f:
        json.dump({
            "test_metrics": test_outcomes,
            "metrics": outcomes,
        }, f)

if __name__ == "__main__":
    # Script entry point. `train` takes no explicit arguments here; its parameters
    # are presumably filled in by the `with_argparse` wrapper (e.g. from the
    # command line) — confirm against iter.misc.config.
    train()
