from __future__ import print_function
from __future__ import unicode_literals

import codecs
import sys

import keras.backend as K
import numpy
from gensim.models import KeyedVectors
from keras import regularizers
from keras.engine import Model
from keras.layers import Input
from keras.layers.core import Dense, Lambda, Reshape, Masking
from keras.layers.embeddings import Embedding
from keras.layers.merge import concatenate
from keras.layers.recurrent import LSTM
from keras.layers.wrappers import TimeDistributed, Bidirectional
from keras.preprocessing import sequence
import argparse


def _read_segmentations(training_file):
    """Parse the training file into per-word morpheme segmentations.

    Every non-blank line is split as ``word:seg1+seg2+...`` and each
    segmentation is split on ``-`` into its morphemes.

    Returns:
        A ``(word2sgmt, seq)`` tuple. ``word2sgmt`` maps each word to its list
        of segmentations (each a list of morphemes); ``seq`` is the flat list
        of every segmentation in file order.

    Raises:
        ValueError: if a non-blank line does not contain exactly one ``:``.
    """
    word2sgmt = {}
    seq = []
    # The context manager guarantees the file is closed, even on a bad line.
    with codecs.open(training_file, encoding='utf-8') as f:
        for line in f:
            # Strip '\r' as well so CRLF files do not leave a stray carriage
            # return glued to the last morpheme of every line.
            line = line.rstrip('\r\n')
            if not line.strip():
                # Blank lines (e.g. a trailing newline at EOF) carry no data.
                continue
            word, sgmnts = line.split(':')
            sgmt = [s.split('-') for s in sgmnts.split('+')]
            word2sgmt[word] = sgmt
            seq.extend(sgmt)
    return word2sgmt, seq


def _load_word_vectors(path):
    """Load word2vec vectors from ``path``, trying text format, then binary."""
    try:
        return KeyedVectors.load_word2vec_format(path, binary=False, encoding='utf-8')
    except Exception:
        # Intentionally broad: any failure of the text parser is treated as
        # "not a text file" and triggers the binary fallback. Unlike a bare
        # ``except:`` this still lets KeyboardInterrupt/SystemExit through.
        return KeyedVectors.load_word2vec_format(path, binary=True, encoding='utf-8')


def _build_model(number_of_segmentation, number_of_morphs, dim):
    """Build the (uncompiled) attention model over segmentation encodings.

    One integer input is created per segmentation slot; all slots share a
    single embedding layer and a single bidirectional LSTM encoder. The
    encodings are combined by a softmax-weighted sum whose weights come from
    two stacked time-distributed dense layers.
    """
    morph_seg = [Input(shape=(None,), dtype='int32')
                 for _ in range(number_of_segmentation)]

    # Index 0 is reserved for padding (mask_zero=True), hence the "+ 1".
    # The layer name (including its spelling) is kept as-is so previously
    # saved weights remain addressable by name.
    morph_embedding = Embedding(input_dim=number_of_morphs + 1, output_dim=dim // 4,
                                mask_zero=True, name="embeddding")
    embed_seg = [morph_embedding(seg_input) for seg_input in morph_seg]

    biLSTM = Bidirectional(LSTM(dim, dropout=0.2, recurrent_dropout=0.2, return_sequences=False),
                           merge_mode='concat')
    encoded_seg = [biLSTM(embedded) for embedded in embed_seg]

    # Each encoding has 2 * dim features (forward + backward states are
    # concatenated); reshape to one row per segmentation slot.
    concat_vector = concatenate(encoded_seg, axis=-1)
    merge_vector = Reshape((number_of_segmentation, (2 * dim)))(concat_vector)

    masked_vector = Masking()(merge_vector)

    seq_output = TimeDistributed(Dense(dim))(masked_vector)

    attention_1 = TimeDistributed(Dense(units=dim, activation='tanh', use_bias=False))(seq_output)

    attention_2 = TimeDistributed(Dense(units=1,
                                        activity_regularizer=regularizers.l1(0.01),
                                        use_bias=False))(attention_1)

    def attn_merge(inputs, mask):
        vectors = inputs[0]
        logits = inputs[1]
        # Flatten the logits and take a softmax
        logits = K.squeeze(logits, axis=2)
        # NOTE(review): mask[0] is presumably the mask propagated for the
        # first input (seq_output) -- confirm against Keras' Lambda handling.
        # Masked positions get -inf so they receive zero softmax weight.
        pre_softmax = K.switch(mask[0], logits, -numpy.inf)
        weights = K.expand_dims(K.softmax(pre_softmax))
        return K.sum(vectors * weights, axis=1)

    def attn_merge_shape(input_shapes):
        # Drop the middle axis of the first input's shape: (batch, features).
        return (input_shapes[0][0], input_shapes[0][2])

    attn = Lambda(attn_merge, output_shape=attn_merge_shape)
    attn.supports_masking = True
    # The merged output carries no mask forward.
    attn.compute_mask = lambda inputs, mask: None
    content_flat = attn([seq_output, attention_2])

    return Model(inputs=morph_seg, outputs=content_flat)


def main():
    """Train morph2vec from the command line and save the model weights.

    Steps: parse arguments, read the word segmentations, turn morphemes into
    padded index sequences (one input array per segmentation slot), look up
    the pre-trained vector of every word as the regression target, save both
    arrays to ``x_train.npy`` / ``y_train.npy``, then build, fit and save the
    model weights to ``--output``.

    Exits via ``sys.exit`` with an error message when the training file holds
    no examples, when a word does not have exactly ``--segNo`` segmentations,
    or when a word has no pre-trained vector.
    """
    parser = argparse.ArgumentParser(
        description='Learn morph2vec: morpheme-based representation learning ')
    parser.add_argument('--input', type=str, required=True, help="path to training file")
    parser.add_argument('--wordVector', type=str, required=True, help="path to word2vec vector file")
    # The value is handed to model.save_weights() unchanged, i.e. it is a file path.
    parser.add_argument('--output', '-o', type=str, required=True,
                        help="path of the file where the model weights will be saved")
    parser.add_argument('--segNo', type=int, default=10,
                        help="number of segmentations provided for a given word during training")
    parser.add_argument('--batch', type=int, default=32, help="batch size")
    parser.add_argument('--epoch', type=int, default=5, help="number of epochs")
    parser.add_argument('--dim', type=int, default=300, help="dimension")

    args = parser.parse_args()

    number_of_segmentation = args.segNo
    gensim_model = args.wordVector
    training_file = args.input
    output_file = args.output
    batch_size = args.batch
    number_of_epoch = args.epoch
    dim = args.dim
    print('===================================  Prepare data...  ==============================================')
    print('')

    word2sgmt, seq = _read_segmentations(training_file)
    if not word2sgmt:
        sys.exit('ERROR: No training examples found in %s' % training_file)

    morphs = [morph for sgmt in seq for morph in sgmt]
    timesteps_max_len = max(len(sgmt) for sgmt in seq)

    print('number of words: ', len(word2sgmt))

    # NOTE(review): indices follow set iteration order, which can differ
    # between interpreter runs when string hash randomization is enabled.
    # Left unchanged on purpose; confirm that whatever consumes the saved
    # weights derives morpheme indices the same way.
    unique_morphs = set(morphs)
    morph_indices = dict((c, i + 1) for i, c in enumerate(unique_morphs))
    morph_indices['###'] = 0  # index 0 is the padding/mask slot

    print('number of morphemes: ', len(morphs))
    print('number of unique morphemes: ', len(unique_morphs))

    # x_train[i] collects the i-th segmentation of every word, so every word
    # must contribute exactly one sequence to each of the --segNo slots;
    # otherwise the per-slot inputs would have different sample counts.
    x_train = [[] for _ in range(number_of_segmentation)]
    for word, segmentations in word2sgmt.items():
        if len(segmentations) != number_of_segmentation:
            sys.exit('ERROR: word "%s" has %d segmentations but --segNo is %d'
                     % (word, len(segmentations), number_of_segmentation))
        for i, sgmt in enumerate(segmentations):
            x_train[i].append([morph_indices[c] for c in sgmt])

    # pad_sequences accepts the ragged lists directly; converting them with
    # numpy.array() first is unnecessary and fails on recent NumPy versions.
    x_train = [sequence.pad_sequences(slot, maxlen=timesteps_max_len) for slot in x_train]

    print('')
    print('==========================  Load pre-trained word vectors...  ======================================')
    print('')

    w2v_model = _load_word_vectors(gensim_model)

    # Iterate the same dict as for x_train so inputs and targets stay aligned.
    y_train = []
    missing_words = []
    for word in word2sgmt:
        try:
            vector = w2v_model[word]
        except KeyError:
            # Looking up an unknown word raises, so collect the misses here
            # instead of comparing lengths afterwards.
            missing_words.append(word)
            continue
        y_train.append(vector.tolist())
    if missing_words:
        print('number of words without a pre-trained vector: ', len(missing_words))
        sys.exit('ERROR: Pre-trained vectors do not contain all words in wordlist !!')
    y_train = numpy.array(y_train)
    print('number of pre-trained vectors: ', len(w2v_model.vocab))

    print('number of words found: ', len(y_train))
    print('shape of Y: ', y_train.shape)

    print('')
    print('===================================  Save Input and Output...  ===============================================')
    print('')

    numpy.save("x_train", x_train)
    numpy.save("y_train", y_train)

    print('')
    print('===================================  Build model...  ===============================================')
    print('')

    model = _build_model(number_of_segmentation, len(unique_morphs), dim)

    model.compile(loss='cosine_proximity', optimizer='adam', metrics=['accuracy'])
    model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=number_of_epoch)

    print('')
    print('===================================  Save model weights...  ===============================================')
    print('')

    model.save_weights(output_file)
    print("Model saved in path: %s" % output_file)
# Run the training pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
