'''
This file describes the Deep Neural Lemmatization Model Architecture used in the paper. We have implemented the model using 'keras'. This is not an executable file. It just describes how to implement the model.
The "get_model" function defines the structure of the model. If you face any problem understanding it, please email abhisek0842@gmail.com
'''


from __future__ import print_function
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Dropout, Activation, advanced_activations, Embedding, merge, Reshape, Embedding
from keras.layers import LSTM, SimpleRNN, GRU, TimeDistributedDense, TimeDistributed, Bidirectional

############## MODEL PARAMETERS ###################
# NOTE: every value below is a placeholder and must be replaced with an int
# before get_model() is called; None only keeps this module importable.
word_vector_size = None  # Semantic word vector dimension, i.e. the dimension of the vectors obtained from word2vec
word_context_length = None  # Maximum sentence length in the dataset; the sequence length of the outer (word-level) recurrent network that does the edit tree classification
char_feature_output = None  # Number of hidden units of the inner (character-level) recurrent network
hidden_size = None  # Number of hidden units of the outer (word-level) recurrent network
max_word_length = None  # Maximum word length in the dataset; the sequence length of the inner (character-level) recurrent network that builds the syntactic word representations
nb_tree_classes = None  # Number of unique edit trees present in the training data
no_of_uniq_chars = None  # input_dim of the character Embedding layer: largest character index + 1 (number of unique characters + 1 when indices start at 1)
output_dim = None  # Dimension of the dense character embedding
###################################################


############## DEFINING THE MODEL ARCHITECTURE ###########
def get_model(nn1="BGRU", nn2="BGRU", isApplicableTreeInput=True):
    """Build the two-level recurrent lemmatization model.

    Written against the Keras 1.x functional API (``merge``,
    ``dropout_W``/``dropout_U``, ``Model(input=..., output=...)``).

    The inner (character-level) bidirectional RNN turns the character
    sequence of each word into a syntactic feature vector. That vector is
    concatenated with the word's semantic vector and fed to the outer
    (word-level) bidirectional RNN. The outer RNN classifies every position
    of the sequence into one of ``nb_tree_classes`` edit trees.

    Args:
        nn1: cell type of the character-level network. "BRNN" selects
            SimpleRNN and "BLSTM" selects LSTM. Any other value, including
            the default "BGRU", selects GRU.
        nn2: cell type of the word-level network, with the same choices as
            ``nn1``.
        isApplicableTreeInput: if True, add a third input named
            ``applicable_tree_input`` of shape
            (word_context_length, nb_tree_classes). It is concatenated with
            the outer RNN output before classification.

    Returns:
        An uncompiled ``Model``. Its inputs are ``[char_input, word_input]``,
        plus ``applicable_tree_input`` when ``isApplicableTreeInput`` is
        True. Its single output has shape
        (batch, word_context_length, nb_tree_classes).
    """
    print('Build model...')

    # Any name not listed here falls back to GRU, which covers the "BGRU" default.
    rnn_cells = {"BRNN": SimpleRNN, "BLSTM": LSTM}
    char_cell = rnn_cells.get(nn1, GRU)
    word_cell = rnn_cells.get(nn2, GRU)

    # Character-level branch: embed every character of every word, then run
    # the inner bidirectional RNN over each word independently.
    char_input = Input(shape=(word_context_length, max_word_length,), dtype='float32', name='char_input')
    char_embedded = TimeDistributed(Embedding(no_of_uniq_chars, output_dim, input_length=max_word_length))(char_input)
    char_embedded = Dropout(0.2)(char_embedded)
    char_features = TimeDistributed(Bidirectional(char_cell(char_feature_output, dropout_W=0.2, dropout_U=0.2)))(char_embedded)

    # Word-level branch: concatenate the syntactic (character) features and
    # the semantic (word vector) features of each word along the feature axis.
    word_input = Input(shape=(word_context_length, word_vector_size,), name='word_input')
    merged = merge([char_features, word_input], mode='concat', concat_axis=2)
    x = Bidirectional(word_cell(hidden_size, return_sequences=True, dropout_W=0.2, dropout_U=0.2))(merged)

    inputs = [char_input, word_input]
    if isApplicableTreeInput:
        applicable_tree_input = Input(shape=(word_context_length, nb_tree_classes,), name='applicable_tree_input')
        x = merge([x, applicable_tree_input], mode='concat', concat_axis=2)
        inputs.append(applicable_tree_input)

    # Per-word softplus scores over the edit trees, normalised by a softmax.
    tree_scores = TimeDistributed(Dense(nb_tree_classes, activation='softplus'))(x)
    main_loss = Activation('softmax')(tree_scores)
    return Model(input=inputs, output=[main_loss])
###################################################################################





############## DESCRIPTION ABOUT HOW THE TRAINING DATA AND TEST DATA SHOULD BE PRE-PROCESSED FOR FITTING INTO THE MODEL #############

#Suppose there are p <word:lemma> pairs in the training data, divided into m fixed-length sequences of length word_context_length each, i.e. m * word_context_length = p (if p is not a multiple of word_context_length, the last sequence has to be padded).

#X_word_train is the array of semantic vectors of the words in the training data. Its shape should be (m, word_context_length, word_vector_size)

#X_char_train denotes the array of numeric representation of the character sequences of words in the training data.
#For example, let {a,b,c,d} be the vocabulary and build the mapping table {a:1, b:2, c:3, d:4}. Then the word "acb" is represented as the index sequence [1,3,2]. Suppose the maximum word length is restricted to 3. Then every word in the language is represented by a sequence of 3 indices, each in the range [1,4]. Longer words are truncated, e.g. "adbc" is represented as [1,4,2]. Shorter words must be padded (e.g. with 0, which the mapping leaves unused) up to max_word_length.
#So, the shape of X_char_train should be (m, word_context_length, max_word_length). Keras' Embedding layer requires input_dim to be the largest index + 1. So when indices start at 1, no_of_uniq_chars must be the vocabulary size + 1 (5 in this example).

#X_applic_trees_train denotes the array of applicable tree vectors of the words in the training data. For a word, its applicable tree vector should have length nb_tree_classes: the dimensions corresponding to the trees applicable to the word are set to 1 and the remaining dimensions are set to 0. So, the shape of X_applic_trees_train should be (m, word_context_length, nb_tree_classes)

#Y_tree_train denotes the array of one-hot encoded class labels of the words in the training data. For a word, only the dimension corresponding to its correct edit tree is 1; the rest are 0. So, the shape of Y_tree_train should be (m, word_context_length, nb_tree_classes)




#Suppose there are q words in the test data, divided into n fixed-length sequences of length word_context_length each, i.e. n * word_context_length = q

#X_word_test is the array of semantic vectors of the words in the test data. Its shape should be (n, word_context_length, word_vector_size)

#X_char_test denotes the array of numeric representation of the character sequences of words in the test data. The shape of X_char_test should be (n, word_context_length, max_word_length)

#X_applic_trees_test denotes the array of applicable tree vectors of the words in the test data. The shape of X_applic_trees_test array should be (n, word_context_length, nb_tree_classes)

###############################################################################################################




#################### CALLING, COMPILATION AND TRAINING OF THE MODEL. FINALLY USING IT FOR PREDICTION ####################################
if __name__ == "__main__":
    # Illustrative only: X_char_train, X_word_train, X_applic_trees_train,
    # Y_tree_train, the *_test arrays, nb_epoch and batch_size are not defined
    # in this file. Build them as described in the pre-processing notes above.
    model = get_model()  # default arguments give the BGRU-BGRU model with the applicable-tree input
    model.summary()  # summary() prints the table itself and returns None, so it is not wrapped in print()
    model.compile(optimizer='adam', loss=['categorical_crossentropy'], metrics=['accuracy'])
    # Input order must match the model inputs: [char_input, word_input, applicable_tree_input]
    model.fit([X_char_train, X_word_train, X_applic_trees_train], [Y_tree_train], nb_epoch=nb_epoch, batch_size=batch_size)
    # Y_pred has shape (n, word_context_length, nb_tree_classes): one softmax over edit trees per word
    Y_pred = model.predict([X_char_test, X_word_test, X_applic_trees_test])

###############################################################################################################

