# import warnings
# warnings.filterwarnings("ignore")  # Turn warnings off
import argparse
import nltk
import numpy as np
from string import punctuation
import re
from keras.models import load_model
from unicodecsv import DictReader
import tensorflow as tf
from keras.preprocessing import sequence
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from numpy import array
from numpy import argmax
from keras import backend as K
from keras.losses import sparse_categorical_crossentropy
import pdb


# Special token name. NOTE(review): no live code in the visible part of this
# file reads it -- confirm whether other modules still depend on it.
UCI_TOKEN = 'UCI_TOKEN'
# Padding symbol; create_word_num_maps maps it to id 0.
pad_token = '<pad>'
# Out-of-vocabulary symbol; create_word_num_maps maps it to the id right
# after the last ranked word.
unknown_token = '<unk>'


def masked_sparse_categorical_crossentropy(y_true, y_pred):
    """Sparse categorical cross-entropy averaged over non-masked positions.

    A position is treated as masked when every entry of ``y_true`` along its
    last axis equals 0. Masked positions contribute nothing to the summed
    loss and are excluded from the divisor.

    :param y_true: label tensor; 0 marks positions to ignore
    :param y_pred: prediction tensor passed to ``sparse_categorical_crossentropy``
    :return: scalar tensor, summed masked loss divided by the number of kept positions
    """
    # NOTE(review): if every position is masked the divisor is 0 -- confirm
    # that callers never feed an all-padding batch.
    pad_label = K.variable(0)
    is_padding = K.all(K.equal(y_true, pad_label), axis=-1)
    keep = 1 - K.cast(is_padding, K.floatx())
    per_position = sparse_categorical_crossentropy(y_true, y_pred)
    masked_total = K.sum(per_position * keep)
    return masked_total / K.sum(keep)


def categorical_accuracy(y_true, y_pred, mask=None):
    """Fraction of positions whose arg-max prediction matches the arg-max label.

    Background: https://github.com/keras-team/keras/issues/2260

    The masked branch previously relied on ``reduce``/``mul`` (never imported
    in this module) and on the Theano-only ``tensor.nonzero()``. It is now
    expressed with backend-agnostic ops as a masked mean, which yields the
    same value as averaging only the entries whose mask is non-zero.

    :param y_true: label tensor; the class axis is the last one
    :param y_pred: prediction tensor with the same layout as ``y_true``
    :param mask: optional tensor; positions whose mask value is 0 are ignored.
        Assumed to hold one value per position (i.e. ``y_true`` without its
        last axis) -- TODO confirm against callers.
    :return: scalar tensor with the (optionally masked) accuracy
    """
    # 1.0 where the predicted class equals the labelled class, else 0.0.
    hits = K.cast(K.equal(K.argmax(y_true, axis=-1),
                          K.argmax(y_pred, axis=-1)),
                  K.floatx())
    if mask is None:
        return K.mean(hits)

    # Any non-zero mask value counts as "keep", mirroring the old nonzero() use.
    flat_mask = K.cast(K.flatten(mask), K.floatx())
    keep = K.cast(K.not_equal(flat_mask, 0), K.floatx())
    return K.sum(K.flatten(hits) * keep) / K.sum(keep)


def sparse_loss(y_true, y_pred):
    """Loss that lets the model be trained against integer class labels.

    Delegates to ``tf.nn.sparse_softmax_cross_entropy_with_logits`` with
    ``y_true`` as the labels and ``y_pred`` as the logits.

    :param y_true: integer class labels
    :param y_pred: unnormalised class scores (logits)
    :return: the tensor produced by the TensorFlow op
    """
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y_pred, labels=y_true)
    return per_example


def mask_sparse_loss(y_true, y_pred):
    """Sparse softmax cross-entropy with padding labels (0) masked.

    A float mask is built that is 0 where ``y_true`` equals 0 and 1 elsewhere;
    both the labels and the logits are multiplied by it before being handed
    to ``tf.nn.sparse_softmax_cross_entropy_with_logits``.

    :param y_true: integer class labels, 0 meaning padding
    :param y_pred: unnormalised class scores (logits)
    :return: the tensor produced by the TensorFlow op
    """
    # NOTE(review): zeroing the logits does not by itself zero the loss of a
    # padded position -- confirm this is the intended masking.
    pad_id = 0
    keep = K.cast(K.not_equal(y_true, pad_id), K.floatx())
    masked_labels = y_true * keep
    masked_logits = y_pred * keep
    return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=masked_labels,
                                                          logits=masked_logits)


def word_for_id(integer, n2w):
    """Return the first key of ``n2w`` whose value equals ``integer``.

    Despite the parameter name, the mapping is scanned as ``key -> id`` pairs
    (a reverse lookup by value), in the mapping's iteration order.

    :param integer: id to look for among the mapping's values
    :param n2w: mapping whose values are ids
    :return: the matching key, or ``None`` when no value matches
    """
    matches = (key for key, value in n2w.items() if value == integer)
    return next(matches, None)


def str_to_id(source, w2n):
    """Convert a whitespace-tokenised string into a list of word ids.

    Fixes two defects of the previous version: ``source.split('')`` always
    raised ``ValueError: empty separator``, and the fallback id ``len(w2n)``
    did not match the id stored for ``unknown_token`` by
    ``create_word_num_maps`` (which is ``len(w2n) - 1`` for its output).

    :param source: text whose tokens are separated by whitespace
    :param w2n: word-to-id mapping
    :return: list with one id per token; tokens missing from ``w2n`` get the
        id ``w2n`` stores for ``unknown_token``, or ``len(w2n)`` when the
        mapping has no such entry (the previous fallback)
    """
    unknown_token_id = w2n.get(unknown_token, len(w2n))
    return [w2n.get(word, unknown_token_id) for word in source.split()]


# generate target given source sequence
def predict_sequence(model, source, n2w=None, w2n=None):
    """Run ``model`` on ``source`` and return the arg-max class ids.

    :param model: object exposing ``predict(x, verbose=0)``
    :param source: a string (converted with ``str_to_id`` and ``w2n``), a list
        or a ``np.ndarray`` that is passed to the model as-is
    :param n2w: id-to-word mapping; currently unused because ids, not words,
        are returned
    :param w2n: word-to-id mapping, needed only when ``source`` is a string
    :return: ``argmax`` over axis 2 of the model output
    :raises ValueError: if ``source`` is not a string, list or ``np.ndarray``
    """
    if isinstance(source, str):
        source = str_to_id(source, w2n)
    if not isinstance(source, (list, np.ndarray)):
        raise ValueError('source must be a string or np.array.')
    scores = model.predict(source, verbose=0)
    # Ids are returned untranslated; mapping them back to words via n2w is
    # left to the caller.
    return argmax(scores, axis=2)


def get_data(data_df, max_seq_len, emb_file=None):
    """Build padded training arrays and vocabulary maps from ``data_df``.

    :param data_df: data frame handed to ``create_word_num_maps`` and
        ``create_unlabeled_keras_input``
    :param max_seq_len: length every id sequence is padded to
    :param emb_file: optional path to a pre-trained embeddings file
    :return: ``(x_train, y_train, emb_matrix, w2n, n2w, vocab_size)``;
        ``emb_matrix`` is ``None`` when ``emb_file`` is ``None``
    """
    w2n, n2w, vocab_size = create_word_num_maps(data_df)
    id_sequences = create_unlabeled_keras_input(data_df, w2n, vocab_size)

    # NOTE(review): inputs and targets are padded identically (both 'post'),
    # so x_train and y_train hold the same values -- confirm that x was not
    # meant to be padded on the other side.
    pad_kwargs = dict(maxlen=max_seq_len, padding='post')
    x_train = pad_sequences(id_sequences, **pad_kwargs)
    y_train = pad_sequences(id_sequences, **pad_kwargs)

    if emb_file is None:
        emb_matrix = None
    else:
        emb_matrix = get_embeddings(emb_file, w2n, vocab_size)
    return x_train, y_train, emb_matrix, w2n, n2w, vocab_size


def encode_output(sequences, vocab_size):
    """
    One-hot encode every target sequence.

    :param sequences: 2-D array of class ids (its ``shape[0]`` and ``shape[1]``
        are used for the final reshape)
    :param vocab_size: number of classes, i.e. length of each one-hot vector
    :return: array of shape ``(sequences.shape[0], sequences.shape[1], vocab_size)``
    """
    one_hot_rows = [to_categorical(ids, num_classes=vocab_size)
                    for ids in sequences]
    n_rows, n_steps = sequences.shape[0], sequences.shape[1]
    return array(one_hot_rows).reshape(n_rows, n_steps, vocab_size)


def create_word_num_maps(data_df, max_vocab_size=20000):
    """
    Build two hashmaps: word-to-id (w2n) and id-to-word (n2w).

    Words are whitespace-separated tokens, used as-is (no lower-casing, no
    punctuation or digit stripping). Ids 1..K are assigned by descending
    frequency to the K most frequent words, id 0 is ``pad_token`` and id K+1
    is ``unknown_token``.

    Compared with the previous version, the non-raw ``'\\S+'`` regex (an
    invalid escape sequence) is gone: tokens returned by ``str.split()`` are
    never empty and never start with whitespace, so the filter was a no-op.
    Counting uses ``collections.Counter``, whose ``most_common`` gives the
    same ranking as the frequency distribution used before.

    :param data_df: data frame iterated with ``itertuples()`` and unpacked as
        ``(index, value)``, so it must have exactly one column; each value is
        converted with ``str()`` and is expected to be already tokenized
    :param max_vocab_size: upper bound on the number of ranked words
    :return: ``(w2n, n2w, len(w2n))``
    """
    from collections import Counter  # local import: stdlib, used only here

    counts = Counter()
    for _, row in data_df.itertuples():
        counts.update(str(row).split())

    vocab_size = len(counts)
    print('Vocab-size:', vocab_size)
    final_vocab_size = min(max_vocab_size, vocab_size)
    print('Final Vocab-size:', final_vocab_size)

    # Rank by frequency starting at 1, leaving 0 for padding.
    top = counts.most_common(final_vocab_size)
    w2n = {word: rank for rank, (word, _) in enumerate(top, start=1)}
    n2w = {rank: word for word, rank in w2n.items()}

    pad_id, unknown_id = 0, final_vocab_size + 1
    w2n[pad_token], w2n[unknown_token] = pad_id, unknown_id
    n2w[pad_id], n2w[unknown_id] = pad_token, unknown_token
    return w2n, n2w, len(w2n)


def create_unlabeled_keras_input(data_df, w2n, vocab_size):
    """Encode every row of ``data_df`` as a list of word ids.

    :param data_df: data frame iterated with ``itertuples()`` and unpacked as
        ``(index, value)``; each value is converted with ``str()`` and split
        on whitespace
    :param w2n: word-to-id mapping; it must contain ``unknown_token`` whenever
        a row holds a word that is not in the mapping
    :param vocab_size: unused, kept for interface compatibility
    :return: list with one list of ids per row
    """
    encoded_rows = []
    for _, cell in data_df.itertuples():
        ids = []
        for token in str(cell).split():
            if token in w2n:
                ids.append(w2n[token])
            else:
                # Looked up only on a miss, so fully known input never
                # requires the unknown-token entry to exist.
                ids.append(w2n[unknown_token])
        encoded_rows.append(ids)
    return encoded_rows


def get_embeddings(filename, w2n, vocab_size):
    """Load pre-trained word vectors and build an embedding matrix.

    Each non-blank line of the file must be ``word v1 v2 ... vd`` separated
    by whitespace. The embedding dimension is taken from the last vector read.

    Changes from the previous version: the file is opened in a ``with`` block
    so the handle is released even if parsing fails, blank lines are skipped
    instead of raising ``IndexError``, and a file without any vector raises a
    clear ``ValueError`` instead of an unbound-variable error.

    :param filename: path to the embeddings text file
    :param w2n: word-to-id mapping; each id is used as a row index
    :param vocab_size: number of rows of the returned matrix
    :return: ``np.ndarray`` of shape ``(vocab_size, embedding_dim)``; rows of
        words without a pre-trained vector stay all-zero
    :raises ValueError: if the file contains no word vectors
    """
    embeddings_index = {}
    embedding_dim = None
    with open(filename) as f:
        for line in f:
            values = line.split()
            if not values:
                continue  # tolerate blank lines (e.g. a trailing newline)
            coefs = np.asarray(values[1:], dtype='float32')
            embeddings_index[values[0]] = coefs
            embedding_dim = len(coefs)
    if embedding_dim is None:
        raise ValueError('No word vectors found in %s' % filename)
    print('Found %s word vectors in the pre-trained matrix.' % len(embeddings_index))
    # prepare embedding matrix
    print('Preparing embedding matrix.')
    embedding_matrix = np.zeros((vocab_size, embedding_dim))
    for word, i in w2n.items():
        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            # oovs will have all-zeros emb, including padding and unknown
            embedding_matrix[i] = embedding_vector
    return embedding_matrix


# ----------------------------------------------------------------------------------
def get_model(model_fname):
    """Load a saved Keras model.

    :param model_fname: path passed unchanged to ``keras.models.load_model``
    :return: the loaded model
    """
    # NOTE(review): a model saved with one of this module's custom losses may
    # need ``custom_objects`` at load time -- confirm against how models are saved.
    model = load_model(model_fname)
    return model


def create_argparser():
    """Create the command-line parser for the LSTM training script.

    Options are declared in a table of ``(flag, type, extra keyword
    arguments)`` and registered in table order, so the help output lists them
    in that order. Options without an explicit default resolve to ``None``.

    :return: a configured ``argparse.ArgumentParser``
    """
    parser = argparse.ArgumentParser(
        description='Train an LSTM model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        add_help=True,
    )

    option_specs = (
        # files and paths
        ('--train_file', str, {'default': 'None'}),
        ('--test_file', str, {'default': 'None'}),
        ('--model_file', str, {'default': 'lstm-model-file'}),
        ('--emb_file', str, {'default': 'none'}),
        ('--log_file', str, {}),
        ('--models_dir', str, {}),
        # stored under a destination that differs from the flag name
        ('--val_results_file', str, {'dest': 'val_results_fname'}),
        ('--results_file', str, {'default': 'temp-results', 'help': 'Output filename'}),
        ('--prefix', str, {'default': 'None', 'help': 'Prefix for file name'}),
        # model hyper-parameters
        ('--pretrained', int, {'default': 1}),
        ('--emb_dim', int, {'default': 100}),
        ('--max_seq_len', int, {'default': 100}),
        ('--lstm_dim', int, {'default': 100}),
        ('--batch_size', int, {'default': 32}),
        ('--dropout', float, {'default': 0.2, 'help': 'Dropout'}),
        ('--lr', float, {'default': 0.1, 'help': 'Learning Rate'}),
        # experiment setup
        ('--n', int, {'default': 20, 'help': 'Number of experiments to run'}),
        ('--num_classes', int, {'default': 2}),
        ('--gpu', str, {'default': '1', 'help': 'GPU'}),
    )
    for flag, value_type, extra in option_specs:
        parser.add_argument(flag, type=value_type, action='store', **extra)

    return parser

