import math
import sys
from tqdm import tqdm
import numpy as np
import torch
import torch.utils.data

from gpt2_experiments.base_model import CustomGPT2LMHeadModel, CustomGPT2Config
from toy_task.toy_task import get_synthetic_dataset

# Stream that train() passes to tqdm (file=...) for its progress bar;
# sys.stdout is the alternative. The per-epoch summary is print()ed separately.
TQDM_FILE = sys.stderr  # sys.stdout

class Config:
    """Hyperparameters for the synthetic-task GPT-2 runs.

    Every setting is a plain class attribute read as ``Config.<name>``.
    ``layers`` and ``custom_attention`` are deliberately not defined here:
    the ``__main__`` sweep assigns them on the class before each ``main()``
    call, so calling ``main()`` without setting them raises AttributeError.
    """

    # dataset
    num_train_samples = 100000
    num_eval_samples = 10000
    sequence_length = 16 * 8
    num_jumps = 16 - 1

    # model
    # (layers is assigned at runtime by the __main__ sweep)
    embed_dim = 512
    ffn_dim = embed_dim * 4
    num_heads = 8

    reorder_and_upcast_attn = True
    scale_attn_by_inverse_layer_idx = False
    scale_attn_weights = True

    # (custom_attention is assigned at runtime by the __main__ sweep)

    # training
    epochs = 30
    batch_size = 128
    lr = 3e-4
    betas = (0.9, 0.98)
    warmup_epochs = 10
    # Derived from the dataset and batch size instead of a hard-coded
    # steps-per-epoch count, so it stays correct when either value is edited.
    # ceil() because the last, partial batch of an epoch is still one step.
    warmup_steps = warmup_epochs * math.ceil(num_train_samples / batch_size)
    # grad_clip = 0.5  (gradient clipping is currently disabled)



def _train_one_epoch(epoch, model, train_loader, optimizer, criterion, scheduler, device):
    """Run one optimisation pass over ``train_loader``, showing a tqdm bar."""
    smooth_loss = None

    model.train()
    bar = tqdm(train_loader, file=TQDM_FILE, leave=False)
    for x, y in bar:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(input_ids=x).logits

        # Flatten the targets and reshape the logits so that each target
        # position is paired with one row of logits.
        y = y.view(-1)
        y_pred = y_pred.view(*y.shape, -1)

        loss = criterion(y_pred, y)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), Config.grad_clip)
        optimizer.step()
        # The scheduler is stepped once per batch, not once per epoch.
        scheduler.step()

        # Exponential moving average, used only for the progress-bar display.
        loss_value = loss.item()
        smooth_loss = loss_value if smooth_loss is None else 0.99 * smooth_loss + 0.01 * loss_value
        bar.set_description(f"Epoch {epoch+1}: loss: {smooth_loss:.3f}")


def _evaluate(model, eval_loader, criterion, task_info, device):
    """Return ``(mean_loss, accuracy)`` over the positions whose label is not -100.

    Both values are ``nan`` when the loader yields no scored position.
    """
    total = 0
    total_loss = 0.0
    total_correct = 0.0

    model.eval()
    with torch.no_grad():
        for x, y in eval_loader:
            x, y = x.to(device), y.to(device)
            y = y.view(-1)

            # -100 marks positions excluded from scoring (it is also the
            # default ignore_index of CrossEntropyLoss).
            label_mask = (y != -100)
            num_valid = int(label_mask.sum().item())
            if num_valid == 0:
                # Nothing to score; the mean over an empty selection would be
                # NaN and would poison the running totals.
                continue

            y_pred = model(input_ids=x).logits
            y_pred = y_pred.view(*y.shape, -1)

            # The batch loss and the batch accuracy are means over the valid
            # positions, so they are re-weighted by the valid count rather
            # than by the total number of positions.
            loss = criterion(y_pred, y)
            total += num_valid
            total_loss += loss.item() * num_valid

            # NOTE(review): in the non-multiclass case y_pred is 2-D here
            # while y is 1-D, so the comparison below broadcasts - confirm
            # the binary path is exercised and behaves as intended.
            pred_labels = y_pred.argmax(-1) if task_info.multiclass else (y_pred > 0)
            total_correct += (pred_labels == y)[label_mask].float().sum().item()

    if total == 0:
        return float("nan"), float("nan")
    return total_loss / total, total_correct / total


def train(epoch, model, train_loader, eval_loader, optimizer, criterion, scheduler, task_info):
    """Train ``model`` for one epoch, then evaluate it and print a summary line.

    Args:
        epoch: Zero-based epoch index; displayed as ``epoch + 1``.
        model: Module called as ``model(input_ids=x)`` whose result exposes
            ``.logits``. Batches are moved to the device of its first parameter.
        train_loader: Iterable of ``(x, y)`` batches used for optimisation.
        eval_loader: Iterable of ``(x, y)`` batches used for evaluation.
        optimizer: Optimizer stepped once per training batch.
        criterion: Loss called as ``criterion(logits, targets)`` on the
            flattened batch.
        scheduler: LR scheduler stepped once per training batch; its last
            learning rate is included in the summary.
        task_info: Object with a boolean ``multiclass`` attribute selecting
            argmax (True) or ``logit > 0`` (False) predictions.

    Returns:
        None. The evaluation loss, accuracy and learning rate are printed.
    """
    device = next(model.parameters()).device

    _train_one_epoch(epoch, model, train_loader, optimizer, criterion, scheduler, device)
    eval_loss, eval_accuracy = _evaluate(model, eval_loader, criterion, task_info, device)

    print(f"Epoch {epoch+1}: loss: {eval_loss:.3f}, accuracy: {eval_accuracy:.2%}, lr: {scheduler.get_last_lr()[0]:.2e}")


def main():
    """Build a model from the current ``Config`` and train it for ``Config.epochs`` epochs.

    Reads ``Config.layers`` and ``Config.custom_attention``, which must have
    been assigned on the class beforehand. ``get_synthetic_dataset`` is called
    again for both the training and the evaluation split at every epoch.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Only the task description is needed here; the generated samples are discarded.
    _unused_samples, info = get_synthetic_dataset(Config.num_eval_samples, Config.sequence_length, Config.num_jumps)

    model_kwargs = dict(
        vocab_size=max(info.n_tokens, info.n_classes),
        n_positions=info.max_seq_length,
        n_embd=Config.embed_dim,
        n_layer=Config.layers,
        n_head=Config.num_heads,
        n_inner=Config.ffn_dim,
        activation_function="gelu",
        custom_attention=Config.custom_attention,
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        reorder_and_upcast_attn=Config.reorder_and_upcast_attn,
        scale_attn_by_inverse_layer_idx=Config.scale_attn_by_inverse_layer_idx,
        scale_attn_weights=Config.scale_attn_weights,
        use_cache=False,
    )
    model_config = CustomGPT2Config(**model_kwargs)
    model = CustomGPT2LMHeadModel(model_config)
    model.to(device)

    # Dump every non-dunder Config attribute, then the trainable parameter count.
    print("Config:")
    for name, value in vars(Config).items():
        if name.startswith("__"):
            continue
        print(f"  {name}: {value}")
    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {num_trainable:,} parameters")

    optimizer = torch.optim.Adam(model.parameters(), lr=Config.lr, betas=Config.betas)

    # Linear warmup for Config.warmup_steps optimizer steps, then cosine
    # annealing over the remaining steps of the run.
    steps_per_epoch = math.ceil(Config.num_train_samples/Config.batch_size)
    annealing_steps = Config.epochs*steps_per_epoch-Config.warmup_steps
    warmup = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1e-6, end_factor=1.0, total_iters=Config.warmup_steps
    )
    annealing = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=annealing_steps)
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, [warmup, annealing], milestones=[Config.warmup_steps]
    )

    if info.multiclass:
        criterion = torch.nn.CrossEntropyLoss()
    else:
        criterion = torch.nn.BCEWithLogitsLoss()

    for epoch_idx in range(Config.epochs):
        # Training data is requested before evaluation data, every epoch.
        train_dataset, _ = get_synthetic_dataset(Config.num_train_samples, Config.sequence_length, Config.num_jumps)
        eval_dataset, _ = get_synthetic_dataset(Config.num_eval_samples, Config.sequence_length, Config.num_jumps)

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True)
        eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=Config.batch_size, shuffle=False)

        train(epoch_idx, model, train_loader, eval_loader, optimizer, criterion, scheduler, info)



if __name__ == "__main__":
    # Optional seeding and the single-run entry point, both left disabled:
    # np.random.seed(42)
    # main()

    repeats = 2

    # Sweep plan, processed in order: (custom_attention flag, layer counts to try).
    sweep_plan = (
        (False, (1, 2, 3, 4, 5)),
        (True, (1,)),
    )

    for use_custom_attention, layer_counts in sweep_plan:
        Config.custom_attention = use_custom_attention
        for layers in layer_counts:
            for run_idx in range(repeats):
                print("\n\n===============================================")
                print(f"Starting run {run_idx+1}/{repeats} with {layers} layers\n")

                # Disabled per-depth learning-rate override:
                # Config.lr = 1e-3 if layers < 5 else 3e-4
                Config.layers = layers
                main()