
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

def shard_features(x, num_datashards):
    """Split `x` along its first axis into `num_datashards` nearly equal shards.

    The batch size is read from the dynamic shape of `x`, so the split sizes
    are computed inside the graph. With `q, r = divmod(batch_size,
    num_datashards)`, the first `r` shards receive `q + 1` rows and the
    remaining shards receive `q` rows (possibly zero).

    Args:
        x: A tensor, or any value accepted by `tf.convert_to_tensor`.
        num_datashards: Python int, the number of shards to produce.

    Returns:
        A list of `num_datashards` tensors, as returned by `tf.split`.
    """
    x = tf.convert_to_tensor(x)
    batch_size = tf.shape(x)[0]

    # The quotient and remainder do not depend on the shard index, so build
    # those ops once instead of once per shard.
    base_size = batch_size // num_datashards
    remainder = tf.mod(batch_size, num_datashards)

    # Shard i gets one extra row iff i < remainder. Adding the comparison
    # result as 0/1 avoids a tf.cond per shard.
    size_splits = [
        base_size + tf.cast(tf.less(i, remainder), base_size.dtype)
        for i in range(num_datashards)
    ]

    return tf.split(x, size_splits, axis=0)


def linear(inputs, output_size, bias, concat=True, dtype=None, scope=None):
    """Apply a learned linear map to the last axis of one or more tensors.

    Each input is flattened to 2-D (`[-1, last_dim]`), multiplied by a weight
    matrix, and the result is reshaped so that its leading dimensions equal
    those of the first input and its last dimension is `output_size`.

    Args:
        inputs: A tensor, or a list/tuple of tensors. The last dimension of
            every input must be statically known. After flattening, all
            inputs must have the same number of rows, since they are either
            concatenated along axis 1 or their products are summed.
        output_size: Size of the last dimension of the result.
        bias: If truthy, a learned vector named "bias" of shape
            `[output_size]` is added to the result.
        concat: If truthy, the flattened inputs are concatenated and
            multiplied by a single variable named "matrix". Otherwise input
            `i` is multiplied by its own variable "matrix_<i>" and the
            products are summed.
        dtype: Passed to `tf.variable_scope` as the scope's dtype.
        scope: Optional variable scope name; defaults to "linear".

    Returns:
        A tensor with the leading dimensions of `inputs[0]` and a last
        dimension of `output_size`.

    Raises:
        ValueError: If the last dimension of any input is not statically
            known.
    """
    with tf.variable_scope(scope, default_name="linear", values=[inputs], dtype=dtype):
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        # Static size of each input's last axis; it determines the weight
        # shapes, so it has to be known at graph-construction time.
        input_sizes = [item.get_shape().as_list()[-1] for item in inputs]

        if any(size is None for size in input_sizes):
            raise ValueError(
                "linear requires the last dimension of every input to be "
                "statically known, got sizes %s" % (input_sizes,))

        # Dynamic output shape: leading dims of the first input + output_size.
        output_shape = tf.concat([tf.shape(inputs[0])[:-1], [output_size]], axis=0)

        flat_inputs = [
            tf.reshape(inp, [-1, size])
            for inp, size in zip(inputs, input_sizes)
        ]

        if concat:
            # One matrix applied to the concatenation of all inputs.
            matrix = tf.get_variable("matrix", [sum(input_sizes), output_size])
            output = tf.matmul(tf.concat(flat_inputs, 1), matrix)
        else:
            # One matrix per input; the individual products are summed.
            results = []
            for i, (flat, size) in enumerate(zip(flat_inputs, input_sizes)):
                matrix = tf.get_variable("matrix_%d" % i, [size, output_size])
                results.append(tf.matmul(flat, matrix))
            output = tf.add_n(results)

        if bias:
            # Separate local name so the `bias` flag is not overwritten.
            bias_var = tf.get_variable("bias", [output_size])
            output = tf.nn.bias_add(output, bias_var)

        return tf.reshape(output, output_shape)

def layer_norm(inputs, epsilon=1e-6, name=None, reuse=None):
    """Normalize `inputs` over the last axis, then apply a learned scale and offset.

    The mean and variance are computed along the last axis (dimensions
    kept), and the result is
    `(inputs - mean) * rsqrt(variance + epsilon) * scale + offset`.

    Args:
        inputs: Tensor to normalize. The static size of its last dimension is
            used as the shape of the scale and offset variables.
        epsilon: Value added to the variance before taking the reciprocal
            square root.
        name: Optional variable scope name; defaults to "layer_norm".
        reuse: Passed to `tf.variable_scope` as `reuse`.

    Returns:
        The normalized, scaled and shifted tensor.
    """
    with tf.variable_scope(name, default_name="layer_norm", values=[inputs], reuse=reuse):
        param_shape = [inputs.get_shape().as_list()[-1]]

        # Scale starts at ones and offset at zeros.
        gain = tf.get_variable(
            "layer_norm_scale", param_shape, initializer=tf.ones_initializer())
        shift = tf.get_variable(
            "layer_norm_offset", param_shape, initializer=tf.zeros_initializer())

        # Statistics over the last axis, keeping that axis for broadcasting.
        mu = tf.reduce_mean(inputs, -1, True)
        centered = inputs - mu
        sigma_sq = tf.reduce_mean(tf.square(centered), -1, True)

        normalized = centered * tf.rsqrt(sigma_sq + epsilon)

        return normalized * gain + shift

