Python 沿特定维度和特定通道索引整个张量

Python 沿特定维度和特定通道索引整个张量,python,numpy,tensorflow,indexing,pytorch,Python,Numpy,Tensorflow,Indexing,Pytorch,假设我们有一个维度为dim(a)=[i,j,k=6,u,v]的张量a。现在我们有兴趣得到通道=[0:3]在维度k处的整个张量。我知道我们可以这样做: B = A[:, :, 0:3, :, :] 现在,我想知道是否有更好的“pythonic”方法来实现相同的结果,而不必进行这种次优索引。我的意思是类似于 B = subset(A, dim=2, index=[0, 1, 2]) 无论在哪个框架中,即Pytork、tensorflow、numpy等 非常感谢numpy中的,您可以使用take方

假设我们有一个维度为dim(a)=[i,j,k=6,u,v]的张量a。现在我们有兴趣得到通道=[0:3]在维度k处的整个张量。我知道我们可以这样做:

B = A[:, :, 0:3, :, :]
现在,我想知道是否有更好的“pythonic”方法来实现相同的结果,而不必进行这种次优索引。我的意思是类似于

B = subset(A, dim=2, index=[0, 1, 2])
无论在哪个框架中,即Pytork、tensorflow、numpy等


非常感谢numpy中的,您可以使用
take
方法:

B = A.take([0,1,2], axis=2)

在TensorFlow中,没有比使用传统方法更简洁的方法了。使用
tf.slice
会非常冗长:

B = tf.slice(A,[0,0,0,0,0],[-1,-1,3,-1,-1])
您可以潜在地使用试验版的
take
(自TF 2.4起):


在PyTorch中,您可以使用
索引\u选择

torch.index_select(A, dim=2, index=torch.tensor([0,1,2]))

请注意,您可以使用
省略号来跳过第一个维度(或最后一个维度)的显式列出:

# Both are equivalent in that case
B = A[..., 0:3, :, :]
B = A[:, :, 0:3, ...]

@非常感谢,我还发现PyTorch中的“torch.index_select”也有同样的功能。我会将其添加到答案中。谢谢。
tf.gather
将不必要的复杂,在您的情况下,您必须构建索引数组,例如
tf.gather(A,tf.tile([[0,1,2]],[10,10,1]),batch_dims=2)
# Both are equivalent in that case
B = A[..., 0:3, :, :]
B = A[:, :, 0:3, ...]