import numpy as np
from tqdm import tqdm
import json
# import sys
# sys.path.append('buildpy36')
# import MatterSim
from eval import Evaluation
from utils import read_vocab,write_vocab,build_vocab,Tokenizer,padding_idx,timeSince, read_img_features
from transformers import BertTokenizer
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import os
import time
import json
import numpy as np
from collections import defaultdict
from speaker import Speaker
from mbert import mBERT
from visual import Visual

from utils import read_vocab,write_vocab,build_vocab,Tokenizer,padding_idx,timeSince, read_img_features
import utils
from env import R2RBatch
from agent import Seq2SeqAgent
from eval import Evaluation
from param import args

import warnings
warnings.filterwarnings("ignore")


from tensorboardX import SummaryWriter
from transformers import BertTokenizer

from transformers import BertModel, BertConfig, AdamW, get_linear_schedule_with_warmup



# Precomputed ResNet-152 image features, loaded once at start-up.
features = 'img_features/ResNet-152-imagenet.tsv'
feat_dict = read_img_features(features)

# Collect the prefix before the first "_" of every feature key (presumably the
# scan id -- confirm against read_img_features); this set is handed to the
# evaluator below.
featurized_scans = {feature_key.split("_")[0] for feature_key in feat_dict.keys()}

# Multilingual BERT tokenizer plus an evaluator for the val_unseen split; both
# result sets compared later in this script are scored by this one instance.
tok = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
evaluator = Evaluation(['val_unseen'], featurized_scans, tok)

def _load_results(path):
    """Return the parsed contents of a submission JSON file at *path*.

    The file is read as UTF-8 (the JSON encoding) rather than with the
    platform's locale encoding, and the ``with`` block closes it.
    """
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Baseline run (agent42) and the run being compared against it (agent61).
results_baseline = _load_results("snap/agent42/submit_val_unseen.json")
results_clear = _load_results("snap/agent61/submit_val_unseen.json")

# Index both runs by instruction id, keeping only the trajectory and path_id.
# If an instr_id occurs more than once in a file, the last record wins.
results_baseline_dict = {
    result['instr_id']: [result['trajectory'], result['path_id']]
    for result in results_baseline
}
results_clear_dict = {
    result['instr_id']: [result['trajectory'], result['path_id']]
    for result in results_clear
}

# The paired bootstrap below is only meaningful when both runs cover exactly
# the same instructions. Check the id sets explicitly (a bare `assert` is
# stripped under -O, and equal lengths do not imply equal ids) and report
# how they differ instead of failing later with a bare KeyError.
if results_baseline_dict.keys() != results_clear_dict.keys():
    only_baseline = results_baseline_dict.keys() - results_clear_dict.keys()
    only_clear = results_clear_dict.keys() - results_baseline_dict.keys()
    raise ValueError(
        "result files cover different instructions: "
        f"{len(only_baseline)} only in baseline, {len(only_clear)} only in clear"
    )

# Build index-aligned lists in the baseline's order, so element i of both
# lists refers to the same instruction; the bootstrap relies on this pairing.
results_baseline_ordered = [
    {'instr_id': instr_id, 'trajectory': trajectory, 'path_id': path_id}
    for instr_id, (trajectory, path_id) in results_baseline_dict.items()
]
results_clear_ordered = [
    {
        'instr_id': instr_id,
        'trajectory': results_clear_dict[instr_id][0],
        'path_id': results_clear_dict[instr_id][1],
    }
    for instr_id in results_baseline_dict
]

def bootstrap_win_rates(evaluator, baseline, clear, n_bootstrap=1000,
                        metrics=('ndtw', 'success_rate'), progress=None,
                        rng=None):
    """Paired bootstrap: fraction of resamples in which `clear` beats `baseline`.

    Each iteration draws ``len(baseline)`` indices with replacement, builds the
    same resample from both lists (they must be aligned element-wise), scores
    both with ``evaluator.score`` and records, per metric, whether the clear
    score is strictly greater than the baseline score.

    Args:
        evaluator: object whose ``score(results)`` returns ``(scores, _)``,
            where ``scores`` is a dict containing every name in ``metrics``.
        baseline: list of result records for the reference run.
        clear: list of result records for the compared run, aligned with
            ``baseline`` by index.
        n_bootstrap: number of bootstrap resamples.
        metrics: keys of the scores dict to compare.
        progress: optional iterable wrapper (e.g. ``tqdm``) for progress output.
        rng: optional object with a numpy-style ``choice(a, size)`` method; when
            omitted, the global (unseeded) ``np.random`` state is used, so
            results vary between runs.

    Returns:
        dict mapping each metric name to its win fraction in ``[0, 1]``.

    Raises:
        ValueError: if the lists differ in length, are empty, or
            ``n_bootstrap`` is not positive.
    """
    if len(baseline) != len(clear):
        raise ValueError("baseline and clear must have the same length for a "
                         "paired bootstrap")
    if not baseline:
        raise ValueError("cannot bootstrap over empty result lists")
    if n_bootstrap < 1:
        raise ValueError("n_bootstrap must be a positive integer")

    rng = np.random if rng is None else rng
    n_results = len(baseline)
    wins = dict.fromkeys(metrics, 0)

    iterations = range(n_bootstrap)
    if progress is not None:
        iterations = progress(iterations)

    for _iteration in iterations:
        # Same indices for both runs -> paired comparison on identical samples.
        sample = rng.choice(n_results, n_results)
        score_baseline, _unused = evaluator.score([baseline[i] for i in sample])
        score_clear, _unused = evaluator.score([clear[i] for i in sample])
        for metric in wins:
            # Compare clear against baseline; the previous version compared
            # score_clear with itself, so no resample could ever count as a win.
            if score_clear[metric] > score_baseline[metric]:
                wins[metric] += 1

    # Normalize by the number of resamples actually drawn (the previous
    # version divided by a hard-coded 5000 while running 1000 iterations).
    return {metric: count / n_bootstrap for metric, count in wins.items()}


if __name__ == "__main__":
    win_rates = bootstrap_win_rates(
        evaluator,
        results_baseline_ordered,
        results_clear_ordered,
        n_bootstrap=1000,
        progress=tqdm,
    )
    print("ndtw", win_rates['ndtw'])
    print("sr", win_rates['success_rate'])