import torch

import os
import time
import json
import numpy as np
from collections import defaultdict


import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
from collections import defaultdict


import heapq

import CLIP.clip as clip


# Generate image features using CLIP

import torch
import clip
from PIL import Image
import os
from transformers import BertTokenizer

# Run on the GPU when one is visible to PyTorch; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# `clip.load` returns the model together with the image preprocessing
# callable that must be applied to every PIL image before encoding.
model, preprocess = clip.load("ViT-B/32", device=device)

# Generate Image Features
#
# Walks ../views_img/<scan>/<view>/<image>, encodes every image of a view with
# CLIP and stores the stacked features (one row per image, in os.listdir order)
# in ./CLIP_features/<scan>_<view>.npy.
path = "../views_img"
files = os.listdir(path)

# The .npy files are written below; make sure their directory exists instead
# of failing on the first save.
os.makedirs("./CLIP_features", exist_ok=True)

for scan in files:
    scan_dir = os.path.join(path, scan)
    if not os.path.isdir(scan_dir):
        continue
    print("scan:", scan)
    views = os.listdir(scan_dir)
    for view in views:
        view_dir = os.path.join(scan_dir, view)
        imgs = os.listdir(view_dir)
        if not imgs:
            # Nothing to encode: skip rather than crash on an empty view.
            continue

        per_image = []
        # Inference only: disabling autograd avoids building (and keeping
        # alive) a graph for every encoded image.
        with torch.no_grad():
            for img in imgs:
                # The context manager closes the underlying file handle once
                # the image has been converted to a tensor.
                with Image.open(os.path.join(view_dir, img)) as pil_image:
                    image = preprocess(pil_image).unsqueeze(0).to(device)
                per_image.append(model.encode_image(image))

        # Concatenate once instead of re-copying the accumulated tensor after
        # every image.
        image_features = torch.cat(per_image, 0)
        out_file = os.path.join("./CLIP_features", scan + "_" + view + ".npy")
        with open(out_file, "wb") as f:
            np.save(f, image_features.cpu().numpy())

# Generate Language Features
#
# For every kept instruction: tokenize it with the multilingual BERT
# tokenizer, encode each word piece separately with CLIP's text encoder,
# average the per-token embeddings and store the L2-normalised result.
with open('../RXR/rxr-data/rxr_train_guide_dep.jsonl') as f:
    raw_text = f.read()
try:
    new_data = json.loads(raw_text)
except json.JSONDecodeError:
    # The file carries a .jsonl extension: if it is not a single JSON
    # document, parse it as JSON Lines (one record per non-empty line).
    new_data = [json.loads(line) for line in raw_text.splitlines() if line.strip()]

tok = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

# instruction_id -> (1, D) normalised text feature as a numpy array
features = dict()
# instruction_id -> full record, so per-instruction metadata (e.g. the
# 'path' and 'path_id' fields) stays available after this loop.
ins_index_data = dict()
print(len(new_data))
print("Generating Text features")
for i, data in enumerate(new_data):
    print(i)
    # Telugu and Hindi instructions are skipped.
    if data['language'] in ('te-IN', 'hi-IN'):
        continue
    encoding = tok(data['instruction'], padding='max_length', truncation=True, max_length=80)
    seq = encoding['input_ids']
    text_seq = tok.convert_ids_to_tokens(seq)
    # Keep only the real tokens: drop the leading [CLS] and everything from
    # the first [SEP] onwards (the [SEP] itself and the padding).
    index = text_seq.index('[SEP]')

    # Inference only: no autograd graph is needed for the text encoder.
    with torch.no_grad():
        text = clip.tokenize(text_seq[1:index]).to(device)
        text_features = model.encode_text(text)
        # Mean over the token axis gives one vector per instruction.
        representation = torch.mean(text_features, dim=0).unsqueeze(0)
        normalised = F.normalize(representation, dim=1)

    features[data['instruction_id']] = normalised.cpu().numpy()
    ins_index_data[data['instruction_id']] = data


# Load Visual Features
#
# Read every saved .npy array back from ./CLIP_features into `img_features`,
# keyed by the file name without its ".npy" suffix.
path = "./CLIP_features"
files = os.listdir(path)
img_features = dict()
for file_name in files:
    if not file_name.endswith('.npy'):
        continue
    key = file_name[:-4]
    with open("./CLIP_features/" + file_name, "rb") as handle:
        img_features[key] = np.load(handle)


print("Computing Similarity")
# For every instruction, find the two most similar other instructions among
# those that follow a different path of the same length.
# results: instruction_id -> list of two (similarity, instruction_id) tuples
# in heap order (smallest similarity first).
results = dict()
for i, (ins, query) in enumerate(features.items()):
    print(i)
    # Path metadata is read from the stored record; this assumes each record
    # has 'path_id' and a sized 'path' field -- TODO confirm against the data.
    query_record = ins_index_data[ins]
    query_path_id = query_record['path_id']
    query_path_length = len(query_record['path'])

    # Size-2 min-heap: topN[0] is always the weaker of the two current best
    # matches. The (0, "0") entries are placeholders that remain in the
    # result when fewer than two candidates have positive similarity.
    topN = [(0, "0"), (0, "0")]
    heapq.heapify(topN)

    for other_id, feature in features.items():
        other_record = ins_index_data[other_id]
        # Skip instructions describing the same path (this also excludes the
        # query itself) and those whose path length differs.
        if (other_record['path_id'] == query_path_id
                or len(other_record['path']) != query_path_length):
            continue
        # Both features are (1, D) and L2-normalised, so this is the cosine
        # similarity as a Python float.
        sim = np.matmul(query, feature.T).item()
        if sim > topN[0][0]:
            heapq.heapreplace(topN, (sim, other_id))

    results[ins] = topN





