
import torch
import pickle

import os
import time
import json
import numpy as np
from collections import defaultdict
from speaker import Speaker
from mbert import mBERT

from utils import read_vocab,write_vocab,build_vocab,Tokenizer,padding_idx,timeSince, read_img_features
import utils
from env import R2RBatch
from agent import Seq2SeqAgent
from eval import Evaluation
import warnings
warnings.filterwarnings("ignore")


from tensorboardX import SummaryWriter
from transformers import BertTokenizer

import sys
import random
import math

import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
from utils import padding_idx, add_idx, Tokenizer
from collections import defaultdict

from transformers import BertModel, BertConfig, AdamW, get_linear_schedule_with_warmup

import heapq

import CLIP.clip as clip



#-------Main---------

# Run on the GPU when one is available; the encoder and its inputs are placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Selects the text encoder to probe. Supported values are the branches below:
# 'language' (mBERT wrapper restored from `path`), 'mbert' (pretrained
# multilingual BERT) and 'clip' (CLIP ViT-B/32).
model = 'mbert'

# Checkpoint file; only read when model == 'language'.
path = 'snap/encoder1/state_dict/best_val_unseen_loss'

tok = BertTokenizer.from_pretrained('bert-base-multilingual-cased')


if model == 'language':
    encoder = mBERT("", "", tok, 0)

    # map_location lets a checkpoint that was saved on a GPU be restored on a
    # CPU-only machine (and vice versa) instead of failing while deserializing.
    states = torch.load(path, map_location=device)

    state = encoder.encoder.state_dict()
    model_keys = set(state.keys())
    load_keys = set(states['encoder']['state_dict'].keys())
    if model_keys != load_keys:
        print("NOTICE: DIFFERENT KEYS IN THE LISTEREN")
    # Overlay the saved weights on the freshly built state so that parameters
    # absent from the checkpoint keep their current values.
    state.update(states['encoder']['state_dict'])
    encoder.encoder.load_state_dict(state)
elif model == 'mbert':
    model_config = BertConfig.from_pretrained("bert-base-multilingual-cased", return_dict=True)
    encoder = BertModel.from_pretrained("bert-base-multilingual-cased", config=model_config).to(device)
elif model == 'clip':
    encoder, preprocess = clip.load("ViT-B/32", device=device)
else:
    # Without this, `encoder` would stay undefined and the script would only
    # fail later with an unrelated-looking NameError.
    raise ValueError("Unknown model %r; expected 'language', 'mbert' or 'clip'" % model)

# splits = ['train', 'val_seen', 'val_unseen']
splits = ['val_unseen']
# splits = ['train']

# Pick the forward callable once. Only 'language' and 'mbert' have a text
# forward pass in this script; any other setting (e.g. 'clip') is rejected
# here, before any data is loaded, instead of hitting an undefined
# `text_features` inside the loop.
if model == 'language':
    encode_text = encoder.encoder
elif model == 'mbert':
    encode_text = encoder
else:
    raise NotImplementedError("Text feature extraction is not implemented for model %r" % model)

# Token budget passed to the tokenizer; it includes the special tokens BERT
# puts at the start and end of the sequence.
max_length = 160
# Number of nearest neighbours reported for each probe word.
top_k = 5
# Probe words whose nearest neighbours are printed (looked up as word pieces).
test_list = ['kitchen', 'table', 'clothes', 'chair', 'stair', 'sink', 'fire', 'sofa', 'living', 'bed']

for split in splits:
    # NOTE(review): despite the .jsonl suffix the file is parsed as a single
    # JSON document — confirm this matches how the file was produced.
    with open('../RXR/rxr-data/rxr_%s_guide_multi.jsonl' % split) as f:
        new_data = json.load(f)

    # word piece -> list of L2-normalised contextual vectors, one per occurrence
    word_vectors = defaultdict(list)

    print(len(new_data))
    print("Generating Text features")
    for i, data in enumerate(new_data):
        print(i)
        # Only English instructions (language tag starting with "en") are used.
        if data['language'].split("-")[0] != "en":
            continue
        encoding = tok(data['instruction'], padding='max_length', truncation=True, max_length=max_length)
        words = tok.tokenize(data['instruction'])
        attention_mask = encoding['attention_mask']
        # Count of real word pieces that survived truncation: attended
        # positions minus the two special tokens. Capping here keeps the
        # trailing special token from being attributed to a word piece when a
        # long instruction is truncated.
        num_words = min(len(words), sum(attention_mask) - 2)
        if num_words <= 0:
            continue
        input_ids = torch.tensor([encoding['input_ids']], dtype=torch.long, device=device)
        attn = torch.tensor([attention_mask], dtype=torch.long, device=device)
        # Pure feature extraction: no autograd graph is needed.
        with torch.no_grad():
            text_features = encode_text(input_ids, attention_mask=attn)
        # Skip position 0 (leading special token) and keep one row per word piece.
        hidden = text_features.last_hidden_state[0, 1:num_words + 1, :]
        words_representation = F.normalize(hidden, dim=1).cpu().numpy()
        # zip stops at the shorter sequence, so truncated-away pieces are ignored.
        for word, vector in zip(words, words_representation):
            word_vectors[word].append(vector)

    # Average every word piece over all of its occurrences and renormalise so
    # that the dot products below are cosine similarities.
    word_dict = dict()
    for word, vectors in word_vectors.items():
        mean_vector = np.mean(vectors, axis=0)
        word_dict[word] = mean_vector / np.linalg.norm(mean_vector)

    print("Computing Similarity")
    print(len(word_dict))
    results = dict()
    for word in test_list:
        if word not in word_dict:
            continue
        print("Processing:", word)
        target = word_dict[word]
        scored = (
            (float(np.dot(target, vector)), candidate)
            for candidate, vector in word_dict.items()
            if candidate != word
        )
        # Keep only strictly positive similarities (the same threshold as
        # before); the result is sorted best-first and never contains
        # placeholder entries.
        results[word] = heapq.nlargest(top_k, (pair for pair in scored if pair[0] > 0))

    print(results)

    # print(word_dict.keys())
    # print("kitchen" in word_dict)
    # print(tok.tokenize("fridge"))
    # print(tok.tokenize("washroom"))
    # print(tok.tokenize("wash basin"))
    # print("table" in word_dict)
    # print("desk" in word_dict)
    # print(tok.tokenize("table"))
    # print(tok.tokenize("desk"))
    # clothes
    # print("fridge" in word_dict)
    # print("washroom" in word_dict)