import os
import torch
import argparse
import numpy as np
from torch import nn
from dataloaders import coco, flickr30k

# Maps the value of the --dataset command-line option to the dataloader module
# that main() queries for its text and image loaders.
dataset_dict = {
    'coco': coco,
    'flickr': flickr30k,
}

def _strip_wrapper_prefixes(key):
    """Return *key* with everything up to the last wrapper marker removed.

    Two markers are handled, in this order: ``module.`` and ``_orig_mod.``.
    For each marker present in the key, only the text after its LAST
    occurrence is kept. (Such prefixes are typically added by
    DataParallel/DDP wrappers and by ``torch.compile`` respectively --
    NOTE(review): confirm that is how the checkpoints were produced.)
    """
    for marker in ("module.", "_orig_mod."):
        if marker in key:
            key = key.split(marker)[-1]
    return key


def main(args):
    """Evaluate image<->text retrieval for one dataset and log the recall scores.

    Steps: build the image and text towers, load the checkpoint, extract
    normalised text and image features, compute the image-by-text similarity
    matrix on the GPU, then print the recall dictionaries and append them to
    the file named by ``args.output_dir``.

    Args:
        args: Parsed command-line namespace. Besides the model
            hyper-parameters forwarded to ``VRWKV6`` / ``Text_RWKV``, this
            function reads ``model_weight``, ``input_size``, ``dataset``,
            ``batch_size`` and ``output_dir``.

    Raises:
        KeyError: If ``args.dataset`` is not a key of ``dataset_dict``.
        RuntimeError: If the checkpoint keys do not match the model exactly
            (``strict=True``).
    """
    tokenize_model = tokenize

    model_image_rwkv = VRWKV6(
        img_size=args.input_size,
        patch_size=args.Image_patch_size,
        embed_dims=args.Image_embed_dims,
        hidden_rate=args.Image_hidden_rate,
        depth=args.Image_depth,
        num_heads=args.Image_num_heads,
        output_cls_token=args.Image_output_cls_token,
        with_cls_token=args.Image_with_cls_token,
    )
    model_text_rwkv = Text_RWKV(args)
    model = get_model(model_image_rwkv, model_text_rwkv,
                      Image_cls_token=args.Image_output_cls_token)

    # Deserialize onto the CPU so the checkpoint loads regardless of which
    # device index it was saved from; the model is moved to the GPU below.
    state_dict = torch.load(args.model_weight, map_location="cpu")
    state_dict_removed = {
        _strip_wrapper_prefixes(k): value for k, value in state_dict.items()
    }
    model.load_state_dict(state_dict_removed, strict=True)

    model.eval()
    model.cuda()

    transform = get_transform(args.input_size)
    dataset_module = dataset_dict[args.dataset]
    assert hasattr(dataset_module, "get_loader_image")
    assert hasattr(dataset_module, "get_loader_text")

    kwargs_text = {
        "batch_size": args.batch_size,
        "preprocess": transform,
        "tokenize": tokenize_model,
    }
    kwargs_image = {
        "batch_size": args.batch_size,
        "preprocess": transform,
    }

    text_loader = dataset_module.get_loader_text(**kwargs_text)
    text_features = get_text_feature(model, text_loader, args)

    image_loader, txt2img, img2txt = dataset_module.get_loader_image(**kwargs_image)
    image_features = get_image_feature(model, image_loader, args)

    # Bring both modalities to float32 before the matmul. ``.to`` converts the
    # existing tensors instead of copy-constructing them with ``torch.tensor``.
    text_features = text_features.to(torch.float32)
    image_features = image_features.to(torch.float32)

    # Rows index images, columns index texts.
    similarity_scores = image_features.cuda() @ text_features.cuda().t()
    t2i_dict, i2t_dict = compute_retrieval(similarity_scores, txt2img, img2txt)
    print('Text retrieval', i2t_dict)
    print('Image retrieval', t2i_dict)
    # NOTE: despite its name, ``output_dir`` is opened as a file (append mode).
    with open(args.output_dir, 'a') as f:
        f.write(args.dataset + '\n')
        f.write('Text retrieval: ')
        f.write(str(i2t_dict) + '\n')
        f.write('Image retrieval: ')
        f.write(str(t2i_dict) + '\n')

def _ranking_positions(scores):
    """Return each candidate's 0-based position in the descending-score ranking.

    ``positions[j]`` is the index at which candidate ``j`` appears in
    ``torch.argsort(scores, descending=True)``.
    """
    order = torch.argsort(scores, descending=True)
    positions = torch.empty_like(order)
    # Invert the permutation once so every lookup afterwards is O(1).
    positions[order] = torch.arange(order.numel(), device=order.device)
    return positions


def _recall_report(ranks):
    """Build the ``r1`` / ``r5`` / ``r10`` recall dict (percentages) from 0-based ranks."""
    total = len(ranks)
    return {
        "r{}".format(k): 100.0 * int((ranks < k).sum()) / total
        for k in (1, 5, 10)
    }


def compute_retrieval(similarity_scores, txt2img, img2txt):
    """Compute recall@1/5/10 for text->image and image->text retrieval.

    Args:
        similarity_scores: 2-D tensor indexed as ``[image, text]``.
        txt2img: For every text index, the index of its ground-truth image.
        img2txt: For every image index, an iterable of its ground-truth text
            indices. An image with no texts gets rank ``1e10`` (always a miss).

    Returns:
        Tuple ``(t2i_report_dict, i2t_report_dict)``; each maps ``"r1"``,
        ``"r5"`` and ``"r10"`` to a percentage in ``[0, 100]``.

    Raises:
        ZeroDivisionError: If either axis of ``similarity_scores`` is empty.
    """
    # text -> image: rank of the single ground-truth image for every text.
    t2i_similarity_score = similarity_scores.t()
    num_texts = t2i_similarity_score.shape[0]
    t2i_ranks = torch.zeros(num_texts)
    for index, score in enumerate(t2i_similarity_score):
        t2i_ranks[index] = _ranking_positions(score)[int(txt2img[index])]
        print('Evaluating batch {}/{}, {}'.format(index, num_texts, t2i_ranks[index]), end = "\r")
    t2i_report_dict = _recall_report(t2i_ranks)

    # image -> text: best (lowest) rank among the image's ground-truth texts.
    i2t_similarity_score = similarity_scores
    num_images = i2t_similarity_score.shape[0]
    i2t_ranks = torch.zeros(num_images)
    for index, score in enumerate(i2t_similarity_score):
        print('Evaluating batch {}/{}'.format(index, num_images), end = "\r")
        targets = [int(i) for i in img2txt[index]]
        if targets:
            i2t_ranks[index] = _ranking_positions(score)[targets].min()
        else:
            i2t_ranks[index] = 1e10
    i2t_report_dict = _recall_report(i2t_ranks)

    return t2i_report_dict, i2t_report_dict

def get_image_feature(model, data_loader, args):
    """Embed every image yielded by *data_loader* and return L2-normalised features.

    Args:
        model: Model handed to ``WarperCLIP_V_T_RWKV_method``.
        data_loader: Iterable of ``(images, _)`` batches; the second element
            is ignored. Must support ``len()`` for the progress message.
        args: Unused; kept for signature parity with ``get_text_feature``.

    Returns:
        CPU tensor of all embeddings concatenated along dim 0, each row
        divided by its norm over the last dimension.
    """
    image_features = []
    # Inference only: disable autograd so no graph is built and held per batch.
    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            print('Evaluating batch {}/{}'.format(batch_idx, len(data_loader)), end = "\r")
            images, _ = batch
            images = images.cuda()
            image_embedding = WarperCLIP_V_T_RWKV_method(model, images)
            # Move to the CPU straight away so GPU memory does not grow
            # with the dataset size.
            image_features.append(image_embedding.detach().cpu())

    image_features = torch.cat(image_features, 0)

    print('Done image feature extract.')
    print(image_features.shape)

    # normalized features
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    return image_features

def get_text_feature(model, data_loader, args):
    """Embed every text batch yielded by *data_loader* and return L2-normalised features.

    Args:
        model: Model handed to ``WarperCLIP_V_T_RWKV_text_change_head``.
        data_loader: Iterable of token tensors. Must support ``len()`` for
            the progress message.
        args: Unused; kept for signature parity with ``get_image_feature``.

    Returns:
        CPU tensor of all embeddings concatenated along dim 0, each row
        divided by its norm over the last dimension.
    """
    text_features = []
    # Inference only: disable autograd so no graph is built and held per batch.
    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            print('Evaluating batch {}/{}'.format(batch_idx, len(data_loader)), end = "\r")
            # Drop singleton axes (presumably a per-sample axis left by the
            # tokenizer -- TODO confirm against the dataloader).
            text = batch.squeeze()
            if text.dim() < 2:
                # A batch holding a single sample also loses its batch axis
                # in squeeze(); restore it so the model still sees a batch.
                text = text.unsqueeze(0)
            text = text.cuda()
            text_embedding = WarperCLIP_V_T_RWKV_text_change_head(model, text)

            # Move to the CPU straight away so GPU memory does not grow
            # with the dataset size.
            text_features.append(text_embedding.detach().cpu())

    text_features = torch.cat(text_features, 0)
    print('Done text feature extract.')
    print(text_features.shape)

    # normalized features
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features

def get_transform(image_size):
    """Build the evaluation-time image preprocessing for inputs of *image_size*.

    Delegates to ``image_transform`` with ``is_train=False`` and fixed
    per-channel normalisation statistics.
    """
    # NOTE(review): these values look like the normalisation statistics
    # published with OpenAI CLIP -- confirm they match the training setup.
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    return image_transform(
        image_size,
        is_train=False,
        mean=channel_mean,
        std=channel_std,
    )

if __name__ == '__main__':
    # Script entry point: parse the CLI options, derive dependent settings,
    # export the configuration through environment variables, import the
    # model package, and run the evaluation.

    parser = argparse.ArgumentParser(description="ZeroShot")
    parser.add_argument("--batch-size", default=128, type=int,
                        help="Number of samples per evaluation batch.")
    # Restricted to the keys of dataset_dict so a typo fails at parse time
    # rather than with a KeyError after the model has been built.
    parser.add_argument("--dataset", default="coco", type=str, choices=sorted(dataset_dict))
    # --model-type and --model-name are parsed but not read anywhere in this file.
    parser.add_argument("--model-type", default="rwkv_clip", type=str)
    parser.add_argument("--model-name", default="ViT-B-32-384")
    parser.add_argument("--model-weight", default="WEIGHT_PATH")
    # Despite the name, main() opens this path as a file in append mode.
    parser.add_argument("--output_dir", default="OUTPUT_PATH", type=str)

    ###################################################################################################
    ################################# Additional Setting
    ###################################################################################################
    parser.add_argument("--precision", default="bf16", type=str,
                        choices=["fp32", "tf32", "fp16", "bf16"])
    parser.add_argument("--dataset_type", default="original", type=str)
    parser.add_argument('--dropout', type=float, default=0.0, metavar='PCT',help='Dropout rate (default: 0.)')
    ################################# Image RWKV #################################
    parser.add_argument("--my_testing_image", default='V6', type=str, choices=["V4", "V6"],
                        help="Image rwkv version")
    parser.add_argument("--input-size", default=224, type=int, help="input_image_size")
    parser.add_argument("--Image_depth", default=12, type=int)
    parser.add_argument("--Image_embed_dims", default=384, type=int)
    parser.add_argument("--Image_patch_size", default=16, type=int)
    parser.add_argument("--Image_hidden_rate", default=4, type=int)
    parser.add_argument("--Image_num_heads", default=6, type=int)
    # Boolean flags are passed as the literal strings "True"/"False" and
    # converted after parsing.
    parser.add_argument("--Image_output_cls_token", default="False", type=str)
    parser.add_argument("--Image_with_cls_token", default="False", type=str)

    ################################# Text RWKV #################################
    parser.add_argument("--data_type", default="utf-8", type=str,
                        choices=["utf-8", "utf-16le", "numpy", "binidx", "dummy", "uint16"])
    parser.add_argument("--ctx_len", default=77, type=int, help="")
    parser.add_argument("--vocab_size", default=49408, type=int, help="Vocabular number")
    parser.add_argument("--my_testing", default='V4', type=str, help="Text rwkv version")
    parser.add_argument("--text_initialization", default="True", type=str)
    parser.add_argument("--head_size_a", default=64, type=int)
    parser.add_argument("--Text_num_head", default=0, type=int)
    parser.add_argument("--head_size_divisor", default=8, type=int)
    parser.add_argument("--n_layer", default=12, type=int)
    parser.add_argument("--n_embd", default=384, type=int)
    parser.add_argument("--dim_att", default=0, type=int)
    parser.add_argument("--dim_ffn", default=0, type=int)
    parser.add_argument("--pre_ffn", default=0, type=int)
    parser.add_argument("--my_pos_emb", default=0, type=int)
    parser.add_argument("--head_qk", default=0, type=int)
    parser.add_argument("--tiny_att_dim", default=0, type=int)
    parser.add_argument("--tiny_att_layer", default=-999, type=int)

    args = parser.parse_args()

    # Both towers must produce embeddings of the same width.
    if args.Image_embed_dims != args.n_embd:
        parser.error("--Image_embed_dims ({}) must equal --n_embd ({})".format(
            args.Image_embed_dims, args.n_embd))

    # Only the exact string "True" enables a flag; anything else disables it.
    args.text_initialization = args.text_initialization == "True"

    # --Image_with_cls_token is always overwritten to follow
    # --Image_output_cls_token; its own command-line value is ignored.
    use_cls_token = args.Image_output_cls_token == "True"
    args.Image_output_cls_token = use_cls_token
    args.Image_with_cls_token = use_cls_token

    # Non-positive values mean "derive from the embedding size".
    if args.dim_att <= 0:
        args.dim_att = args.n_embd
    if args.dim_ffn <= 0:
        args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32) # default = 3.5x emb size
    if args.Text_num_head != 0:
        assert args.n_embd % args.Text_num_head == 0, "text embedding size can not divide head num"
        args.head_size_a = args.n_embd//args.Text_num_head
    args.with_cp = False

    # These variables are exported BEFORE the model package is imported below;
    # presumably the package reads them at import time -- confirm before
    # reordering.
    os.environ["RWKV_MY_TESTING"] = str(args.my_testing)
    os.environ['RWKV_MY_TESTING_IMAGE'] = str(args.my_testing_image)
    os.environ["RWKV_CTXLEN"] = str(args.ctx_len)
    os.environ["RWKV_HEAD_SIZE_A"] = str(args.head_size_a)
    os.environ['RWKV_FLOAT_MODE'] = str(args.precision)
    # True division makes this a float string (e.g. "196.0" for 224/16).
    os.environ['Image_T_max'] = str((args.input_size/args.Image_patch_size)**2)
    os.environ['Text_T_max'] = str(256)
    # NOTE(review): "SIE" looks like a typo for "SIZE", but the key must match
    # whatever the consumer reads, so it is left unchanged here.
    os.environ['Image_HEAD_SIE'] = str(args.Image_embed_dims // args.Image_num_heads)

    print("--------------------------------------------")
    print("Image_output_cls_token: ", args.Image_output_cls_token)
    print("Image RWKV Version: ", args.my_testing_image)
    print("Text RWKV Version: ", args.my_testing)
    print("--------------------------------------------")

    # Deferred import: must come after the environment variables are set.
    from model import Text_RWKV, VRWKV6, get_model, tokenize, image_transform
    from model.utils import WarperCLIP_V_T_RWKV_text_change_head, WarperCLIP_V_T_RWKV_method

    main(args)
