Machine learning 具有动态批量大小的批量矩阵乘法

Machine learning 具有动态批量大小的批量矩阵乘法,machine-learning,pytorch,matrix-multiplication,transformation-matrix,Machine Learning,Pytorch,Matrix Multiplication,Transformation Matrix,我正在编写一个使用批处理矩阵乘法的程序,可能不是在普通设置下。我正在考虑下列投入: # Let's say I have a list of points in R^3, from 3 distinct objects # (so my data batch has 3 data entry) # X: (B1+B2+B3) * 3 X = torch.tensor([[1,1,1],[1,1,1], [2,2,2],[2,2,2],[2,2,2],

我正在编写一个使用批处理矩阵乘法的程序,可能不是在普通设置下。我正在考虑下列投入:

# Let's say I have a list of points in R^3, from 3 distinct objects
# (so my data batch has 3 data entry)
# X: (B1+B2+B3) * 3 
X = torch.tensor([[1,1,1],[1,1,1],
                  [2,2,2],[2,2,2],[2,2,2],
                  [3,3,3],])
# To indicate which object the points are corresponding to,
# I have a list of indices (say, starting from 0):
# idx: (B1+B2+B3)
idx = torch.tensor([0,0,1,1,1,2])
# For each point from the same object, I want to multiply it to a 3x3 matrix, A_i.
# As I have 3 objects here, I have A_0, A_1, A_2.
# A: 3 x 3 x 3
A = torch.tensor([[[1,1,1],[1,1,1],[1,1,1]],
                  [[2,2,2],[2,2,2],[2,2,2]],
                  [[3,3,3],[3,3,3],[3,3,3]]])
所需输出为:

out = X.unsqueeze(1).bmm(A[idx])
out = out.squeeze(1)                # just to remove excessive dimension
# out = torch.tensor([[[1,1,1]],[[1,1,1]],            # obj0 mult with A_0
                      [[2,2,2]],[[2,2,2]],[[2,2,2]],  # obj1 mult with A_1
                      [[3,3,3]],])                    # obj2 mult with A_2
它实际上在pytorch中非常方便,只有一行

在这里,我想改进这个程序。请注意,我使用[idx]为每个点复制一个矩阵A_I,因此我可以在这里使用torch.bmm()函数(1点1矩阵)。首先,它需要为[idx]的中间表示分配内存。通常,如果我的数据批中有BN个对象,那么[idx]=(B1+…+BN)*3*3的大小可能非常大

因此,我想知道是否可以避免矩阵A_I的复制

我已经找到了以前提出的有关Batch Mat的大多数问题。穆特。仅假设固定批量大小。提出了与我相同的问题,并提供了tensorflow中的解决方案。但是,解决方案是使用tf.tile()实现的,它也在复制矩阵


总之,我的问题是关于批量矩阵乘法,同时实现:

- dynamic batch size
    - input shape: (B1+...+BN) x 3
    - index shape: (B1+...+BN)
- memory efficiency
    - probably w/out massive replication of matrix
我在这里使用pytorch,但我也接受其他实现。我也接受在其他结构中表示输入(例如,要乘法的矩阵,A),如果它能提高内存效率的话