Python 高级索引正在返回形状错误的数组

Python 高级索引正在返回形状错误的数组,python,arrays,numpy,Python,Arrays,Numpy,我一直用它来理解高级索引。一个具体例子如下: cols = x[:, k, i, j] 示例 假设x.shape为(10,20,30),ind为(2,3,4)形索引intp数组,则result=x[…,ind,:]具有形状(10,2,3,4,30),因为(20,)形子空间已替换为(2,3,4)形广播索引子空间。如果我们让i,j,k在(2,3,4)形子空间上循环,那么结果[…,i,j,k,:]=x[…,ind[i,j,k],:]。此示例生成与x.take(ind,axis=-2)相同的结果 我试

我一直用它来理解高级索引。一个具体例子如下:

cols = x[:, k, i, j]
示例 假设x.shape为(10,20,30),ind为(2,3,4)形索引intp数组,则result=x[…,ind,:]具有形状(10,2,3,4,30),因为(20,)形子空间已替换为(2,3,4)形广播索引子空间。如果我们让i,j,k在(2,3,4)形子空间上循环,那么结果[…,i,j,k,:]=x[…,ind[i,j,k],:]。此示例生成与x.take(ind,axis=-2)相同的结果

我试着理解这一点已经有一段时间了,为了帮助我,我有了一个生成一些数组的小脚本。我有

Indexing arrays
i => 12 x 25
j => 12 x 25
k => 12 x 1

Input array
x => 2 x 3 x 4 x 4

Output Array
Cols => 2 x 12 x 25
我用来制作Cols的代码如下:

cols = x[:, k, i, j]
根据我对示例的理解,cols实际上应该具有形状(2x12x1x12x25x12x25)。我得出如下结论:

cols = x[:, k, i, j]
它的原始尺寸是2x3x4x4

2未更改,但所有其他尺寸均已更改

将3替换为k,一个12 x 1的阵列

前4个替换为i,一个12 x 25的阵列

第二个4由j代替,j也是一个12 x 25阵列


很明显,我在这里误解了什么,我哪里出了问题?

这正是你想要的:

i=np.random.randint(0,4,(12,25))
j=np.random.randint(0,4,(12,25))
k=np.random.randint(0,3,(12,1))
x=np.random.randint(1,11,(2,3,4,4))

x1 = x[:,k,:,:][:,:,:,i,:][:,:,:,:,:,j]
x1.shape

(2, 12, 1, 12, 25, 12, 25)
为什么最初的方法不是这样工作的?我认为高级索引在确定是否同时按多个维度进行索引时可能是贪婪的。例如,您的原始形状:

x.shape
(2,3,4,4) 
可以从很多方面来解释。您想要的是每个轴都是独立的,但将其解释为6
(4,4)
矩阵或2
(3,4,4)
张量同样有效。因此,当通过
[…,i,j]
进行索引时,您可以将
i
解释为在第三个轴上,将
j
解释为在第四个轴上,或者将
i,j
解释为在最后两个轴上。Numpy猜你指的是第二个:

x[...,i,j].shape
(2,3,12,25)
您还可以将
x
解释为8
(3,4)
矩阵,这是执行以下操作时发生的情况:

x[:,k,i,:].shape
(2,12,25,4)
请注意,is还将您的
(12,1)
k
数组广播到
(12,25)
,以便匹配
i
进行索引。您可以使用
k
上的
.squeeze()
确认正在进行广播:

x[:,k.squeeze(),i,:]
Traceback (most recent call last):
IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (12,) (12,25)
如果将
x
解释为2
(3,4,4)
张量,则numpy同时执行这两种操作。它将
k
广播到
(12,25)
,然后根据一组三个
(12,25)
索引数组对最后三个维度进行索引,将所有三个维度作为一个单元进行缩减

您可以使用
np.ix\uuu
在某种程度上覆盖此行为,但是
np.ix\uuu
的所有参数都必须是1d,因此如果不进行展平和重塑,您在那里就很倒霉,这在某种程度上违背了此处的目的,但也起到了作用:

x2 = x[np.ix_(np.arange(x.shape[0]), k.flat, i.flat, j.flat)].reshape((x.shape[0], ) + k.shape + i.shape + j.shape)

x2.shape
(2, 12, 1, 12, 25, 12, 25)

np.all(x1 == x2)
True

谢谢你,丹尼尔!我真的很想了解numpy是如何创建Cols数组的,但我就是无法在头脑中容纳那么多维度!是的,我不是那种人。试着这样看:如果
a.shape=(5,5)
,那么
a[np.arange(5),np.arange(5)]
a
不同,它是
a
的对角线,或者
[a[1,1],a[2,2],a[3,3],a[4,4],a[5]
。换句话说,
a[np.arange(5),np.arange(5)]。shape==arange(5)。shape
。所有三个索引数组都广播到(12,25)。广播是理解多索引数组如何工作的核心。使用(n,)和(n,)进行索引生成(n,)输出;使用(n,1)和(m)进行索引生成(n,m)输出。切片增加了一个维度,但如果切片位于索引元组的中间,则有很大的变化。此外,从DOCS部分:高级索引总是被广播和迭代为:。相反,您单独应用它们。