import math
from imagenetv2_pytorch.ImageNetV2_dataset import ImageNetV2Dataset
from tqdm import tqdm, trange
from PIL import Image
from training.transforms import keys_to_transforms
import os
import time
import json
import numpy as np

import torch
import torch.nn as nn

from torch.cuda.amp import autocast
import torch.distributed as dist
from torch.nn.modules.loss import KLDivLoss
from torchvision.datasets.folder import is_image_file
from torchvision.transforms import Compose, Normalize, ToTensor
from torch.nn.utils.clip_grad import clip_grad_norm_
from clip.clip import tokenize
from .zero_shot import zero_shot_eval
import wandb
import logging
from tqdm import tqdm
import random
import cv2
import faiss
# FAISS indexes used for hard-negative (ANCE-style) mining in get_loss().
# They are loaded at import time from paths relative to the current working
# directory, so importing this module requires both files to be present there.
# caption_index: searched with image features; the hits are reconstructed into
# caption vectors that serve as hard negatives for the image encoder.
caption_index = faiss.read_index('caption.index')
# images_index: in the code above it is only referenced by the commented-out
# text-encoder hard-negative block in get_loss().
images_index = faiss.read_index('logitScalealso_images.index')
# Alternative (disabled) Flickr30k indexes:
# try:
#     caption_index = faiss.read_index('/home/roy/open_clip_2/open_clip-main/src/flickr30k_caption.index')
#     images_index = faiss.read_index('/home/roy/open_clip_2/open_clip-main/src/flickr30k_largerbs_images.index')
# except:
#     logging.info("Flickr30k index not exist")

# NOTE(review): not read anywhere above; presumably tracks the best validation
# R@10 seen so far for checkpoint selection — confirm against the rest of the file.
best_r10 = 0.0


def is_master(args):
    """Tell whether this process should do rank-0-only work (logging, eval).

    Every process counts as the master in a non-distributed run. In a
    distributed run, only the process whose ``args.gpu`` is 0 does.
    """
    if args.distributed:
        return args.gpu == 0
    return True


@torch.no_grad()
def js_divergence(dist1, dist2):
    """Return the Jensen-Shannon divergence between two categorical distributions.

    ``dist1`` and ``dist2`` are unnormalised logits. Each is turned into a
    probability distribution with a softmax over the last dimension, and

        JS(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M),   M = (P + Q) / 2

    is computed in nats. The value is summed over the last (class) dimension
    and averaged over any leading (batch) dimensions. It is symmetric and
    lies in [0, ln 2].

    Args:
        dist1: logits tensor; the last dimension indexes categories.
        dist2: logits tensor with the same shape as ``dist1``.

    Returns:
        The divergence as a Python float.
    """
    log_p = torch.log_softmax(dist1, dim=-1)
    log_q = torch.log_softmax(dist2, dim=-1)
    # The mixture must be formed from probabilities, not logits:
    # log M = log((P + Q) / 2), evaluated in log space for numerical stability.
    log_m = torch.logaddexp(log_p, log_q) - math.log(2.0)
    # KL(P || M) and KL(Q || M). The divergence is taken *from* P/Q *to* M;
    # nn.KLDivLoss(input, target) would compute KL(target || input) instead.
    kl_pm = torch.sum(log_p.exp() * (log_p - log_m), dim=-1)
    kl_qm = torch.sum(log_q.exp() * (log_q - log_m), dim=-1)
    js = 0.5 * (kl_pm + kl_qm)
    return js.mean().cpu().item()


def get_metrics(args, image_features, text_features, object_embeddings, mapped_img_features):
    """Compute retrieval metrics for paired image/text features.

    Row ``i`` of ``image_features`` is treated as the match of row ``i`` of
    ``text_features``.

    If ``args.object_embedding`` is truthy and ``object_embeddings`` is given,
    only text-to-image recall is computed, using a two-stage search:

    1. Each caption and image is projected onto the rows of
       ``object_embeddings.weight`` and softmax-normalised. Images are ranked
       by KL(image_obj || text_obj).
    2. The closest ``args.rerank_candidates`` images (default 3500) are
       re-ranked by dot-product similarity with the caption.

    That path returns the keys ``text_to_image_R@1/5/10``.

    Otherwise an exhaustive dot-product ranking is done in both directions.
    The returned keys are ``{image_to_text,text_to_image}_{mean_rank,
    median_rank,R@1,R@5,R@10}``.

    Args:
        args: needs ``object_embedding``; ``rerank_candidates`` is optional.
        image_features: image embeddings, one row per image.
        text_features: text embeddings, one row per caption.
        object_embeddings: module with a ``weight`` matrix, or None.
        mapped_img_features: unused; kept for call compatibility.

    Returns:
        dict mapping metric name to value.
    """
    metrics = {}
    if args.object_embedding and object_embeddings is not None:
        num_texts = len(text_features)
        num_candidates = getattr(args, "rerank_candidates", 3500)
        obj_weight = object_embeddings.weight.cpu().t()
        # log_softmax rather than softmax(...).log(): underflowed probabilities
        # would otherwise give log(0) = -inf and NaN/inf KL scores.
        text2obj_logprob = torch.log_softmax(text_features @ obj_weight, dim=1)
        image2obj_logprob = torch.log_softmax(image_features @ obj_weight, dim=1)
        image2obj_prob = image2obj_logprob.exp()
        hits = {1: 0, 5: 0, 10: 0}
        for i in trange(num_texts):
            # KL(image || text_i) for every image at once (broadcast over images).
            kld = torch.sum(
                image2obj_prob * (image2obj_logprob - text2obj_logprob[i].unsqueeze(0)), dim=1)
            assert kld.size(0) == len(image_features)
            candidate_ids = torch.argsort(kld, descending=False)[:num_candidates]
            # Second stage: plain similarity re-ranking of the shortlisted images.
            sims = image_features[candidate_ids] @ text_features[i]
            k = min(10, candidate_ids.numel())
            _, top_local = torch.topk(sims, k=k)
            final_ids = candidate_ids[top_local].tolist()
            for cutoff in hits:
                if i in final_ids[:cutoff]:
                    hits[cutoff] += 1
        for cutoff, hit in hits.items():
            metrics[f'text_to_image_R@{cutoff}'] = hit / num_texts
        return metrics

    logits_per_image = image_features @ text_features.t()
    logits_per_text = logits_per_image.t()

    logits = {"image_to_text": logits_per_image,
              "text_to_image": logits_per_text}
    ground_truth = (
        torch.arange(len(text_features)).view(-1, 1).to(logits_per_image.device)
    )

    for name, logit in logits.items():
        ranking = torch.argsort(logit, descending=True)
        # Column at which the true match appears in each row's ranking (0 = best).
        preds = torch.where(ranking == ground_truth)[1]
        preds = preds.detach().cpu().numpy()
        metrics[f"{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in [1, 5, 10]:
            metrics[f"{name}_R@{k}"] = np.mean(preds < k)

    return metrics


def get_contrastive_loss(model, teacher_model, images, texts, loss_img, args):
    """Image-to-image contrastive distillation loss for the pre-training stage.

    The student's image embeddings are scored against the frozen teacher's
    image embeddings of the same batch. ``loss_img`` is then applied with the
    diagonal (same-sample pairs) as targets. Text outputs and logit scales of
    both models are ignored. ``args`` is unused.

    Returns:
        The value of ``loss_img`` on the (batch, batch) similarity matrix.
    """
    # The teacher only provides targets, so no autograd graph is kept for it.
    with torch.no_grad():
        teacher_image_features, _, _ = teacher_model(images, texts)
    student_image_features, _, _ = model(images, texts)

    similarity = student_image_features @ teacher_image_features.t()
    targets = torch.arange(len(images)).long().to(images.device)
    return loss_img(similarity, targets)


def _ance_image_loss(image_features, text_features, indices, logit_scale, loss_img):
    """Hard-negative (ANCE-style) loss for the image encoder using ``caption_index``.

    For each image in the batch:

    * its ``bs`` nearest caption vectors are retrieved from the module-level
      FAISS ``caption_index``;
    * the sample's own id (``indices[i]``) is dropped if it was retrieved,
      otherwise ``bs - 1`` of the hits are sampled at random;
    * the image is classified against [its in-batch text feature] + those
      reconstructed caption vectors, with the positive at index 0.

    The per-image losses are averaged.
    NOTE(review): assumes dataset ``indices`` coincide with the FAISS ids of
    ``caption_index`` — confirm against the index-building script.
    """
    bs = image_features.shape[0]
    device = image_features.device
    # FAISS needs a C-contiguous float32 matrix (features may be fp16 under AMP).
    queries = np.ascontiguousarray(image_features.detach().cpu().numpy(), dtype=np.float32)
    _, caption_indices = caption_index.search(x=queries, k=bs)
    positive_target = torch.tensor([0]).long().to(device)
    ance_loss = 0.0
    for i in range(bs):
        cur_id = indices[i].item()
        caption_ids = list(caption_indices[i])
        if cur_id in caption_ids:
            # The true caption must not be used as a negative.
            caption_ids.pop(caption_ids.index(cur_id))
        else:
            caption_ids = random.sample(caption_ids, bs - 1)
        assert len(caption_ids) == (bs - 1)
        if caption_ids:
            negatives = np.stack(
                [caption_index.reconstruct(int(cid)) for cid in caption_ids], axis=0)
        else:
            # bs == 1: no negatives, only the positive remains.
            negatives = np.empty((0, caption_index.d), dtype=np.float32)
        hard_negatives = torch.from_numpy(negatives).float().to(device)  # (bs-1, dim)
        candidates = torch.cat([text_features[i].unsqueeze(0), hard_negatives], dim=0)  # (bs, dim)
        assert candidates.shape[0] == bs
        logits = logit_scale * image_features[i].unsqueeze(0) @ candidates.t()  # (1, bs)
        ance_loss += loss_img(logits, positive_target)
    return ance_loss / bs


def get_loss(object_embedding, mapping_layer, model, teacher_model, images, texts, indices, loss_img, loss_txt, args, step):
    """Training loss for the fine-tuning (epoch > 0) stage.

    The loss is the sum of:

    1. Symmetric in-batch image/text contrastive loss (CLIP-style). When
       ``args.distributed and args.aggregate`` it uses features gathered from
       all ranks, with this rank's features placed first.
    2. Image-to-image contrastive distillation against the frozen teacher's
       image features (this rank's batch only).
    3. If ``args.ance``: a hard-negative loss for the image encoder, with
       negatives mined from the module-level FAISS ``caption_index``.

    Args:
        object_embedding, mapping_layer, step: unused; kept for call compatibility.
        model, teacher_model: callables returning
            ``(image_features, text_features, logit_scale)``.
        images, texts: batch inputs fed to both models.
        indices: per-sample dataset ids (used only by the hard-negative term).
        loss_img, loss_txt: classification losses (CrossEntropyLoss in train()).
        args: reads ``distributed``, ``aggregate`` and ``ance``.

    Returns:
        Scalar loss tensor.
    """
    with torch.no_grad():
        t_image_features, _, _ = teacher_model(images, texts)
    image_features, text_features, logit_scale = model(images, texts)
    # DataParallel returns one logit_scale per replica.
    logit_scale = logit_scale.mean()

    if args.distributed and args.aggregate:
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        # Gather tensors from all GPUs to get more negatives to contrast with.
        gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
        gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
        dist.all_gather(gathered_image_features, image_features)
        dist.all_gather(gathered_text_features, text_features)

        # all_gather outputs carry no gradient, so the local (differentiable)
        # features are put first in place of this rank's gathered copy.
        all_image_features = torch.cat(
            [image_features] + gathered_image_features[:rank] + gathered_image_features[rank + 1:]
        )
        all_text_features = torch.cat(
            [text_features] + gathered_text_features[:rank] + gathered_text_features[rank + 1:]
        )

        logits_per_image = logit_scale * all_image_features @ all_text_features.t()
        logits_per_text = logits_per_image.t()
    else:
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

    # Matching pairs sit on the diagonal; targets live on the logits' device.
    ground_truth = torch.arange(len(logits_per_image), device=logits_per_image.device)

    # 1. In-batch negatives.
    total_loss = (
        loss_img(logits_per_image, ground_truth)
        + loss_txt(logits_per_text, ground_truth)
    ) / 2

    # 2. Contrastive KD. Uses only this rank's batch, so it needs local targets:
    # ground_truth above has world_size * batch entries when features are gathered.
    local_ground_truth = torch.arange(len(image_features), device=image_features.device)
    img2img_matrix = image_features @ t_image_features.t()
    total_loss = total_loss + loss_img(img2img_matrix, local_ground_truth)

    # 3. Hard negatives for the image encoder. A symmetric text-encoder variant
    # (mining images from images_index) was tried and is currently disabled.
    if args.ance:
        total_loss = total_loss + _ance_image_loss(
            image_features, text_features, indices, logit_scale, loss_img)

    return total_loss


def train(object_embedding, mapping_layer, model, teacher_model, data, gcc_data, epoch, optimizer, scaler, scheduler, args, tb_writer=None):
    """Run one training epoch.

    Epochs <= 0 are the pre-training stage:
      * batches come from ``gcc_data``;
      * the loss is get_contrastive_loss;
      * at the end of epoch 0 the visual tower is saved to
        "pit_small_after_cc12m.pt".

    Epochs > 0 are the MS COCO stage:
      * batches come from ``data['train']``;
      * the loss is get_loss;
      * text-transformer gradients are clipped to norm 0.25;
      * the master process runs evaluate_correct_val every 250 batches.

    With ``args.precision == "amp"`` the forward pass runs under autocast and
    backward/step go through ``scaler``. Otherwise everything runs in fp32.

    Args:
        object_embedding, mapping_layer: forwarded to get_loss.
        model: the student network, possibly wrapped (``args.distributed`` / ``args.dp``).
        teacher_model: frozen teacher forwarded to the loss functions.
        data: dict whose ``'train'`` entry has ``dataloader`` and ``sampler``.
        gcc_data: object with ``dataloader`` and ``sampler`` for pre-training.
        epoch: current epoch index.
        optimizer, scaler: optimizer and AMP grad scaler (``scaler`` is only used with amp).
        scheduler: callable invoked with the global step after each optimizer step.
        args: run configuration.
        tb_writer: optional TensorBoard writer.
    """
    os.environ["WDS_EPOCH"] = str(epoch)

    model.train()

    pretraining = epoch <= 0
    if pretraining:
        dataloader, sampler = gcc_data.dataloader, gcc_data.sampler
        logging.info("Training with merged GCC/Flickr30k")
    else:
        dataloader, sampler = data['train'].dataloader, data['train'].sampler
        logging.info("Training with MS COCO")

    loss_img = nn.CrossEntropyLoss()
    loss_txt = nn.CrossEntropyLoss()
    if args.gpu is not None:
        loss_img = loss_img.cuda(args.gpu)
        loss_txt = loss_txt.cuda(args.gpu)

    if args.distributed and sampler is not None:
        sampler.set_epoch(epoch)

    num_batches_per_epoch = dataloader.num_batches
    # DDP / DataParallel wrap the network; its submodules live on `.module`.
    m = model.module if args.distributed or args.dp else model
    use_amp = args.precision == "amp"

    end = time.time()
    acc_loss = .0
    for i, batch in enumerate(tqdm(dataloader, total=len(dataloader))):
        step = num_batches_per_epoch * epoch + i

        optimizer.zero_grad()

        images, texts, indices = batch
        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
            texts = texts.cuda(args.gpu, non_blocking=True)

        data_time = time.time() - end

        # Only the forward pass runs under autocast; backward/step stay outside it.
        with autocast(enabled=use_amp):
            if pretraining:
                total_loss = get_contrastive_loss(
                    model, teacher_model, images, texts, loss_img, args)
            else:
                total_loss = get_loss(object_embedding, mapping_layer, model, teacher_model,
                                      images, texts, indices, loss_img, loss_txt, args, step)

        if use_amp:
            scaler.scale(total_loss).backward()
        else:
            total_loss.backward()
        acc_loss += total_loss.item()

        if not pretraining:
            if use_amp:
                # Clip the real gradients, not the loss-scaled ones.
                scaler.unscale_(optimizer)
            clip_grad_norm_(m.transformer.parameters(), max_norm=0.25)

        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        scheduler(step)

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        m.logit_scale.data = torch.clamp(m.logit_scale.data, 0, 4.6052)

        batch_time = time.time() - end
        end = time.time()

        if is_master(args) and (i % 100) == 0:
            num_samples = i * len(images) * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * i / num_batches_per_epoch
            logging.info(
                f"Train Epoch: {epoch} [{num_samples}/{samples_per_epoch} ({percent_complete:.0f}%)]\t"
                f"Loss: {acc_loss / (i+1):.6f}\tBatch (t) {batch_time:.3f}"
                f"\tLR: {optimizer.param_groups[0]['lr']:5f}\tlogit_scale {m.logit_scale.data:.3f}"
            )

            log_data = {
                "loss": total_loss.item(),
                "data_time": data_time,
                "batch_time": batch_time,
                "scale":  m.logit_scale.data.item(),
                "lr": optimizer.param_groups[0]["lr"]
            }
            for name, val in log_data.items():
                name = "train/" + name
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)
                if args.wandb:
                    wandb.log({name: val, 'step': step})

        # Only evaluate during the MS COCO stage. Evaluate the unwrapped module:
        # DDP/DP wrappers do not expose encode_image / encode_text.
        if is_master(args) and (i % 250) == 0 and i > 0 and not pretraining:
            evaluate_correct_val(m, args)
            model.train()

    if epoch == 0:
        # Save the visual tower after contrastive pre-training on GCC/Flickr30k.
        torch.save(
            m.visual.state_dict(),
            "pit_small_after_cc12m.pt"
        )
        logging.info("ckpt after GCC/Flickr is saved")


@torch.no_grad()
def evaluate_vilt(vilt_model, data, epoch, args, tokenizer):
    """Text-to-image retrieval with a ViLT-style cross-encoder on COCO test fold 0.

    * The gallery is the images of the first 1000 entries of coco_test.json.
    * Each query caption is scored against every gallery image with
      ``vilt_model.sim_score_raw``, in chunks of 32 images.
    * Only the first 50 captions are evaluated, because every query needs a
      full pass over the gallery.
    * R@1/5/10 over the evaluated queries are logged.

    ``data`` and ``epoch`` are unused.
    """
    vilt_model.eval()
    coco_test_json_pth = "/home/roy/CLIP/albef_data/coco_test.json"
    with open(coco_test_json_pth, "r") as f:
        coco_test_json = json.load(f)

    fold = 0
    fold_size = 1000
    max_queries = 50        # cross-encoder cost is queries x gallery; keep this small
    image_batch_size = 32

    all_caption = []
    all_images = []
    image2id = dict()
    ground_truth_image_id = []
    for img_id, img_captions in tqdm(enumerate(coco_test_json), total=len(coco_test_json)):
        if not (fold * fold_size <= img_id < (fold + 1) * fold_size):
            continue
        if not img_captions['caption']:
            continue
        full_image_path = os.path.join("/home/roy/CLIP", img_captions['image'])
        if not os.path.exists(full_image_path):
            # Fall back to the raw COCO folder, looked up by file name (the
            # enumerate index `img_id` is an int and cannot be a path component).
            # NOTE(review): assumes annotation file names match train2017 names.
            full_image_path = os.path.join(
                "/home/roy/mscoco/images/train2017", os.path.basename(img_captions['image']))
        assert os.path.exists(full_image_path), f"{full_image_path}"
        if full_image_path not in image2id:
            image2id[full_image_path] = len(image2id)
            all_images.append(full_image_path)
        for caption in img_captions['caption']:
            all_caption.append(caption)
            ground_truth_image_id.append(image2id[full_image_path])

    # Preprocess the whole gallery once.
    pixlebert_transform = keys_to_transforms(['pixelbert'])[0]
    transformed_images = []
    for image_path in tqdm(all_images, total=len(all_images)):
        # Close the file promptly and force 3 channels so grayscale/RGBA files
        # produce tensors that stack with the rest.
        with Image.open(image_path) as pil_image:
            transformed_images.append(pixlebert_transform(pil_image.convert("RGB")))
    transformed_images = torch.stack(transformed_images, dim=0)

    hit1 = 0
    hit5 = 0
    hit10 = 0
    num_queries = min(len(all_caption), max_queries)
    for i in trange(num_queries):
        text_ids = tokenizer(
            all_caption[i], return_tensors='pt', max_length=40, truncation=True).input_ids.cuda(args.gpu)
        score_chunks = []
        for j in trange(0, len(transformed_images), image_batch_size):
            batched_transformed_images = transformed_images[j:j + image_batch_size].cuda(args.gpu)
            ranking_scores = vilt_model.sim_score_raw(
                batched_transformed_images, text_ids).cpu().detach().squeeze(0)
            score_chunks.append(ranking_scores)
        all_ranking_scores = torch.cat(score_chunks, dim=0)
        assert len(all_ranking_scores) == len(transformed_images)
        _, top10_ids = torch.topk(all_ranking_scores, k=min(10, len(all_ranking_scores)))
        top10_ids = top10_ids.tolist()
        gt_id = ground_truth_image_id[i]
        if gt_id in top10_ids[:1]:
            hit1 += 1
        if gt_id in top10_ids[:5]:
            hit5 += 1
        if gt_id in top10_ids[:10]:
            hit10 += 1

    if num_queries == 0:
        logging.warning("evaluate_vilt: no captions found for the selected fold")
        return
    logging.info(f"R@1: {hit1 / num_queries:.2f}")
    logging.info(f"R@5: {hit5 / num_queries:.2f}")
    logging.info(f"R@10: {hit10 / num_queries:.2f}")

def evaluate_correct_test_1k(model, args):
    """Text-to-image retrieval on the test split, averaged over 1K-entry folds.

    * Annotations come from ``coco_test.json`` when ``args.train_data``
      contains 'coco', otherwise from ``flickr30k_test.json``.
    * Each entry is a dict with an ``'image'`` path and a ``'caption'`` list.
    * The entries are cut into consecutive folds of 1000 and each fold is
      scored independently with L2-normalised ``encode_image`` /
      ``encode_text`` embeddings.
    * R@1/5/10 are averaged over the folds and logged.

    A 5000-entry split gives 5 folds. A 1000-entry split gives one fold; a
    fixed 5-fold loop would create empty folds there and crash.
    Leaves ``model`` in eval mode.
    """
    def centerSizeCrop(image, crop_size):
        rows, cols = image.shape[:2]
        x = round((cols - crop_size) / 2.0)
        y = round((rows - crop_size) / 2.0)
        return image[y:y+crop_size, x:x+crop_size]

    def _transform(n_px):
        return Compose([
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])

    def _load_image(image_path):
        # Returns a (1, 3, 224, 224) CPU tensor.
        img = cv2.imread(image_path)
        if img is None:
            raise FileNotFoundError(f"cv2 could not read image: {image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
        img = centerSizeCrop(img, 224)
        return preprocess(img).unsqueeze(0)

    def _encode_images(image_paths):
        tensors = [_load_image(p) for p in tqdm(image_paths, total=len(image_paths))]
        chunks = []
        with torch.no_grad():
            for start in tqdm(range(0, len(tensors), 10)):
                batch = torch.cat(tensors[start:start+10], dim=0).cuda(args.gpu)
                vec = model.encode_image(batch)
                vec = vec / torch.norm(vec, dim=1, keepdim=True)
                chunks.append(vec.cpu().detach().numpy().astype("float32"))
        return np.concatenate(chunks, axis=0)

    def _encode_texts(captions):
        chunks = []
        with torch.no_grad():
            for start in trange(0, len(captions), 50):
                input_ids = tokenize(captions[start:start+50]).cuda(args.gpu)  # (bs, 77)
                vec = model.encode_text(input_ids)
                vec = vec / torch.norm(vec, dim=1, keepdim=True)
                chunks.append(vec.cpu().detach().numpy().astype("float32"))
        return np.concatenate(chunks, axis=0)

    model.eval()
    preprocess = _transform(224)

    if 'coco' in args.train_data:
        val_json_pth = "/home/roy/CLIP/albef_data/coco_test.json"
    else:
        val_json_pth = "/home/roy/CLIP/albef_data/flickr30k_test.json"
    with open(val_json_pth, "r") as f:
        val_json = json.load(f)

    fold_size = 1000
    num_folds = math.ceil(len(val_json) / fold_size)
    if num_folds == 0:
        logging.warning(f"No test entries found in {val_json_pth}")
        return

    r1s, r5s, r10s = 0.0, 0.0, 0.0
    for fold in range(num_folds):
        all_caption = []
        all_images = []
        image2id = dict()
        ground_truth_image_id = []
        for img_captions in val_json[fold * fold_size:(fold + 1) * fold_size]:
            for caption in img_captions['caption']:
                full_image_path = img_captions['image']
                assert os.path.exists(full_image_path), f"{full_image_path}"
                if full_image_path not in image2id:
                    image2id[full_image_path] = len(image2id)
                    all_images.append(full_image_path)
                all_caption.append(caption)
                ground_truth_image_id.append(image2id[full_image_path])

        all_image_vectors = _encode_images(all_images)
        all_text_vectors = _encode_texts(all_caption)

        sim_matrix = torch.from_numpy(all_text_vectors) @ torch.from_numpy(all_image_vectors).t()
        _, indices = sim_matrix.topk(k=min(10, sim_matrix.shape[1]), dim=-1)
        gt_img_id = torch.tensor(ground_truth_image_id).long().unsqueeze(-1)
        hit = (indices == gt_img_id)
        r1s += torch.mean(torch.sum(hit[:, :1], dim=-1).float()).item()
        r5s += torch.mean(torch.sum(hit[:, :5], dim=-1).float()).item()
        r10s += torch.mean(torch.sum(hit[:, :10], dim=-1).float()).item()

    r1 = r1s / num_folds
    r5 = r5s / num_folds
    r10 = r10s / num_folds
    if 'flickr' in args.train_data:
        logging.info(
            f"Test on Flickr30K: text2image R@1:{r1:.3f}\ttext2image R@5:{r5:.3f}\ttext2image R@10:{r10:.3f}\n")
    else:
        logging.info("1K Testset")
        logging.info(
            f"Test on MSCOCO: text2image R@1:{r1:.3f}\ttext2image R@5:{r5:.3f}\ttext2image R@10:{r10:.3f}\n")


def evaluate_correct_test(model, args):
    """Text-to-image retrieval over the whole test split; logs R@1/5/10.

    * Annotations come from ``coco_test.json`` when ``args.train_data``
      contains 'coco', otherwise from ``flickr30k_test.json``.
    * Each entry is a dict with an ``'image'`` path and a ``'caption'`` list.
    * All captions are ranked against all images by the dot product of
      L2-normalised ``encode_text`` / ``encode_image`` embeddings.

    Leaves ``model`` in eval mode.
    """
    def centerSizeCrop(image, crop_size):
        rows, cols = image.shape[:2]
        x = round((cols - crop_size) / 2.0)
        y = round((rows - crop_size) / 2.0)
        return image[y:y+crop_size, x:x+crop_size]

    def _transform(n_px):
        return Compose([
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])

    def _load_image(image_path):
        # Returns a (1, 3, 224, 224) CPU tensor.
        img = cv2.imread(image_path)
        if img is None:
            raise FileNotFoundError(f"cv2 could not read image: {image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
        img = centerSizeCrop(img, 224)
        return preprocess(img).unsqueeze(0)

    model.eval()
    preprocess = _transform(224)

    all_caption = []
    all_images = []
    image2id = dict()
    ground_truth_image_id = []

    if 'coco' in args.train_data:
        val_json_pth = "/home/roy/CLIP/albef_data/coco_test.json"
    else:
        val_json_pth = "/home/roy/CLIP/albef_data/flickr30k_test.json"
    with open(val_json_pth, "r") as f:
        val_json = json.load(f)
    for img_captions in tqdm(val_json, total=len(val_json)):
        for caption in img_captions['caption']:
            full_image_path = img_captions['image']
            assert os.path.exists(full_image_path), f"{full_image_path}"
            if full_image_path not in image2id:
                image2id[full_image_path] = len(image2id)
                all_images.append(full_image_path)
            all_caption.append(caption)
            ground_truth_image_id.append(image2id[full_image_path])

    all_image_features = [_load_image(p) for p in tqdm(all_images, total=len(all_images))]
    image_chunks = []
    with torch.no_grad():
        for i in tqdm(range(0, len(all_image_features), 10)):
            tmp_feat = torch.cat(all_image_features[i:i+10], dim=0).cuda(args.gpu)  # (10, 3, 224, 224)
            image_vector = model.encode_image(tmp_feat)
            image_vector = image_vector / torch.norm(image_vector, dim=1, keepdim=True)
            image_chunks.append(image_vector.cpu().detach().numpy().astype("float32"))
    all_image_vectors = np.concatenate(image_chunks, axis=0)

    text_chunks = []
    with torch.no_grad():
        for start in trange(0, len(all_caption), 50):
            input_ids = tokenize(all_caption[start:start+50]).cuda(args.gpu)  # (bs, 77)
            text_feature = model.encode_text(input_ids)
            text_feature = text_feature / torch.norm(text_feature, dim=1, keepdim=True)
            text_chunks.append(text_feature.cpu().detach().numpy().astype('float32'))
    all_text_vectors = np.concatenate(text_chunks, axis=0)

    sim_matrix = torch.from_numpy(all_text_vectors) @ torch.from_numpy(all_image_vectors).t()
    _, indices = sim_matrix.topk(k=min(10, sim_matrix.shape[1]), dim=-1)
    gt_img_id = torch.tensor(ground_truth_image_id).long().unsqueeze(-1)
    hit = (indices == gt_img_id)
    r1 = torch.mean(torch.sum(hit[:, :1], dim=-1).float()).item()
    r5 = torch.mean(torch.sum(hit[:, :5], dim=-1).float()).item()
    r10 = torch.mean(torch.sum(hit[:, :10], dim=-1).float()).item()
    if 'flickr' in args.train_data:
        logging.info(
            f"Test on Flickr30K: text2image R@1:{r1:.3f}\ttext2image R@5:{r5:.3f}\ttext2image R@10:{r10:.3f}\n")
    else:
        logging.info("5K Testset")
        logging.info(
            f"Test on MSCOCO: text2image R@1:{r1:.3f}\ttext2image R@5:{r5:.3f}\ttext2image R@10:{r10:.3f}\n")


def evaluate_correct_val(model, args):
    """Run text-to-image retrieval on the validation split and track the best R@10.

    Image paths and captions are read from a JSON file chosen by
    ``args.train_data`` (COCO annotations if it contains ``'coco'``, Flickr30K
    otherwise). Every caption embedding is scored against every image embedding
    (both L2-normalised, so the dot product is cosine similarity) and
    text-to-image R@1 / R@5 / R@10 are logged. When R@10 beats the module-level
    ``best_r10``, the visual tower's weights are saved and the test-set
    evaluation(s) are run.

    Args:
        model: CLIP-style model exposing ``encode_image``, ``encode_text`` and
            ``visual``; it may also be wrapped in a container that exposes the
            real model as ``.module`` (e.g. DistributedDataParallel).
        args: namespace providing at least ``train_data`` and ``gpu``.

    Raises:
        AssertionError: if an image path listed in the JSON does not exist.
        ValueError: if an image cannot be decoded or the JSON has no captions.
    """
    global best_r10

    def centerSizeCrop(image, crop_size):
        # Take a crop_size x crop_size square from the centre of an HxW(xC) array.
        rows, cols = image.shape[:2]
        x = round((cols - crop_size) / 2.0)
        y = round((rows - crop_size) / 2.0)
        return image[y:y + crop_size, x:x + crop_size]

    def _transform(n_px):
        # n_px is unused: resizing/cropping is done with cv2 in _load_image.
        return Compose([
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])

    def _load_image(image_path):
        # Decode one image from disk into a (1, 3, 224, 224) normalised CPU tensor.
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Failed to decode image: {image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
        img = centerSizeCrop(img, 224)
        return preprocess(img).unsqueeze(0)

    model.eval()
    # Wrappers such as (Distributed)DataParallel hide encode_* / visual behind
    # `.module`; fall back to the model itself when it is not wrapped.
    net = getattr(model, "module", model)
    preprocess = _transform(224)

    if 'coco' in args.train_data:
        val_json_pth = "/home/roy/CLIP/albef_data/coco_val.json"
    else:
        val_json_pth = "/home/roy/CLIP/albef_data/flickr30k_val.json"
    # Close the annotation file deterministically instead of leaking the handle.
    with open(val_json_pth, "r") as f:
        val_json = json.load(f)

    all_caption = []
    all_images = []
    image2id = {}
    ground_truth_image_id = []  # per caption: index into all_images
    for img_captions in tqdm(val_json, total=len(val_json)):
        full_image_path = img_captions['image']
        for caption in img_captions['caption']:
            if full_image_path not in image2id:
                assert os.path.exists(full_image_path), f"{full_image_path}"
                image2id[full_image_path] = len(image2id)
                all_images.append(full_image_path)
            all_caption.append(caption)
            ground_truth_image_id.append(image2id[full_image_path])

    if not all_caption:
        raise ValueError(f"No captions found in {val_json_pth}")

    image_chunks = []
    with torch.no_grad():
        for start in tqdm(range(0, len(all_images), 10)):
            # Decode one batch at a time instead of keeping every preprocessed
            # image tensor in host memory before encoding.
            batch = torch.cat(
                [_load_image(p) for p in all_images[start:start + 10]], dim=0
            ).cuda(args.gpu)
            image_vector = net.encode_image(batch)
            image_vector = image_vector / \
                torch.norm(image_vector, dim=1, keepdim=True)
            image_chunks.append(
                image_vector.cpu().detach().numpy().astype("float32"))
    # Concatenate once at the end rather than re-copying on every batch.
    all_image_vectors = np.concatenate(image_chunks, axis=0)

    text_chunks = []
    with torch.no_grad():
        for start in trange(0, len(all_caption), 50):
            input_ids = tokenize(
                all_caption[start:start + 50]).cuda(args.gpu)  # (bs, 77)
            text_feature = net.encode_text(input_ids)
            text_feature = text_feature / \
                torch.norm(text_feature, dim=1, keepdim=True)
            text_chunks.append(
                text_feature.cpu().detach().numpy().astype('float32'))
    all_text_vectors = np.concatenate(text_chunks, axis=0)

    # (num_captions, num_images) cosine-similarity matrix.
    sim_matrix = torch.from_numpy(
        all_text_vectors) @ torch.from_numpy(all_image_vectors).t()
    # topk needs k <= number of images; with fewer than 10 images the higher
    # cut-offs simply cover every image.
    top_k = min(10, sim_matrix.shape[1])
    _, indices = sim_matrix.topk(k=top_k, dim=-1)
    gt_img_id = torch.tensor(ground_truth_image_id).long().unsqueeze(-1)
    hit = indices == gt_img_id  # (num_captions, top_k) boolean
    # Convert to Python floats so the global best_r10 stays a plain float.
    r1, r5, r10 = (
        hit[:, :cutoff].sum(dim=-1).float().mean().item()
        for cutoff in (1, 5, 10)
    )
    if 'flickr' in args.train_data:
        logging.info(
            f"Eval on Flickr30K: text2image R@1:{r1:.3f}\ttext2image R@5:{r5:.3f}\ttext2image R@10:{r10:.3f}\n")
    else:
        logging.info(
            f"Eval on MSCOCO: text2image R@1:{r1:.3f}\ttext2image R@5:{r5:.3f}\ttext2image R@10:{r10:.3f}\n")

    if r10 > best_r10:
        best_r10 = r10
        torch.save(net.visual.state_dict(), "pit_small_coco.pt")
        logging.info("New best val performance, start testing")
        evaluate_correct_test(model, args)
        if 'flickr' not in args.train_data:
            evaluate_correct_test_1k(model, args)


def evaluate(object_embedding, mapping_layer, model, data, epoch, args, tb_writer=None, steps=None):
    """Compute validation loss and retrieval metrics on ``data['val']``.

    Only the master process does any work; other ranks return ``None``
    immediately. For each validation batch the symmetric CLIP contrastive loss
    (mean of image->text and text->image cross-entropy) is accumulated, and the
    image/text features are gathered on CPU for ``get_metrics``. Metrics are
    logged, optionally written to TensorBoard / wandb / ``results.jsonl``, and
    the visual tower is checkpointed whenever ``text_to_image_R@10`` improves
    on the module-level ``best_r10``.

    Args:
        object_embedding: passed to ``get_metrics`` when ``args.object_embedding``.
        mapping_layer: applied to image features when ``args.object_embedding``.
        model: CLIP-style model returning ``(image_features, text_features,
            logit_scale)``; may be wrapped in a container exposing ``.module``.
        data: dict whose ``'val'`` entry has a ``.dataloader`` yielding
            ``(images, texts, _)`` batches.
        epoch: epoch number used for logging.
        args: run configuration namespace.
        tb_writer: optional TensorBoard writer.
        steps: unused; kept for call compatibility.

    Returns:
        The metrics dict, or ``None`` on non-master processes.
    """
    global best_r10
    if not is_master(args):
        return

    model.eval()

    dataloader = data['val'].dataloader

    loss_img = nn.CrossEntropyLoss()
    loss_txt = nn.CrossEntropyLoss()
    if args.gpu is not None:
        loss_img = loss_img.cuda(args.gpu)
        loss_txt = loss_txt.cuda(args.gpu)

    cpu = torch.device('cpu')
    cumulative_loss = 0.0
    num_elements = 0.0
    all_image_features, all_text_features = [], []
    all_mapped_img_features = []
    with torch.no_grad():
        for images, texts, _ in tqdm(dataloader, total=len(dataloader)):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                texts = texts.cuda(args.gpu, non_blocking=True)

            image_features, text_features, logit_scale = model(images, texts)
            if args.object_embedding:
                mapped_img_features = mapping_layer(image_features)
                all_mapped_img_features.append(mapped_img_features.to(cpu))
            all_image_features.append(image_features.to(cpu))
            all_text_features.append(text_features.to(cpu))

            logit_scale = logit_scale.mean()
            logits_per_image = logit_scale * image_features @ text_features.t()
            logits_per_text = logits_per_image.t()

            # Matching image/text pairs sit on the diagonal of the logits.
            ground_truth = torch.arange(len(images)).long()
            if args.gpu is not None:
                ground_truth = ground_truth.cuda(args.gpu, non_blocking=True)
            total_loss = (
                loss_img(logits_per_image, ground_truth)
                + loss_txt(logits_per_text, ground_truth)
            ) / 2

            batch_size = len(images)
            cumulative_loss += total_loss * batch_size
            num_elements += batch_size

        use_objects = bool(args.object_embedding)
        metrics = get_metrics(
            args,
            torch.cat(all_image_features),
            torch.cat(all_text_features),
            object_embedding if use_objects else None,
            torch.cat(all_mapped_img_features) if use_objects else None,
        )
        loss = cumulative_loss / num_elements
        metrics.update(
            **{"val_loss": loss.item(), "epoch": epoch, "num_elements": num_elements}
        )

        logging.info(
            f"Eval Epoch: {epoch} "
            + "\t".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
        )

        if args.save_logs and tb_writer is not None:
            for name, val in metrics.items():
                tb_writer.add_scalar(f"val/{name}", val, epoch)
        if args.wandb:
            # Log everything in one call: each wandb.log call commits and
            # advances the step by default, which would scatter one validation
            # pass across many steps.
            wandb.log({**{f"val/{name}": val for name, val in metrics.items()},
                       'epoch': epoch})

    if args.save_logs:
        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if metrics['text_to_image_R@10'] > best_r10:
        best_r10 = metrics['text_to_image_R@10']
        # Wrappers such as (Distributed)DataParallel do not expose `.visual`
        # directly; unwrap via `.module` when present.
        torch.save(
            getattr(model, "module", model).visual.state_dict(),
            "pretrainedViTSmall_flickr30k_ANCE.pt"
        )
    return metrics
