#numpy and theano
import theano
import theano.tensor as T
import theano.ifelse as ifelse
from theano import shared
from theano.tensor.shared_randomstreams import RandomStreams
import numpy as np
#common
from itertools import izip
import random
import logging
from collections import OrderedDict
import time
import json 
import scipy.io
import Optimizer
from Evaluator import Evaluator2, Evaluator5
from word2vec import load_word2vec_format

class RNN(object):
    def init_param(self):
        """Create the shared parameters and the per-epoch history store.

        Each parameter is drawn once from ``self.srng`` with a symmetric
        uniform range and stored as a named Theano shared variable.  The
        biases ``b`` and ``br`` use a bound of 0, i.e. a zero-width range.
        ``Wp`` is only created when ``dimp > 0`` and ``T`` only when the
        tensor (RNTN) term is enabled.
        """
        def make_shared(shape, label, bound=0):
            # Compile a zero-argument function that draws one sample from the
            # random stream, then freeze that sample into a shared variable.
            draw = theano.function(
                [],
                self.srng.uniform(shape, low=-bound, high=bound,
                                  dtype=theano.config.floatX))
            return theano.shared(draw(), name=label)

        scale = self.uniform_range
        n_comp, n_tag = self.numg, self.nump
        d_vec, d_tag = self.dimv, self.dimp
        d_pair = 2 * d_vec + 2 * d_tag  # two child vectors plus two tag vectors

        # The creation order below fixes the order of draws from the stream.
        self.W = make_shared((n_comp, d_vec, d_pair), 'W', scale)
        self.b = make_shared((n_comp, d_vec), 'b', 0)
        self.Wr = make_shared((n_comp, 2 * d_vec, d_vec + d_tag), 'Wr', scale)
        self.br = make_shared((n_comp, 2 * d_vec), 'br', 0)
        self.Wv = make_shared((self.n_words, d_vec), 'Wv', scale)
        self.Ws = make_shared((d_vec, self.n_grained), 'Ws', scale)
        self.params = [self.W, self.b, self.Wr, self.br, self.Wv, self.Ws]

        if self.dimp > 0:
            self.Wp = make_shared((n_tag, d_tag), 'Wp', scale)
            self.params.append(self.Wp)
        if self.rntn == 1:
            self.T = make_shared((d_pair, d_vec, d_pair), 'T', scale)
            self.params.append(self.T)

        # One zero buffer per parameter, with the parameter's own shape.
        self.sum_grad = {p: np.zeros_like(p.get_value()) for p in self.params}

        # Per-epoch training history; one list per tracked quantity.
        history_keys = ('loss', 'loss_rec', 'loss_sen', 'loss_l2',
                        'acc_train', 'acc_dev', 'acc_test', 'avg_llk')
        self.details = {key: [] for key in history_keys}

    def __init__(self, numg, nump, n_words, dimv, dimp, n_grained, uniform_range, rntn, \
            normalize, w_rec, w_sen, w_l2, name, dropout, verbose):
        """Build the symbolic graph of the tree-structured model and compile it.

        Args:
            numg: size of the first axis of W/b/Wr/br, i.e. the number of
                composition functions that can be selected per node.
            nump: number of rows of the tag-embedding matrix Wp.
            n_words: number of rows of the word-embedding matrix Wv.
            dimv: dimension of word/phrase vectors.
            dimp: dimension of tag vectors; 0 disables tag embeddings.
            n_grained: number of output classes (columns of Ws).
            uniform_range: bound of the uniform parameter initialisation.
            rntn: truthy to add the tensor term to the composition.
            normalize: truthy to L2-normalise composed vectors.
            w_rec, w_sen, w_l2: weights of the reconstruction, sentiment and
                L2 terms of the loss.
            name: model name, used in checkpoint and result file names.
            dropout: stored on the instance; not used in this method.
            verbose: stored on the instance; read by train() and test().

        Compiles two functions: ``encode_func`` (returns the four loss values
        and adds the gradients into ``self.grad``) and ``func_test`` (returns
        the class probabilities).
        """
        logging.info('RNN init: %s' % (locals()))
        self.rntn = rntn
        self.numg = numg
        self.nump = nump
        self.n_words = n_words
        self.dimv = dimv
        self.dimp = dimp
        self.n_grained = n_grained
        self.uniform_range = uniform_range
        # The random stream is seeded from the wall clock; the seed is logged
        # so that a run's initialisation can be identified afterwards.
        seed=int(time.time())
        self.srng = RandomStreams(seed=seed)
        logging.info('seed %d' % seed)
        self.init_param()

        self.normalize = normalize
        self.name = name
        self.dropout = dropout
        self.verbose = verbose
        
        # Symbolic inputs. Each row of `nodes` is read below as
        # (word index, composition-function index, tag index); each row of
        # `seq` is read in encode() as (left child, right child, parent).
        self.nodes = T.lmatrix()
        self.seq = T.lmatrix()
        self.solution = T.matrix()
        # Initial vector of every node: the Wv row selected by column 0.
        vectors, _ = theano.scan(lambda x: self.Wv[x[0]], sequences=self.nodes)
        # Column 1 of every node row: index into the first axis of W/b/Wr/br.
        tag_g, _ = theano.scan(lambda x: x[1], sequences=self.nodes)
        
        if self.dimp > 0:
            # Tag embedding of every node: the Wp row selected by column 2.
            tag_p, _ = theano.scan(lambda x: self.Wp[x[2]], sequences=self.nodes)
        
        def encode(t, vectors, _loss_rec):
            # One composition step: combine the vectors of nodes t[0] and t[1]
            # and store the result at row t[2] of `vectors`.
            x, y = vectors[t[0]], vectors[t[1]]
            g = tag_g[t[2]]
            
            v = T.concatenate([x, y])
            if self.dimp > 0:
                # Append the children's tag embeddings to the composition input.
                px, py = tag_p[t[0]], tag_p[t[1]]
                v = T.concatenate([x, y, px, py])
            
            z = T.tanh(self.b[g] + T.dot(self.W[g], v))
            if self.rntn:
                # Same affine term plus the bilinear term v^T . T . v.
                z = T.tanh(self.b[g] + T.dot(self.W[g], v) + \
                        T.tensordot(v, T.tensordot(self.T, v, [[2], [0]]), [[0], [0]]))

            if self.normalize:
                # Scale the composed vector to unit L2 norm.
                z = z/T.sqrt(T.sum(z**2)) 
            
            if w_rec > 0:
                # Reconstruct both children from the parent vector (plus the
                # parent's tag embedding when tags are used).
                if self.dimp == 0:
                    r = self.br[g] + T.dot(self.Wr[g], z)
                else:
                    pz = tag_p[t[2]]
                    r = self.br[g] + T.dot(self.Wr[g], T.concatenate([z, pz]))
                xr, yr = r[:self.dimv], r[self.dimv:]
                if self.normalize:
                    xr = xr/T.sqrt(T.sum(xr**2)) 
                    yr = yr/T.sqrt(T.sum(yr**2)) 
                # Squared reconstruction error of this step.
                loss_rec = T.sum((xr-x)**2) + T.sum((yr-y)**2)
            else:
                # Reconstruction disabled: pass the previous value through.
                loss_rec = _loss_rec

            return T.set_subtensor(vectors[t[2]], z), loss_rec
        
        zero = theano.shared(np.array(0.0, dtype=theano.config.floatX))
        # Every scan step outputs the whole node-vector matrix and that step's
        # reconstruction loss.
        scan_result, _ = theano.scan(encode, sequences=[self.seq], outputs_info=[vectors, zero])
        self.loss_rec = T.sum(scan_result[1]) * w_rec
        
        # Classify every node from the node vectors after the last step.
        # NOTE(review): indexing [-1] presumably requires `seq` to have at
        # least one row - confirm how single-node inputs are handled.
        pred = T.dot(scan_result[0][-1], self.Ws)
        self.pred = T.nnet.softmax(pred)
        
        # Sum over all entries of solution * log(pred), negated and weighted.
        self.loss_sen = -T.tensordot(self.solution, T.log(self.pred), axes=2) * w_sen
        self.loss_l2 = sum([T.sum(param**2) for param in self.params]) * w_l2
        self.loss = self.loss_rec + self.loss_sen + self.loss_l2
        
        logging.info('get grads...')
        grads = T.grad(self.loss, self.params)
        self.updates = OrderedDict()
        self.grad = {}
        for param, grad in zip(self.params, grads):
            # One accumulator per parameter; every encode_func call adds this
            # example's gradient to it. Nothing in this method resets it.
            g = theano.shared(np.asarray(np.zeros_like(param.get_value()), dtype=theano.config.floatX))
            self.grad[param] = g
            self.updates[g] = g + grad

        logging.info("compile func of encode_func")
        self.encode_func = theano.function(
                inputs = [self.nodes, self.seq, self.solution],
                outputs = [self.loss, self.loss_rec, self.loss_sen, self.loss_l2],
                updates = self.updates,
                on_unused_input='warn')

        logging.info("compile func of test")
        self.func_test = theano.function(
                inputs = [self.nodes, self.seq],
                outputs = self.pred,
                on_unused_input='warn')

    def fit(self, optimizer, lr, batch_size, start_epoch, end_epoch, train_data, dev_data, test_data):
        """Train from ``start_epoch`` up to (not including) ``end_epoch``.

        After every epoch the model is evaluated on the three data sets, a
        checkpoint is written with ``dump`` and the accumulated history is
        rewritten to ``result/<name>.txt`` as JSON.

        Args:
            optimizer: 'SGD' or 'ADADELTA'.
            lr: learning rate handed to the optimizer.
            batch_size: number of sentences per mini-batch.
            start_epoch: first epoch to run; when > 0 the checkpoint of
                ``start_epoch - 1`` and the stored history are loaded first.
            end_epoch: epoch index at which training stops.
            train_data: list of training examples; shuffled IN PLACE at the
                start of every epoch.
            dev_data, test_data: evaluation examples.
        """
        if start_epoch > 0:
            self.load(start_epoch - 1)
            with open('result/%s.txt' % self.name) as f:
                history = json.load(f)
            # Drop anything recorded at or after the epoch we resume from.
            for key, value in history.items():
                self.details[key] = value[:start_epoch]

        logging.info('start epoch: %s, lr: %s' % (start_epoch, lr))

        # Ceiling of len(train_data) / batch_size. Floor division is spelled
        # out so the count stays an integer under true division as well.
        batch_n = (len(train_data) - 1) // batch_size + 1
        if optimizer == 'SGD':
            self.optimizer = Optimizer.SGD(self.params, lr)
        else:
            assert optimizer == 'ADADELTA'
            self.optimizer = Optimizer.ADADELTA(self.params, lr)

        epoch = start_epoch
        while epoch < end_epoch:
            random.shuffle(train_data)
            loss = self.train(epoch, train_data, batch_size, batch_n, lr)
            for key, value in loss.items():
                self.details[key].append(value)
            self.details['acc_train'].append(self.test(train_data))
            self.details['acc_dev'].append(self.test(dev_data))
            self.details['acc_test'].append(self.test(test_data))
            logging.info('train data %s, dev data: %s, test data:%s' %
                    (self.details['acc_train'][-1], self.details['acc_dev'][-1], self.details['acc_test'][-1]))

            self.dump(epoch)
            # The whole history is rewritten after every epoch.
            with open('result/%s.txt' % self.name, 'w') as f:
                f.write(json.dumps(self.details))

            epoch += 1

    def onlytest(self, train_data, dev_data, test_data, end_epoch):
        """Re-evaluate the stored checkpoints of epochs ``0 .. end_epoch - 1``.

        For every epoch the checkpoint is loaded, the three data sets are
        evaluated, the results are appended to ``self.details`` and the
        history is rewritten to ``result/<name>.txt`` as JSON.
        """
        # range() exists on both Python 2 and 3; xrange() only on Python 2.
        for epoch in range(end_epoch):
            self.load(epoch)
            self.details['acc_train'].append(self.test(train_data))
            self.details['acc_dev'].append(self.test(dev_data))
            self.details['acc_test'].append(self.test(test_data))
            logging.info('train data %s, dev data: %s, test data:%s' %
                    (self.details['acc_train'][-1], self.details['acc_dev'][-1], self.details['acc_test'][-1]))
            with open('result/%s.txt' % self.name, 'w') as f:
                f.write(json.dumps(self.details))
    
    def train(self, epoch_num, train_data, batch_size, batch_n, lr):
        st_time = time.time()
        loss_sum = np.array([0.0, 0.0, 0.0, 0.0])
        total_nodes = 0
        for batch in xrange(batch_n):
            start = batch * batch_size
            end = min((batch + 1) * batch_size, len(train_data))
            batch_loss, batch_total_nodes = self.do_train(train_data[start:end])
            loss_sum += batch_loss
            total_nodes += batch_total_nodes
            if self.verbose:
                logging.info("epoch %s batch %d: llk: %f loss sum: %s" % (epoch_num, batch, 
                    batch_loss[2]/batch_total_nodes, batch_loss))
        logging.info("epoch %s: loss sum: %.1f, loss rec: %.1f, loss sen: %.1f, loss l2: %.1f, llk: %f" % 
                (epoch_num, loss_sum[0], loss_sum[1], loss_sum[2], loss_sum[3],
                    loss_sum[2]/total_nodes))
        logging.info('time of train: %s' % (time.time() - st_time))
        return {'loss':loss_sum[0], 'loss_rec':loss_sum[1], \
                'loss_sen':loss_sum[2], 'loss_l2':loss_sum[3], \
                'avg_llk':loss_sum[2]/total_nodes}
    
    def do_train(self, train_data):
        eps0 = 1e-8
        batch_loss = np.array([0.0, 0.0, 0.0, 0.0])
        total_nodes = 0
        for nodes, seq, solution in train_data:
            batch_loss += np.array(self.encode_func(nodes, seq, solution))
            total_nodes += len(solution)
        for _, grad in self.grad.iteritems():
            grad.set_value(grad.get_value() / float(len(train_data)))
        self.optimizer.iterate(self.grad)
        return batch_loss, total_nodes

    def test(self, test_data, batch_size=128):
        """Evaluate the model on ``test_data`` and return the evaluator result.

        Only the last row of each example's ``solution`` and of the model's
        prediction is passed to the evaluator.

        Args:
            test_data: iterable of ``(nodes, seq, solution)`` examples.
            batch_size: unused; kept so existing callers keep working.

        Returns:
            Whatever ``statistic()`` of the selected evaluator returns.
        """
        if self.n_grained == 5:
            evaluator = Evaluator5(self.verbose)
        else:
            evaluator = Evaluator2(self.verbose)

        for nodes, seq, solution in test_data:
            pred = self.func_test(nodes, seq)
            evaluator.accumulate(solution[-1:], pred[-1:])
        return evaluator.statistic()
    
    def dump(self, epoch):
        """Write the current value of every parameter to ``mat/<name>.<epoch>``.

        Values are stored in MATLAB format, keyed by each parameter's name.
        """
        snapshot = {param.name: param.get_value() for param in self.params}
        target = 'mat/%s.%s' % (self.name, epoch)
        scipy.io.savemat(target, mdict=snapshot)

    def load(self, epoch):
        """Restore every parameter from the checkpoint ``mat/<name>.<epoch>``.

        Parameters with two or more dimensions take the stored array as is.
        One-dimensional parameters take row 0 of the stored array
        (presumably because the .mat file holds them as a single-row matrix -
        verify against scipy.io.savemat).
        """
        saved = scipy.io.loadmat('mat/%s.%s' % (self.name, epoch))
        float_type = theano.config.floatX
        for param in self.params:
            if param.get_value().ndim == 1:
                first_row = saved[param.name][0]
                param.set_value(np.asarray(first_row, dtype=float_type))
            # Checked on the value as it is now, exactly like the first test.
            if param.get_value().ndim >= 2:
                param.set_value(np.asarray(saved[param.name], dtype=float_type))
    
    def load_wordvec(self, fname, wordlist):
        """Overwrite rows of ``Wv`` with pre-trained word vectors.

        Args:
            fname: path handed to ``load_word2vec_format``.
            wordlist: mapping word -> row index in ``Wv``.

        Words missing from the loaded vectors keep their current row; their
        count is logged.  When ``self.normalize`` is set, each loaded vector
        is scaled to unit L2 norm before it is stored.
        """
        dic = load_word2vec_format(fname, wordlist, self.dimv)
        logging.info('loaded word vectors. words: %s' % len(dic))

        not_found = 0
        Wv = self.Wv.get_value()
        for word, index in wordlist.items():
            # `in` replaces dict.has_key(), which no longer exists on Python 3.
            if word not in dic:
                not_found += 1
                continue
            vector = dic[word]
            if self.normalize:
                Wv[index] = vector / np.sqrt(np.dot(vector, vector))
            else:
                Wv[index] = vector
        self.Wv.set_value(Wv)
        logging.info('loading wordvec... %s words not found.' % not_found)
        
def gen_wordlist(dataset):
    """Build the vocabulary of the train/dev/test trees of ``dataset``.

    Reads ``<dataset>/train.res``, ``dev.res`` and ``test.res`` (one JSON tree
    per line), collects the lower-cased 'word' of every leaf and returns a
    mapping word -> index.

    Indices are assigned in sorted word order, so the mapping is the same on
    every run and interpreter.  NOTE: indices produced by an earlier,
    set-order-dependent version of this function may differ.

    Blank lines and empty JSON objects are skipped.
    """
    def leaf_words(tree):
        # A node carrying 'word' is a leaf; any other node has two children.
        if 'word' in tree:
            return [tree['word'].lower()]
        left, right = tree['children'][0], tree['children'][1]
        return leaf_words(left) + leaf_words(right)

    vocab = set()
    for fname in ['train', 'dev', 'test']:
        # `with` closes each file as soon as it has been read.
        with open('%s/%s.res' % (dataset, fname)) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                data = json.loads(line)
                if len(data) > 0:
                    vocab.update(leaf_words(data))

    return {word: index for index, word in enumerate(sorted(vocab))}

def gen_postag(dataset, numg, nump):
    """Rank the tags of the data set and map each tag to a pair of indices.

    Reads ``<dataset>/train.res``, ``dev.res`` and ``test.res`` (one JSON tree
    per line).  Blank lines and empty JSON objects are skipped.

    Args:
        dataset: directory holding the three ``.res`` files.
        numg: how many of the most frequent tags of internal (non-leaf) nodes
            get their own index.
        nump: how many of the most frequent tags over all nodes get their own
            index.

    Returns:
        dict tag -> ``(g_index, p_index)``.  Indices start at 1 for the most
        frequent tag; 0 is used for every tag outside the respective top-k.
    """
    def collect(tree, all_tags, inner_tags):
        # Post-order walk: children first, then the node itself.
        if 'word' in tree:
            all_tags.append(tree['postag'])
            return
        collect(tree['children'][0], all_tags, inner_tags)
        collect(tree['children'][1], all_tags, inner_tags)
        all_tags.append(tree['postag'])
        inner_tags.append(tree['postag'])

    def get_top_k(tags, k):
        # Count occurrences, keep the k most frequent, number them from 1.
        counts = {}
        for tag in tags:
            counts[tag] = counts.get(tag, 0) + 1
        ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)[:k]
        return {tag: rank + 1 for rank, (tag, _) in enumerate(ranked)}

    postag_all = []
    postag_no_leaf = []
    for fname in ['train', 'dev', 'test']:
        # `with` closes each file as soon as it has been read.
        with open('%s/%s.res' % (dataset, fname)) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                data = json.loads(line)
                if len(data) > 0:
                    collect(data, postag_all, postag_no_leaf)

    tags = set(postag_all)
    top_all = get_top_k(postag_all, nump)
    top_no_leaf = get_top_k(postag_no_leaf, numg)
    return {tag: (top_no_leaf.get(tag, 0), top_all.get(tag, 0)) for tag in tags}

def load_data(fname, wordlist, postag, numg, nump, sent_label, grained):
    """Convert a file of JSON trees into the arrays consumed by the model.

    Args:
        fname: path of a ``.res`` file with one JSON tree per line; blank
            lines and empty JSON objects are skipped.
        wordlist: mapping lower-cased word -> word index.
        postag: mapping tag -> ``(g_index, p_index)``.
        numg, nump: largest g/p index; only used to size the tag histograms
            that are logged at the end.
        sent_label: 0 keeps only the root label (all other rows zero), 1
            additionally copies the root label to every internal node, any
            other value keeps every node's own label.
        grained: number of classes, 2 or 5.

    Returns:
        list of ``(nodes, seq, solution)`` tuples, one per tree:
        ``nodes`` is an int64 array with one ``(word, g, p)`` row per node in
        post-order (internal nodes use word index 0), ``seq`` is an int64
        array of shape (n_internal, 3) with one ``(left, right, parent)`` row
        per internal node, and ``solution`` is a floatX array with one label
        row per node.
    """
    result = []
    postag_g = [0] * (numg+1)
    postag_p = [0] * (nump+1)

    def get_rating(s):
        # One-hot label. With 2 classes any other rating gives an all-zero
        # row; with 5 classes a rating outside 0..4 gives None.
        if grained == 2:
            label = int(s)
            if label == 0:
                return [1.0, 0.0]
            if label == 1:
                return [0.0, 1.0]
            return [0, 0]
        if grained == 5:
            label = int(s)
            if 0 <= label <= 4:
                one_hot = [0.0] * 5
                one_hot[label] = 1.0
                return one_hot
        return None

    def dfs(tree, nodes, seq, solution):
        # Post-order walk; returns the index this node received in `nodes`.
        if 'word' in tree:
            now = len(nodes)
            tg, tp = postag[tree['postag']]
            nodes.append((wordlist[tree['word'].lower()], tg, tp))
            postag_p[tp] += 1
            solution.append(get_rating(tree['rating']))
            return now
        l = dfs(tree['children'][0], nodes, seq, solution)
        r = dfs(tree['children'][1], nodes, seq, solution)
        now = len(nodes)
        tg, tp = postag[tree['postag']]
        postag_g[tg] += 1
        postag_p[tp] += 1
        nodes.append((0, tg, tp))
        seq.append([l, r, now])
        solution.append(get_rating(tree['rating']))
        return now

    # `with` closes the file once every line has been converted.
    with open(fname, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            data = json.loads(line)
            if len(data) == 0:
                continue
            nodes, seq, solution = [], [], []
            dfs(data, nodes, seq, solution)
            if sent_label in (0, 1):
                # Blank every label except the root's (the last row).
                root = solution[-1]
                solution = [[0.0] * grained for _ in range(len(solution) - 1)] + [root]
                if sent_label == 1:
                    # Give every internal node the root's label.
                    for x in seq:
                        solution[x[2]] = root
            # The model declares nodes/seq as int64 matrices (T.lmatrix), so
            # the dtype is fixed explicitly and `seq` keeps three columns even
            # when a tree has no internal node.
            result.append((
                np.array(nodes, dtype='int64'),
                np.array(seq, dtype='int64').reshape(-1, 3),
                np.array(solution, dtype=theano.config.floatX)))
    postag_g = [str(x) for x in postag_g]
    postag_p = [str(x) for x in postag_p]
    logging.info('number of postag_g: %s' % '/'.join(postag_g))
    logging.info('number of postag_p: %s' % '/'.join(postag_p))
    return result

# Command-line entry point: parse the options, build the vocabulary and tag
# tables, load the three data splits, then construct and train the model.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    #model
    parser.add_argument('name', help='the name of this model')
    parser.add_argument('--grained', type=int, default=5, choices=[2, 5], \
            help='label types')
    parser.add_argument('--model', choices=['rnn', 'rntn'], default='rnn', \
            help='rnn or rntn model')
    parser.add_argument('--dimv', type=int, default=25, \
            help='dimension for word vector and parse vector')
    parser.add_argument('--dimp', type=int, default=0, \
            help='dimension for tag vector, 0 for traditional RNN and TG-RNN, \
            otherwise for TE-RNN/RNTN')
    parser.add_argument('--numg', type=int, default=8, \
            help='the number of composition function, \
            0 for traditional RNN and TE-RNN/RNTN, otherwise for TG-RNN')
    parser.add_argument('--nump', type=int, default=30, \
            help='the number of useful tag vectors, please select 30')
    parser.add_argument('--uniform_range', type=float, default=0.1)
    parser.add_argument('--normalize', type=int, default=1, choices=[0, 1], \
            help='normalize or not, please use 1')
    parser.add_argument('-r', '--w_rec', type=float, default=0, \
            help='the weight of reconstruction loss, please use 0')
    parser.add_argument('-s', '--w_sen', type=float, default=1, \
            help='the weight of sentiment error loss')
    parser.add_argument('-l', '--w_l2', type=float, default=0.0, \
            help='the weight of l2 norm loss')
    #train
    parser.add_argument('lr', type=float, help='learning rate')
    parser.add_argument('--epoch_start', type=int, default=0, \
            help='start epoch, select 0 for training a new model')
    parser.add_argument('--epoch_end', type=int, default=30, help='end epoch')
    parser.add_argument('--batch', type=int, default=30, help='batch size')
    parser.add_argument('--wordvec', type=str, default='', help='the directory of word vector.')
    parser.add_argument('--dropout', type=int, default=0, \
            choices=[0, 1], help='whether to use dropout in training')
    parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'ADADELTA'], \
            help='optimizer for training')
    parser.add_argument('--verbose', type=int, choices=[0, 1], default=0, \
            help='verbose for output loss for each batch')
    #dataset
    parser.add_argument('--dataset', type=str, default='trees/', help='the directory of dataset')
    parser.add_argument('--fast', type=int, choices=[0, 1], default=0, \
            help='fast train with a few sentences')
    parser.add_argument('--label', type=int, default=2, choices=[0, 1, 2], \
            help='use all label(2), use root copy(1), use root only(0), please use 2')
    args = parser.parse_args()
    
    # All log output goes to log/<name>.log; nothing in this script creates
    # the log/ directory.
    logging.basicConfig(
            filename = 'log/%s.log' % args.name, 
            level=logging.DEBUG, 
            format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s', 
            datefmt='%a, %d %b %Y %H:%M:%S')
    
    # Vocabulary and tag tables are built from all three splits.
    wordlist= gen_wordlist(args.dataset)
    logging.info('the number of words: %s' % len(wordlist))

    postag = gen_postag(args.dataset, args.numg, args.nump)
    logging.info('numg: %s, nump: %s' % (args.numg, args.nump))
    
    # Arguments shared by the three load_data calls below.
    load_data_args = [wordlist, postag, args.numg, args.nump, args.label, args.grained]
    logging.info('sentence label type: %s' % args.label)
    train_data = load_data('%s/train.res' % args.dataset, *load_data_args)
    dev_data = load_data('%s/dev.res' % args.dataset, *load_data_args)
    test_data = load_data('%s/test.res' % args.dataset, *load_data_args)
    if args.fast == 1:
        # Quick runs: keep only the first few examples of every split.
        train_data = train_data[:700]
        dev_data = dev_data[:100]
        test_data = test_data[:200]
    logging.info('train data: %s, dev data: %s, test data: %s' % \
            (len(train_data), len(dev_data), len(test_data)))
    
    # numg+1 / nump+1: gen_postag numbers the ranked tags from 1 and uses
    # index 0 for every other tag, so one extra slot is needed.
    rnn = RNN(args.numg+1, args.nump+1, len(wordlist), args.dimv, args.dimp, \
            args.grained, args.uniform_range, args.model=='rntn', args.normalize, \
            args.w_rec, args.w_sen, args.w_l2, args.name, args.dropout, args.verbose)
    if args.wordvec != '':
        rnn.load_wordvec(args.wordvec, wordlist)
    rnn.fit(args.optimizer, args.lr, args.batch, args.epoch_start, args.epoch_end, \
            train_data, dev_data, test_data)

