多维numpy数组的对角线

多维numpy数组的对角线,numpy,multidimensional-array,diagonal,Numpy,Multidimensional Array,Diagonal,是否有一种更具蟒蛇风格的方法来执行以下操作: import numpy as np def diagonal(A): (x,y,y) = A.shape diags = [] for a in A: diags.append(np.diagonal(a)) result = np.vstack(diags) assert result.shape == (x,y) return result 假设A将是一个具有形状(m,n,n)的数组(即A可以解

是否有一种更具蟒蛇风格的方法来执行以下操作:

import numpy as np
def diagonal(A):
    (x,y,y) = A.shape
    diags = []
    for a in A: diags.append(np.diagonal(a))
    result = np.vstack(diags)
    assert result.shape == (x,y)
    return result

假设
A
将是一个具有形状(m,n,n)的数组(即
A
可以解释为具有形状
(n,n)
m
数组的集合),下面是一个返回输入视图的快速方法:

In [14]: from numpy.lib.stride_tricks import as_strided

In [15]: def diags(a):
   ....:     b = as_strided(a, strides=(a.strides[0], a.strides[1]+a.strides[2]), shape=(a.shape[0], a.shape[1]))
   ....:     return b
   ....: 

In [16]: a
Out[16]: 
array([[[8, 6, 6, 5],
        [1, 0, 3, 5],
        [8, 1, 6, 7],
        [2, 8, 7, 1]],

       [[0, 8, 8, 0],
        [1, 4, 2, 4],
        [1, 4, 5, 6],
        [2, 5, 2, 7]],

       [[5, 2, 5, 2],
        [2, 5, 7, 6],
        [6, 5, 1, 8],
        [7, 6, 5, 8]]])

In [17]: diags(a)
Out[17]: 
array([[8, 0, 6, 1],
       [0, 4, 5, 7],
       [5, 5, 1, 8]])
当我说返回值是一个视图时,我的意思是它引用与输入相同的底层内存。因此,如果以后更改返回值,原始输入也会更改。比如说,

In [24]: d = diags(a)

In [25]: d[0, :] = 99

In [26]: a[0]
Out[26]: 
array([[99,  6,  6,  5],
       [ 1, 99,  3,  5],
       [ 8,  1, 99,  7],
       [ 2,  8,  7, 99]])
方法#1

一种干净的方法是使用输入数组的转置版本,如下所示-

np.diagonal(A.T)
基本上,我们使用
A.T
翻转输入数组的维度,让
np.diagonal
使用最后两个轴来提取对角线元素,因为默认情况下,它会使用前两个轴。最好的情况是,这适用于任意维数的数组

方法#2

下面是另一种使用-

还可以对
基本索引
-

out = A.reshape(m,-1)[:,np.eye(n,dtype=bool).ravel()]
样本运行-

In [87]: A
Out[87]: 
array([[[73, 52, 62],
        [20,  7,  7],
        [ 1, 68, 89]],

       [[15, 78, 98],
        [24, 22, 35],
        [19,  1, 91]],

       [[ 5, 37, 64],
        [22,  4, 43],
        [84, 45, 12]],

       [[24, 45, 42],
        [70, 45,  1],
        [ 6, 48, 60]]])

In [88]: np.diagonal(A.T)
Out[88]: 
array([[73,  7, 89],
       [15, 22, 91],
       [ 5,  4, 12],
       [24, 45, 60]])

In [89]: m,n = A.shape[:2]

In [90]: A[np.arange(m)[:,None],np.eye(n,dtype=bool)]
Out[90]: 
array([[73,  7, 89],
       [15, 22, 91],
       [ 5,  4, 12],
       [24, 45, 60]])

请详细说明
A
是什么。你认为它是一个
n*m*m
matrix吗?@evan058是的,我试图用对角线()函数的第一行来表达这个假设。谢谢,看起来很像python。尽管如此,它看起来不是特别干净或直观。有更好的方法吗?“干净”、“直观”、“更好”。。。这取决于你对这些词的定义。话虽如此,@Divakar的回答看起来很好(和往常一样!),尤其是方法1。我想你再也洗不干净了。
In [87]: A
Out[87]: 
array([[[73, 52, 62],
        [20,  7,  7],
        [ 1, 68, 89]],

       [[15, 78, 98],
        [24, 22, 35],
        [19,  1, 91]],

       [[ 5, 37, 64],
        [22,  4, 43],
        [84, 45, 12]],

       [[24, 45, 42],
        [70, 45,  1],
        [ 6, 48, 60]]])

In [88]: np.diagonal(A.T)
Out[88]: 
array([[73,  7, 89],
       [15, 22, 91],
       [ 5,  4, 12],
       [24, 45, 60]])

In [89]: m,n = A.shape[:2]

In [90]: A[np.arange(m)[:,None],np.eye(n,dtype=bool)]
Out[90]: 
array([[73,  7, 89],
       [15, 22, 91],
       [ 5,  4, 12],
       [24, 45, 60]])