Python 在自定义函数中实现轴参数

Python 在自定义函数中实现轴参数,python,numpy,array-broadcasting,Python,Numpy,Array Broadcasting,我正在编写一个相当简单的函数,用于在对数空间中应用梯形规则执行积分 我想添加axis参数以实现类似于numpy.trapz函数的功能,但对于如何正确实现它,我有点困惑 不可广播的函数如下所示: import numpy as np def logtrapz(y, x): logx = np.log(x) dlogx = np.diff(logx) logy = np.log(y) dlogy = np.diff(logy) b = dlogx +

我正在编写一个相当简单的函数,用于在对数空间中应用梯形规则执行积分

我想添加axis参数以实现类似于
numpy.trapz
函数的功能,但对于如何正确实现它,我有点困惑

不可广播的函数如下所示:

import numpy as np

def logtrapz(y, x):

    logx = np.log(x)
    dlogx = np.diff(logx)

    logy = np.log(y)
    dlogy = np.diff(logy)

    b = dlogx + dlogy
    a = np.exp(logx + logy)

    dF = a[:-1] * (np.exp(b) - 1)/b * dlogx

    return np.sum(dF)
这适用于1D输入


我认为解决方案在于
numpy.expand_dims
,但我不确定如何实现它来演示交互式会话中的
片段探索:

In [216]: slice(None)                                                           
Out[216]: slice(None, None, None)
In [217]: slice??                                                               
Init signature: slice(self, /, *args, **kwargs)
Docstring:     
slice(stop)
slice(start, stop[, step])

Create a slice object.  This is used for extended slicing (e.g. a[0:10:2]).
Type:           type
Subclasses:     
In [218]: np.s_[:]                                                              
Out[218]: slice(None, None, None)
我没有看过
np.trapz
代码,但我知道其他
numpy
函数通常在需要
轴时构造索引元组

例如,三维阵列的通用索引:

In [221]: arr = np.arange(24).reshape(2,3,4)                                    
In [223]: idx = [slice(None) for _ in range(3)]                                 
In [224]: idx                                                                   
Out[224]: [slice(None, None, None), slice(None, None, None), slice(None, None, None)]
In [225]: idx[1]=1                                                              
In [226]: idx                                                                   
Out[226]: [slice(None, None, None), 1, slice(None, None, None)]
In [227]: tuple(idx)                                                            
Out[227]: (slice(None, None, None), 1, slice(None, None, None))
In [228]: arr[tuple(idx)]     # arr[:,1,:]                                                  
Out[228]: 
array([[ 4,  5,  6,  7],
       [16, 17, 18, 19]])
In [229]: idx[2]=2                                                              
In [230]: arr[tuple(idx)]     # arr[:,1,2]                                                  
Out[230]: array([ 6, 18])

要演示交互式会话中的
片段
探索,请执行以下操作:

In [216]: slice(None)                                                           
Out[216]: slice(None, None, None)
In [217]: slice??                                                               
Init signature: slice(self, /, *args, **kwargs)
Docstring:     
slice(stop)
slice(start, stop[, step])

Create a slice object.  This is used for extended slicing (e.g. a[0:10:2]).
Type:           type
Subclasses:     
In [218]: np.s_[:]                                                              
Out[218]: slice(None, None, None)
我没有看过
np.trapz
代码,但我知道其他
numpy
函数通常在需要
轴时构造索引元组

例如,三维阵列的通用索引:

In [221]: arr = np.arange(24).reshape(2,3,4)                                    
In [223]: idx = [slice(None) for _ in range(3)]                                 
In [224]: idx                                                                   
Out[224]: [slice(None, None, None), slice(None, None, None), slice(None, None, None)]
In [225]: idx[1]=1                                                              
In [226]: idx                                                                   
Out[226]: [slice(None, None, None), 1, slice(None, None, None)]
In [227]: tuple(idx)                                                            
Out[227]: (slice(None, None, None), 1, slice(None, None, None))
In [228]: arr[tuple(idx)]     # arr[:,1,:]                                                  
Out[228]: 
array([[ 4,  5,  6,  7],
       [16, 17, 18, 19]])
In [229]: idx[2]=2                                                              
In [230]: arr[tuple(idx)]     # arr[:,1,2]                                                  
Out[230]: array([ 6, 18])

我通过复制
numpy.trapz
中使用的方法解决了这个问题。这有点复杂,但效果相当不错

对于将来的读者,上述函数的可广播版本是

import numpy as np

def logtrapz(y, x, axis=-1):

    x = np.asanyarray(x)
    logx = np.log(x)
    if x.ndim == 1:
        dlogx = np.diff(logx)
        # reshape to correct shape
        shape1 = [1]*y.ndim
        shape1[axis] = dlogx.shape[0]
        shape2 = [1]*y.ndim
        shape2[axis] = logx.shape[0]
        dlogx = dlogx.reshape(shape1)
        logx  = logx.reshape(shape2)
    else:
        dlogx = np.diff(x, axis=axis)

    nd = y.ndim
    slice1 = [slice(None)]*nd
    slice2 = [slice(None)]*nd
    slice1[axis] = slice(None, -1)
    slice2[axis] = slice(1, None)
    slice1 = tuple(slice1)
    slice2 = tuple(slice2)

    logy = np.log(y)
    dlogy = logy[slice2] - logy[slice1]

    b = dlogx + dlogy
    a = np.exp(logx + logy)

    dF = a[slice1] * (np.exp(b) - 1)/b * dlogx

    np.sum(dF, axis=axis)
为了实现“可广播性”,采用了
重塑
切片
的组合,显式地创建具有所需输出形状的“形状”向量


我认为这可以用一种更短、更简洁的方法来实现,但显然这是numpy本身实现的方法。

我复制了
numpy.trapz
中使用的方法来解决这个问题。这有点复杂,但效果相当不错

对于将来的读者,上述函数的可广播版本是

import numpy as np

def logtrapz(y, x, axis=-1):

    x = np.asanyarray(x)
    logx = np.log(x)
    if x.ndim == 1:
        dlogx = np.diff(logx)
        # reshape to correct shape
        shape1 = [1]*y.ndim
        shape1[axis] = dlogx.shape[0]
        shape2 = [1]*y.ndim
        shape2[axis] = logx.shape[0]
        dlogx = dlogx.reshape(shape1)
        logx  = logx.reshape(shape2)
    else:
        dlogx = np.diff(x, axis=axis)

    nd = y.ndim
    slice1 = [slice(None)]*nd
    slice2 = [slice(None)]*nd
    slice1[axis] = slice(None, -1)
    slice2[axis] = slice(1, None)
    slice1 = tuple(slice1)
    slice2 = tuple(slice2)

    logy = np.log(y)
    dlogy = logy[slice2] - logy[slice1]

    b = dlogx + dlogy
    a = np.exp(logx + logy)

    dF = a[slice1] * (np.exp(b) - 1)/b * dlogx

    np.sum(dF, axis=axis)
为了实现“可广播性”,采用了
重塑
切片
的组合,显式地创建具有所需输出形状的“形状”向量


我原以为这可以用一种更短更简洁的方法来实现,但显然这是numpy本身实现的方法。

你可以阅读
np.trapz
np.expand_dims
@hpaulj的Python代码,我现在正在做,但有些段落使用了几乎没有文档的特性。特别是
np.trapz
源代码()的第4059到4064行使用了
slice(None)
,这在官方slice文档中没有记录。
slice(None)
生成一个slice对象,
slice(None,None,None)
与索引符号中的
[:]
[:]
。这是一种从经验中比从文档中获得更多的细节。在交互环境中尝试一些代码对于这种代码阅读和开发是必不可少的。您可以阅读
np.trapz
np.expand_dims
@hpaulj的Python代码。我现在正在做这件事,但有些段落使用了几乎没有文档记录的特性。特别是
np.trapz
源代码()的第4059到4064行使用了
slice(None)
,这在官方slice文档中没有记录。
slice(None)
生成一个slice对象,
slice(None,None,None)
与索引符号中的
[:]
[:]
。这是一种从经验中比从文档中获得更多的细节。在交互环境中尝试代码位对于这种代码阅读和开发非常重要。我非常感谢您对
slice(None)
的解释,但我不能说这是对我问题的回答。我非常感谢您对
slice(None)
的解释,但我不能说这是对我问题的回答。