import torch 
import torch.nn
import torch.nn.functional
import numpy as np
import tensorflow as tf
import torchsample

from torch import optim
'''
additional part -- torch version
'''
def listwisemle(corr_mat, rank_mat, mask, maxlen):
    """Compute a masked listwise maximum-likelihood ranking loss.

    For every (batch, time) row the masked scores are summed, and then, for
    each rank position ``i``, the log of the exp-score mass that is still
    "unranked" is subtracted before the exp-score of the item ranked ``i``-th
    is removed from that mass.

    Args:
      corr_mat: Float tensor of pairwise scores; indexed as [batch, time, item].
      rank_mat: Integer tensor with the same leading dimensions as `corr_mat`;
        ``rank_mat[b, t, i]`` is the index (along the last axis of `corr_mat`)
        of the item ranked ``i``-th. Cast to int64 internally.
      mask: Tensor broadcastable against `corr_mat`; zero entries are excluded.
      maxlen: Number of rank positions to iterate over. The code indexes
        ``mask[:, :, i]`` and ``rank_mat[:, :, i]`` for ``i < maxlen``.

    Returns:
      A scalar tensor: the negated total loss divided by ``sqrt(sum(mask))``.
    '''
    """
    scores = mask * corr_mat
    exp_scores = mask * torch.exp(scores)  # masked entries contribute exactly 0
    # The small constant keeps the first logarithms away from log(0).
    remaining = torch.sum(exp_scores, dim=-1) + 0.00001

    # ranked_exp[b, t, i] == exp_scores[b, t, rank_mat[b, t, i]]; this replaces
    # the gather_nd call with an explicit (batch, time, rank) index triple.
    ranked_exp = torch.gather(exp_scores, -1, rank_mat.long())

    loss = torch.sum(scores, dim=-1)
    for i in range(maxlen):
        # Out-of-place updates: torch.log keeps `remaining` for its backward
        # pass, so modifying it in place would break autograd.
        loss = loss - mask[:, :, i] * torch.log(remaining)
        remaining = remaining - ranked_exp[:, :, i]
    return -torch.sum(loss) / torch.sqrt(torch.sum(mask))


def positional_encoding(dim, sentence_length, dtype=torch.float32):
    """Build a sinusoidal positional-encoding table.

    The angle for position ``pos`` and channel ``i`` is
    ``pos / 10000 ** (2 * i / dim)``. The table is flattened row-major, sine is
    applied to even flat indices and cosine to odd flat indices, and the result
    is reshaped back (so for an even `dim`, even channels hold sines and odd
    channels hold cosines).

    Args:
      dim: Number of channels per position.
      sentence_length: Number of positions.
      dtype: torch dtype of the returned tensor.

    Returns:
      A tensor of shape ``[sentence_length, dim]`` with the requested dtype.
    """
    positions = np.arange(sentence_length, dtype=np.float64)[:, None]
    channels = np.arange(dim, dtype=np.float64)[None, :]
    flat = (positions / np.power(10000, 2 * channels / dim)).reshape(-1)
    flat[::2] = np.sin(flat[::2])
    flat[1::2] = np.cos(flat[1::2])

    # torch.from_numpy has no dtype argument; convert after wrapping the array.
    return torch.from_numpy(flat.reshape(sentence_length, dim)).to(dtype)


def normalize(inputs, epsilon=1e-8):
    '''Applies layer normalization over the last dimension.

    The scale (`gamma`, ones) and shift (`beta`, zeros) are created anew on
    every call, so they are never trained or shared between calls; the result
    is therefore the plain standardised input.

    Args:
      inputs: A tensor whose last dimension is normalised.
      epsilon: A floating number. A very small number for preventing ZeroDivision Error.

    Returns:
      A tensor with the same shape and data dtype as `inputs`.
    '''
    params_shape = inputs.shape[-1:]

    mean = torch.mean(inputs, dim=-1, keepdim=True)
    centered = inputs - mean
    # Layer normalisation uses the biased (population) variance, hence the
    # explicit mean of squares rather than torch.var's unbiased default.
    variance = torch.mean(centered * centered, dim=-1, keepdim=True)

    beta = torch.zeros(params_shape, dtype=inputs.dtype, device=inputs.device)
    gamma = torch.ones(params_shape, dtype=inputs.dtype, device=inputs.device)
    normalized = centered / ((variance + epsilon) ** (.5))
    outputs = gamma * normalized + beta

    return outputs

def embedding(inputs, vocab_size, num_units, zero_pad=True, scale=True, weight_decay=1e-5, scope="embedding", with_t=False,
              reuse=None):
    """Look up rows of a freshly initialised embedding table.

    NOTE: the table is drawn from a standard normal on every call, so it is
    neither shared between calls nor registered with any optimizer. Callers
    that need persistent weights should keep the table returned via `with_t`
    (or use `torch.nn.Embedding`).

    Args:
      inputs: Integer tensor of ids (any shape); cast to int64 for the lookup.
      vocab_size: Number of rows in the table.
      num_units: Embedding width.
      zero_pad: If True, row 0 of the table is replaced by zeros.
      scale: If True, the looked-up vectors are multiplied by sqrt(num_units).
      weight_decay: Unused; kept for interface compatibility.
      scope: Unused; kept for interface compatibility.
      with_t: If True, also return the lookup table.
      reuse: Unused; kept for interface compatibility.

    Returns:
      A tensor of shape ``inputs.shape + (num_units,)``; when `with_t` is True,
      a tuple ``(outputs, lookup_table)``.
    """
    lookup_table = torch.randn(vocab_size, num_units, device=inputs.device, requires_grad=True)

    if zero_pad:
        pad_row = torch.zeros(1, num_units, dtype=lookup_table.dtype, device=lookup_table.device)
        lookup_table = torch.cat((pad_row, lookup_table[1:, :]), dim=0)

    # F.embedding accepts index tensors of any rank, unlike torch.index_select
    # which only takes a 1-D index.
    outputs = torch.nn.functional.embedding(inputs.long(), lookup_table)

    if scale:
        outputs = outputs * (num_units ** 0.5)
    if with_t:
        return outputs, lookup_table
    return outputs


'''
additional part
'''
def cal_rbf(hidden_emb, maxlen, num_units, num_heads):
    """Return an RBF-style similarity matrix between all pairs of positions.

    The pairwise squared distance is assembled as ``|a|^2 - 2 a.b + |b|^2`` and
    mapped through ``exp(-d / (num_units / num_heads) ** 1.3)``.

    Args:
      hidden_emb: 3-D tensor; the second axis must have length `maxlen`.
      maxlen: Length of the second axis of `hidden_emb`.
      num_units: Total model width (only used for the bandwidth).
      num_heads: Number of heads (only used for the bandwidth).

    Returns:
      A tensor of shape ``[hidden_emb.shape[0], maxlen, maxlen]``.
    """
    # Row-wise squared norms, tiled across the last axis.
    sq_norms = hidden_emb.square().sum(dim=-1, keepdim=True).repeat(1, 1, maxlen)
    gram = hidden_emb @ hidden_emb.permute(0, 2, 1)
    sq_dist = sq_norms - 2 * gram + sq_norms.permute(0, 2, 1)
    bandwidth = (num_units / num_heads) ** 1.3
    return torch.exp(-sq_dist / bandwidth)

'''
additional part
'''
def cal_mul(hidden_emb):
    """Return the batched Gram matrix ``hidden_emb @ hidden_emb^T``.

    Args:
      hidden_emb: 3-D tensor; the last two axes are swapped for the product.

    Returns:
      A tensor of shape ``[B, T, T]`` for an input of shape ``[B, T, D]``.
    """
    swapped = hidden_emb.permute(0, 2, 1)
    return hidden_emb @ swapped

'''
changed part
'''
def _split_heads(tensor, num_heads, dim):
    """Cut `tensor` into `num_heads` equal chunks along `dim` and stack them on axis 0."""
    return torch.cat(torch.chunk(tensor, num_heads, dim=dim), dim=0)


def _merge_heads(tensor, num_heads):
    """Undo `_split_heads(..., dim=2)`: move the head chunks from axis 0 back to axis 2."""
    return torch.cat(torch.chunk(tensor, num_heads, dim=0), dim=2)


def _dense(inputs, out_features):
    """Apply a newly initialised, bias-enabled linear map to the last axis of `inputs`."""
    layer = torch.nn.Linear(inputs.size(-1), out_features)
    layer = layer.to(device=inputs.device, dtype=inputs.dtype)
    return layer(inputs)


def multihead_attention(queries, keys, avg_mat, rank_mat, corr_mat, maxlen, user_emb, kernel_type='mul',
                        shape_gamma=5, shape_learnable=False, num_units=None, num_heads=8, kernel='',
                        dropout_rate=0, is_training=True, sampling=False, with_qk=False, prior='skew_norm'):
    """Causal multi-head attention with kernel-based correlation and stochastic logits.

    NOTE: every weight used here (projection layers, `user_item_sim_w`, kernel
    mixing weights) is created on each call. They are not shared between calls
    and are not visible to an optimizer; wrap the computation in an `nn.Module`
    if the weights must be trained.

    Args:
      queries: Float tensor (N, T_q, C).
      keys: Float tensor (N, T_k, C). Positions whose feature sum is zero are
        treated as padding. The mask product requires T_q == T_k.
      avg_mat: 3-D tensor; ``avg_mat @ avg_mat^T`` is row-max-normalised and
        used as the skewness ("shape") matrix.
      rank_mat: Ranking indices forwarded to `listwisemle`; only read when the
        kernel contains 'item' or 'user'.
      corr_mat: Precomputed correlation matrix; only read when the kernel
        contains 'count', in which case it is tiled `num_heads` times on axis 0.
      maxlen: Sequence length forwarded to `cal_rbf` and `listwisemle`.
      user_emb: Float tensor (N, C) of user embeddings.
      kernel_type: 'mul' selects `cal_mul` for the item kernel, anything else
        selects `cal_rbf`.
      shape_gamma: Multiplier for the shape matrix; overwritten by a learned
        value whenever ``shape_learnable != 'none'``.
      shape_learnable: 'none' keeps `shape_gamma`; 'all' uses an unconstrained
        learned value; any other value (including the default ``False``) uses
        a softplus-constrained learned value.
      num_units: Projection width; defaults to the last dimension of `queries`.
      num_heads: Number of heads; must divide `num_units`.
      kernel: Underscore-separated kernel names drawn from 'count', 'item',
        'user' (a '_rbf' suffix is stripped). At least one must be present.
      dropout_rate: Dropout probability applied to the attention weights.
      is_training: Enables dropout; also compared with `sampling` (see below).
      sampling: Logits are sampled when ``bool(is_training) == bool(sampling)``
        and deterministic otherwise.
        NOTE(review): this equality test is carried over from the original code
        as written; confirm it is the intended condition.
      with_qk: If True, return only the projected ``(Q, K)``.
      prior: 'skew_norm' or anything else; selects the sampling scheme.

    Returns:
      ``(Q, K)`` when `with_qk` is True; otherwise the tuple
      ``(outputs, loss, attention_output, corr_mat, shape[:, -1, :], pre_outputs)``
      where `outputs` is (N, T_q, C) and `loss` is 0 unless the kernel contains
      'item' or 'user'.

    Raises:
      ValueError: if `num_units` is not divisible by `num_heads`, or if
        `kernel` selects no known kernel.
    """
    kernel = kernel.replace('_rbf', '')
    kernel_names = kernel.split('_')

    # Set the fall back option for num_units
    if num_units is None:
        num_units = queries.size(-1)
    if num_units % num_heads != 0:
        raise ValueError("num_units (%d) must be divisible by num_heads (%d)" % (num_units, num_heads))
    head_dim = num_units // num_heads
    training = bool(is_training)
    factory = {"dtype": queries.dtype, "device": queries.device}

    # Key Masking
    key_masks = torch.sign(torch.abs(torch.sum(keys, dim=-1)))  # (N, T_k)
    key_masks = key_masks.repeat(num_heads, 1)  # (h*N, T_k)
    key_masks = key_masks.unsqueeze(1).repeat(1, queries.size(1), 1)  # (h*N, T_q, T_k)

    # Query Masking
    masks_total = key_masks * key_masks.permute(0, 2, 1)

    # Kernels are collected in the fixed order count, item, user.
    corr_mats = []
    if 'count' in kernel:
        corr_mats.append(corr_mat.repeat(num_heads, 1, 1))

    hidden_emb_init = _split_heads(queries, num_heads, dim=2)  # (h*N, T, C/h)
    user_emb_init = _split_heads(user_emb, num_heads, dim=1)  # (h*N, C/h)

    if 'item' in kernel:
        hidden_emb = torch.nn.functional.normalize(hidden_emb_init, p=2, dim=-1)  # (h*N, T, C/h)
        if kernel_type == 'mul':
            corr_mats.append(cal_mul(hidden_emb))
        else:
            corr_mats.append(cal_rbf(hidden_emb, maxlen, num_units, num_heads))

    if 'user' in kernel:
        user_item_sim_w = torch.randn(head_dim, head_dim, requires_grad=True, **factory)
        user_item_sim = torch.matmul(user_emb_init, user_item_sim_w)  # (h*N, C/h)
        user_item_sim = torch.nn.functional.normalize(
            user_item_sim[:, :, None] * hidden_emb_init.permute(0, 2, 1), p=2, dim=-1)
        corr_mats.append(torch.matmul(user_item_sim.permute(0, 2, 1), user_item_sim))

    if len(kernel_names) > 1:
        # Per-user softmax weights that mix the individual kernels.
        kernel_weight_w = torch.randn(head_dim, len(kernel_names), requires_grad=True, **factory)
        kernel_weight_b = torch.zeros(len(kernel_names), requires_grad=True, **factory)
        kernel_weight = torch.softmax(torch.matmul(user_emb_init, kernel_weight_w) + kernel_weight_b, dim=-1)
        corr_mat = torch.sum(torch.stack(corr_mats, dim=1) * kernel_weight[:, :, None, None], dim=1)
    else:
        if not corr_mats:
            raise ValueError("kernel %r selects none of 'count', 'item', 'user'" % (kernel,))
        corr_mat = corr_mats[0]

    # Small diagonal jitter; added out of place so no input is mutated.
    corr_mat = corr_mat + torch.eye(queries.size(1), dtype=corr_mat.dtype, device=corr_mat.device) * 0.00001

    if any(x in kernel for x in ['item', 'user']):
        loss = listwisemle(corr_mat, rank_mat, masks_total, maxlen)
    else:
        loss = 0

    # Linear projections
    Q = _dense(queries, num_units)  # (N, T_q, C)
    K = _dense(keys, num_units)  # (N, T_k, C)
    V = _dense(keys, num_units)  # (N, T_k, C)

    # Split and concat
    Q_ = _split_heads(Q, num_heads, dim=2)  # (h*N, T_q, C/h)
    K_ = _split_heads(K, num_heads, dim=2)  # (h*N, T_k, C/h)
    V_ = _split_heads(V, num_heads, dim=2)  # (h*N, T_k, C/h)

    scale_factor = K_.size(-1) ** 0.5

    # Multiplication
    outputs = torch.bmm(Q_, K_.transpose(2, 1))  # (h*N, T_q, T_k)

    # Lower-triangular causal mask combined with the padding mask.
    tril = torch.tril(torch.ones_like(outputs[0, :, :]))  # (T_q, T_k)
    masks = tril.unsqueeze(0).repeat(outputs.size(0), 1, 1) * masks_total  # (h*N, T_q, T_k)
    paddings = torch.ones_like(key_masks) * (-2 ** 32 + 1)

    # Scale
    outputs = outputs / scale_factor
    # Logits of the last query position, captured before any noise is added.
    pre_outputs = outputs[:, -1, :]

    # variance
    var_Q_ = _split_heads(_dense(queries, num_units), num_heads, dim=2)
    var_K_ = _split_heads(_dense(keys, num_units), num_heads, dim=2)
    variance = torch.matmul(var_Q_, var_K_.permute(0, 2, 1))  # (h*N, T_q, T_k)
    variance = torch.nn.functional.softplus(variance / scale_factor)

    # normal sampling -> multiply correlation
    shape = torch.matmul(avg_mat, avg_mat.permute(0, 2, 1))  # (N, T, T)
    # torch.max with a dim returns (values, indices); only the values are needed.
    max_shape = torch.max(shape, dim=-1, keepdim=True).values
    shape = shape / (max_shape + 0.000001)
    shape = shape.repeat(num_heads, 1, 1)

    if shape_learnable != 'none':
        shape_Q_ = _split_heads(_dense(queries, num_units), num_heads, dim=2)
        shape_K_ = _split_heads(_dense(keys, num_units), num_heads, dim=2)
        shape_logits = torch.matmul(shape_Q_, shape_K_.permute(0, 2, 1)) / scale_factor  # (h*N, T, T)
        if shape_learnable == 'all':
            shape_gamma = shape_logits
        else:
            shape_gamma = torch.nn.functional.softplus(shape_logits)
    shape = shape_gamma * shape
    delta = shape / torch.sqrt(1 + torch.square(shape))  # (h*N, T, T)

    if training == bool(sampling):
        noise = torch.randn_like(outputs)
        if prior == 'skew_norm':
            outputs = outputs + variance * noise
        else:
            # The Cholesky factor is only needed on this path, so it is computed
            # here rather than unconditionally.
            l_mat = torch.linalg.cholesky(corr_mat)  # (h*N, T, T)
            correlated = torch.matmul(noise * 10, l_mat.transpose(2, 1))
            half_normal = torch.abs(torch.randn(outputs.shape[:2], **factory))  # (h*N, T)
            half_normal = half_normal.unsqueeze(-1).repeat(1, 1, outputs.size(-1))
            z = half_normal * delta + torch.sqrt(1 - torch.square(delta)) * correlated
            outputs = outputs + variance * z
    elif prior == 'skew_norm':
        outputs = outputs + variance * delta * (2 / np.pi) ** 0.5

    # Causality = Future blinding
    outputs = torch.where(masks == 0, paddings, outputs)  # (h*N, T_q, T_k)
    pre_outputs = torch.where(masks[:, -1, :] == 0, paddings[:, -1, :], pre_outputs)

    # Activation
    outputs = torch.softmax(outputs, dim=-1)  # (h*N, T_q, T_k)
    pre_outputs = torch.softmax(pre_outputs, dim=-1)

    attention_output = outputs[:, -1, :]

    # Dropouts
    outputs = torch.nn.functional.dropout(outputs, p=dropout_rate, training=training)

    # Weighted sum
    outputs = torch.matmul(outputs, V_)  # (h*N, T_q, C/h)

    # Restore shape
    outputs = _merge_heads(outputs, num_heads)  # (N, T_q, C)

    # Residual connection (layer normalisation is deliberately not applied here)
    outputs = outputs + queries

    if with_qk:
        return Q, K
    return outputs, loss, attention_output, corr_mat, shape[:, -1, :], pre_outputs

def feedforward(inputs, num_units=(2048, 512), dropout_rate=0.2, is_training=True):
    '''Point-wise feed forward net.

    Two position-wise linear layers (equivalent to kernel-size-1 convolutions
    over the time axis): a ReLU-activated inner layer of width `num_units[0]`
    followed by a linear readout of width `num_units[1]`, each followed by
    dropout, plus a residual connection to `inputs`.

    NOTE: both layers are initialised on every call, so their weights are not
    shared between calls or visible to an optimizer.

    Args:
      inputs: A 3d tensor with shape of [N, T, C].
      num_units: A sequence of two integers; `num_units[1]` must equal C so the
        residual connection is well defined.
      dropout_rate: Dropout probability.
      is_training: Whether dropout is active.

    Returns:
      A 3d tensor with the same shape and dtype as inputs

    Raises:
      ValueError: if `num_units[1]` differs from the last dimension of `inputs`.
    '''
    hidden_units, output_units = num_units[0], num_units[1]
    in_features = inputs.size(-1)
    if output_units != in_features:
        raise ValueError(
            "num_units[1] (%d) must match the input width (%d) for the residual connection"
            % (output_units, in_features))
    training = bool(is_training)

    inner = torch.nn.Linear(in_features, hidden_units).to(device=inputs.device, dtype=inputs.dtype)
    readout = torch.nn.Linear(hidden_units, output_units).to(device=inputs.device, dtype=inputs.dtype)

    # Inner layer
    outputs = torch.nn.functional.relu(inner(inputs))  # N, T, num_units[0]
    outputs = torch.nn.functional.dropout(outputs, p=dropout_rate, training=training)
    # Readout layer (consumes the inner layer's output, no activation)
    outputs = readout(outputs)  # N, T, num_units[1]
    outputs = torch.nn.functional.dropout(outputs, p=dropout_rate, training=training)

    # Residual connection (layer normalisation is deliberately not applied here)
    return outputs + inputs

