from utils import *
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tensorboardX import SummaryWriter
from tqdm import tqdm
import os
import re
import json
import random
from sumeval.metrics.rouge import RougeCalculator
from model import *
import time
import argparse

def exec_cmd(cmd):
    """Run ``cmd`` through the shell and return what it wrote to stdout.

    Args:
        cmd: Command line passed to ``os.popen``; it is interpreted by the
            shell, so it must not be assembled from untrusted text.

    Returns:
        The command's standard output as one string (empty when the command
        printed nothing).
    """
    # The context manager closes the pipe even if ``read`` raises, which a
    # trailing explicit ``close()`` call would not guarantee.
    with os.popen(cmd) as pipe:
        return pipe.read()

def arg_parser() :
	"""Parse the command-line options of this script.

	Reads ``sys.argv`` through ``argparse``.

	Returns:
		The ``argparse.Namespace`` with the data paths (``train_path``,
		``valid_path``, ``test_path``, ``vocab_path``), the output directory
		(``logdir``, set with ``--dir``) and the hyper-parameters
		(``batch_size``, ``hidden_size``, ``embed_size``, ``dropout``,
		``weight_decay``, ``truncate``).
	"""
	# One row per option: (flag, namespace attribute, converter, default).
	option_table = (
		('--train_path', 'train_path', str, "simple/train.json"),
		('--valid_path', 'valid_path', str, "simple/valid.json"),
		('--test_path', 'test_path', str, "simple/test.json"),
		('--vocab_path', 'vocab_path', str, "simple/vocab.json"),
		('--dir', 'logdir', str, 'test'),
		('--batch_size', 'batch_size', int, 128),
		('--hidden_size', 'hidden_size', int, 128),
		('--embed_size', 'embed_size', int, 128),
		('--dropout', 'dropout', float, 0),
		('--weight_decay', 'weight_decay', float, 1e-4),
		('--truncate', 'truncate', int, 200),
	)

	cli = argparse.ArgumentParser(description='')
	for flag, attribute, converter, fallback in option_table :
		cli.add_argument(flag, dest=attribute, type=converter, default=fallback)

	return cli.parse_args()

def bleu_cal(outputs, idx2word) :
	"""Greedily decode per-step vocabulary scores into sentences.

	Despite the name, no BLEU score is computed here; the function only
	turns model outputs into text that can be scored later.

	Args:
		outputs: Iterable of 2-D tensors, one per example. For every row the
			index of the largest value along dim 1 is taken as the token.
		idx2word: Mapping from the *string* form of a token index
			(e.g. ``"17"``) to the token text.

	Returns:
		A list with one string per example: the decoded tokens joined by
		single spaces, stopping before the first ``'</S>'`` token.

	Raises:
		KeyError: If a predicted index is missing from ``idx2word``.
	"""
	predicted = []
	for scores in outputs :
		# A single conversion to Python ints instead of one `.item()` call
		# (and one device sync) per token.
		token_ids = torch.argmax(scores, dim=1).tolist()
		words = []
		for token_id in token_ids :
			word = idx2word[str(token_id)]
			if word == '</S>' :
				break
			words.append(word)
		predicted.append(' '.join(words))
	return predicted

def test(mode, test_iter, word2idx, idx2word, en_vocab, de_vocab, logdir, epoch, pre_file) :
	"""Decode ``test_iter`` with the network, save the output and score it.

	Predictions are written to ``{logdir}/predict-{epoch}.txt`` (one sentence
	per line) and then compared with the gold file by ``multi-bleu.perl``,
	which must be present in the current working directory.

	Args:
		mode: The network to evaluate. The parameter name is historical;
			the callers in this file pass the model in this position.
		test_iter: Data source whose ``gen_batch_data()`` yields batches
			exposing ``src``, ``src_len``, ``tgt``, ``graph``, ``graph_x``,
			``graph_x_len`` and ``center``.
		word2idx: Unused; kept so existing call sites keep working.
		idx2word: Index-to-token mapping forwarded to ``bleu_cal``.
		en_vocab: Unused; kept so existing call sites keep working.
		de_vocab: Unused; kept so existing call sites keep working.
		logdir: Existing directory that receives the prediction file.
		epoch: Value embedded in the prediction file name.
		pre_file: Path of the file holding the gold sentences.

	Returns:
		The text ``multi-bleu.perl`` printed to standard output.
	"""
	# Evaluate the network that was handed in instead of depending on a
	# module-level ``model`` that only exists when this file runs as a script.
	model = mode
	predict_file = '{}/predict-{}.txt'.format(logdir, epoch)

	# Switch to inference behaviour (e.g. dropout off) for scoring and put the
	# network back into whatever state the caller had it in afterwards.
	was_training = model.training
	model.eval()
	try :
		with torch.no_grad(), open(predict_file, "w") as f :
			for batch in test_iter.gen_batch_data() :
				src = torch.tensor(batch.src, dtype=torch.int64).cuda()
				src_len = torch.tensor(batch.src_len, dtype=torch.int64)
				tgt = torch.tensor(batch.tgt, dtype=torch.int64).cuda()
				graph_src = torch.tensor(batch.graph_x, dtype=torch.int64).cuda()
				graph_src_len = torch.tensor(batch.graph_x_len, dtype=torch.int64)
				center = torch.tensor(batch.center, dtype=torch.int64)

				outputs = model(src, src_len, tgt, batch.graph, graph_src, graph_src_len, center, 'test')
				# Same transform as in training; the log is monotonic, so it
				# does not change which token the argmax selects.
				outputs = torch.log(outputs)

				predicted = bleu_cal(outputs, idx2word)
				f.write('\n'.join(predicted) + '\n')
	finally :
		model.train(was_training)

	# multi-bleu.perl takes the reference as its argument and reads the
	# hypothesis from stdin, so the gold file goes first and the predictions
	# are piped in. BLEU is not symmetric; swapping them skews the score.
	return exec_cmd("perl multi-bleu.perl {} < {}".format(pre_file, predict_file))

def _write_golden(path, dataset) :
	"""Write each example's ``"comment"`` field on its own line, unless ``path`` already exists."""
	if os.path.exists(path) :
		return
	with open(path, 'w') as fp :
		for data in dataset :
			fp.write(data["comment"] + '\n')


def train(model, train_iter, valid_iter, test_iter, word2idx, idx2word, en_vocab, de_vocab, logdir, weight_decay, epochs=50) :
	"""Train ``model`` on the GPU, checkpointing and scoring it after every epoch.

	Each epoch runs one pass over ``train_iter``, saves the weights to
	``./{logdir}/model_{epoch}epoch`` and reports the output of ``test`` on
	the validation and test sets. Progress goes to stdout and to
	``{logdir}/log.txt``.

	Args:
		model: Network to optimise; it is moved to CUDA in place.
		train_iter: Training data; ``gen_batch_data()`` yields the batches.
		valid_iter: Validation data; also needs a ``dataset`` of records
			with a ``"comment"`` field, used to write the gold file.
		test_iter: Test data, same requirements as ``valid_iter``.
		word2idx: Forwarded to ``test``.
		idx2word: Forwarded to ``test``.
		en_vocab: Size of the last dimension of the model output, used to
			flatten it for the loss; also forwarded to ``test``.
		de_vocab: Forwarded to ``test``.
		logdir: Output directory; created (with parents) when missing.
		weight_decay: Weight decay passed to the Adam optimiser.
		epochs: Number of passes over the training data (default 50).
	"""
	model.cuda()
	optimizer = torch.optim.Adam(model.parameters(), weight_decay=weight_decay)
	# NOTE(review): no ignore_index is set, so every target position counts
	# towards the loss, including any padding - confirm this is intended.
	criterion = nn.NLLLoss()

	# makedirs also handles nested paths and an already existing directory.
	os.makedirs(logdir, exist_ok=True)

	# The context manager guarantees the log is closed if training aborts.
	with open("{}/log.txt".format(logdir), "w") as log :
		started = time.asctime(time.localtime(time.time()))
		# The format string needs a placeholder, otherwise the timestamp is
		# silently dropped from the log.
		log.write("Traning init: {}\n".format(started))
		print("Traning init: ", started)

		for epoch in range(epochs) :
			# Make sure training-time behaviour is active even if an
			# evaluation step left the network in eval mode.
			model.train()
			epoch_loss = 0
			total_step = 0
			for batch in train_iter.gen_batch_data() :
				src = torch.tensor(batch.src, dtype=torch.int64).cuda()
				src_len = torch.tensor(batch.src_len, dtype=torch.int64)
				tgt = torch.tensor(batch.tgt, dtype=torch.int64).cuda()
				graph_src = torch.tensor(batch.graph_x, dtype=torch.int64).cuda()
				graph_src_len = torch.tensor(batch.graph_x_len, dtype=torch.int64)
				center = torch.tensor(batch.center, dtype=torch.int64)

				outputs = model(src, src_len, tgt, batch.graph, graph_src, graph_src_len, center, 'train')
				# NLLLoss expects log-probabilities, hence the log here.
				outputs = torch.log(outputs)
				optimizer.zero_grad()
				# The targets skip the first position of ``tgt``.
				loss = criterion(outputs.contiguous().view(-1, en_vocab), tgt[:, 1:].contiguous().view(-1))
				loss.backward()
				optimizer.step()

				epoch_loss += loss.item()
				total_step += 1

			# An empty iterator would otherwise raise ZeroDivisionError.
			mean_loss = epoch_loss / total_step if total_step else float('nan')
			print(time.asctime(time.localtime(time.time())), end='    ')
			print("epoch:{}, train loss:{}".format(epoch, mean_loss))
			log.write("epoch:{}, train loss:{}\n".format(epoch, mean_loss))

			torch.save(model.state_dict(), "./{}/model_{}epoch".format(logdir, epoch))

			pre_file = '{}/valid-golden.txt'.format(logdir)
			_write_golden(pre_file, valid_iter.dataset)
			bleu_info = test(model, valid_iter, word2idx, idx2word, en_vocab, de_vocab, logdir, epoch, pre_file)
			print(bleu_info)
			log.write(bleu_info)

			pre_file = '{}/test-golden.txt'.format(logdir)
			_write_golden(pre_file, test_iter.dataset)
			bleu_info = test(model, test_iter, word2idx, idx2word, en_vocab, de_vocab, logdir, epoch, pre_file)
			print(bleu_info)
			log.write(bleu_info + '\n')

			# Keep the on-disk log current in case a later epoch crashes.
			log.flush()

if __name__ == '__main__':

	# Seed Python's, NumPy's and PyTorch's generators so that runs are
	# repeatable regardless of which of them the data pipeline draws from.
	seed = 4
	random.seed(seed)
	np.random.seed(seed)
	torch.manual_seed(seed)
	torch.cuda.manual_seed_all(seed)

	args = arg_parser()

	print("loading data.")
	train_iter = DataLoader(args.train_path, args.vocab_path, args.batch_size, args.truncate, mode='train')
	valid_iter = DataLoader(args.valid_path, args.vocab_path, args.batch_size, args.truncate, mode='valid')
	test_iter = DataLoader(args.test_path, args.vocab_path, args.batch_size, args.truncate, mode='test')

	print("loading vocab.")
	# The vocabulary objects are taken from the training loader.
	word2idx = train_iter.word2idx
	idx2word = train_iter.idx2word
	en_vocab = train_iter.en_vocab
	de_vocab = train_iter.de_vocab

	print("building model.")
	model = Graph2Seq(
		en_vocab,
		de_vocab,
		args.embed_size,
		args.hidden_size,
		layers=1,
		dropout=args.dropout,
		padding_idx=word2idx["<PAD>"],
	)

	print("start training.")
	train(model, train_iter, valid_iter, test_iter, word2idx, idx2word, en_vocab, de_vocab, args.logdir, args.weight_decay)





