Python 在n维数组上使用带网格网格的scipy interpn
我试图翻译一个大型4D数组的Matlab“interpn”插值,但Matlab和Python之间的公式存在显著差异。几年前有一个很好的问题/答案,我一直在尝试与之合作。我想我就快到了,但显然我的网格插值器还没有正确的公式化 在使用我实际使用的维度时,我尽可能地按照上面链接的答案中给出的示例对代码示例进行建模。唯一的变化是我将rollaxis切换为moveaxis,因为前者已被弃用 基本上,给定4D数组skyrad0(取决于第一个代码块中定义的四个元素)以及第三个代码块中定义的两个常量和两个1D数组,我想要插值的2D结果Python 在n维数组上使用带网格网格的scipy interpn,python,matlab,numpy,interpolation,n-dimensional,Python,Matlab,Numpy,Interpolation,N Dimensional,我试图翻译一个大型4D数组的Matlab“interpn”插值,但Matlab和Python之间的公式存在显著差异。几年前有一个很好的问题/答案,我一直在尝试与之合作。我想我就快到了,但显然我的网格插值器还没有正确的公式化 在使用我实际使用的维度时,我尽可能地按照上面链接的答案中给出的示例对代码示例进行建模。唯一的变化是我将rollaxis切换为moveaxis,因为前者已被弃用 基本上,给定4D数组skyrad0(取决于第一个代码块中定义的四个元素)以及第三个代码块中定义的两个常量和两个1D数
from scipy.interpolate import interpn
import numpy as np
# Define the data space in the 4D skyrad0 array
solzen = np.arange(0,70,10) # 7
aod = np.arange(0,0.25,0.05) # 5
index = np.arange(1,92477,1) # 92476
wave = np.arange(350,1050,5) # 140
# Simulated skyrad for the values above
skyrad0 = np.random.rand(
solzen.size,aod.size,index.size,wave.size) # 7, 5, 92476, 140
# Data space for desired output values of skyrad
# with interpolation between input data space
solzen0 = 30 # 1
aod0 = 0.1 # 1
index0 = index # 92476
wave0 = np.arange(350,1050,10) # 70
# Matlab
# result = squeeze(interpn(solzen, aod, index, wave,
# skyrad0,
# solzen0, aod0, index0, wave0))
# Scipy
points = (solzen, aod, index, wave) # 7, 5, 92476, 140
interp_mesh = np.array(
np.meshgrid(solzen0, aod0, index0, wave0)) # 4, 1, 1, 92476, 70
interp_points = np.moveaxis(interp_mesh, 0, -1) # 1, 1, 92476, 70, 4
interp_points = interp_points.reshape(
(interp_mesh.size // interp_mesh.shape[3],
interp_mesh.shape[3])) # 280, 92476
result = interpn(points, skyrad0, interp_points)
我期待一个4D数组“结果”,我可以将其压缩到我需要的2D答案中,但interpn会产生错误:
ValueError: The requested sample points xi have dimension 92476, but this RegularGridInterpolator has dimension 4
在本例中,我最不清楚的是查询点网格的结构,以及将第一个维度移动到末尾并对其进行重塑。这方面还有很多,但我仍然不清楚如何将其应用于这个问题
如果有人能在我的公式中发现明显的低效,那将是一个额外的收获。我需要在许多不同的结构上运行这种类型的插值数千次——甚至扩展到6D——所以效率很重要
更新下面的答案非常优雅地解决了问题。然而,随着计算和数组变得越来越复杂,另一个问题逐渐出现,即数组中的元素不是单调增加的。以下是在6D中重新定义的问题:
# Data space in the 6D rad_boa array
azimuth = np.arange(0, 185, 5) # 37
senzen = np.arange(0, 185, 5) # 37
wave = np.arange(350,1050,5) # 140
# wave = np.array([350, 360, 370, 380, 390, 410, 440, 470, 510, 550, 610, 670, 750, 865, 1040, 1240, 1640, 2250]) # 18
solzen = np.arange(0,65,5) # 13
aod = np.arange(0,0.55,0.05) # 11
wind = np.arange(0, 20, 5) # 4
# Simulated rad_boa
rad_boa = np.random.rand(
azimuth.size,senzen.size,wave.size,solzen.size,aod.size,wind.size,) # 37, 37, 140/18, 13, 11, 4
azimuth0 = 135 # 1
senzen0 = 140 # 1
wave0 = np.arange(350,1010,10) # 66
solzen0 = 30 # 1
aod0 = 0.1 # 1
wind0 = 10 # 1
da = xr.DataArray(name='Radiance_BOA',
data=rad_boa,
dims=['azimuth','senzen','wave','solzen','aod','wind'],
coords=[azimuth,senzen,wave,solzen,aod,wind])
rad_inc_scaXR = da.loc[azimuth0,senzen0,wave0,solzen0,aod0,wind0].squeeze()
目前,它运行,但如果将wave的定义更改为注释行,则会抛出错误:
KeyError: "not all values found in index 'wave'"
最后,为了回应下面的评论(并帮助提高效率),我加入了HDF5文件(在Matlab中创建)的结构,这个“rad_boa”6D阵列实际上就是从这个文件中构建的(上面的示例只使用了一个模拟的随机阵列)。将实际数据库读入Xarray,如下所示:
sdb = xr.open_dataset(db_path, group='sdb')
由此产生的Xarray如下所示:
为什么出现值错误? 首先,
scipy.interpolate.interpn
要求interp_points.shape[-1]
与问题中的维数相同。这就是为什么您从代码段中获得ValueError
的原因——您的interp\u points
的92476为n_dims
,这与实际DIM(4)数量冲突
快速修复
只需更改操作顺序,即可修复此代码段。您试图挤得太早--如果在interp之后挤:
points = (solzen, aod, index, wave) # 7, 5, 92476, 140
mg = np.meshgrid(solzen0, aod0, index0, wave0) # 4, 1, 1, 92476, 70
interp_points = np.moveaxis(mg, 0, -1) # 1, 1, 92476, 70, 4
result_presqueeze = interpn(points,
skyrad0, interp_points) # 1, 1, 92476, 70
result = np.squeeze(result_presqueeze,
axis=(0,1)) # 92476, 70
我在这里将interp_mesh
替换为mg
,并删除了np.array
(这不是必需的,因为np.meshgrid
返回一个ndarray
对象)
绩效评价
我认为您的代码片段很好,但是如果您正在处理带标签的数据,您可能希望使用xarray
,因为:
- 比未标记的
数组更可读numpy
- 还可以使用来处理一些后台工作(如果您在6D中检查大量数据,这很有用)
.interp
,而不是.loc
。下面的代码片段可以工作,因为数据点实际上是原始数据点。作为对他人的警告:
from scipy.interpolate import interpn
import numpy as np
from xarray import DataArray
# Define the data space in the 4D skyrad0 array
solzen = np.arange(0,70,10) # 7
aod = np.arange(0,0.25,0.05) # 5
index = np.arange(1,92477,1) # 92476
wave = np.arange(350,1050,5) # 140
# Simulated skyrad for the values above
skyrad0 = np.random.rand(
solzen.size,aod.size,index.size,wave.size) # 7, 5, 92476, 140
# Data space for desired output values of skyrad
# with interpolation between input data space
solzen0 = 30 # 1
aod0 = 0.1 # 1
index0 = index # 92476
wave0 = np.arange(350,1050,10) # 70
def slow():
points = (solzen, aod, index, wave) # 7, 5, 92476, 140
mg = np.meshgrid(solzen0, aod0, index0, wave0) # 4, 1, 1, 92476, 70
interp_points = np.moveaxis(mg, 0, -1) # 1, 1, 92476, 70, 4
result_presqueeze = interpn(points,
skyrad0, interp_points) # 1, 1, 92476, 70
result = np.squeeze(result_presqueeze,
axis=(0,1)) # 92476, 70
return result
# This function uses .loc instead of .interp!
"""
def fast():
da = DataArray(name='skyrad0',
data=skyrad0,
dims=['solzen','aod','index','wave'],
coords=[solzen, aod, index, wave])
result = da.loc[solzen0, aod0, index0, wave0].squeeze()
return result
"""
通过对OP给出的更新代码段进行两次修改:
import numpy as np
import xarray as xr
from scipy.interpolate import interpn
azimuth = np.arange(0, 185, 5) # 37
senzen = np.arange(0, 185, 5) # 37
#wave = np.arange(350,1050,5) # 140
wave = np.asarray([350, 360, 370, 380, 390, 410, 440, 470, 510,
550, 610, 670, 750, 865, 1040, 1240, 1640, 2250]) # 18
solzen = np.arange(0,65,5) # 13
aod = np.arange(0,0.55,0.05) # 11
wind = np.arange(0, 20, 5) # 4
coords = [azimuth, senzen, wave, solzen, aod, wind]
azimuth0 = 135 # 1
senzen0 = 140 # 1
wave0 = np.arange(350,1010,10) # 66
solzen0 = 30 # 1
aod0 = 0.1 # 1
wind0 = 10 # 1
interp_coords = [azimuth0, senzen0, wave0, solzen0, aod0, wind0]
# Simulated rad_boa
rad_boa = np.random.rand(
*map(lambda x: x.size, coords)) # 37, 37, 140/18, 13, 11, 4
def slow():
mg = np.meshgrid(*interp_coords)
interp_points = np.moveaxis(mg, 0, -1)
result_presqueeze = interpn(coords,
rad_boa, interp_points)
result = np.squeeze(result_presqueeze)
return result
def fast():
da = xr.DataArray(name='Radiance_BOA',
data=rad_boa,
dims=['azimuth','senzen','wave','solzen','aod','wind'],
coords=coords)
interp_dict = dict(zip(da.dims, interp_coords))
rad_inc_scaXR = da.interp(**interp_dict).squeeze()
return rad_inc_scaXR
这是相当迅速的:
>>> %timeit slow()
2.09 ms ± 85.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
>>> %timeit fast()
343 ms ± 6.77 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> np.array_equal(slow(),fast())
True
您可以找到有关xarray
插值的更多信息。数据集实例具有非常相似的语法
还可以根据需要更改插值方法(对于离散插值问题,可能希望将关键字参数method='nearest'
提供给.interp
)
更高级的东西
如果您希望实现更高级的功能,我建议您使用MARS(多元自适应回归样条曲线)的一种实现。它介于标准回归和插值之间,适用于多维情况。在Python 3中,您的最佳选择是。完美;更棒的是,真正的skyrad实际上已经存在,而且xarray已经存在(根据skyrad0=xr.open_数据集(db_path)['skyrad0']),尽管我目前无法直接使用loc引用它(正在使用它),并且必须以它的元素构建da作为示例来让它工作。很高兴提供帮助。如果您遇到问题,请随意编辑您的原始帖子,以包含通过
xarray导入的数据集的字符串表示形式。打开\u Dataset
,让所有人都可以查看。根据HDF5文件的形式,有时将其作为DataArray(xarray.open_DataArray
)导入,然后手动构建数据集会更容易。我编辑了我的问题,添加了另一个关于单调递增值的问题,以及用于将实际数据库从HDF5文件拉入Xarray的代码,以及由此产生的数据库结构的屏幕截图。我已经更新了我的回复。请注意,它以前错误地使用了.loc
,而不是.interp
——与线性间隔点无关,只是所有点实际上都在原始坐标集中。