# -*- coding: utf-8 -*-
from __future__ import print_function
from __future__ import unicode_literals

import codecs

import keras.backend as K
import numpy
import argparse
from keras.engine import Model
from keras.layers import Input
from keras import regularizers
from keras.layers.core import Dense, Lambda, Reshape, Masking
from keras.layers.embeddings import Embedding
from keras.layers.merge import concatenate
from keras.layers.recurrent import LSTM
from keras.layers.wrappers import TimeDistributed, Bidirectional
from keras.preprocessing import sequence


def _parse_args():
    """Parse the command-line options of the script.

    :return: ``argparse.Namespace`` with ``trainingFile``, ``inputFile``,
        ``weightFile``, ``output`` and ``segNo`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Morph2vec: create vector representations for words based on their morpheme vectors')
    parser.add_argument('--trainingFile', type=str, required=True, help="path to training file")
    parser.add_argument('--inputFile', type=str, required=True,
                        help="path to the file which lists words for which vectors will be created")
    parser.add_argument('--weightFile', type=str, required=True, help="path to the morph2vec weightfile")
    parser.add_argument('--output', '-o', type=str, required=True,
                        help="output file where the word vectors will be saved")
    parser.add_argument('--segNo', type=int, default=10,
                        help="number of segmentations provided for a given word during training")
    return parser.parse_args()


def _read_segmentations(path, delim):
    """Read a ``word:seg+seg+...`` file.

    Each ``+``-separated segmentation is split into morphemes on ``delim``.

    :param path: UTF-8 text file with one ``word:segmentations`` entry per line.
    :param delim: separator between the morphemes of a single segmentation.
    :return: ``(word2sgmt, seq)``. ``word2sgmt`` maps each word to its list of
        segmentations (each a list of morphemes). ``seq`` is the flat list of
        every segmentation in file order, duplicate words included.
    """
    word2sgmt = {}
    seq = []
    with codecs.open(path, encoding='utf-8') as handle:
        for line in handle:
            line = line.rstrip('\r\n')
            if not line:
                continue  # blank (e.g. trailing) lines carry no entry
            word, sgmnts = line.split(':', 1)
            sgmt = [s.split(delim) for s in sgmnts.split('+')]
            word2sgmt[word] = sgmt
            seq.extend(sgmt)
    return word2sgmt, seq


def _build_morph_indices(morphs):
    """Assign an integer index to every distinct morpheme.

    Index 0 is reserved for the ``'###'`` padding token. Every other morpheme
    gets an index >= 1.

    :param morphs: all morpheme occurrences, in file order.
    :return: ``(morph_indices, indices_morph)``, the mapping and its inverse.
    """
    # NOTE(review): the numbering follows set iteration order. For str keys this
    # depends on hash randomisation under Python 3 (PYTHONHASHSEED). The loaded
    # weights are only meaningful if this reproduces the numbering used at
    # training time; confirm against the training script before changing it.
    unique = set(morphs)
    morph_indices = dict((c, i + 1) for i, c in enumerate(unique))
    morph_indices['###'] = 0
    # Invert the forward map instead of re-enumerating, so the two always agree.
    indices_morph = dict((i, c) for c, i in morph_indices.items())
    return morph_indices, indices_morph


def _encode_segmentations(word2sgmt, morph_indices, number_of_segmentation, maxlen):
    """Turn every word's segmentations into padded index matrices.

    :return: list of ``number_of_segmentation`` arrays, each with one
        ``maxlen``-long (pre-padded) index row per encoded segmentation.
    :raises ValueError: if a word does not have exactly
        ``number_of_segmentation`` segmentations. The model takes one input per
        segmentation slot, so any other count would misalign the inputs.
    """
    x_train = [[] for _ in range(number_of_segmentation)]
    oov = 0
    for word, segmentations in word2sgmt.items():
        if len(segmentations) != number_of_segmentation:
            raise ValueError(
                'word {!r} has {} segmentations but --segNo is {}'.format(
                    word, len(segmentations), number_of_segmentation))
        for slot, sgmt in zip(x_train, segmentations):
            unknown = [m for m in sgmt if m not in morph_indices]
            if unknown:
                # Indices are built from the same file, so this should not fire.
                # If it did, skipping would misalign rows across the slots.
                print(unknown[0])
                oov += 1
                continue
            slot.append([morph_indices[m] for m in sgmt])
    print("oov", oov)
    # pad_sequences accepts ragged lists directly. Converting them to a numpy
    # array first produces a ragged object array, which recent numpy rejects.
    return [sequence.pad_sequences(slot, maxlen=maxlen) for slot in x_train]


def _build_model(number_of_segmentation, number_of_unique_morpheme):
    """Build the morph2vec encoder mapping a word's segmentations to one vector.

    Layers are created in the same order as at training time, because
    ``load_weights`` must line up with the stored weights.

    :param number_of_segmentation: number of int32 segmentation inputs.
    :param number_of_unique_morpheme: ``input_dim`` of the shared embedding.
    :return: uncompiled ``Model`` whose output is an attention-weighted sum of
        the per-segmentation 200-unit ``Dense`` outputs.
    """
    morph_seg = [Input(shape=(None,), dtype='int32') for _ in range(number_of_segmentation)]

    # NOTE(review): indices run from 0 up to number_of_unique_morpheme inclusive,
    # so input_dim looks one short. It is left unchanged because it must match
    # the shapes stored in the weight file; confirm against the training script.
    morph_embedding = Embedding(input_dim=number_of_unique_morpheme, output_dim=50, mask_zero=True, name="embeddding")
    embed_seg = [morph_embedding(seg) for seg in morph_seg]

    # One BiLSTM shared by all segmentations; concat merge gives 2 * 200 = 400 units.
    biLSTM = Bidirectional(LSTM(200, dropout=0.2, recurrent_dropout=0.2, return_sequences=False), merge_mode='concat')
    encoded_seg = [biLSTM(emb) for emb in embed_seg]

    concat_vector = concatenate(encoded_seg, axis=-1)
    merge_vector = Reshape((number_of_segmentation, 400))(concat_vector)

    masked_vector = Masking()(merge_vector)

    seq_output = TimeDistributed(Dense(200))(masked_vector)

    attention_1 = TimeDistributed(Dense(units=200, activation='tanh', use_bias=False))(seq_output)

    attention_2 = TimeDistributed(Dense(units=1,
                                        activity_regularizer=regularizers.l1(0.01),
                                        use_bias=False))(attention_1)

    def attn_merge(inputs, mask):
        """Softmax the per-segmentation logits and return the weighted sum of vectors."""
        vectors = inputs[0]
        logits = inputs[1]
        # Flatten the logits and take a softmax
        logits = K.squeeze(logits, axis=2)
        # Masked positions get -inf so they receive zero softmax weight.
        # Assumes mask[0] is the mask Keras propagates for seq_output; TODO confirm.
        pre_softmax = K.switch(mask[0], logits, -numpy.inf)
        weights = K.expand_dims(K.softmax(pre_softmax))
        return K.sum(vectors * weights, axis=1)

    def attn_merge_shape(input_shapes):
        """Output shape: (batch, feature dim of the vectors input)."""
        return (input_shapes[0][0], input_shapes[0][2])

    attn = Lambda(attn_merge, output_shape=attn_merge_shape)
    attn.supports_masking = True
    attn.compute_mask = lambda inputs, mask: None  # the merged vector carries no mask
    content_flat = attn([seq_output, attention_2])
    return Model(inputs=morph_seg, outputs=content_flat)


def _load_word_lookup(path):
    """Map each segmentation in the input file back to its word.

    The key is the part after ``:`` with every ``#`` and ``+`` removed, e.g.
    ``hypermarket:hyper!market+###+###`` gives ``hyper!market`` -> ``hypermarket``.
    """
    lookup = {}
    with codecs.open(path, encoding='utf-8') as handle:
        for line in handle:
            line = line.rstrip('\r\n')
            if not line:
                continue
            word, segmentation = line.split(':', 1)
            lookup[segmentation.replace("#", "").replace("+", "")] = word
    return lookup


def _write_vectors(path, first_segmentations, vectors, indices_morph, word_lookup, delim):
    """Write one vector per known word in word2vec text format.

    Each padded index row is turned back into its ``delim``-joined morpheme
    string and looked up in ``word_lookup``. Rows whose word is not listed
    there are skipped. The ``count dim`` header reflects what is actually written.
    """
    lines = []
    skipped = 0
    for row, vector in zip(first_segmentations, vectors):
        key = delim.join(indices_morph[i] for i in row if "#" not in indices_morph[i])
        word = word_lookup.get(key)
        if word is None:
            skipped += 1
            continue
        print(word)
        lines.append(word + " " + " ".join(str(v) for v in vector) + "\n")
    if skipped:
        print('skipped', skipped, 'words not listed in the input file')
    with codecs.open(path, 'w', encoding='utf8') as out:
        # Header uses the real vector dimension (the model emits 200, not 300).
        out.write("{} {}\n".format(len(lines), vectors.shape[1]))
        out.writelines(lines)


def main():
    """Create morph2vec word vectors from a trained weight file.

    Reads the segmentations in ``--trainingFile`` and encodes them with a
    morpheme index built from that same file. Runs them through the morph2vec
    encoder initialised from ``--weightFile``. Writes a vector for each word
    listed in ``--inputFile`` to ``--output``.
    """
    args = _parse_args()

    training_file = args.trainingFile
    # vectors of the words listed in this file will be created.
    # ! only one segmentation (preferably the correct one) is provided in this file
    # file format: hypermarket:hyper!market+###+###+###+###+###+###+###+###+###
    file_to_vec = args.inputFile
    output_vector_file = args.output  # the file where the vectors will be saved
    weight_file = args.weightFile  # the name of the weight file
    number_of_segmentation = args.segNo
    delim = "!"  # separator between the morphemes of one segmentation

    print('===================================  Prepare indixes...  ==============================================')
    print('')

    word2sgmt, seq = _read_segmentations(training_file, delim)
    morphs = [morph for sgmt in seq for morph in sgmt]
    timesteps_max_len = max(len(sgmt) for sgmt in seq) if seq else 0
    morph_indices, indices_morph = _build_morph_indices(morphs)

    print('')
    print('===================================  Prepare data...  ===============================================')
    print('')

    number_of_unique_morpheme = len(set(morphs))
    print('number of words: ', len(word2sgmt))
    print('number of morphemes: ', len(morphs))
    print('number of unique morphemes: ', number_of_unique_morpheme)

    x_train = _encode_segmentations(word2sgmt, morph_indices, number_of_segmentation, timesteps_max_len)

    print('')
    print('===================================  Build model...  ===============================================')
    print('')

    model = _build_model(number_of_segmentation, number_of_unique_morpheme)
    model.load_weights(weight_file)
    q = model.predict(x_train)

    # save word vectors
    word_lookup = _load_word_lookup(file_to_vec)
    _write_vectors(output_vector_file, x_train[0], q, indices_morph, word_lookup, delim)
# Run the vector-creation pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
