from utils import *
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tensorboardX import SummaryWriter
from tqdm import tqdm
import os
import re
import json
import random
from sumeval.metrics.rouge import RougeCalculator
from model import *
import time
import argparse

def exec_cmd(cmd):
	"""Run ``cmd`` through the shell and return everything it wrote to stdout.

	Args:
		cmd: a shell command string. NOTE(review): it is passed to the shell
			verbatim, so never build it from untrusted input.

	Returns:
		The command's captured stdout as a ``str``.
	"""
	# The context manager guarantees the pipe is closed (and the child reaped)
	# even if read() raises, which the manual open/close did not.
	with os.popen(cmd) as pipe:
		return pipe.read()

def arg_parser(argv=None) :
	"""Parse command-line options for the Graph2Seq test run.

	Args:
		argv: optional list of argument strings. ``None`` (the default) makes
			argparse read ``sys.argv[1:]``, exactly as before.

	Returns:
		An ``argparse.Namespace`` with the options below. ``--dir`` is stored as
		``logdir``.

	Raises:
		SystemExit: on invalid arguments, or when ``--model`` is missing.
	"""
	parser = argparse.ArgumentParser(description='Decode a test set with a trained Graph2Seq model.')

	parser.add_argument('--test_path', dest='test_path', type=str, default="simple/test.json")
	parser.add_argument('--vocab_path', dest='vocab_path', type=str, default="simple/vocab.json")
	parser.add_argument('--dir', dest='logdir', type=str, default='test')
	# The checkpoint is always passed to torch.load, so fail fast if it is missing.
	parser.add_argument('--model', dest='model', type=str, required=True)

	parser.add_argument('--batch_size', dest='batch_size', type=int, default=128)
	parser.add_argument('--hidden_size', dest='hidden_size', type=int, default=128)
	parser.add_argument('--embed_size', dest='embed_size', type=int, default=128)
	# Float default so the attribute type matches type=float when the flag is omitted.
	parser.add_argument('--dropout', dest='dropout', type=float, default=0.0)
	parser.add_argument('--weight_decay', dest='weight_decay', type=float, default=1e-4)
	parser.add_argument('--truncate', dest='truncate', type=int, default=200)

	return parser.parse_args(argv)

def bleu_cal(outputs, idx2word, eos='</S>') :
	"""Greedily decode per-example score tensors into space-joined sentences.

	Despite its name, this does not compute BLEU. It turns model scores into
	text, which can then be scored.

	Args:
		outputs: iterable of score tensors, one per example. Each is argmax-ed
			over dim 1, so presumably shaped (seq_len, vocab_size). TODO confirm
			against Graph2Seq's output.
		idx2word: mapping from the *string* form of a token index to its token.
		eos: token that ends a sequence. It and everything after it are dropped.

	Returns:
		A list with one decoded string per element of ``outputs``.
	"""
	predicted = []
	for vec in outputs :
		# One tolist() instead of a per-token .item(), which syncs with the device each time.
		token_ids = torch.argmax(vec, dim=1).tolist()
		words = []
		for idx in token_ids :
			word = idx2word[str(idx)]
			if word == eos :
				break
			words.append(word)
		predicted.append(' '.join(words))
	return predicted

def test(model, test_iter, word2idx, idx2word, en_vocab, de_vocab, out_path='case_pred.txt') :
	"""Decode every batch of ``test_iter`` with ``model`` and write predictions.

	Each batch is moved to the GPU, run through ``model(..., 'test')``, and
	greedily decoded by ``bleu_cal``. The result is appended to ``out_path``,
	one sentence per line.

	Args:
		model: the Graph2Seq model; switched to eval mode here.
		test_iter: loader exposing ``gen_batch_data()`` that yields batches with
			``src``, ``src_len``, ``tgt``, ``graph``, ``graph_x``, ``graph_x_len``
			and ``center`` fields.
		word2idx, en_vocab, de_vocab: not used here; kept for interface
			compatibility.
		idx2word: str(index) -> token mapping passed to ``bleu_cal``.
		out_path: file to (over)write with predictions.
	"""
	# no_grad() does not switch modules out of training mode; eval() disables
	# dropout (and any other train-only behavior) for deterministic decoding.
	model.eval()
	# The with-statement closes the file even if a batch raises.
	with torch.no_grad(), open(out_path, "w", encoding="utf-8") as f :
		for batch in test_iter.gen_batch_data() :
			src = torch.tensor(batch.src, dtype=torch.int64).cuda()
			# Length/index tensors stay on CPU, presumably because the model
			# needs them there (e.g. for packing). TODO confirm.
			src_len = torch.tensor(batch.src_len, dtype=torch.int64)
			tgt = torch.tensor(batch.tgt, dtype=torch.int64).cuda()
			graph_src = torch.tensor(batch.graph_x, dtype=torch.int64).cuda()
			graph_src_len = torch.tensor(batch.graph_x_len, dtype=torch.int64)
			center = torch.tensor(batch.center, dtype=torch.int64)

			outputs = model(src, src_len, tgt, batch.graph, graph_src, graph_src_len, center, 'test')
			# NOTE(review): log is monotonic, so for positive scores it does not
			# change the argmax taken in bleu_cal. It is kept to preserve results
			# if outputs can contain non-positive values.
			outputs = torch.log(outputs)

			predicted = bleu_cal(outputs, idx2word)
			f.write('\n'.join(predicted) + '\n')

if __name__ == '__main__':

	# Seed every RNG this script imports so runs are reproducible.
	seed = 4
	random.seed(seed)
	np.random.seed(seed)
	torch.manual_seed(seed)
	torch.cuda.manual_seed_all(seed)

	args = arg_parser()

	print("loading data.")
	test_iter = DataLoader(args.test_path, args.vocab_path, args.batch_size, args.truncate, mode='test')

	print("loading vocab.")
	word2idx, idx2word, en_vocab, de_vocab = test_iter.word2idx, test_iter.idx2word, test_iter.en_vocab, test_iter.de_vocab

	print("building model.")
	model = Graph2Seq(en_vocab, de_vocab, args.embed_size, args.hidden_size, layers=1, dropout=args.dropout, padding_idx=word2idx["<PAD>"])
	model.cuda()
	# Load tensors onto CPU first: a checkpoint saved on another GPU index would
	# otherwise fail to deserialize. load_state_dict copies them into the CUDA params.
	# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
	model.load_state_dict(torch.load(args.model, map_location='cpu'))
	print("loaded")

	print("start testing.")
	test(model, test_iter, word2idx, idx2word, en_vocab, de_vocab)



