#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
See the following papers for more information on neural translation models.
 * http://arxiv.org/abs/1409.3215
 * http://arxiv.org/abs/1409.0473
 * http://arxiv.org/abs/1412.2007
"""




from __future__ import print_function, division


import os
import sys
import time

import numpy as np
import tensorflow as tf

import data_utils2
import s2s_model2 as s2s_model


tf.app.flags.DEFINE_string('en_vocab','data/s2s.reddit.vocab.enc.txt','vocabulary of encoder')
tf.app.flags.DEFINE_string('de_vocab','data/s2s.reddit.vocab.dec.char.txt','vocabulary of decoder')

tf.app.flags.DEFINE_integer('beam_size',10,'beam search size')
tf.app.flags.DEFINE_boolean('beam_search',False,'Set to True for beam search')
tf.app.flags.DEFINE_boolean('attention',True,'attention model or simple seq2seq model')

tf.app.flags.DEFINE_integer("en_vocab_size", 16000, "encoder vocabulary size.")
tf.app.flags.DEFINE_integer("de_vocab_size", 20000, "decoder vocabulary size.")

tf.app.flags.DEFINE_float('learning_rate',2e-4,'')
tf.app.flags.DEFINE_float('max_gradient_norm',3.0,'')
tf.app.flags.DEFINE_integer('batch_size',1,'')
tf.app.flags.DEFINE_integer('size',100,'')
tf.app.flags.DEFINE_integer('num_layers',1,'')
tf.app.flags.DEFINE_integer('num_epoch',5,'')
tf.app.flags.DEFINE_integer('num_samples', 0, '')
tf.app.flags.DEFINE_integer('num_per_epoch',-1,'')
tf.app.flags.DEFINE_string('buckets_dir','./bucket_dbs','sqlite3')
tf.app.flags.DEFINE_string('model_dir','./model','')
tf.app.flags.DEFINE_string('model_name','s2s.attn.tv3.model','')

tf.app.flags.DEFINE_string('test_file','./data/test/test.char.txt','testfile path')
tf.app.flags.DEFINE_float('diversity_rate',0.1,'diversity rate of gamma k algorithms')


tf.app.flags.DEFINE_integer('beam_test',-1,'test beam results')
tf.app.flags.DEFINE_boolean('test',-1,'test argmax results')
tf.app.flags.DEFINE_boolean("decode", False,"Set to True for interactive decoding.")

FLAGS = tf.app.flags.FLAGS



buckets = [
        (100, 60)
    ]

# TF GPU Config
#os.environ["CUDA_VISIBLE_DEVICES"] = "2"
config = tf.ConfigProto()
#config.gpu_options.allow_growth = True
# config.gpu_options.per_process_gpu_memory_fraction = 0.8
# config.log_device_placement = True


def load_dictionary(fname):
    with open(data_utils2.with_path(fname), 'r') as fp:
        index_word =  {} # OrderedDict()
        word_index = {} # OrderedDict()
        for index, word in enumerate(fp):
            word = word.decode("utf8").strip()
            index_word[index] = word
            word_index[word] = index
        dim=len(index_word.keys())
    return  word_index,index_word,dim

def create_model(session, forward_only,en_vocab_size,de_vocab_size):
    """建立模型"""
    #dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    model = s2s_model.S2SModel(
        en_vocab_size,de_vocab_size,buckets,
        FLAGS.size,FLAGS.num_layers,FLAGS.max_gradient_norm,FLAGS.batch_size,
        FLAGS.learning_rate,num_samples=FLAGS.num_samples,forward_only=forward_only,
        beam_search=FLAGS.beam_search, beam_size=FLAGS.beam_size, attention=FLAGS.attention,gamma=FLAGS.diversity_rate)

    ckpt=tf.train.get_checkpoint_state(FLAGS.model_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from %s" %ckpt.model_checkpoint_path)
        model.saver.restore(session,ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.")
        session.run(tf.global_variables_initializer())
    return model


def train():
    print('prepare data')
    en_word2index,_,en_vocab_size=load_dictionary(FLAGS.en_vocab)
    de_word2index,_,de_vocab_size = load_dictionary(FLAGS.de_vocab)

    bucket_dbs = data_utils2.read_bucket_dbs(FLAGS.buckets_dir,buckets)
    bucket_sizes = []
    for i in range(len(buckets)):
        bucket_size = bucket_dbs[i].size
        bucket_sizes.append(bucket_size)
        print('bucket {} data {} terms'.format(i, bucket_size))
    total_size = sum(bucket_sizes)
    print('total {} terms'.format(total_size))

    # initialize the Session
    with tf.Session(config=config) as sess:
        # Create model.
        print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size))
        model = create_model(sess, forward_only=False,en_vocab_size=en_vocab_size,de_vocab_size=de_vocab_size)
        print('Build Finished.')
        buckets_scale = [
            sum(bucket_sizes[:i + 1]) / total_size
            for i in range(len(bucket_sizes))
        ]
        if not os.path.exists(FLAGS.model_dir):
            os.makedirs(FLAGS.model_dir)

        for epoch_index in range(FLAGS.num_epoch):
            print('Epoch {}:'.format(epoch_index))
            mname = FLAGS.model_name+"."+str(epoch_index)
            sample_trained = 0
            random_number = np.random.random_sample()
            bucket_id = min([
                i for i in range(len(buckets_scale))
                if buckets_scale[i] > random_number])

            time_start = time.time()
            bucket_db = bucket_dbs[bucket_id]         
            for data in bucket_db.get_batch_data(FLAGS.batch_size):
                encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
                    data, bucket_id,en_word2index,de_word2index)
                print('*'*90)
                _, step_loss, output = model.step(sess,encoder_inputs,decoder_inputs,
                                                  decoder_weights,bucket_id,False,FLAGS.beam_search)

                sample_trained += FLAGS.batch_size
                # batch_loss.append(step_loss)
                time_spend = time.time() - time_start
                time_start = time.time()
                print('%.3f\t%.d\t%.3f' % \
                    (time_spend, sample_trained, step_loss))

                if sample_trained % 50000 == 0:
                    model.saver.save(sess, os.path.join(FLAGS.model_dir, mname))
            print('\n')


def beam_test(count):
    beam_size = FLAGS.beam_size
    beam_search = FLAGS.beam_search
    en_index2word, en_word2index = load_dictionary(FLAGS.en_vocab)
    de_index2word, de_word2index = load_dictionary(FLAGS.de_vocab)

    lines=open(FLAGS.test_file,'r').read().split('\n')
    lines=[l.split("#TAB#") for l in lines]
    bucket_dbs = data_utils2.read_bucket_dbs(FLAGS.buckets_dir,buckets)
    bucket_sizes = []
    for i in range(len(buckets)):
        bucket_size = bucket_dbs[i].size
        bucket_sizes.append(bucket_size)
        print('bucket {} data {} terms'.format(i, bucket_size))
    total_size = sum(bucket_sizes)
    print('total {} terms'.format(total_size))

    fw=open('diversity_rate_%.2f.log'% FLAGS.diversity_rate,'w')
    with tf.Session(config=config) as sess:
        # create model.
        print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size))
        model = create_model(sess, forward_only=True)
        model.batch_size = 1

        if not os.path.exists(FLAGS.model_dir):
            raise ValueError("No pre-trained model exists!")

        print('#'*80)
        print('#'*20+' '*10+" Beam Search Result "+' '*10+'#'*20)
        print('#' * 80)
        index=0
        for data in lines :
            print(index)
            index+=1
            if index==count: break
            print('Query:',data[0])
            print("Response:",data[1])
            #fw.write("Query: ")
            write(fw,"\nQuery: ",data[0].decode('utf-8'))
            write(fw,"Answer: " ,data[1].decode('utf-8'))

            bucket_id =0

            encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
                [data],bucket_id,en_word2index,de_word2index)
            #print(len(encoder_inputs),len(encoder_inputs[0]))
            path, symbol, output_logits = model.step(sess,encoder_inputs,decoder_inputs,
                                                     decoder_weights,bucket_id,True,beam_search)
            num_steps=len(path)
            print("Beam Prediction:")
            for k in range(beam_size):
                sentence=[]
                cur=k
                for i in reversed(range(num_steps)):
                    sentence.append(symbol[i][cur])
                    cur=path[i][cur]
                response = data_utils2.indice_sentence(sentence[::-1],de_index2word)
                write(fw,str(k)+' ',response)
                print(response)

            write(fw,'','-'*80)
            print('-'*80)
        fw.close()
        print("*"*30+' '*10+"END"+' '*10+"*"*30)


def write(f,prefix,string):
    try:
        f.write(prefix+string.encode('utf-8', 'ignore')+'\n')

    except Exception:
        f.write(prefix+ string.encode('ascii', 'ignore')+'\n')

def main(_):
    if FLAGS.beam_test>0:
        beam_test(FLAGS.beam_test)
    else:
        train()

if __name__ == '__main__':
    np.random.seed(0)
    tf.set_random_seed(0)
    tf.app.run()
