import json
import requests
# from transformers import CLIPProcessor, CLIPModel
import torch
import clip
from PIL import Image

def image_enhancement(image_features, image_patch_features, text_features, logit_scale=100.0):
    """Add a patch-relevance-weighted text-token summary to a global image embedding.

    Every image patch is compared with every text token (cosine similarity
    scaled by ``logit_scale``, softmaxed over tokens). The per-patch
    distributions are averaged into one weight per token. The weighted sum of
    the L2-normalised token embeddings is then added to ``image_features``.

    Args:
        image_features: global image embedding. The original author's note
            gives shape (1, 512).
        image_patch_features: per-patch embeddings, expected 2-D
            (num_patches, dim). The author's note gives (50, 512).
        text_features: per-token text embeddings, expected 2-D
            (num_tokens, dim). The author's note gives (77, 512).
            This tensor is NOT modified.
        logit_scale: multiplier applied to the cosine similarities before the
            softmax (default 100, as before).

    Returns:
        ``image_features + weights @ normalised_text``, with shape
        (1, dim) for the shapes above.
    """
    # Out-of-place normalisation. The previous in-place ``/=`` rewrote the
    # caller's tensor through an alias.
    text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
    patch_features_norm = image_patch_features / image_patch_features.norm(dim=-1, keepdim=True)
    # Match the patch dtype instead of forcing fp16. A hard ``.half()`` makes
    # the matmul fail whenever the patch features are fp32.
    text_features_norm = text_features_norm.to(patch_features_norm.dtype)

    # (num_patches, num_tokens): for each patch, a distribution over tokens.
    text_to_patch_sim = (logit_scale * patch_features_norm @ text_features_norm.T).softmax(dim=1)
    # Average over patches -> (1, num_tokens) weight per token.
    mean_sim = text_to_patch_sim.mean(dim=0, keepdim=True)

    # NOTE(review): the original stored ``raw_text_features`` as an alias taken
    # before an in-place ``/=``, so the summary was in fact built from the
    # *normalised* token embeddings. That numeric behaviour is preserved here.
    # If the un-normalised embeddings were intended, use ``text_features`` cast
    # to ``mean_sim.dtype`` instead.
    weight_images = mean_sim @ text_features_norm
    return image_features + weight_images

# def text_fusion_feature(text_a, text_b):
#     return torch.cat((text_a,text_b), dim=0).sum(dim=0,keepdim=True)

# Use the GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# The JSON files are only read. Opening them 'r+' (read/write) demanded write
# permission for no reason, and the ``with`` block already closes the file, so
# the explicit close() calls were redundant.
with open("/root/dial2image/train_kw_result_gpt3_all.json", 'r', encoding='utf-8') as train_fp:
    train_file = json.load(train_fp)

with open("/root/dial2image/test_gpt35_kws_to_sentence.json", 'r', encoding='utf-8') as test_fp:
    test_file = json.load(test_fp)

# Test-item keys in iteration order; positions in every per-item list below
# (and rows of the tensors built from them) follow this same order.
keys = list(test_file.keys())

retrieve_res = {}

images = []
photo_descs = []
sentence_descs = []
keyword_descs = []
train_sentence_descs = []
train_keyword_descs = []

for item in test_file:
    entry = test_file[item]
    p_desc = entry['photo_description']
    # Keep only the text from the "Objects" marker onwards. str.find returns -1
    # when the marker is absent, and slicing from -1 would keep just the last
    # character, so fall back to the whole description in that case.
    objects_start = p_desc.find('Objects')
    photo_desc = p_desc[objects_start:] if objects_start != -1 else p_desc
    photo_text = photo_desc.lower().strip().replace("objects in the photo: ", '')
    image_path = "/root/dial2image/test_image/" + entry['photo_id'].split('/')[-1] + ".jpg"

    with torch.no_grad():
        # The text encoding previously ran outside no_grad and built an
        # autograd graph for every item that was never used.
        photo_text_tokens = clip.tokenize(photo_text, truncate=True).to(device)
        photo_text_tokens_features = model.encode_text_tokens(photo_text_tokens)
        # Close the image file as soon as it has been preprocessed instead of
        # leaving the handle to the garbage collector.
        with Image.open(image_path) as pil_image:
            image = preprocess(pil_image).unsqueeze(0).to(device)
        image_patch_features, image_features = model.encode_image(image)
        images_enhanced = image_enhancement(image_features, image_patch_features, photo_text_tokens_features)

    images.append(images_enhanced)
    sentence_descs.append(entry['sentence_description'])
    keyword_descs.append(entry['dialogue_description'])
    photo_descs.append(photo_desc)
    
# Training-set queries: sentence and keyword descriptions, in train_file order.
# They are only used as sampled distractor queries in the evaluation loop.
train_sentence_descs = [record['sentence_description'] for record in train_file.values()]
train_keyword_descs = [record['dialogue_description'] for record in train_file.values()]

with torch.no_grad():
    train_num = len(train_file)
    train_text_sentence_desc = clip.tokenize(train_sentence_descs, truncate=True).to(device)
    train_text_keyword_desc = clip.tokenize(train_keyword_descs, truncate=True).to(device)
    train_text_sentence_desc_features = model.encode_text(train_text_sentence_desc)
    train_text_keyword_desc_features = model.encode_text(train_text_keyword_desc)
    # Fuse the two text views by summation, then L2-normalise each row.
    fused = train_text_sentence_desc_features + train_text_keyword_desc_features
    train_text_fusion_features = fused / fused.norm(dim=-1, keepdim=True)

with torch.no_grad():
    # Scene text: the "Objects..." part of each test photo description,
    # encoded and L2-normalised.
    text_photo_desc = clip.tokenize(photo_descs, truncate=True).to(device)
    text_photo_desc_features = model.encode_text(text_photo_desc)
    text_photo_desc_features /= text_photo_desc_features.norm(dim=-1, keepdim=True)

    # Query text: sentence and keyword embeddings fused by summation, then
    # L2-normalised (same recipe as for the training queries).
    text_sentence_desc = clip.tokenize(sentence_descs, truncate=True).to(device)
    text_sentence_desc_features = model.encode_text(text_sentence_desc)
    text_keyword_desc = clip.tokenize(keyword_descs, truncate=True).to(device)
    text_keyword_desc_features = model.encode_text(text_keyword_desc)
    text_fusion_features = text_sentence_desc_features + text_keyword_desc_features
    text_fusion_features /= text_fusion_features.norm(dim=-1, keepdim=True)

    # Stack the per-image enhanced embeddings into one matrix. torch.cat already
    # returns a new tensor on the inputs' device and dtype. The former
    # .cpu().numpy() -> torch.from_numpy().to(device) round trip only added two
    # host/device copies.
    images = torch.cat(images)
    images /= images.norm(dim=-1, keepdim=True)

# Query-bank evaluation, run 5 times with fresh random samples.
#
# Each test query is scored against every test image together with
# (percent_num - 1) randomly sampled training queries. The image and scene
# similarities are softmax-normalised over that bank of queries (dim=0), not
# over the images. Row 0 of the bank is always the real query.
percent_num = 1000  # total query-bank size (1 real query + percent_num - 1 distractors)
print(percent_num)
num_test_queries = len(test_file)
recall_ks = (1, 5, 10)

with torch.no_grad():
    # These logits do not depend on which training queries get sampled, so
    # compute them once. Each iteration below only gathers rows, instead of
    # redoing a (percent_num x num_images) matmul per test item per run.
    train_image_logits = 100.0 * train_text_fusion_features @ images.T
    train_scene_logits = 100.0 * train_text_fusion_features @ text_photo_desc_features.T
    test_image_logits = 100.0 * text_fusion_features @ images.T
    test_scene_logits = 100.0 * text_fusion_features @ text_photo_desc_features.T

    for sda in range(5):
        hits = dict.fromkeys(recall_ks, 0)
        for current_idx, item in enumerate(test_file):
            # Sample percent_num - 1 distractor queries from the training set.
            # NOTE(review): this excludes the *training* index equal to the
            # current *test* position. The two index spaces are unrelated, so
            # this looks vestigial; it is kept to preserve the original sampling.
            all_indices = torch.arange(train_num)
            all_indices = all_indices[all_indices != current_idx]
            sampled = all_indices[torch.randperm(all_indices.size(0))[:percent_num - 1]]

            image_logits = torch.cat((test_image_logits[current_idx].unsqueeze(0), train_image_logits[sampled]), dim=0)
            scene_logits = torch.cat((test_scene_logits[current_idx].unsqueeze(0), train_scene_logits[sampled]), dim=0)
            similarity = (image_logits.softmax(dim=0) + scene_logits.softmax(dim=0))[0]

            # One sorted topk serves every cutoff. Clamp k so that having fewer
            # than max(recall_ks) test images does not raise.
            top_k = min(max(recall_ks), similarity.numel())
            top_indices = similarity.topk(top_k).indices.tolist()
            retrieved = {
                k: {test_file[keys[i]]['photo_id'] for i in top_indices[:k]}
                for k in recall_ks
            }
            retrieve_res[item] = {f'top@{k}': list(retrieved[k]) for k in recall_ks}

            target_photo_id = test_file[item]['photo_id']
            for k in recall_ks:
                if target_photo_id in retrieved[k]:
                    hits[k] += 1

        print(f'{sda}=============')
        # Normalise by the actual number of test queries. The divisor used to be
        # hard-coded as 1000.0, which is wrong for any other test-set size.
        for k in recall_ks:
            print(f'Recall@{k}:{hits[k] / num_test_queries}')

#with open("test_retrieve_images_photochat.json", 'w+', encoding='utf-8') as output:
    #json.dump(retrieve_res, output)

