import os

import tensorflow as tf
from tensorflow.python.keras.callbacks import History
import matplotlib.pyplot as plt
import numpy as np
from defiNNet.DefAnalyzer import DefAnalyzer
from defiNNet.parsing.parse_tree import Parser
from preprocessing.w2v_preprocessing_embedding import PreprocessingWord2VecEmbedding, POSToIndex
from nltk.corpus import wordnet as wn


class DefiNNet:
    """Keras network ("denn") mapping two word vectors plus POS tags to a 300-d vector.

    The network takes the 300-dimensional vectors of two words (``w1``, ``w2``)
    together with POS-tag indices for ``w1``, ``w2`` and a target word, and is
    trained with a cosine-similarity loss to regress the target's vector.
    At prediction time ``DefAnalyzer`` supplies ``w1``/``w2`` and the three POS
    tags for a target word (NOTE(review): presumably extracted from the target's
    definition -- confirm against ``DefAnalyzer.analyze``), and
    ``PreprocessingWord2VecEmbedding`` supplies the word vectors.
    """

    def __init__(self, denn, pretrained_embeddings_path):
        """Wrap a Keras model and load the resources needed for prediction.

        :param denn: ``tf.keras.Model`` with the five-input signature built by
            :meth:`train` (w1 vector, w1 POS, w2 vector, w2 POS, target POS).
        :param pretrained_embeddings_path: path to the pretrained embeddings,
            loaded with ``binary=True``.
        """
        self.denn = denn
        self.preprocessor = PreprocessingWord2VecEmbedding(pretrained_embeddings_path, binary=True)

        self.defAnalyzer = DefAnalyzer(parser=Parser())

    @staticmethod
    def _learning_curve_path(output_dir, iteration):
        """Return the file path of the learning-curve image for one training run."""
        suffix = '' if iteration is None else str(iteration)
        filename = 'learning_curve_model_' + suffix
        return filename if output_dir is None else os.path.join(output_dir, filename)

    @staticmethod
    def train(pretrained_embeddings_path, train, tagset, output_dir=None, iteration=None,
              epochs=50, batch_size=128):
        """Build, compile and fit a new network, save its loss curve, and wrap it.

        :param pretrained_embeddings_path: embeddings path forwarded to the returned
            :class:`DefiNNet` (not used for training itself).
        :param train: tuple ``(data, target, target_pos, w1_pos, w2_pos)``;
            ``data[:, 0]`` / ``data[:, 1]`` feed the w1 / w2 inputs (300 features
            each), ``target`` is the regression target and the ``*_pos`` arrays
            feed the POS-embedding inputs.
        :param tagset: collection of POS tags; its length sizes the POS embedding table.
        :param output_dir: directory for the learning-curve image (current
            directory when ``None``).
        :param iteration: optional suffix appended to the image file name.
        :param epochs: number of training epochs.
        :param batch_size: mini-batch size.
        :return: a :class:`DefiNNet` wrapping the trained model.
        """
        (train_data, train_target, train_target_pos, train_w1_pos, train_w2_pos) = train

        FEATURES = 300
        # One POS embedding table shared by the w1, w2 and target POS inputs.
        embedding_pos_tags = tf.keras.layers.Embedding(input_dim=len(tagset), output_dim=30, input_length=1)

        first_embedding = tf.keras.layers.Input(shape=(FEATURES,))
        x1 = tf.keras.layers.Dense(500)(first_embedding)
        x1 = tf.keras.layers.LeakyReLU(alpha=0.5)(x1)

        second_embedding = tf.keras.layers.Input(shape=(FEATURES,))
        x2 = tf.keras.layers.Dense(500)(second_embedding)
        x2 = tf.keras.layers.LeakyReLU(alpha=0.5)(x2)

        x = tf.keras.layers.Add()([x1, x2])
        x = tf.keras.layers.Dense(300)(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.5)(x)

        w1_pos = tf.keras.Input(shape=(1,))
        w1_x2 = embedding_pos_tags(w1_pos)
        w1_x2 = tf.keras.layers.Flatten()(w1_x2)
        w1_x2 = tf.keras.layers.Dense(50)(w1_x2)

        w2_pos = tf.keras.Input(shape=(1,))
        w2_x2 = embedding_pos_tags(w2_pos)
        w2_x2 = tf.keras.layers.Flatten()(w2_x2)
        w2_x2 = tf.keras.layers.Dense(50)(w2_x2)

        target_pos = tf.keras.Input(shape=(1,))
        target_x2 = embedding_pos_tags(target_pos)
        target_x2 = tf.keras.layers.Flatten()(target_x2)
        target_x2 = tf.keras.layers.Dense(50)(target_x2)

        t = tf.keras.layers.Concatenate()([w1_x2, w2_x2, target_x2])
        t = tf.keras.layers.Dense(100)(t)
        t = tf.keras.layers.LeakyReLU(alpha=0.5)(t)

        x = tf.keras.layers.Concatenate()([x, t])
        x = tf.keras.layers.Dense(300)(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.5)(x)
        x = tf.keras.layers.Dense(400)(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.5)(x)
        output = tf.keras.layers.Dense(300)(x)

        # Input order must match the x lists passed to fit/evaluate/predict.
        denn = tf.keras.Model(inputs=[first_embedding, w1_pos, second_embedding, w2_pos, target_pos],
                              outputs=output)
        denn.summary()

        denn.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
            loss=tf.keras.losses.cosine_similarity,
            metrics=[tf.keras.metrics.mse, tf.keras.losses.cosine_similarity]
        )

        print(train_data.shape)
        print(train_data[:, 0].shape)
        print(train_data[:, 1].shape)
        print(train_w1_pos.shape)
        print(train_w2_pos.shape)
        history: History = denn.fit(
            x=[train_data[:, 0], train_w1_pos, train_data[:, 1], train_w2_pos, train_target_pos],
            y=train_target, epochs=epochs,
            batch_size=batch_size)

        # Draw on a dedicated figure and close it afterwards: pyplot state is
        # global, so repeated calls (e.g. one per iteration) would otherwise
        # accumulate every previous run's curve in the same image.
        fig = plt.figure()
        plt.plot(history.history['loss'])
        plt.title('model loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        # Only the training loss is recorded (fit() receives no validation data).
        plt.legend(['train'], loc='upper left')
        plt.savefig(DefiNNet._learning_curve_path(output_dir, iteration))
        plt.close(fig)

        return DefiNNet(denn, pretrained_embeddings_path=pretrained_embeddings_path)

    def test(self, test):
        """Evaluate the model on ``(data, target, target_pos, w1_pos, w2_pos)``.

        Prints the value returned by ``Model.evaluate`` and also returns it.
        """
        (test_data, test_target, test_target_pos, test_w1_pos, test_w2_pos) = test
        test_history = self.denn.evaluate(
            x=[test_data[:, 0], test_w1_pos, test_data[:, 1], test_w2_pos, test_target_pos],
            y=test_target)
        print(test_history)
        return test_history

    def save(self, output_model_path="denn.h5"):
        """Save the wrapped Keras model to ``output_model_path``."""
        self.denn.save(output_model_path)

    @staticmethod
    def load(pretrained_embeddings_path, input_model_path="denn.h5"):
        """Load a saved Keras model and wrap it in a new :class:`DefiNNet`."""
        denn = tf.keras.models.load_model(input_model_path)
        return DefiNNet(denn, pretrained_embeddings_path)

    def _analyze(self, target, pos=None):
        """Analyze ``target`` and return ``(w1, w1_pos, w2, w2_pos, target_pos)``.

        When ``pos`` is ``None`` the POS of the first WordNet synset of ``target``
        is used (it stays ``None`` if WordNet has no synsets for it); otherwise
        ``pos`` is lower-cased before being passed to ``DefAnalyzer.analyze``.
        """
        if pos is None:
            synsets = wn.synsets(target)
            if synsets:
                pos = synsets[0].pos()
        else:
            pos = pos.lower()

        words, tags = self.defAnalyzer.analyze(target, category=pos)
        return words[1], tags['w1_pos'], words[2], tags['w2_pos'], tags['target_pos']

    def predict_and_word(self, target, pos=None):
        """Like :meth:`predict`, but also return the analyzed words joined as ``"w1_w2"``."""
        w1, w1_pos, w2, w2_pos, target_pos = self._analyze(target, pos)
        return self.predict_analyzed(w1, w1_pos, w2, w2_pos, target_pos), w1 + "_" + w2

    def predict(self, target, pos=None):
        """Predict a vector for ``target`` (optionally with an explicit POS category)."""
        w1, w1_pos, w2, w2_pos, target_pos = self._analyze(target, pos)
        return self.predict_analyzed(w1, w1_pos, w2, w2_pos, target_pos)

    def predict_analyzed(self, w1, w1_pos, w2, w2_pos, target_pos):
        """Run the model on already-analyzed words/tags as a batch of size one.

        Word vectors come from ``self.preprocessor.get_vector``; POS tags are
        mapped to indices with ``POSToIndex.index`` (NOTE(review): presumably a
        list of tags -- an unknown tag raises from ``index``; confirm).
        """
        return self.denn.predict([np.array([self.preprocessor.get_vector(w1)]),
                                  np.array([POSToIndex.index(w1_pos)]),
                                  np.array([self.preprocessor.get_vector(w2)]),
                                  np.array([POSToIndex.index(w2_pos)]),
                                  np.array([POSToIndex.index(target_pos)])])
