import torch
def mini_batch(x, edge_index, seq_len, edge_attr=None):
    """Flatten a padded batch of graphs into one disconnected graph.

    Args:
        x: Node features indexed as ``x[batch, node, ...]``; it is reshaped to
            ``(batch_size * sent_len, -1)`` and then squeezed, so any size-1
            dimensions are dropped.
        edge_index: Rank-3 tensor indexed as ``edge_index[batch, :, edge]``
            (presumably ``(batch, 2, max_edges)`` -- TODO confirm). The edges
            of each graph lie along the last axis.
        seq_len: Indexable so that ``seq_len[i][1]`` is the number of valid
            (non-padded) edges of graph ``i``. The meaning of
            ``seq_len[i][0]`` is not used here.
        edge_attr: Optional per-edge attributes indexed as
            ``edge_attr[batch, edge, ...]``.

    Returns:
        ``(x, edge_index)``, or ``(x, edge_index, edge_attr)`` when
        ``edge_attr`` is given. Padding edges are dropped, and edge attributes
        are concatenated along dim 0 in the same order as the edges.
    """
    batch_size = x.size(0)
    sent_len = x.size(1)
    x = x.reshape(batch_size * sent_len, -1).squeeze()

    # Graph i's nodes start at row i * sent_len of the flattened x, so its
    # node indices are shifted by that amount; edges past seq_len[i][1] are
    # padding and are sliced off.
    edge_index = torch.cat(
        [edge_index[i, :, :seq_len[i][1]] + i * sent_len for i in range(batch_size)],
        dim=1,
    )
    if edge_attr is None:
        return x, edge_index

    # Keep only the attributes of the valid edges, in the same order as above.
    edge_attr = torch.cat(
        [edge_attr[i, :seq_len[i][1]] for i in range(batch_size)],
        dim=0,
    )
    return x, edge_index, edge_attr

def mini_batch_pad(x, edge_index, edge_attr=None):
    """Flatten a padded batch of graphs, keeping every (padded) edge slot.

    Args:
        x: Node features indexed as ``x[batch, node, ...]``; it is reshaped to
            ``(batch_size * sent_len, -1)`` and then squeezed.
        edge_index: Rank-3 tensor indexed as ``edge_index[batch, :, edge]``
            (presumably ``(batch, 2, num_edges)`` -- TODO confirm). Every graph
            contributes all ``edge_index.size(2)`` edge slots.
        edge_attr: Optional per-edge attributes whose total size is divisible
            by ``batch_size * num_edges`` (presumably
            ``(batch, num_edges, dim)`` -- TODO confirm).

    Returns:
        ``(x, edge_index)``, or ``(x, edge_index, edge_attr)`` when
        ``edge_attr`` is given. ``edge_attr`` is reshaped to
        ``(batch_size * num_edges, -1)`` and squeezed, so row ``k`` lines up
        with column ``k`` of the returned ``edge_index``.
    """
    batch_size, sent_len = x.size(0), x.size(1)
    x = x.reshape(batch_size * sent_len, -1).squeeze()

    # Edges lie along the last axis (that is the axis concatenated below), so
    # the per-graph edge count is size(2). size(1) is the source/target axis.
    num_edges = edge_index.size(2)

    # Node indices must be shifted by the per-graph node count (sent_len), so
    # they point at graph i's rows in the flattened x. The previous code used
    # the edge_index.size(1) value here, which points at the wrong nodes.
    edge_index = torch.cat(
        [edge_index[i] + i * sent_len for i in range(batch_size)],
        dim=1,
    )
    if edge_attr is None:
        return x, edge_index

    edge_attr = edge_attr.reshape(batch_size * num_edges, -1).squeeze()
    return x, edge_index, edge_attr

def mini_batch_list(x, edge_index, edge_atter=None):
    """Flatten a batch of graphs whose edges are given as per-graph lists.

    Args:
        x: Node features indexed as ``x[batch, node, ...]``; it is reshaped to
            ``(batch_size * seq_len, -1)`` and then squeezed.
        edge_index: Sequence with one edge-index tensor per graph. Each tensor
            is concatenated along dim 1, presumably ``(2, num_edges_i)``
            (TODO confirm).
        edge_atter: Optional per-edge attributes. This is either a sequence of
            per-graph tensors, which are concatenated along dim 0, or a tensor
            that is already stacked, which is returned unchanged. The
            misspelled name is kept for compatibility with existing callers.

    Returns:
        ``(x, edge_index)``, or ``(x, edge_index, edge_attr)`` when
        ``edge_atter`` is given.

    Raises:
        ValueError: If ``len(edge_index)`` differs from ``x.size(0)``.
    """
    batch_size, seq_len = x.size(0), x.size(1)
    x = x.reshape(batch_size * seq_len, -1).squeeze()

    # Raise instead of print() + exit() so callers can handle the error,
    # rather than having the whole process terminated.
    if len(edge_index) != batch_size:
        raise ValueError(
            f"edge_index does not match batch_size: got {len(edge_index)} "
            f"graphs, expected {batch_size}"
        )

    # Graph i's nodes start at row i * seq_len of the flattened x.
    edge_index = torch.cat(
        [e_i + i * seq_len for i, e_i in enumerate(edge_index)],
        dim=1,
    )
    if edge_atter is None:
        return x, edge_index

    if isinstance(edge_atter, torch.Tensor):
        # Already a single tensor: pass it through, as the original code did.
        return x, edge_index, edge_atter

    # The previous code concatenated edge_index_list here, into an unused
    # variable, and returned the raw list.
    edge_attr = torch.cat(list(edge_atter), dim=0)
    return x, edge_index, edge_attr


