import os
import sys

import numpy as np
import tensorflow as tf

import perceptron

from model.nn import SentimentModel
from utils.data_utils import (
    load_pickle,
    load_vocab
)
from utils.logger import get_logger
from utils.initialize import initialize_weights


# Module-level logger used by grad_stats and analysis below; built by the
# project helper utils.logger.get_logger.
logger = get_logger(__name__)


def has_but(sentence):
    """Report whether *sentence* contains the word "but" with a space on each side.

    Only the exact substring ``' but '`` is searched for, so a "but" at the very
    start or end of the string, or one touching punctuation, does not count.
    """
    needle = ' but '
    return bool(needle in sentence)


def ids_to_sent(ids, rev_vocab):
    """Turn an instance's token ids back into a space-separated string.

    Args:
        ids: instance dict whose ``'sentence'`` entry is a sequence of token ids.
        rev_vocab: indexable id -> token lookup (list or dict).

    Returns:
        The tokens joined by single spaces, with every id equal to 0 (padding)
        left out.
    """
    tokens = (rev_vocab[token_id] for token_id in ids['sentence'] if token_id != 0)
    return ' '.join(tokens)


def calculate(sess, model_dev, data, args, vocab, rev_vocab):
    """Compute per-position gradient values for *data* and log summary statistics.

    Instances are fed through ``model_dev`` in batches of
    ``args.config.eval_batch_size``. The last, partial batch is filled up by
    repeating its final instance, and the rows produced for those copies are
    discarded.

    Args:
        sess: TensorFlow session in which ``model_dev.grads_l2`` is evaluated.
        model_dev: model exposing ``inputs``, ``sentence_mask`` and ``grads_l2``.
        data: list of instance dicts with at least ``'sentence'`` (token ids)
            and ``'sentence_len'`` keys.
        args: run arguments; only ``args.config.eval_batch_size`` is read here.
        vocab: token -> id mapping, forwarded to ``grad_stats``.
        rev_vocab: unused; kept so existing call sites keep working.

    Returns:
        ``np.ndarray`` of shape ``(len(data), data[0]['sentence_len'])`` with the
        ``model_dev.grads_l2`` output for every instance. When *data* is empty,
        an empty ``(0, 0)`` array is returned and no statistics are logged.
    """
    if len(data) == 0:
        # Nothing to feed: data[0] below would raise IndexError, and the
        # statistics over zero instances would all be nan anyway.
        return np.zeros((0, 0))

    batch_size = args.config.eval_batch_size
    num_batches = int(np.ceil(float(len(data)) / batch_size))
    # NOTE(review): assumes every instance shares data[0]'s 'sentence_len'
    # (i.e. sentences are pre-padded to a common length) - confirm upstream.
    grads_l2 = np.zeros([len(data), data[0]['sentence_len']])

    for i in range(num_batches):
        start = i * batch_size
        # Copy the slice so extending it below can never touch `data`.
        split = list(data[start:start + batch_size])
        total = len(split)
        # The last batch is likely to be smaller than batch_size: repeat its
        # final instance so the batch is full, then keep only `total` rows.
        split.extend([split[-1]] * (batch_size - total))

        seq_len = np.array([x['sentence_len'] for x in split])
        max_seq_len = np.max(seq_len)
        # Right-pad every token-id sequence with zeros up to the batch maximum.
        sents = [np.array(x['sentence']) for x in split]
        sentences = np.array([np.pad(s, (0, max_seq_len - len(s)), 'constant') for s in sents])
        sentence_mask = perceptron.compute_mask(split)
        feed_dict = {
            model_dev.inputs.name: sentences,
            model_dev.sentence_mask: sentence_mask
        }

        # Tensorflow gradient computation code
        grads_l2[start:start + total] = sess.run(model_dev.grads_l2, feed_dict=feed_dict)[:total]

    grad_stats(grads_l2, data, vocab)
    return grads_l2


def _mean_or_nan(values):
    """Return the mean of *values* as a float, or nan when *values* is empty.

    ``np.mean`` of an empty sequence also gives nan, but emits a RuntimeWarning.
    """
    return float(np.mean(values)) if len(values) > 0 else float('nan')


def grad_stats(grads_l2, data, vocab):
    """Log average gradient values, overall and before/after the word "but".

    For every instance, the row of *grads_l2* is averaged twice: over all
    positions, and over non-padding positions only (token id 0 is treated as
    padding). For sentences containing ``vocab['but']``, the row is also split
    at the first "but":

    - part A is everything before it;
    - part B is the "but" token and everything after it.

    The split is skipped when "but" is the first position, because A would
    then be empty.

    Args:
        grads_l2: 2-D array with one row of per-position values per instance,
            aligned with *data*.
        data: instance dicts; ``instance['sentence']`` must support
            ``.index`` (a list of token ids).
        vocab: token -> id mapping that contains ``'but'``.

    Returns:
        dict with keys ``'all'``, ``'all_no_pad'``, ``'A'``, ``'A_no_pad'``,
        ``'B'`` and ``'B_no_pad'`` holding the logged averages. A value is nan
        when no instance contributed to it.
    """
    but_id = vocab['but']
    avg_grads = []
    avg_no_pad_grads = []
    avg_A_grads = []
    avg_A_no_pad_grads = []
    avg_B_grads = []
    avg_B_no_pad_grads = []

    for grad, instance in zip(grads_l2, data):
        sentence = instance['sentence']
        grad_no_pad = np.array([g for token, g in zip(sentence, grad) if token != 0])
        sent_no_pad = [token for token in sentence if token != 0]

        avg_grads.append(np.mean(grad))
        if grad_no_pad.size:
            # An all-padding instance has no real tokens; np.mean would give
            # nan and turn the whole no-pad average into nan.
            avg_no_pad_grads.append(np.mean(grad_no_pad))

        if but_id not in sentence:
            continue

        # Try to see values of gradient before (A) and from "but" onwards (B).
        # Both branches skip a leading "but": an empty A slice averages to nan,
        # which would poison the corresponding overall average.
        but_location = sentence.index(but_id)
        if but_location > 0:
            avg_A_grads.append(np.mean(grad[:but_location]))
            avg_B_grads.append(np.mean(grad[but_location:]))

        but_no_pad_location = sent_no_pad.index(but_id)
        if but_no_pad_location > 0:
            avg_A_no_pad_grads.append(np.mean(grad_no_pad[:but_no_pad_location]))
            avg_B_no_pad_grads.append(np.mean(grad_no_pad[but_no_pad_location:]))

    stats = {
        'all': _mean_or_nan(avg_grads),
        'all_no_pad': _mean_or_nan(avg_no_pad_grads),
        'A': _mean_or_nan(avg_A_grads),
        'A_no_pad': _mean_or_nan(avg_A_no_pad_grads),
        'B': _mean_or_nan(avg_B_grads),
        'B_no_pad': _mean_or_nan(avg_B_no_pad_grads),
    }

    logger.info("Average gradients :- %.4f", stats['all'])
    logger.info("Average no pad gradients :- %.4f", stats['all_no_pad'])
    logger.info("Average A gradients :- %.4f", stats['A'])
    logger.info("Average A no pad gradients :- %.4f", stats['A_no_pad'])
    logger.info("Average B gradients :- %.4f", stats['B'])
    logger.info("Average B no pad gradients :- %.4f", stats['B_no_pad'])
    return stats


def analysis(args):
    """Restore the sentiment model and log gradient statistics on the test split.

    The function does the following, in order:

    1. Loads the vocabulary.
    2. Builds the evaluation graph, sized from the first training instance.
    3. Restores weights via ``initialize_weights``.
    4. Runs ``perceptron.append_features`` over the train, dev and test splits.
    5. Calls ``calculate`` on the test split.

    Args:
        args: run arguments. Reads ``thread_restrict`` and ``config``. As a
            side effect, sets ``args.vocab_size`` and ``args.config.seq_len``.
    """
    session_config = None
    if args.thread_restrict is True:
        # Cap intra-op parallelism when thread usage is restricted.
        session_config = tf.ConfigProto(intra_op_parallelism_threads=2)

    with tf.Session(config=session_config) as sess:
        vocab, rev_vocab = load_vocab(args)
        args.vocab_size = len(rev_vocab)

        # The model's sequence length is taken from the first training instance.
        train_set = load_pickle(args, split='train')
        args.config.seq_len = train_set[0]['sentence_len']

        # Build the evaluation graph, then restore checkpointed weights if any.
        with tf.variable_scope("model", reuse=None):
            model_test = SentimentModel(args, queue=None, mode='eval')
        steps_done = initialize_weights(sess, model_test, args, mode='train')
        logger.info("loaded %d completed steps", steps_done)

        perceptron.append_features(args, train_set, model_test, vocab, rev_vocab)
        splits = {'train': train_set}
        # Each remaining split is loaded and featurized before the next one.
        for split_name in ('dev', 'test'):
            splits[split_name] = load_pickle(args, split=split_name)
            perceptron.append_features(args, splits[split_name], model_test, vocab, rev_vocab)

        calculate(sess, model_test, splits['test'], args, vocab, rev_vocab)
