import os
import random
import time
import torch
import torch.nn as nn
import numpy as np

from transformers import AdamW
from colbert.utils.runs import Run
from colbert.utils.amp import MixedPrecisionManager

from colbert.training.lazy_batcher import LazyBatcher
from colbert.training.eager_batcher import EagerBatcher
from colbert.parameters import DEVICE

from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message
from colbert.training.utils import print_progress, manage_checkpoints

# Reporting interval, in training steps: on every LOG_STEP-th step `train`
# prints the running average loss and passes log_to_mlflow=True to
# Run.log_metric.
# LOG_STEP = 100 #?@ debugging
LOG_STEP = 5000 #!@ custom


def _load_state_dict_lenient(model, state_dict):
    """Load ``state_dict`` into ``model``, retrying non-strictly on a mismatch.

    A strict load is attempted first. If it raises ``RuntimeError`` (what
    ``load_state_dict`` raises for missing/unexpected keys), a warning is
    printed and the load is retried with ``strict=False``. Any other
    exception (including ``KeyboardInterrupt``) propagates unchanged.
    """
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        print_message("[WARNING] Loading checkpoint with strict=False")
        model.load_state_dict(state_dict, strict=False)


def _distillation_loss(student_scores, teacher_scores, temperature):
    """Return KL(teacher || student) over the last dim, averaged over dim 0.

    NOTE(review): only the teacher scores are divided by ``temperature``; the
    student scores are used as-is. This mirrors the existing behaviour --
    confirm it is intended.
    """
    tempered = teacher_scores / temperature
    soft_labels = torch.nn.functional.softmax(tempered, dim=-1)
    log_soft_labels = torch.nn.functional.log_softmax(tempered, dim=-1)
    student_log_probs = torch.nn.functional.log_softmax(student_scores, dim=-1)

    per_query = (soft_labels * (log_soft_labels - student_log_probs)).sum(-1)
    # per_query: float tensor (bsize)

    return per_query.mean(0)


def train(args):
    """Train a ColBERT model, optionally warm-started and/or distilled from a teacher.

    Attributes read from ``args``:
        distributed, rank, nranks: multi-process setup; ``rank`` is -1 or 0 on
            the process that logs and manages checkpoints.
        bsize, accumsteps, maxsteps, lr, amp: optimisation settings. In
            distributed mode ``args.bsize`` is overwritten in place with the
            per-process batch size and ``accumsteps`` must be 1.
        lazy: selects ``LazyBatcher`` over ``EagerBatcher``.
        pseudo_query_indicator, query_maxlen, doc_maxlen, dim, similarity,
            mask_punctuation: forwarded to ``ColBERT.from_pretrained``.
        prune_tokens: if set, the student is warm-started from
            ``teacher_checkpoint`` after checking both configurations agree.
        knowledge_distillation, kd_temperature: if set, a frozen-in-eval-mode
            teacher is loaded from ``teacher_checkpoint`` and the student is
            trained against the teacher's softened score distribution instead
            of cross-entropy on in-batch labels.
        teacher_checkpoint: required when ``prune_tokens`` or
            ``knowledge_distillation`` is set.
        checkpoint, resume, resume_optimizer: optional student checkpoint to
            start from; ``resume`` also skips the reader ahead to the stored
            batch index. Reloading the optimizer is not supported.

    Raises:
        AssertionError: on inconsistent arguments (see the asserts below).
    """
    seed = 12345
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if args.distributed:
        torch.cuda.manual_seed_all(seed)

        assert args.bsize % args.nranks == 0, (args.bsize, args.nranks)
        assert args.accumsteps == 1
        args.bsize = args.bsize // args.nranks

        print("Using args.bsize =", args.bsize, "(per process) and args.accumsteps =", args.accumsteps)

    batcher_cls = LazyBatcher if args.lazy else EagerBatcher
    reader = batcher_cls(args, (0 if args.rank == -1 else args.rank), args.nranks)

    # Non-zero ranks block here until rank 0 reaches its own barrier further
    # down, i.e. until rank 0 has finished building/loading the models
    # (presumably so that only one process downloads the pretrained weights).
    if args.rank not in [-1, 0]:
        torch.distributed.barrier()

    colbert = ColBERT.from_pretrained('bert-base-uncased',
                                      pseudo_query_indicator=args.pseudo_query_indicator,
                                      query_maxlen=args.query_maxlen,
                                      doc_maxlen=args.doc_maxlen,
                                      dim=args.dim,
                                      similarity_metric=args.similarity,
                                      mask_punctuation=args.mask_punctuation)

    # The teacher checkpoint is needed both for the warm start and for
    # distillation, so load it when *either* is requested. (Previously it was
    # loaded only under prune_tokens, so distillation alone hit a NameError.)
    teacher_checkpoint = None
    if args.prune_tokens or args.knowledge_distillation:
        assert args.teacher_checkpoint is not None, \
            "args.teacher_checkpoint is required for prune_tokens / knowledge_distillation"
        teacher_checkpoint = torch.load(args.teacher_checkpoint, map_location='cpu')

    if args.prune_tokens:
        print_message(f"#> Warm-start PQA-ColBERT from {args.teacher_checkpoint}.")

        # Student and teacher must share the same model configuration.
        teacher_args = teacher_checkpoint['arguments']
        for attr in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'mask_punctuation']:
            assert getattr(args, attr) == teacher_args[attr], \
                f'Different "{attr}" => student ({getattr(args, attr)}) != teacher ({teacher_args[attr]})'

        _load_state_dict_lenient(colbert, teacher_checkpoint['model_state_dict'])

    teacher = None
    if args.knowledge_distillation:
        print_message('#> Instantiate teacher (ColBERT that uses full tokens).')
        teacher_args = teacher_checkpoint["arguments"]
        teacher = ColBERT.from_pretrained('bert-base-uncased',
                                          pseudo_query_indicator=False,
                                          query_maxlen=teacher_args["query_maxlen"],
                                          doc_maxlen=teacher_args["doc_maxlen"],
                                          dim=teacher_args["dim"],
                                          similarity_metric=teacher_args["similarity"],
                                          mask_punctuation=teacher_args["mask_punctuation"])
        print_message(f'#> Load teacher checkpoint from {args.teacher_checkpoint}.')
        # The teacher must match its checkpoint exactly: strict load, no fallback.
        teacher.load_state_dict(teacher_checkpoint['model_state_dict'])

    checkpoint = None
    if args.checkpoint is not None:
        assert args.resume_optimizer is False, "TODO: This would mean reload optimizer too."
        print_message(f"#> Starting from checkpoint {args.checkpoint} -- but NOT the optimizer!")

        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        _load_state_dict_lenient(colbert, checkpoint['model_state_dict'])

    # Rank 0 is done loading: release the ranks waiting at the barrier above.
    if args.rank == 0:
        torch.distributed.barrier()

    colbert = colbert.to(DEVICE)
    colbert.train()

    if args.distributed:
        colbert = torch.nn.parallel.DistributedDataParallel(colbert, device_ids=[args.rank],
                                                            output_device=args.rank,
                                                            find_unused_parameters=True)

    optimizer = AdamW(filter(lambda p: p.requires_grad, colbert.parameters()), lr=args.lr, eps=1e-8)
    optimizer.zero_grad()

    if args.knowledge_distillation:
        teacher = teacher.to(DEVICE)
        print_message('#> Set eval mode for teacher.')
        teacher.eval()

        if args.distributed:
            teacher = torch.nn.parallel.DistributedDataParallel(teacher, device_ids=[args.rank],
                                                                output_device=args.rank,
                                                                find_unused_parameters=True)

    amp = MixedPrecisionManager(args.amp)
    criterion = nn.CrossEntropyLoss()
    labels = torch.arange(args.bsize, dtype=torch.long, device=DEVICE)

    start_time = time.time()
    train_loss = 0.0

    start_batch_idx = 0

    if args.resume:
        assert args.checkpoint is not None
        start_batch_idx = checkpoint['batch']

        reader.skip_to_batch(start_batch_idx, checkpoint['arguments']['bsize'])

    for batch_idx, BatchSteps in zip(range(start_batch_idx, args.maxsteps), reader):

        # One optimizer step = all accumulation sub-batches of this BatchSteps.
        for queries, passages in BatchSteps:

            with amp.context():
                scores = colbert(queries, passages, inbatch_negatives=True)
                # scores: float tensor (bsize, 2 * bsize)

                if args.knowledge_distillation:
                    with torch.no_grad():
                        # Replace pruning_mask by attention_mask, to use all tokens
                        input_ids, attention_mask, _ = passages
                        psg_all_tokens = (input_ids, attention_mask, attention_mask.clone())

                        teacher_scores = teacher(queries, psg_all_tokens, inbatch_negatives=True)
                        # teacher_scores: float tensor (bsize, 2 * bsize)

                    loss = _distillation_loss(scores, teacher_scores, args.kd_temperature)
                else:
                    loss = criterion(scores, labels[:scores.size(0)])

                loss = loss / args.accumsteps

            amp.backward(loss)

            # loss is already divided by accumsteps, so summing over the
            # sub-batches adds the mean sub-batch loss of this step.
            train_loss += loss.item()

        amp.step(colbert, optimizer)

        if args.rank < 1:
            # Steps completed in *this* run. Using batch_idx + 1 would dilute
            # the average after a resume, because train_loss restarts at 0.
            steps_done = batch_idx - start_batch_idx + 1
            avg_loss = train_loss / steps_done

            # Include the batch that was just processed: `elapsed` covers it too.
            num_examples_seen = steps_done * args.bsize * args.nranks
            elapsed = float(time.time() - start_time)

            is_log_step = (batch_idx + 1) % LOG_STEP == 0

            Run.log_metric('train/avg_loss', avg_loss, step=batch_idx, log_to_mlflow=is_log_step)
            Run.log_metric('train/examples', num_examples_seen, step=batch_idx, log_to_mlflow=is_log_step)
            Run.log_metric('train/throughput', num_examples_seen / elapsed, step=batch_idx, log_to_mlflow=is_log_step)

            if is_log_step:
                print_message(batch_idx+1, avg_loss)

            manage_checkpoints(args, colbert, optimizer, batch_idx+1)
