import copy
from models.model_builder import build_model
from utils.data_loader import get_DataLoader, get_vocab
from tqdm import tqdm
from rouge import Rouge
from utils.data_utils import load_pretrained_embedding, make_tgt, make_src, parse_model_path_cfg, parse_data_path_cfg
from models.decode import BeamSearch

# Configuration: data locations and model options come from the parsed cfg helpers.
path = parse_data_path_cfg()
model_opt = parse_model_path_cfg()

# Only the exact string 'True' in the model options switches the GPU flag on.
gpu = model_opt["gpu"] == 'True'

# Fixed limits handed to the beam-search decoder further down.
shard_size = 0
src_max_len = 200
tgt_max_len = 52
abs_max_sent_num = 49
node_max_neighbor = 736
batch_size = int(model_opt["batch_size"])
# When true, per-instance hypotheses/references/scores are collected and written out.
save_results = True

word2index = get_vocab(path)
# Pre-trained embeddings are not loaded here; the model is built from a checkpoint.
# word_embeddings = load_pretrained_embedding(path, word2index)
word_embeddings = None
# Only the second value returned by get_DataLoader is used by this script.
_, test_iter = get_DataLoader(path, word2index, batch_size, True)
model = build_model(model_opt, len(word2index), gpu, word_embeddings, checkpoint=True)
rouge = Rouge()

# Two-way vocabulary: "stoi" is an independent copy of word -> index,
# "itos" is the inverse mapping index -> word.
tgt_vocab = {
    "stoi": copy.deepcopy(word2index),
    "itos": {index: word for word, index in word2index.items()},
}

beam_search = BeamSearch(model, gpu, tgt_max_len, tgt_vocab)

def _ids_to_text(token_ids):
    """Return the vocabulary words for *token_ids* joined by single spaces.

    Each id is looked up in ``tgt_vocab["itos"]``; trailing whitespace is
    stripped from the joined string.
    """
    return " ".join(tgt_vocab["itos"][idx] for idx in token_ids).rstrip()


# Running sums of the per-instance ROUGE F-scores and the number of scored instances.
rouge_1 = 0.0
rouge_2 = 0.0
rouge_l = 0.0
count = 0
result = []
# tqdm wraps the iterator directly (not enumerate(...)) so it can pick up the
# iterator's length for the progress total, if test_iter defines one.
for data in tqdm(test_iter):
    abstract_src_map = make_src(data, 'abs_src_map', gpu)
    context_src_map = make_src(data, 'context_src_map', gpu)
    decoded_batch, tgt = beam_search.decode(data, src_max_len, abs_max_sent_num, abstract_src_map,
                                            context_src_map, node_max_neighbor)
    for i, reference_ids in enumerate(tgt):
        # Element 0 of each decoded entry is the sequence scored as the hypothesis.
        hypothesis = _ids_to_text(decoded_batch[i][0])
        reference = _ids_to_text(reference_ids)
        # NOTE(review): rouge.get_scores presumably raises on an empty hypothesis;
        # confirm whether the decoder can emit an empty sequence.
        scores = rouge.get_scores(hypothesis, reference)[0]
        f_rouge_1 = scores["rouge-1"]["f"]
        f_rouge_2 = scores["rouge-2"]["f"]
        f_rouge_l = scores["rouge-l"]["f"]
        rouge_1 += f_rouge_1
        rouge_2 += f_rouge_2
        rouge_l += f_rouge_l
        if save_results:
            result.append({'hypothesis': hypothesis, 'reference': reference, 'rouge-1': f_rouge_1,
                           'rouge-2': f_rouge_2, 'rouge-l': f_rouge_l})
        count += 1

if count:
    # Report the mean F-score over all scored instances.
    print(f"rouge-1 scores: {rouge_1 / count}")
    print(f"rouge-2 scores: {rouge_2 / count}")
    print(f"rouge-l scores: {rouge_l / count}")
else:
    # Nothing was scored, so the averages are undefined; avoid dividing by zero.
    print("no test instances were scored; rouge averages are undefined")

if save_results:
    import json

    # Write the collected per-instance records (hypothesis, reference, ROUGE
    # F-scores) to output.json in the current working directory.
    # The with-block closes the file on exit, so no explicit close() is needed.
    with open("output.json", "w", encoding="utf-8") as f:
        json.dump(result, f)