import os
import random
import time
import torch
import torch.nn as nn
import numpy as np
import itertools

from torch.optim import Adam
from transformers import AdamW
from pruner.utils.runs import Run
from pruner.utils.amp import MixedPrecisionManager

from pruner.training.batcher import Batcher
from pruner.parameters import DEVICE

from pruner.modeling.colbert import ColBERTPruner
from pruner.utils.utils import print_message
from pruner.training.utils import print_progress, manage_checkpoints

from colbert.modeling.colbert import ColBERT

# Reporting interval, in training steps: train() prints the running average loss
# and passes log_to_mlflow=True to Run.log_metric whenever
# (batch_idx + 1) % LOG_STEP == 0.
# Debugging alternative: LOG_STEP = 100
LOG_STEP = 5000 #!@ custom


def train(args):
    """Train a ``ColBERTPruner`` starting from a ColBERT checkpoint.

    The pruner is built from ``bert-base-uncased``, its weights are overwritten
    with ``args.colbert_checkpoint`` (the optimizer state in that checkpoint is
    not restored), and it is then trained with a two-class cross-entropy loss
    over the non-padding token positions of each batch.

    Attributes read from ``args``:
        distributed, rank, nranks: process layout. Distributed training is
            currently rejected by an assertion.
        bsize, accumsteps, maxsteps: batch size, number of micro-batches whose
            gradients are accumulated per optimizer step, and the upper bound
            on optimizer steps.
        doc_maxlen, dim, mask_punctuation: forwarded to
            ``ColBERTPruner.from_pretrained``.
        colbert_checkpoint: path of the checkpoint to load (required).
        finetune_encoder: if true, every parameter is trained with AdamW;
            otherwise ``pruner.bert`` and ``pruner.linear`` are frozen and the
            remaining parameters are trained with Adam.
        lr: learning rate.
        amp: passed to ``MixedPrecisionManager``.
        gumbel_softmax: must be false (see Raises).

    Raises:
        NotImplementedError: if ``args.gumbel_softmax`` is set. That mode never
            produced a loss, so training could not run.
        AssertionError: if ``args.distributed`` is set or
            ``args.colbert_checkpoint`` is ``None``.
    """
    # Fail fast, before seeding RNGs or loading any weights. The former
    # Gumbel-softmax branch only sampled a boolean keep-mask and never assigned
    # `loss`, so the first backward pass died with UnboundLocalError.
    if args.gumbel_softmax:
        raise NotImplementedError(
            "args.gumbel_softmax is not supported: the Gumbel-softmax branch "
            "does not define a training loss."
        )

    # Fixed seeds for reproducibility.
    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)
    if args.distributed:
        torch.cuda.manual_seed_all(12345)

    # TODO: the message below says (in Korean) that with distributed training
    # the data cannot be iterated for multiple epochs.
    assert not args.distributed, "distributed 를 사용하면, 여러 epoch 을 iteration 돌 수 없음."
    if args.distributed:
        # Only reachable when assertions are disabled (python -O).
        assert args.bsize % args.nranks == 0, (args.bsize, args.nranks)
        assert args.accumsteps == 1
        args.bsize = args.bsize // args.nranks

        print("Using args.bsize =", args.bsize, "(per process) and args.accumsteps =", args.accumsteps)

    reader = Batcher(args, (0 if args.rank == -1 else args.rank), args.nranks)

    # Ranks other than -1/0 wait here; rank 0 joins the barrier once it has
    # finished building and loading the model (see below).
    if args.rank not in [-1, 0]:
        torch.distributed.barrier()

    pruner = ColBERTPruner.from_pretrained(
        'bert-base-uncased',
        doc_maxlen=args.doc_maxlen,
        dim=args.dim,
        mask_punctuation=args.mask_punctuation,
    )

    # Load ColBERT checkpoint
    assert (args.colbert_checkpoint is not None), "ColBERT checkpoint should be given"
    print_message(f"#> Starting from checkpoint {args.colbert_checkpoint} -- but NOT the optimizer!")
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    colbert_checkpoint = torch.load(args.colbert_checkpoint, map_location='cpu')
    model_state = colbert_checkpoint['model_state_dict']
    try:
        pruner.load_state_dict(model_state)
    except RuntimeError:
        # load_state_dict reports missing/unexpected keys with RuntimeError.
        # Retry leniently for that case only, so that KeyboardInterrupt and
        # unrelated errors are no longer swallowed.
        print_message("[WARNING] Loading colbert_checkpoint with strict=False")
        pruner.load_state_dict(model_state, strict=False)

    if args.rank == 0:
        torch.distributed.barrier()

    pruner = pruner.to(DEVICE)
    pruner.train()

    #!@ custom
    if not args.finetune_encoder:
        # Freeze BERT encoder & linear layer. This is done on the bare module,
        # before any DistributedDataParallel wrapping: the wrapper does not
        # expose `.bert` / `.linear` as attributes.
        print_message("[WARNING] Freeze BERT encode and linear layer. (optimizer=Adam)")
        for param in itertools.chain(pruner.bert.parameters(), pruner.linear.parameters()):
            param.requires_grad = False

    if args.distributed:
        pruner = torch.nn.parallel.DistributedDataParallel(pruner, device_ids=[args.rank],
                                                            output_device=args.rank,
                                                            find_unused_parameters=True)

    trainable_params = [param for param in pruner.parameters() if param.requires_grad]
    if args.finetune_encoder:
        print_message("[WARNING] Fine-tuning BERT encoder. This will take more GPU memories. (optimizer=AdamW)")
        optimizer = AdamW(trainable_params, lr=args.lr, eps=1e-8)
    else:
        optimizer = Adam(trainable_params, lr=args.lr, eps=1e-8)
    optimizer.zero_grad()

    amp = MixedPrecisionManager(args.amp)
    criterion = nn.CrossEntropyLoss()

    start_time = time.time()
    train_loss = 0.0

    start_batch_idx = 0

    # zip() ends training at args.maxsteps or when the reader is exhausted,
    # whichever comes first.
    for batch_idx, BatchSteps in zip(range(start_batch_idx, args.maxsteps), reader):
        for ids, mask, label in BatchSteps:
            # Shapes below are taken from the original author's notes:
            # ids   : tensor (B, L) = input_ids for BERT encoder inputs
            # mask  : tensor (B, L) = attention_mask for BERT encoder inputs (0 for [PAD] and 1 for the others)
            # label : tensor (B, L) = score labels
            with amp.context():
                scores = pruner(ids, mask)  # B, L, 1
                # Two-class logits per token: a constant 0 for class 0 and the
                # pruner score for class 1.
                scores = torch.cat((torch.zeros_like(scores), scores), dim=-1)  # B, L, 2

                # Only positions selected by the attention mask contribute.
                token_mask = mask.to(scores.device).bool()
                scores_masked = scores[token_mask, :]
                label_masked = label.to(scores.device)[token_mask].long()

                # Scale so that gradients accumulated over the micro-batches
                # of one step average rather than sum.
                loss = criterion(scores_masked, label_masked) / args.accumsteps

            amp.backward(loss)

            # A single .item() call, i.e. one host/device sync per micro-batch.
            train_loss += loss.item()

        amp.step(pruner, optimizer)

        if args.rank < 1:
            # This step has already been applied, hence the +1.
            steps_done = batch_idx - start_batch_idx + 1
            avg_loss = train_loss / steps_done

            num_examples_seen = steps_done * args.bsize * args.nranks
            elapsed = float(time.time() - start_time)

            at_log_step = ((batch_idx + 1) % LOG_STEP == 0)  #!@ custom

            Run.log_metric('train/avg_loss', avg_loss, step=batch_idx, log_to_mlflow=at_log_step)
            Run.log_metric('train/examples', num_examples_seen, step=batch_idx, log_to_mlflow=at_log_step)
            Run.log_metric('train/throughput', num_examples_seen / elapsed, step=batch_idx, log_to_mlflow=at_log_step)

            #!@ custom
            if at_log_step:
                print_message(batch_idx+1, avg_loss)

            manage_checkpoints(args, pruner, optimizer, batch_idx+1)
