#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
See the following papers for more information on neural translation models.
 * http://arxiv.org/abs/1409.3215
 * http://arxiv.org/abs/1409.0473
 * http://arxiv.org/abs/1412.2007
"""

from __future__ import print_function, division


import os
import sys
import time

import cPickle as pickle

import numpy as np
import tensorflow as tf

import data_utils2 as data_utils
import s2s_model

# ---------------------------------------------------------------------------
# Command-line flags.
# ---------------------------------------------------------------------------

# Vocabulary files, parsed by load_dictionary() as "index<TAB>word" lines.
tf.app.flags.DEFINE_string('en_vocab','data/s2s.model.vocab.enc.txt','vocabulary of encoder')
tf.app.flags.DEFINE_string('de_vocab','data/s2s.model.vocab.dec.char.txt','vocabulary of decoder')

# Decoding options forwarded to S2SModel by create_model().
tf.app.flags.DEFINE_integer('beam_size',-1,'beam search size')
tf.app.flags.DEFINE_boolean('beam_search',False,'Set to True for beam search')
tf.app.flags.DEFINE_boolean('attention',True,'attention model or simple seq2seq model')


# Model / optimisation hyper-parameters forwarded to S2SModel by create_model().
tf.app.flags.DEFINE_float('learning_rate',2e-4,'')
tf.app.flags.DEFINE_float('max_gradient_norm',3.0,'')
tf.app.flags.DEFINE_integer('batch_size',1,'')
# 'size' is reported as the number of units per layer when the model is built.
tf.app.flags.DEFINE_integer('size',20,'')
tf.app.flags.DEFINE_integer('num_layers',1,'')
tf.app.flags.DEFINE_integer('num_epoch',5,'')
tf.app.flags.DEFINE_integer('num_samples', 0, '')
# NOTE(review): 'num_per_epoch' is not read anywhere in this file.
tf.app.flags.DEFINE_integer('num_per_epoch',-1,'')
tf.app.flags.DEFINE_string('buckets_dir','./bucket_dbs','sqlite3')
# Checkpoints are written to <model_dir>/<model_name>.<epoch index>.
tf.app.flags.DEFINE_string('model_dir','./model','')
tf.app.flags.DEFINE_string('model_name','s2s.attn.tv3.model','')

tf.app.flags.DEFINE_string('test_file','./data/test/test.char.txt','testfile path')
tf.app.flags.DEFINE_float('diversity_rate',0.0,'diversity rate of gamma k algorithms')

# Mode selectors, checked in this order by main(): beam_test, test, decode;
# training runs when none of them is set.
tf.app.flags.DEFINE_integer('beam_test',-1,'test beam results')
# 'test' carries a sample count (main() compares it with -1 and hands it to
# test(count)), so it has to be an integer flag: a boolean flag can neither
# hold the -1 default nor accept a count from the command line.
tf.app.flags.DEFINE_integer('test',-1,'test argmax results')
tf.app.flags.DEFINE_boolean("decode", False,"Set to True for interactive decoding.")

tf.app.flags.DEFINE_boolean("bidirection", False,"bidirection encoder")

FLAGS = tf.app.flags.FLAGS

# Sequence-length buckets handed to S2SModel and read_bucket_dbs(). decode()
# compares the first element of each pair with the number of input tokens;
# the second element is presumably the decoder-side length -- TODO confirm
# against s2s_model.
buckets = [(100,60)
        #(50, 50)
    ]


def load_dictionary(fname):
    """Load a vocabulary file made of ``index<TAB>word`` lines.

    Args:
        fname: Vocabulary file name; it is resolved through
            ``data_utils.with_path`` before being opened.

    Returns:
        A ``(word2index, index2word, dim)`` tuple where ``word2index`` maps
        each stripped word to its integer index, ``index2word`` is the
        reverse mapping and ``dim`` is the number of distinct indices.

    Raises:
        ValueError: If a non-blank line does not consist of exactly two
            tab-separated fields or its index is not an integer.
    """
    word2index = {}
    index2word = {}
    # The context manager guarantees the handle is released even when a
    # malformed line raises below.
    with open(data_utils.with_path(fname), 'r') as vocab_file:
        entries = vocab_file.read().split('\n')

    for entry in entries:
        if not entry.strip():
            # A trailing newline leaves an empty last entry that cannot be
            # unpacked into (index, word); blank lines carry no data.
            continue
        index, word = entry.split('\t')
        if isinstance(word, bytes):
            # Python 2 reads byte strings; decode them so words are unicode.
            word = word.decode("utf8")
        word = word.strip()
        index = int(index)
        index2word[index] = word
        word2index[word] = index

    return word2index, index2word, len(index2word)

# Note: the `gamma` argument given to S2SModel in create_model() denotes the
# diversity rate.
#
# TensorFlow GPU configuration used by the sessions opened with
# `config=config`: GPU memory is allocated on demand and each process is
# capped at 80% of the device memory.
# (Device selection through CUDA_VISIBLE_DEVICES and device-placement logging
# are intentionally left disabled.)
config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(
        allow_growth=True,
        per_process_gpu_memory_fraction=0.8,
    )
)
def create_model(session, forward_only,en_vocab_size,de_vocab_size):
    """Build an ``S2SModel`` and load or initialise its variables.

    Args:
        session: TensorFlow session used to restore or initialise variables.
        forward_only: Forwarded to ``S2SModel`` as ``forward_only``.
        en_vocab_size: Encoder vocabulary size.
        de_vocab_size: Decoder vocabulary size.

    Returns:
        The constructed model. Its variables are restored from the latest
        checkpoint recorded in ``FLAGS.model_dir`` when one exists, otherwise
        they are freshly initialised.
    """
    seq2seq = s2s_model.S2SModel(
        en_vocab_size,
        de_vocab_size,
        buckets,
        FLAGS.size,
        FLAGS.num_layers,
        FLAGS.max_gradient_norm,
        FLAGS.batch_size,
        FLAGS.learning_rate,
        num_samples=FLAGS.num_samples,
        forward_only=forward_only,
        beam_search=FLAGS.beam_search,
        beam_size=FLAGS.beam_size,
        attention=FLAGS.attention,
        gamma=FLAGS.diversity_rate,
        bidirection=FLAGS.bidirection,
    )

    checkpoint = tf.train.get_checkpoint_state(FLAGS.model_dir)
    if not (checkpoint and tf.train.checkpoint_exists(checkpoint.model_checkpoint_path)):
        # Nothing to restore: start from freshly initialised variables.
        print("Created model with fresh parameters.")
        session.run(tf.global_variables_initializer())
        return seq2seq

    print("Reading model parameters from %s" %checkpoint.model_checkpoint_path)
    seq2seq.saver.restore(session, checkpoint.model_checkpoint_path)
    return seq2seq


def train():
    """Train the model on the bucketed data and checkpoint it.

    For each of ``FLAGS.num_epoch`` epochs one bucket is drawn at random
    (with probability proportional to its size) and every batch that bucket
    yields is fed through a training step. One line per step is printed with
    the step duration, the number of samples seen so far in the epoch and
    the step loss.

    Checkpoints go to ``<FLAGS.model_dir>/<FLAGS.model_name>.<epoch index>``.
    They are written each time another 50000 samples have been processed and
    once more when the epoch ends, so a short epoch, or a batch size that
    never lands exactly on a multiple of 50000, still leaves a checkpoint.
    """
    save_every = 50000  # samples between two periodic checkpoints

    print('prepare data')
    en_word2index,_,en_dim = load_dictionary(FLAGS.en_vocab)
    de_word2index,_,de_dim = load_dictionary(FLAGS.de_vocab)
    bucket_dbs = data_utils.read_bucket_dbs(FLAGS.buckets_dir,buckets)
    bucket_sizes = []
    for i in range(len(buckets)):
        bucket_size = bucket_dbs[i].size
        bucket_sizes.append(bucket_size)
        print('bucket {} data {} terms'.format(i, bucket_size))
    total_size = sum(bucket_sizes)
    print('total data {} terms'.format(total_size))

    # initialize the Session
    with tf.Session(config=config) as sess:
        # Create model.
        print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size))
        print('en_dim:',en_dim)
        model = create_model(sess, forward_only=False,en_vocab_size=en_dim,de_vocab_size=de_dim)
        print('Build Finished.')
        # Cumulative share of the data covered by buckets 0..i; used below to
        # turn a uniform random number into a size-weighted bucket choice.
        buckets_scale = [
            sum(bucket_sizes[:i + 1]) / total_size
            for i in range(len(bucket_sizes))
        ]
        if not os.path.exists(FLAGS.model_dir):
            os.makedirs(FLAGS.model_dir)

        for epoch_index in range(FLAGS.num_epoch):
            print('Epoch {}:'.format(epoch_index))
            checkpoint_path = os.path.join(
                FLAGS.model_dir, FLAGS.model_name+"."+str(epoch_index))
            sample_trained = 0
            last_saved = 0  # value of sample_trained at the latest checkpoint
            random_number = np.random.random_sample()
            bucket_id = min(
                i for i, scale in enumerate(buckets_scale)
                if scale > random_number)

            time_start = time.time()
            bucket_db = bucket_dbs[bucket_id]
            for data in bucket_db.get_batch_data(FLAGS.batch_size):

                encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
                    data, bucket_id,en_word2index,de_word2index)
                _, step_loss, _ = model.step(sess,encoder_inputs,decoder_inputs,
                                             decoder_weights,bucket_id,False,FLAGS.beam_search)

                sample_trained += FLAGS.batch_size
                now = time.time()
                time_spend = now - time_start
                time_start = now
                print('%.3f\t%.d\t%.3f' % \
                    (time_spend, sample_trained, step_loss))

                # Crossing the threshold (rather than hitting an exact
                # multiple) keeps periodic saves working for any batch size.
                if sample_trained - last_saved >= save_every:
                    model.saver.save(sess, checkpoint_path)
                    last_saved = sample_trained

            # Persist whatever was trained since the last periodic save so
            # the end-of-epoch state is never lost.
            if sample_trained > last_saved:
                model.saver.save(sess, checkpoint_path)
            print('\n')


def beam_test(count):
    """Decode up to ``count`` test queries and log the candidates per beam.

    ``FLAGS.test_file`` is read as one ``query#TAB#answer`` pair per line.
    For every processed pair the query, the reference answer and one decoded
    response per beam (with the last entry of that beam's ``log_probs`` row)
    are appended to ``diversity_rate_<rate>.log``; the responses and scores
    are also printed. The per-query response lists are pickled to
    ``answer_list.pkl``.

    Args:
        count: Maximum number of test pairs to decode. Lines without both a
            query and an answer (for instance the empty string left by a
            trailing newline) are skipped and not counted.

    Raises:
        ValueError: If ``FLAGS.model_dir`` does not exist.
    """
    beam_size = FLAGS.beam_size
    beam_search = FLAGS.beam_search
    en_word2index,_,en_dim = load_dictionary(FLAGS.en_vocab)
    de_word2index,de_index2word,de_dim = load_dictionary(FLAGS.de_vocab)

    with open(FLAGS.test_file,'r') as test_file:
        lines = [l.split("#TAB#") for l in test_file.read().split('\n')]

    # The bucket databases are only opened to report their sizes.
    bucket_dbs = data_utils.read_bucket_dbs(FLAGS.buckets_dir,buckets)
    bucket_sizes = []
    for i in range(len(buckets)):
        bucket_size = bucket_dbs[i].size
        bucket_sizes.append(bucket_size)
        print('bucket {} data {} terms'.format(i, bucket_size))
    total_size = sum(bucket_sizes)
    print('total {} terms'.format(total_size))

    answer_list = []
    log_name = 'diversity_rate_%.2f.log'% FLAGS.diversity_rate
    # Both context managers are released even if decoding raises.
    with open(log_name,'w') as fw, tf.Session(config=config) as sess:
        # create model.
        print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size))
        model = create_model(sess, forward_only=True,en_vocab_size=en_dim,de_vocab_size=de_dim)
        model.batch_size = 1

        if not os.path.exists(FLAGS.model_dir):
            raise ValueError("No pre-trained model exists!")

        processed = 0
        for data in lines:
            if processed >= count:
                break
            if len(data) < 2:
                # Blank or malformed line: there is no answer field to log.
                continue
            processed += 1

            write(fw,"\nQuery: ",data[0].decode('utf-8'),'')
            write(fw,"Answer: " ,data[1].decode('utf-8'),'')

            bucket_id = 0

            encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
                [data],bucket_id,en_word2index,de_word2index)
            path, symbol,log_probs, output_logits = model.step(sess,encoder_inputs,decoder_inputs,
                                                     decoder_weights,bucket_id,True,beam_search)
            num_steps = len(path)

            # A fresh list per query; reusing one list would make every entry
            # of answer_list point at the same accumulated responses.
            responses = []
            for k in range(beam_size):
                # Walk backwards from final beam slot k: symbol[i][cur] is
                # taken as the token emitted at step i and path[i][cur] as
                # the beam slot to continue from at step i - 1.
                sentence = []
                cur = k
                for i in reversed(range(num_steps)):
                    sentence.append(symbol[i][cur])
                    cur = path[i][cur]
                response = data_utils.indice_sentence(sentence[::-1],de_index2word)
                response = response+' .'
                write(fw,str(k)+' ',response,str(log_probs[k][-1]))
                responses.append(response)
                print(response,log_probs[k][-1])

            write(fw,'','-'*80,'')
            print('-'*80)
            answer_list.append(responses)

    # Pickle data is binary, so the file is opened in binary mode.
    with open('answer_list.pkl','wb') as f:
        pickle.dump(answer_list,f)
    print("*"*30+' '*10+"END"+' '*10+"*"*30)



def decode():
    """Interactively decode sentences typed on standard input.

    After a ``"> "`` prompt each line is tokenised on single spaces, mapped
    to ids, assigned to the smallest bucket whose input length exceeds the
    number of tokens and run through the model. With ``FLAGS.beam_search``
    the distinct replies recovered from the beams are printed; otherwise the
    greedy (argmax) reply is printed. The loop ends at end-of-file.

    NOTE(review): the ``model.get_batch``, ``data_utils.sentence_indice`` and
    ``data_utils.indice_sentence`` calls below keep their original argument
    lists, which differ from how train()/beam_test() call the same helpers
    (those pass the vocabulary dictionaries as well) -- confirm against
    s2s_model / data_utils2.
    """
    # create_model() requires both vocabulary sizes.
    _, _, en_dim = load_dictionary(FLAGS.en_vocab)
    _, _, de_dim = load_dictionary(FLAGS.de_vocab)

    with tf.Session() as sess:
        # Create model and load parameters.
        beam_size = FLAGS.beam_size
        beam_search = FLAGS.beam_search
        model = create_model(sess, True, en_dim, de_dim)
        model.batch_size = 1  # We decode one sentence at a time.

        # Decode from standard input.
        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            # Get token-ids for the input sentence.
            token_ids = data_utils.sentence_indice(sentence.split(' '))
            # Which bucket does it belong to?
            fitting = [b for b in range(len(buckets))
                       if buckets[b][0] > len(token_ids)]
            if not fitting:
                # Report instead of crashing the session on min([]).
                print("Sentence is too long for the configured buckets.")
            else:
                bucket_id = min(fitting)
                # Get a 1-element batch to feed the sentence to the model.
                encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                    {bucket_id: [(token_ids, [])]}, bucket_id)
                step_result = model.step(sess, encoder_inputs, decoder_inputs,
                                         target_weights, bucket_id, True, beam_search)

                if beam_search:
                    # Only the first two results are needed here. Indexing
                    # (rather than unpacking) works whether step() returns
                    # three values or, as beam_test() expects, four.
                    path, symbol = step_result[0], step_result[1]

                    # Back-track every beam from the last step to the first.
                    paths = [[] for _ in range(beam_size)]
                    curr = list(range(beam_size))  # mutable on Python 3 too
                    for i in range(len(path) - 1, -1, -1):
                        for kk in range(beam_size):
                            paths[kk].append(symbol[i][curr[kk]])
                            curr[kk] = path[i][curr[kk]]

                    recos = set()
                    print("Replies --------------------------------------->")
                    for kk in range(beam_size):
                        foutputs = [int(logit) for logit in paths[kk][::-1]]
                        # If there is an EOS symbol in outputs, cut them at that point.
                        if data_utils.EOS_ID in foutputs:
                            foutputs = foutputs[:foutputs.index(data_utils.EOS_ID)]
                        rec = " ".join([data_utils.indice_sentence(output) for output in foutputs])
                        # Different beams may decode to the same text; show it once.
                        if rec not in recos:
                            recos.add(rec)
                            print(rec)
                else:
                    _, _, output_logits = step_result
                    # This is a greedy decoder - outputs are just argmaxes of output_logits.
                    outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
                    # If there is an EOS symbol in outputs, cut them at that point.
                    if data_utils.EOS_ID in outputs:
                        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
                    rec = " ".join([data_utils.indice_sentence(output) for output in outputs])
                    # Show the decoded reply.
                    print(rec)

            # Prompt on the same line, without a trailing newline.
            sys.stdout.write("> ")
            sys.stdout.flush()
            sentence = sys.stdin.readline()



def test(count):
    """Print greedy (argmax) predictions for samples of one random bucket.

    A bucket is drawn at random with probability proportional to its size
    and its samples are decoded one at a time; each prediction is printed.

    Args:
        count: Maximum number of samples to decode. A value ``<= 0`` means
            "as many as the total number of samples in all buckets".

    Raises:
        ValueError: If ``FLAGS.model_dir`` does not exist.
    """
    print('prepare')
    # The helper calls below take the same arguments as in train() and
    # beam_test(): vocabulary sizes for create_model(), the bucket list for
    # read_bucket_dbs() and the dictionaries for get_batch()/indice_sentence().
    en_word2index,_,en_dim = load_dictionary(FLAGS.en_vocab)
    de_word2index,de_index2word,de_dim = load_dictionary(FLAGS.de_vocab)
    bucket_dbs = data_utils.read_bucket_dbs(FLAGS.buckets_dir,buckets)
    bucket_sizes = []
    for i in range(len(buckets)):
        bucket_size = bucket_dbs[i].size
        bucket_sizes.append(bucket_size)
        print('bucket {} data {} terms'.format(i, bucket_size))
    total_size = sum(bucket_sizes)
    print('total {} terms'.format(total_size))
    if count <= 0:
        count = total_size

    with tf.Session(config=config) as sess:
        # create model.
        print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size))
        model = create_model(sess, forward_only=True,en_vocab_size=en_dim,de_vocab_size=de_dim)
        model.batch_size = 1
        # Cumulative share of the data covered by buckets 0..i.
        buckets_scale = [
            sum(bucket_sizes[:i + 1]) / total_size
            for i in range(len(bucket_sizes))]
        if not os.path.exists(FLAGS.model_dir):
            raise ValueError("No pre-trained model exists!")

        random_number = np.random.random_sample()
        bucket_id = min(i for i, scale in enumerate(buckets_scale) if scale > random_number)
        bucket_db = bucket_dbs[bucket_id]
        # Banner text kept as-is for log compatibility; decoding below is
        # greedy (step() is called with beam search disabled).
        print('#' * 80)
        print('#' * 20 + ' ' * 10 + " Beam Search Result " + ' ' * 10 + '#' * 20)
        print('#' * 80)
        # One sample per batch: the int(np.argmax(...)) conversion below only
        # works for a single row, matching model.batch_size = 1.
        for data in bucket_db.get_batch_data(model.batch_size):
            # Check the budget first so no batch is built just to be dropped.
            if count <= 0:
                break
            count -= 1

            encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
                data, bucket_id,en_word2index,de_word2index)

            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                             decoder_weights, bucket_id, True,False)

            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]

            response = data_utils.indice_sentence(outputs,de_index2word)

            print(" Prediction:" + '\t' + response)
            print('-' * 80)
        print("*" * 30 + ' ' * 10 + "END" + ' ' * 10 + "*" * 30)


def write(f,prefix,string,score):
    """Write ``prefix + score + string`` followed by a newline to ``f``.

    Each part may be a text or a byte string; byte strings are decoded as
    UTF-8 with undecodable bytes dropped. On Python 2 the assembled line is
    encoded back to UTF-8 before writing, because files opened in ``'w'``
    mode take byte strings there; on Python 3 the text is written as is.

    Args:
        f: Writable file-like object.
        prefix: Label written first.
        string: Main text, written last (after ``score``).
        score: Text written between ``prefix`` and ``string``; may be empty.
    """
    def _to_text(value):
        # Normalise to text so the parts can be concatenated safely.
        if isinstance(value, bytes):
            return value.decode('utf-8', 'ignore')
        return value

    line = _to_text(prefix) + _to_text(score) + _to_text(string) + u'\n'
    if str is bytes:
        # Python 2: encode once at the I/O boundary.
        line = line.encode('utf-8', 'ignore')
    f.write(line)

def main(_):
    """Run the mode selected by the command-line flags.

    The first matching condition wins: ``FLAGS.beam_test > 0`` runs
    beam_test(), ``FLAGS.test > -1`` runs test(), ``FLAGS.decode`` runs
    decode(); when none applies the model is trained.
    """
    if FLAGS.beam_test>0:
        beam_test(FLAGS.beam_test)
        return

    if FLAGS.test>-1:
        test(FLAGS.test)
        return

    if FLAGS.decode:
        decode()
        return

    train()

if __name__ == '__main__':
    # Seed NumPy's and TensorFlow's random generators with a fixed value
    # before anything is built or sampled.
    np.random.seed(0)
    tf.set_random_seed(0)
    # Hand control to TensorFlow's app runner, which invokes main().
    tf.app.run()
