from mmap import MAP_ANON
import os
import time
import logging
from time import gmtime, strftime
from pathlib import Path
import json
from torch.nn.modules.container import ModuleList
from torch.utils.mobile_optimizer import optimize_for_mobile

import wandb
import torch
from torch import optim
from torch import nn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler

from clip.clip import _transform, load, tokenize
from clip.model import convert_weights, CLIP, DeiT, PretrainedViT, Transformer, PETransformer, PEVisualTransformer, Pit
from training.train import train, evaluate, evaluate_vilt, evaluate_correct_val, evaluate_correct_test
from training.data import get_data, get_cc_data
from training.params import parse_args
from training.logger import setup_primary_logging, setup_worker_logging
from training.scheduler import cosine_lr
from training.vilt import ViLT
import time
from transformers import (
get_polynomial_decay_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup
)

from tqdm import trange, tqdm
import faiss
import numpy as np

# Workaround from https://github.com/openai/CLIP/issues/83: cast parameters (and any
# gradients) back to fp32. Used below for the "amp" and "fp32" precision modes.
def convert_models_to_fp32(model):
    """Cast every parameter of ``model`` (and its gradient, if present) to float32 in place.

    Used to undo an earlier half-precision conversion so that the optimizer /
    GradScaler operate on fp32 tensors.

    Args:
        model: a ``torch.nn.Module``; modified in place, nothing is returned.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # ``p.grad`` is either None or a tensor. Testing a tensor's truthiness
        # raises for multi-element tensors (and is False for a zero scalar),
        # so check for presence explicitly.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

def is_master(args):
    """Decide whether this process performs rank-0-only work such as logging.

    Non-distributed runs always qualify; otherwise the worker on GPU 0 does.
    Failing both, the result is whatever ``args.dp`` holds.
    """
    if not args.distributed:
        return True
    if args.gpu == 0:
        return True
    return args.dp

def test_inference(model, device=None):
    """Time a single forward pass of ``model`` on a random 1x3x224x224 input and print it.

    Args:
        model: module called as ``model(images)``; it is switched to eval mode.
        device: device for the dummy input (string or ``torch.device``).
            Defaults to ``"cuda:0"``, the previously hard-coded device; it must
            match the device ``model`` lives on.
    """
    device = torch.device("cuda:0" if device is None else device)
    dummy_input = torch.rand(1, 3, 224, 224).to(device)
    model.eval()

    def _sync():
        # CUDA kernels run asynchronously: without a sync the timer would only
        # capture kernel launches. Sync the input's device, not merely the
        # current one. Nothing to do on CPU.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    _sync()  # keep earlier queued work (e.g. the input copy) out of the measurement
    s = time.perf_counter()  # monotonic, high-resolution clock for intervals
    with torch.no_grad():
        _ = model(dummy_input)
    _sync()
    e = time.perf_counter()
    print(f"Time duration: {(e-s)}")


def get_vilt_optim(model, config, args, dataloader, lr=2e-4):
    """Build the AdamW optimizer and linear warmup/decay LR schedule for ViLT.

    All of ``model.named_parameters()`` are split into four param groups, in
    this order:

    1. backbone, with weight decay
    2. backbone, without weight decay
    3. task heads, with weight decay
    4. task heads, without weight decay

    A parameter is "no decay" when its name contains one of the bias/norm
    substrings below. It is a "head" parameter when its name contains one of
    the head names. Head groups use ``lr * config["lr_mult"]``.

    Args:
        model: module whose parameters are all handed to the optimizer.
        config: mapping providing ``"weight_decay"`` and ``"lr_mult"``
            (and ``"learning_rate"`` when ``lr`` is None).
        args: namespace providing ``epochs``.
        dataloader: sized iterable. The schedule runs for
            ``args.epochs * len(dataloader)`` steps; the first 10% is linear warmup.
        lr: base learning rate. Defaults to 2e-4, the value the original code
            hard-coded over ``config["learning_rate"]``. Pass ``None`` to use
            the config value instead.

    Returns:
        ``(optimizer, scheduler)``.
    """
    if lr is None:
        lr = config["learning_rate"]
    wd = config["weight_decay"]
    head_lr = lr * config["lr_mult"]

    no_decay = (
        "bias",
        "LayerNorm.bias",
        "LayerNorm.weight",
        "norm.bias",
        "norm.weight",
        "norm1.bias",
        "norm1.weight",
        "norm2.bias",
        "norm2.weight",
    )
    head_names = ("vqa_classifier", "nlvr2_classifier")

    def _is_no_decay(name):
        return any(nd in name for nd in no_decay)

    def _is_head(name):
        return any(hn in name for hn in head_names)

    named_params = list(model.named_parameters())

    def _group(decay, head):
        # Decaying parameters are exactly those NOT matching ``no_decay``.
        return {
            "params": [
                p
                for n, p in named_params
                if (not _is_no_decay(n)) == decay and _is_head(n) == head
            ],
            "weight_decay": wd if decay else 0.0,
            "lr": head_lr if head else lr,
        }

    optimizer_grouped_parameters = [
        _group(decay=True, head=False),
        _group(decay=False, head=False),
        _group(decay=True, head=True),
        _group(decay=False, head=True),
    ]

    optimizer = optim.AdamW(
        optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98)
    )
    max_steps = args.epochs * len(dataloader)
    warmup_steps = int(max_steps * 0.1)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=max_steps,
    )
    return optimizer, scheduler


def main_worker_vilt(gpu, ngpus_per_node, log_queue, args):
    """Train a ViLT student with an in-batch caption->image contrastive loss.

    Epoch 0 iterates the loader from ``get_cc_data``; later epochs iterate
    ``data['train']``. After every epoch the student is evaluated with
    ``evaluate_vilt``.

    Args:
        gpu: device index for this worker (also used as its rank).
        ngpus_per_node: unused; kept so the signature matches ``main_worker``.
        log_queue: queue returned by ``setup_primary_logging``.
        args: parsed command-line namespace (mutated: gpu, rank, batch_size).
    """
    args.gpu = gpu
    args.rank = gpu
    setup_worker_logging(args.rank, log_queue, args.log_level)

    # Log and save params.
    if is_master(args):
        logging.info("Params:")
        params_file = os.path.join(args.logs, args.name, "params.txt")
        with open(params_file, "w") as f:
            for name in sorted(vars(args)):
                val = getattr(args, name)
                logging.info(f"  {name}: {val}")
                f.write(f"{name}: {val}\n")

    if args.distributed:
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )

    if args.dp:
        args.batch_size *= args.world_size

    if args.gpu is not None:
        logging.info(f"Use GPU: {args.gpu} for training")
        torch.cuda.set_device(args.gpu)

    PATH_MLM = "/home/roy/ViLT/weights/vilt_200k_mlm_itm.ckpt"

    # Student only; the teacher/distillation path is currently disabled.
    vilt_model_student = ViLT(ckpt_path=PATH_MLM, student=False)
    vilt_model_student.cuda(args.gpu)

    gcc_data = get_cc_data(args, (None, None))
    data = get_data(args, (None, None))
    train_loader = data['train'].dataloader
    tokenizer = train_loader.dataset.bert_tokenizer

    optimizer, _ = get_vilt_optim(vilt_model_student, vilt_model_student.config, args, train_loader)
    # get_vilt_optim sizes its schedule as epochs * len(train_loader), but
    # epoch 0 runs over the CC loader. Rebuild the schedule (same 10% warmup)
    # with the real number of optimizer steps, so the LR neither hits zero
    # early nor stops short of decaying. The new LambdaLR reuses each group's
    # "initial_lr" recorded by the first scheduler.
    total_steps = sum(
        len(gcc_data.dataloader) if e == 0 else len(train_loader)
        for e in range(args.epochs)
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(total_steps * 0.1),
        num_training_steps=total_steps,
    )

    logging.info("Start training ViLT")
    vilt_model_student.train()
    logging_steps = 100
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    for e in trange(0, args.epochs):
        loader = gcc_data.dataloader if e == 0 else train_loader
        batch_iter = tqdm(loader, total=len(loader), desc='Iteration:')
        for batch in batch_iter:
            images, texts, _ = batch
            images = images.cuda(args.gpu)
            texts = texts.cuda(args.gpu)

            image_cls, _, _ = vilt_model_student.forward_image(images)
            text_cls, _, _ = vilt_model_student.forward_text(texts)
            # Cosine similarity of every caption (rows) against every image (cols).
            text_unit = text_cls / torch.norm(text_cls, dim=1, keepdim=True)
            image_unit = image_cls / torch.norm(image_cls, dim=1, keepdim=True)
            sim_matrix = text_unit @ image_unit.t()
            # The i-th caption belongs to the i-th image.
            gt = torch.arange(len(images)).cuda(args.gpu)
            loss = nn.CrossEntropyLoss()(sim_matrix, gt)

            # Clear the previous step's gradients; otherwise they accumulate.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            loss_val = loss.item()
            tr_loss += loss_val
            global_step += 1
            batch_iter.set_description(desc=f"Batch loss: {loss_val:.3f}")
            # Global step counter: every window is exactly `logging_steps` batches.
            if global_step % logging_steps == 0:
                logging.info(f"Avg loss over {logging_steps}: {(tr_loss-logging_loss)/logging_steps:.3f}")
                logging_loss = tr_loss

        logging.info("Start evaluation on COCO2014 5k test set")
        evaluate_vilt(vilt_model_student, data, e, args, tokenizer)
        vilt_model_student.train()
        


def _log_and_save_params(args):
    """On the master process, log every argument and write them to <logs>/<name>/params.txt."""
    if not is_master(args):
        return
    logging.info("Params:")
    params_file = os.path.join(args.logs, args.name, "params.txt")
    with open(params_file, "w") as f:
        for name in sorted(vars(args)):
            val = getattr(args, name)
            logging.info(f"  {name}: {val}")
            f.write(f"{name}: {val}\n")


def _setup_process(args):
    """Join the process group, scale the DataParallel batch size, and pin the CUDA device."""
    if args.distributed:
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )

    if args.dp:
        args.batch_size *= args.world_size

    if args.gpu is not None:
        logging.info(f"Use GPU: {args.gpu} for training")
        torch.cuda.set_device(args.gpu)


def _build_models(args):
    """Create the student model and, unless using OpenAI weights, the frozen teacher.

    Returns:
        ``(model, teacher_model, preprocess_train, preprocess_val)``.
        ``teacher_model`` is None when ``args.openai_pretrained`` is set.
    """
    if args.openai_pretrained:
        model, preprocess_train, preprocess_val = load(
            args.model,
            jit=False,
            is_train=True)
        return model, None, preprocess_train, preprocess_val

    model_config_file = Path(__file__).parent / f"model_configs/{args.model.replace('/', '-')}.json"
    print('Loading model from', model_config_file)
    if not os.path.exists(model_config_file):
        raise FileNotFoundError(f"Model config not found: {model_config_file}")
    with open(model_config_file, 'r') as f:
        model_info = json.load(f)
    model = CLIP(**model_info)
    convert_weights(model)

    # Frozen ViT-B/32 teacher; its text tower also initializes the student's.
    teacher_model, *_ = load("ViT-B/32", jit=False, is_train=False)
    convert_weights(teacher_model)
    teacher_model.eval()
    for p in teacher_model.parameters():
        p.requires_grad = False
    model.transformer.load_state_dict(teacher_model.transformer.state_dict())
    model.token_embedding.load_state_dict(teacher_model.token_embedding.state_dict())
    model.positional_embedding.data.copy_(teacher_model.positional_embedding.data)
    model.text_projection.data.copy_(teacher_model.text_projection.data)
    if not args.tune_text:
        for p in model.transformer.parameters():
            p.requires_grad = False
        for p in model.token_embedding.parameters():
            p.requires_grad = False
        model.positional_embedding.requires_grad = False
        model.text_projection.requires_grad = False

    # Transforms come from the CLIP visual tower's resolution, so build them
    # before that tower is replaced below.
    preprocess_train = _transform(model.visual.input_resolution, is_train=True)
    preprocess_val = _transform(model.visual.input_resolution, is_train=False)

    model.visual = PretrainedViT()
    model.visual.load_state_dict(torch.load("/home/roy/open_clip_2/open_clip-main/src/PretrainedViTSmall_logitScalealso.pt"))

    if not args.prompt_tuning:
        # Train only the image tower (this overrides ``tune_text`` above).
        for p in model.parameters():
            p.requires_grad = False
        for p in model.visual.parameters():
            p.requires_grad = True
    else:
        model.visual.model.prompt_embedding.requires_grad = True
    return model, teacher_model, preprocess_train, preprocess_val


def _place_model(model, teacher_model, args):
    """Apply the precision mode, move to device, and wrap for DDP/DP.

    Returns:
        The (possibly wrapped) student model. ``teacher_model`` may be None.
    """
    # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
    if args.precision in ("amp", "fp32") or args.gpu is None:
        convert_models_to_fp32(model)
        if teacher_model is not None:
            convert_models_to_fp32(teacher_model)

    if not torch.cuda.is_available():
        model.float()
        logging.warning("using CPU, this will be slow")
        return model

    model.cuda(args.gpu)
    if teacher_model is not None:
        teacher_model.cuda(args.gpu)
    # Cast once, before wrapping; DDP/DP wrapping does not change dtypes, so
    # a second conversion afterwards would be a no-op.
    if args.precision == "fp16":
        convert_weights(model)
        if teacher_model is not None:
            convert_weights(teacher_model)

    if args.distributed and args.use_bn_sync:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    if args.dp:
        model = torch.nn.DataParallel(model, device_ids=args.multigpu)
    return model


def _trainable_named_parameters(model, args):
    """Return the ``(name, param)`` pairs eligible for optimization."""
    if not args.prompt_tuning:
        return [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    # DDP/DataParallel keep the real model in ``.module``.
    wrappers = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
    base = model.module if isinstance(model, wrappers) else model
    return [
        (n, p)
        for n, p in base.visual.model.named_parameters()
        if 'prompt' in n or 'head' in n
    ]


def main_worker(gpu, ngpus_per_node, log_queue, args):
    """Build the student/teacher models, data and optimizer, then train or evaluate.

    Args:
        gpu: device index for this worker (also used as its rank).
        ngpus_per_node: unused; kept for the ``mp.spawn`` call signature.
        log_queue: queue returned by ``setup_primary_logging``.
        args: parsed command-line namespace (mutated: gpu, rank, batch_size, save_logs).
    """
    args.gpu = gpu
    args.rank = gpu
    setup_worker_logging(args.rank, log_queue, args.log_level)
    _log_and_save_params(args)
    _setup_process(args)

    model, teacher_model, preprocess_train, preprocess_val = _build_models(args)
    # NOTE(review): with --openai_pretrained there is no teacher and
    # ``teacher_model`` is None; confirm ``train`` tolerates that.
    model = _place_model(model, teacher_model, args)

    data = get_data(args, (preprocess_train, preprocess_val))
    gcc_data = get_cc_data(args, (preprocess_train, preprocess_val))

    named_parameters = _trainable_named_parameters(model, args)

    def exclude(n):
        # Norm gains, biases and the logit temperature get no weight decay.
        return "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n

    gain_or_bias_params = [p for n, p in named_parameters if exclude(n) and p.requires_grad]
    rest_params = [p for n, p in named_parameters if not exclude(n) and p.requires_grad]

    if args.train_data is None:
        optimizer = None
        scheduler = None
    else:
        optimizer = optim.AdamW(
            [
                {"params": gain_or_bias_params, "weight_decay": 0.},
                {"params": rest_params, "weight_decay": args.wd},
            ],
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            eps=args.eps,
        )
        total_steps = data["train"].dataloader.num_batches * args.epochs
        scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)

    scaler = GradScaler() if args.precision == "amp" else None

    # Resuming from a checkpoint is not implemented.
    start_epoch = 0

    cudnn.benchmark = True
    cudnn.deterministic = False

    # Only the 0th worker (or a non-distributed run) saves logs.
    args.save_logs = (args.logs is not None and args.logs != '' and args.logs.lower() != 'none') and (
        (not args.distributed) or args.gpu == 0
    )
    writer = None
    if args.save_logs and args.tensorboard:
        writer = SummaryWriter(args.tensorboard_path)

    if args.train_data is None:
        evaluate(model, data, start_epoch, args, writer, 0)
        return

    logging.info("Start pretraining with contrastive loss")

    # Unused extra inputs of ``train``; always None here.
    object_embedding = None
    mapping_layer = None
    for epoch in range(start_epoch, args.epochs):
        if args.gpu == 0:
            logging.info(f'Start epoch {epoch}')
        train(object_embedding, mapping_layer, model, teacher_model, data, gcc_data, epoch, optimizer, scaler, scheduler, args, writer)
        if args.val_data is not None:
            evaluate_correct_val(model, args)
        # NOTE(review): checkpoint saving is disabled in this script; nothing
        # is written to args.checkpoint_path.

    if args.wandb and (args.gpu == 0 or (not args.distributed)):
        wandb.finish()


def main():
    """Parse CLI arguments, prepare log directories, and launch the training worker(s).

    Returns:
        -1 if an experiment with the same name already exists, otherwise None.
    """
    args = parse_args()

    # get the name of the experiments
    if args.name is None:
        args.name = strftime(
            f"lr={args.lr}_"
            f"wd={args.wd}_"
            f"agg={args.aggregate}_"
            f"model={args.model}_"
            f"batchsize={args.batch_size}_workers={args.workers}_date=%Y-%m-%d-%H-%M-%S",
            gmtime(),
        )

    args.log_path = os.path.join(args.logs, args.name, "out.log")
    if os.path.exists(args.log_path):
        # The message previously printed a literal "{}" because it was never formatted.
        print(
            f"Error. Experiment already exists ({args.log_path}). "
            "Use --name to specify a new experiment."
        )
        return -1

    assert args.precision in ['amp', 'fp16', 'fp32']

    args.ngpus_per_node = torch.cuda.device_count()

    args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
    args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to

    args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else ''
    args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
    # Also creates <logs>/<name>/, where out.log and params.txt are written.
    for dirname in [args.tensorboard_path, args.checkpoint_path]:
        if dirname:
            os.makedirs(dirname, exist_ok=True)

    # Set logger
    args.log_level = logging.DEBUG if args.debug else logging.INFO
    log_queue = setup_primary_logging(args.log_path, args.log_level)

    # Distributed training = one process per visible GPU (no --gpu, no --dp).
    args.distributed = (args.gpu is None) and torch.cuda.is_available() and (not args.dp)
    if args.distributed:
        ngpus_per_node = torch.cuda.device_count()
        args.world_size = ngpus_per_node
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, log_queue, args))
    else:
        if args.dp:
            args.gpu = args.multigpu[0]
            args.world_size = len(args.multigpu)
        else:
            args.world_size = 1
        main_worker(args.gpu, None, log_queue, args)


if __name__ == "__main__":
    # Propagate main()'s status: None exits 0, -1 (experiment already exists)
    # exits non-zero so shell scripts can detect the failure.
    raise SystemExit(main())
