# Utils.py put it on the other modules
import torch
import torch.nn.functional as F
import numpy as np
import math


def change_feat_location(x):
    """Swap dimension 1 of ``x`` with its last dimension.

    Returns a contiguous tensor, so the result can be passed to ``view``.
    """
    swapped = x.transpose(-1, 1)
    return swapped.contiguous()


def gelu(x):
    """Apply the tanh-based GELU approximation element-wise.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))``.
    """
    cubic_term = 0.044715 * torch.pow(x, 3)
    tanh_arg = math.sqrt(2 / math.pi) * (x + cubic_term)
    return 0.5 * x * (1 + torch.tanh(tanh_arg))


# reordering tensor to simulate the backward lstm or gpt.
def reverse_tensor(input_tensor, leng_st):
    """Reverse the valid prefix of every row of ``input_tensor``.

    For row ``k`` only the first ``leng_st[k]`` entries (along the row's
    first dimension) are flipped; everything after that prefix is left as
    zero in the result. Rows beyond ``len(leng_st)`` stay all-zero.

    Args:
        input_tensor: batch-first tensor to reverse.
        leng_st: per-row number of valid leading entries.

    Returns:
        A new tensor with the same shape/dtype as ``input_tensor``.
    """
    flipped = torch.zeros_like(input_tensor)
    for row, valid_len in enumerate(leng_st):
        prefix = input_tensor[row][:valid_len]
        # Flip along the row's own leading axis; padding is never touched.
        flipped[row][:valid_len] = prefix.flip(0)
    return flipped


def _addindent(s_, numSpaces):
    """Indent every line of ``s_`` except the first by ``numSpaces`` spaces.

    A string without any newline is returned unchanged.
    """
    first, *rest = s_.split('\n')
    if not rest:
        # Single-line input: nothing to indent.
        return s_
    pad = numSpaces * ' '
    return first + '\n' + '\n'.join(pad + line for line in rest)


# order sequences and restore sequences according to their lengths
# TODO: test the speed of orderSeq and restoreSeq
def orderSeq(seq_unordered, leng_unordered):
    """Sort sequences by descending length and drop zero-length ones.

    Args:
        seq_unordered: indexable batch of sequences (first axis = batch).
        leng_unordered: 1-D tensor of sequence lengths.

    Returns:
        ``(seq_ordered, leng_ordered, reverse_index)`` where
        ``seq_ordered``/``leng_ordered`` contain only entries with length > 0
        in descending-length order, and ``reverse_index`` maps each original
        position to its position in the full sorted order (including the
        zero-length entries that were dropped).
    """
    sorted_leng, perm = torch.sort(leng_unordered, descending=True)
    # Sorting the permutation yields its inverse as the returned indices.
    reverse_index = torch.sort(perm)[1]

    kept_leng = sorted_leng[sorted_leng > 0]
    # Zero lengths sort last, so the kept rows are a prefix of the permutation.
    kept_rows = perm[:len(kept_leng)].cpu()
    return seq_unordered[kept_rows], kept_leng, reverse_index

def restoreSeq(seq_ordered, reverse_index):
    """Undo a length-based ordering, re-inserting dropped rows as zeros.

    ``seq_ordered`` is padded along its first axis with zero rows until it
    has ``len(reverse_index)`` rows, then the rows are gathered with
    ``reverse_index``.

    Args:
        seq_ordered: tensor whose first axis holds the ordered sequences.
        reverse_index: index tensor giving, for each output row, which row
            of the zero-padded input to take.

    Returns:
        Tensor with ``len(reverse_index)`` rows, same dtype and device as
        ``seq_ordered``.
    """
    pad_shape = list(seq_ordered.shape)
    pad_shape[0] = len(reverse_index) - pad_shape[0]
    # new_zeros allocates directly with seq_ordered's dtype *and* device,
    # avoiding a CPU allocation followed by a type-string cast (which would
    # target the current CUDA device rather than seq_ordered's device).
    padding = seq_ordered.new_zeros(pad_shape)
    padded = torch.cat([seq_ordered, padding])
    return padded[reverse_index]


# padding sequences 
def pad_packed_sequence(var_data, batch_sizes, batch_first=True, padding_value=0.0):
    """Unpack time-major packed data into a padded tensor.

    Args:
        var_data: packed data; rows are laid out step by step, each step
            contributing ``batch_sizes[t]`` rows.
        batch_sizes: 1-D tensor with the number of active sequences at each
            time step (first entry is the largest).
        batch_first: if True the result is transposed to put batch first.
        padding_value: value written to positions past a sequence's end.

    Returns:
        ``(output, lengths)`` - the padded tensor and a ``LongTensor`` with
        the length of every sequence, longest first.
    """
    num_steps = batch_sizes.size(0)
    widest = int(batch_sizes[0])
    feat_dims = var_data.size()[1:]

    output = var_data.data.new(num_steps, widest, *feat_dims).fill_(padding_value)

    lengths = []
    cursor = 0          # read position inside var_data
    run_start = 0       # first time step of the current constant-width run
    run_width = widest  # batch size shared by every step of the current run
    # The trailing 0 forces the last run to be flushed.
    for step, width in enumerate(batch_sizes.tolist() + [0]):
        if width != run_width:
            run_len = step - run_start
            count = run_width * run_len
            chunk = var_data[cursor:cursor + count]
            # A whole run of equal-width steps is copied in one assignment.
            output[run_start:step, :run_width] = chunk.view(run_len, run_width, *chunk.size()[1:])
            cursor += count
            run_start = step
        finished = run_width - width
        if finished > 0:
            # `finished` sequences ended right before this step.
            lengths.extend([step] * finished)
        run_width = width

    # Shortest sequences were recorded first; report longest first.
    lengths.reverse()
    if batch_first:
        output = output.transpose(0, 1)
    return output, torch.LongTensor(lengths)


######################################################################################################################################
# as_token

def reshape_as_token(info, leng_tk, leng_tk_mask, leng_st, misc_info):
    """Flatten a 4-D tensor to one row per token and order rows by length.

    ``info`` of shape ``(bs, a, b, c)`` is viewed as ``(bs * a, b, c)`` and
    ``leng_tk`` as ``(bs * a,)``; both are then passed through ``orderSeq``.
    ``leng_tk_mask``, ``leng_st`` and ``misc_info`` are not used here.

    Returns:
        ``(ordered_info, ordered_leng, reverse_id, shape)`` where
        ``ordered_leng`` holds the lengths matching ``ordered_info`` and
        ``shape`` is the original shape of ``info``.
    """
    shape = info.shape
    bs, a, b, c = shape

    flat_info = info.contiguous().view(bs * a, b, c)
    flat_leng = leng_tk.view(bs * a)

    ordered_info, ordered_leng, reverse_id = orderSeq(flat_info, flat_leng)
    return ordered_info, ordered_leng, reverse_id, shape


def restore_as_token_extractor(info, leng_tk, leng_tk_mask, leng_st, misc_info, shape, reverse_id):
    """Restore token-ordered rows and reshape them to ``(bs, a, b, feat)``.

    ``bs``, ``a`` and ``b`` come from ``shape``; ``feat`` is the last
    dimension of ``info`` (which may differ from the original one). Only
    ``info``, ``shape`` and ``reverse_id`` are used.
    """
    bs, a, b, _ = shape
    restored = restoreSeq(info, reverse_id)
    feat = restored.size(-1)
    return restored.contiguous().view(bs, a, b, feat)


def restore_as_token_reducer(info, leng_tk, leng_tk_mask, leng_st, misc_info, shape, reverse_id):
    """Restore token-ordered rows and reshape them to ``(bs, a, feat)``.

    ``bs`` and ``a`` come from ``shape``; ``feat`` is the last dimension of
    ``info``. Only ``info``, ``shape`` and ``reverse_id`` are used.
    """
    bs, a, _, _ = shape
    restored = restoreSeq(info, reverse_id)
    feat = restored.size(-1)
    return restored.contiguous().view(bs, a, feat)


######################################################################################################################################
# as_sent
def reshape_as_sent(info, leng_tk, leng_tk_mask, leng_st, misc_info):
    """Pack the unmasked entries of each sentence into one dense sequence.

    ``info`` of shape ``(bs, a, b, c)`` is flattened to ``(bs, a * b, c)``.
    For each sentence the positions where ``leng_tk_mask`` equals 0 are
    gathered to the front of a zero-initialised buffer whose length is the
    largest per-sentence sum of ``leng_tk``. The packed batch is then
    ordered with ``orderSeq``. ``leng_st`` and ``misc_info`` are not used.

    NOTE(review): assumes the number of zero mask entries in a sentence
    equals ``leng_tk[sentence].sum()`` - otherwise the assignment shapes
    do not match.

    Returns:
        ``(ordered_info, ordered_leng, reverse_id, shape)`` where
        ``ordered_leng`` holds the lengths matching ``ordered_info`` and
        ``shape`` is the original shape of ``info``.
    """
    shape = info.shape
    bs, a, b, c = shape
    device = info.device
    flat_len = a * b

    flat_info = info.contiguous().view(bs, flat_len, c)
    flat_mask = leng_tk_mask.view(bs, flat_len)

    # Number of valid positions per sentence.
    leng_grain_st = torch.sum(leng_tk, 1)

    packed = torch.zeros(bs, torch.max(leng_grain_st), c, device=device)
    slots = torch.arange(flat_len, device=device)
    for sent_idx in range(bs):
        keep = slots.masked_select(flat_mask[sent_idx] == 0)
        packed[sent_idx][:leng_grain_st[sent_idx]] = flat_info[sent_idx][keep]

    ordered_info, ordered_leng, reverse_id = orderSeq(packed, leng_grain_st)
    return ordered_info, ordered_leng, reverse_id, shape

def restore_as_sent_extractor(info, leng_tk, leng_tk_mask, leng_st, misc_info, shape, reverse_id):
    """Scatter packed sentence sequences back to ``(bs, a, b, feat)``.

    After undoing the ordering with ``restoreSeq``, the first
    ``leng_tk[sentence].sum()`` entries of each sentence are written to the
    flat positions where ``leng_tk_mask`` equals 0; all other positions are
    zero. ``feat`` is the last dimension of ``info``. ``leng_st`` and
    ``misc_info`` are not used.
    """
    device = info.device
    bs, a, b, _ = shape
    flat_len = a * b

    restored = restoreSeq(info, reverse_id)
    feat = restored.size(-1)
    scattered = torch.zeros(bs, flat_len, feat, device=device)

    flat_mask = leng_tk_mask.view(bs, flat_len)
    leng_grain_st = torch.sum(leng_tk, 1)

    slots = torch.arange(flat_len, device=device)
    for sent_idx in range(bs):
        keep = slots.masked_select(flat_mask[sent_idx] == 0)
        scattered[sent_idx][keep] = restored[sent_idx][:leng_grain_st[sent_idx]]

    return scattered.contiguous().view(bs, a, b, feat)

def restore_as_sent_reducer_fwd(info, leng_tk, leng_tk_mask, leng_st, misc_info, shape, reverse_id):
    """Restore packed sentences and keep one vector per token (its last grain).

    Steps:
      1. undo the ordering with ``restoreSeq``;
      2. scatter every sentence back onto the flat ``a * b`` grid, at the
         positions where ``leng_tk_mask`` equals 0;
      3. for each of the first ``leng_st[sentence]`` tokens pick the grid
         entry at ``token_index * b + leng_tk - 1`` (index 0 is used for
         tokens whose ``leng_tk`` is 0).

    Returns:
        Tensor of shape ``(bs, a, feat)`` with ``feat = info.size(-1)``;
        rows past ``leng_st[sentence]`` are zero. ``misc_info`` is not used.
    """
    device = info.device
    bs, a, b, _ = shape
    flat_len = a * b

    restored = restoreSeq(info, reverse_id)
    feat = restored.size(-1)
    scattered = torch.zeros(bs, flat_len, feat, device=device)

    flat_mask = leng_tk_mask.view(bs, flat_len)
    leng_grain_st = torch.sum(leng_tk, 1)

    slots = torch.arange(flat_len, device=device)
    for sent_idx in range(bs):
        keep = slots.masked_select(flat_mask[sent_idx] == 0)
        scattered[sent_idx][keep] = restored[sent_idx][:leng_grain_st[sent_idx]]

    # Flat grid index of the last grain of every token.
    num_sent, num_tok = leng_tk.size(0), leng_tk.size(1)
    tok_pos = torch.arange(num_tok, device=device).unsqueeze(0).expand(num_sent, num_tok)
    endgrain_idx = (tok_pos * b + leng_tk - 1).masked_fill(leng_tk == 0, 0)

    reduced = torch.zeros((bs, a, feat), device=device)
    for sent_id in range(bs):
        tk_num = leng_st[sent_id]
        picks = endgrain_idx[sent_id][:tk_num]
        reduced[sent_id][:tk_num] = scattered[sent_id][picks]

    return reduced

def restore_as_sent_reducer_bi(info, leng_tk, leng_tk_mask, leng_st, misc_info, shape, reverse_id):
    """Bidirectional variant of the sentence reducer (BI-MIX / BI-SEP).

    The restored, re-scattered features are split in two halves along the
    last dimension. For each of the first ``leng_st[sentence]`` tokens:
      * the first (forward) half is read at the token's last grain,
        ``token_index * b + leng_tk - 1``;
      * the second (backward) half is read at the token's first grain,
        ``token_index * b``.
    Tokens whose ``leng_tk`` is 0 read from flat index 0 in both cases.

    Returns:
        Tensor of shape ``(bs, a, feat)`` - forward and backward halves
        concatenated on the last dimension. ``misc_info`` is not used.
    """
    device = info.device
    bs, a, b, _ = shape
    flat_len = a * b

    restored = restoreSeq(info, reverse_id)
    scattered = torch.zeros(bs, flat_len, restored.size(-1), device=device)

    flat_mask = leng_tk_mask.view(bs, flat_len)
    leng_grain_st = torch.sum(leng_tk, 1)

    slots = torch.arange(flat_len, device=device)
    for sent_idx in range(bs):
        keep = slots.masked_select(flat_mask[sent_idx] == 0)
        scattered[sent_idx][keep] = restored[sent_idx][:leng_grain_st[sent_idx]]

    fwd_half, bwd_half = scattered.chunk(2, -1)

    num_sent, num_tok = leng_tk.size(0), leng_tk.size(1)
    tok_pos = torch.arange(num_tok, device=device).unsqueeze(0).expand(num_sent, num_tok)
    empty_tok = leng_tk == 0
    # Last grain of each token for the forward direction ...
    endgrain_idx = (tok_pos * b + leng_tk - 1).masked_fill(empty_tok, 0)
    # ... and first grain of each token for the backward direction.
    startgrain_idx = tok_pos.masked_fill(empty_tok, 0) * b

    fwd_out = torch.zeros((bs, a, fwd_half.size(-1)), device=device)
    bwd_out = torch.zeros((bs, a, bwd_half.size(-1)), device=device)
    for sent_id in range(bs):
        tk_num = leng_st[sent_id]
        fwd_out[sent_id][:tk_num] = fwd_half[sent_id][endgrain_idx[sent_id][:tk_num]]
        bwd_out[sent_id][:tk_num] = bwd_half[sent_id][startgrain_idx[sent_id][:tk_num]]

    return torch.cat([fwd_out, bwd_out], -1)
    

######################################################################################################################################
   

################################
def reshape_untouch(info, leng_tk, leng_tk_mask, leng_st, misc_info):
    """Pass-through reshape: return ``info`` and ``leng_tk`` unchanged.

    The result is a 4-tuple ``(info, leng_tk, None, None)``; note that
    ``leng_tk`` is returned as the length, not ``leng_st`` or
    ``leng_tk_mask``.
    """
    passthrough = (info, leng_tk, None, None)
    return passthrough
    
def restore_untouch(info, leng_tk, leng_tk_mask, leng_st, misc_info, shape, reverse_id):
    """Pass-through restore: hand back ``info`` exactly as received.

    All other arguments are accepted but ignored.
    """
    return info
################################



# define: reshape and restore tensors under different scenarios.
def reshape_grain_sequene(info, leng_tk, shape):
    """Pack the valid grains of every batch item into one dense sequence.

    ``shape`` is ``(batch_size, st_size, tk_size, embed_size)``. ``info`` is
    flattened to ``(-1, embed_size)`` and, for each token, the first
    ``leng_tk`` flat positions of its ``tk_size``-wide slot are treated as
    valid. The valid rows of each batch item are copied, in order, to the
    front of a zero-initialised buffer.

    Returns:
        ``(packed, grain_idx, leng)`` where ``packed`` is the buffer cut to
        the longest packed length, ``grain_idx`` is the sorted tensor of
        valid flat indices, and ``leng`` holds the number of valid rows per
        batch item.
    """
    device = info.device
    batch_size, st_size, tk_size, embed_size = shape
    per_batch = st_size * tk_size

    flat_leng = leng_tk.view(batch_size, -1).reshape(-1)
    flat_info = info.reshape(-1, embed_size)

    # Flat index of the first grain slot of every token.
    token_start = torch.arange(0, batch_size * per_batch, tk_size).to(device)

    pieces = []
    for back in range(tk_size):
        # Tokens that own a grain `back` steps before their last one.
        alive = F.relu(flat_leng - back) != 0
        pieces.append(token_start[alive] + flat_leng[alive] - back - 1)
    grain_idx = torch.cat(pieces, dim=0).sort()[0]

    new_info = torch.zeros((batch_size, per_batch, embed_size)).to(device)
    counts = []
    for i in range(batch_size):
        lo = i * per_batch
        hi = (i + 1) * per_batch
        picked = grain_idx[(grain_idx >= lo) & (grain_idx < hi)]
        count = len(picked)
        new_info[i, 0:count] = flat_info[picked]
        counts.append(count)
    leng = torch.tensor(counts).to(device)

    longest = leng.max(dim=0)[0]
    return new_info[:, 0:longest], grain_idx, leng

def restore_grain_sequence(info, grain_idx, shape):
    """Scatter packed grain sequences back to their original flat slots.

    ``shape`` is ``(batch_size, st_size, tk_size, embed_size)``. For every
    batch item the entries of ``grain_idx`` that fall inside that item's
    flat index range receive the leading rows of ``info[item]``; all other
    slots stay zero.

    Returns:
        Tensor of shape ``(batch_size, st_size, tk_size, embed_size)``.
    """
    device = info.device
    batch_size, st_size, tk_size, embed_size = shape
    per_batch = st_size * tk_size

    flat = torch.zeros((batch_size * per_batch, embed_size)).to(device)
    for i in range(batch_size):
        in_batch = (grain_idx >= i * per_batch) & (grain_idx < (i + 1) * per_batch)
        slots = grain_idx[in_batch]
        flat[slots] = info[i, 0:len(slots)]

    return flat.reshape(batch_size, st_size, tk_size, embed_size)

def generate_square_subsequent_mask(st_size):
    r"""Build a ``(st_size, st_size)`` float mask for causal attention.

    Entry ``[i, j]`` is ``0.0`` when ``j <= i`` and ``float('-inf')`` when
    ``j > i``.
    """
    upper = torch.triu(torch.ones(st_size, st_size)) == 1
    # Transposing the upper triangle marks the positions j <= i.
    visible = upper.transpose(0, 1)
    mask = visible.float()
    mask = mask.masked_fill(~visible, float('-inf'))
    mask = mask.masked_fill(visible, float(0.0))
    return mask

def get_leng_mask(leng):
    """Build a padding mask from a 1-D tensor of sequence lengths.

    Args:
        leng: 1-D tensor with one (non-negative) length per batch item.

    Returns:
        ``uint8`` tensor of shape ``(batch_size, max(leng))`` on ``leng``'s
        device, holding 1 at positions ``>= leng[i]`` (padding) and 0
        elsewhere. An empty ``leng`` yields an empty ``(0, 0)`` mask.
    """
    device = leng.device
    batch_size = leng.shape[0]
    if batch_size == 0:
        # No sequences: there is no maximum length to take.
        return torch.zeros(0, 0, dtype=torch.uint8, device=device)

    # The longest sequence in the batch defines the mask width.
    max_st_leng = int(leng.max())

    # One broadcast comparison replaces a per-row Python loop.
    positions = torch.arange(max_st_leng, device=device).unsqueeze(0)
    mask = positions >= leng.unsqueeze(1)

    # Keep the historical uint8 dtype for callers.
    return mask.to(torch.uint8)

# not using
def LM_mask(leng):
    """Build flat source/target index pairs for next-step prediction.

    Row ``i`` of length ``j`` contributes source indices
    ``0 .. j-2`` and target indices ``1 .. j-1``, each offset by
    ``i * leng[0]``.

    NOTE(review): ``leng[0]`` is used as the row stride, so this presumably
    expects lengths sorted longest-first - confirm with callers.

    Args:
        leng: 1-D tensor of sequence lengths.

    Returns:
        A two-element list ``[source_idx, target_idx]`` of ``LongTensor``.
    """
    leng = leng.cpu().numpy()

    # Stride between consecutive rows in the flattened layout.
    max_st_leng = leng[0]

    src_parts = []
    tgt_parts = []
    for row, n in enumerate(leng):
        offset = row * max_st_leng
        src_parts.append(np.arange(n - 1) + offset)
        tgt_parts.append(np.arange(1, n) + offset)

    source_idx = torch.LongTensor(np.concatenate(src_parts))
    target_idx = torch.LongTensor(np.concatenate(tgt_parts))
    return [source_idx, target_idx]



