import os

import time
import argparse
import numpy as np
import torch
from sklearn.utils import shuffle
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
from utils import read_data, read_imbdata
from modcnn import modCNN
from sklearn.metrics import precision_score, accuracy_score
from sklearn.metrics import f1_score
from itertools import chain

import nltk
nltk.download('punkt')

'''
This script is the training script for the adapted CNN model
'''

# Device handle plus fixed CPU / CUDA seeds for the torch random generators.
cuda = torch.device('cuda')
torch.manual_seed(210)
torch.cuda.manual_seed(123)

# Labelled corpus; read_imbdata returns the train/dev/test splits used below.
data_file = './set2_labeled.txt'
data = read_imbdata(data_file)

# Vocabulary: every distinct token seen in any split, in sorted order.
_unique_tokens = set()
for _split in (data["train_x"] + data["dev_x"] + data["test_x"]):
    _unique_tokens.update(_split)
data["vocab"] = sorted(_unique_tokens)

# Class inventory comes from the training labels only.
data["classes"] = sorted(set(data["train_y"]))

# Bidirectional token <-> row-index lookup tables over the sorted vocabulary.
data["word_to_idx"] = dict(zip(data["vocab"], range(len(data["vocab"]))))
data["idx_to_word"] = dict(enumerate(data["vocab"]))

# Command-line interface. Each entry is (flag, keyword arguments for
# add_argument); the two boolean switches default to False via store_true.
parser = argparse.ArgumentParser(
    description="----[CNN classifier training]------"
)
_cli_specs = (
    ("--mode", dict(
        default="train",
        help="train: train (with test) a model / test: test saved models")),
    ("--model", dict(
        default="non-static",
        help="available models: rand, static, non-static, multichannel")),
    ("--save_model", dict(
        default=False, action='store_true',
        help="whether saving model or not")),
    ("--early_stopping", dict(
        default=False, action='store_true',
        help="whether to apply early stopping")),
    ("--epoch", dict(
        default=30, type=int,
        help="number of max epoch")),
    ("--learning_rate", dict(
        default=1, type=float,
        help="learning rate")),
)
for _flag, _flag_kwargs in _cli_specs:
    parser.add_argument(_flag, **_flag_kwargs)

options = parser.parse_args()

# w_neg = len(data['train_y'])/(data['train_y'].count(0))
# w_neu = len(data['train_y'])/(data['train_y'].count(1))
# w_pos = len(data['train_y'])/(data['train_y'].count(2))

# Settings shared by the model constructor (modCNN(**params)) and the
# training / evaluation code. Sentences are padded to the longest one found
# in any split.
_all_sentences = data["train_x"] + data["dev_x"] + data["test_x"]
_longest_sentence = max(map(len, _all_sentences))

params = dict(
    model=options.model,
    save_model=options.save_model,
    early_stopping=options.early_stopping,
    epoch=options.epoch,
    learning_rate=options.learning_rate,
    max_sent_len=_longest_sentence,
    batch_size=50,
    word_dim=300,
    vocab_size=len(data["vocab"]),
    class_size=len(data["classes"]),
    filters=[3, 4, 5],
    filter_num=[100, 100, 100],
    dropout_prob=0.5,
    norm_limit=3,
    use_cuda=True,
)

print('Step 1: Model initialization complete')
##### Load the modified word vectors, then append the unknown / padding rows #####
path_to_word_vectors = './domain_beauty_input.txt'
#path_to_word_vectors = '/Users/prathyusha/neuralCodes/modcnn/glove_mr1k_input.txt'

# Every embedding row holds two word_dim-sized halves stored back to back.
# NOTE(review): what each half represents is not visible here - confirm
# against the file that produces the vectors.
half_dims = params["word_dim"]
num_dims = params["word_dim"] * 2

# Parse "<word> <v1> <v2> ..." lines. The context manager closes the file
# even if parsing fails; blank lines are skipped instead of crashing split().
mod_word_dict = {}
with open(path_to_word_vectors, 'r') as wobj:
    for line in wobj:
        word_vec_obj = line.rstrip()
        if not word_vec_obj:
            continue
        word, vec = word_vec_obj.split(' ', 1)
        mod_word_dict[word] = np.fromstring(vec, sep=' ')

# One row per vocabulary word, in vocabulary-index order. Words without a
# pre-trained vector get a small random initialisation.
wv_mat = []
for word in data["vocab"]:
    if word in mod_word_dict:  # dict membership: O(1) per word
        vector = mod_word_dict[word]
        if vector.shape[0] != num_dims:
            raise ValueError(
                "word vector for %r has %d values, expected %d"
                % (word, vector.shape[0], num_dims)
            )
        wv_mat.append(vector)
    else:
        wv_mat.append(
            np.random.uniform(-0.01, 0.01, num_dims).astype("float32"))

# Row vocab_size is the shared out-of-vocabulary vector; row vocab_size + 1
# is the all-zero padding vector (the indices used by the batching code).
wv_mat.append(np.random.uniform(-0.01, 0.01, num_dims).astype("float32"))
wv_mat.append(np.zeros(num_dims).astype("float32"))

wv_mat = torch.from_numpy(np.array(wv_mat))

# Interleave the two halves column-wise: 0, word_dim, 1, word_dim + 1, ...
indx = torch.LongTensor(list(chain.from_iterable(
    zip(range(half_dims), range(half_dims, num_dims)))))

wv_mat = wv_mat.view(len(data["vocab"]) + 2, 1, num_dims)
wv_mat = wv_mat[:, :, indx]

params["wv_mat"] = wv_mat.cuda()

# Build the network in double precision on the GPU.
model = modCNN(**params)
model.double()
model.cuda()

print('Step 2: All parameters loaded')
######################### define testing function #########
def test(data, model, params, mode="test"):
    """Evaluate ``model`` on the dev or test split in a single forward pass.

    Args:
        data: Dataset dict. Reads ``"dev_x"``/``"dev_y"`` or
            ``"test_x"``/``"test_y"`` (depending on ``mode``) plus
            ``"word_to_idx"`` and ``"classes"``.
        model: Network called on a CUDA ``LongTensor`` of padded token
            indices; its output is reduced with ``max(1)`` to pick one
            class index per sentence.
        params: Settings dict. Reads ``"vocab_size"`` and ``"max_sent_len"``.
        mode: ``"dev"`` or ``"test"``.

    Returns:
        Tuple ``(accuracy, macro_f1)``.

    Raises:
        ValueError: If ``mode`` is neither ``"dev"`` nor ``"test"``.
    """
    if mode == "dev":
        x, y = data["dev_x"], data["dev_y"]
    elif mode == "test":
        x, y = data["test_x"], data["test_y"]
    else:
        raise ValueError("mode must be 'dev' or 'test', got %r" % (mode,))

    model.eval()

    # Unknown words share index vocab_size; padding uses vocab_size + 1.
    unk_idx = params["vocab_size"]
    pad_idx = params["vocab_size"] + 1
    max_len = params["max_sent_len"]
    word_to_idx = data["word_to_idx"]

    # word_to_idx is keyed by the vocabulary, so .get() replaces the
    # linear "w in data['vocab']" scan with a constant-time lookup.
    x = [
        [word_to_idx.get(w, unk_idx) for w in sent]
        + [pad_idx] * (max_len - len(sent))
        for sent in x
    ]
    x = torch.cuda.LongTensor(x)
    y = np.asarray([data["classes"].index(c) for c in y])

    # No gradients are needed for evaluation; skip building the graph.
    with torch.no_grad():
        pred = model(x).cpu().max(1)[1]
    pred = np.asarray(pred)

    new_acc = int((y == pred).sum()) / len(y)
    fscr = f1_score(y, pred, average='macro')

    return (new_acc, fscr)

######################## training model ########################
# Keep the trainable parameters in a list: a bare filter() iterator would be
# exhausted by the optimizer constructor, leaving gradient clipping with
# nothing to clip.
parameters = [p for p in model.parameters() if p.requires_grad]
# NOTE(review): --learning_rate is parsed into params["learning_rate"] but the
# optimizer uses this fixed rate - confirm which one is intended.
optimizer = optim.RMSprop(parameters, lr=0.0005)

# Inverse-frequency class weights (total / count per class), ordered like
# data["classes"] so they line up with the label indices fed to the loss.
# Every class comes from the training labels, so each count is at least 1.
train_labels = list(data["train_y"])
class_weights = [
    len(train_labels) / train_labels.count(c) for c in data["classes"]
]
criterion = nn.CrossEntropyLoss(
    weight=torch.DoubleTensor(class_weights)).cuda()

prev_dev_acc = 0
pad_idx = params["vocab_size"] + 1

print('Step 3: Begin training')
for e in range(params["epoch"]):
    data["train_x"], data["train_y"] = shuffle(data["train_x"], data["train_y"])
    total_loss = 0
    num_batches = 0

    for i in range(0, len(data["train_x"]), params["batch_size"]):
        # Slicing past the end simply yields the shorter final batch.
        batch_sents = data["train_x"][i:i + params["batch_size"]]
        batch_labels = data["train_y"][i:i + params["batch_size"]]

        batch_y = [data["classes"].index(c) for c in batch_labels]
        batch_x = [
            [data["word_to_idx"][w] for w in sent]
            + [pad_idx] * (params["max_sent_len"] - len(sent))
            for sent in batch_sents
        ]

        # Tensors are created directly on the GPU.
        batch_x = torch.cuda.LongTensor(batch_x)
        batch_y = torch.cuda.LongTensor(batch_y)

        model.train()
        optimizer.zero_grad()
        pred = model(batch_x)
        loss = criterion(pred, batch_y)
        loss.backward()
        nn.utils.clip_grad_norm_(parameters, max_norm=params["norm_limit"])
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Evaluate once per epoch; only the end-of-epoch metrics are reported.
    dev_acc, dev_f = test(data, model, params, mode="dev")
    test_acc, test_f = test(data, model, params)

    print("epoch:", e+1, "test_metrics:", (test_acc, dev_acc), (test_f, dev_f))
    # Mean of the per-batch losses for this epoch.
    print(total_loss / max(num_batches, 1))

    # Honour --early_stopping: stop as soon as dev accuracy stops improving.
    if params["early_stopping"] and dev_acc <= prev_dev_acc:
        print("early stopping by dev_acc")
        break
    prev_dev_acc = dev_acc


# Copy the weights of the model's `conv1w` attribute to the CPU and print them.
est_kernel = model.conv1w.weight.cpu()
print("Estimated kernel,", est_kernel)


# Drop into the interactive debugger so the final state can be inspected.
# NOTE(review): this halts unattended runs at the end of the script - confirm
# it is still wanted.
import pdb
pdb.set_trace()