import heapq
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import os
import time
import json
import numpy as np
from collections import defaultdict
from speaker import Speaker
from mbert import mBERT
from visual import Visual

from utils import read_vocab,write_vocab,build_vocab,Tokenizer,padding_idx,timeSince, read_img_features
import utils
from env import R2RBatch
from agent import Seq2SeqAgent
from eval import Evaluation
from param import args

import warnings
warnings.filterwarnings("ignore")


from tensorboardX import SummaryWriter
from transformers import BertTokenizer


from visual import Visual
from utils import read_img_features

# Generate learned image features by passing precomputed ResNet features through the Visual encoder

import torch
from PIL import Image
import os

# Pick the compute device: use the GPU when torch reports CUDA is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

from param import args
import numpy as np

class Env:
    """Minimal stand-in for the batched MatterSim environment wrapper.

    The class only records the feature dimensionality; it does not create
    a simulator or load any features.
    """

    def __init__(self, feature_store=None, batch_size=100):
        # `feature_store` and `batch_size` are accepted but not used here.
        # NOTE(review): 2048 presumably matches the width of the ResNet
        # feature vectors -- confirm against the feature file.
        self.feature_size = 2048

# Build the visual encoder on top of the stub environment and restore its
# weights from the saved checkpoint.
train_env = Env()
encoder = Visual(train_env, "", None, 7000)

encoder.load("snap/visual18/state_dict/best_val_unseen_loss")

# ResNet features, loaded via read_img_features and looked up below by
# "<scan>_<view>" keys.
features = 'img_features/ResNet-152-imagenet.tsv'
feat_dict = read_img_features(features)

args.load_visual = "snap/visual18/state_dict/best_val_unseen_loss"

# Generate Image Features
path = "../views_img"
output_dir = "./learned_features"
# Create the output directory up front so the first save does not fail
# on a fresh checkout.
os.makedirs(output_dir, exist_ok=True)

# NOTE(review): the encoder is never switched to eval mode here; if
# visual_encoder contains dropout or batch-norm layers the saved features may
# not be deterministic -- confirm against the Visual implementation.
for scan in os.listdir(path):
    scan_dir = os.path.join(path, scan)
    if not os.path.isdir(scan_dir):
        # Skip stray files that sit next to the per-scan directories.
        continue
    print("scan:", scan)
    for view in os.listdir(scan_dir):
        imgs = os.listdir(os.path.join(scan_dir, view))
        resnet_feature = feat_dict[scan + '_' + view]
        # Each viewpoint is expected to have 36 feature rows and 36 images.
        assert len(resnet_feature) == 36, \
            "expected 36 feature rows for %s_%s, got %d" % (scan, view, len(resnet_feature))
        assert len(imgs) == 36, \
            "expected 36 images for %s_%s, got %d" % (scan, view, len(imgs))

        # Inference only: skip autograd bookkeeping, and run on the device
        # selected at the top of the file instead of hard-coding CUDA.
        with torch.no_grad():
            image_feature = encoder.visual_encoder(
                torch.from_numpy(resnet_feature).to(device))

        # The context manager closes the file; no explicit close() is needed.
        with open(os.path.join(output_dir, scan + "_" + view + ".npy"), "wb") as f:
            np.save(f, image_feature.detach().cpu().numpy())