Python 如何通过不规则索引获取子传感器?

Python 如何通过不规则索引获取子传感器?,python,tensorflow,Python,Tensorflow,我想通过不规则索引得到次张量。 这是我的问题 Input tensor = 2x8x10x1(Batch x Height x Width x Channel) index_Height = [0,1,4,5] index_Width = [0,1,4,5,8,9] Output_tensor = 2x4x6x1 我怎样才能得到这个结果 以下是我的python版本: input_np = np.zeros((2,8,10,1)) nx, ny = (10, 8) x = np.linsp

我想通过不规则索引得到次张量。 这是我的问题

Input tensor = 2x8x10x1(Batch x Height x Width x Channel)

index_Height = [0,1,4,5]

index_Width = [0,1,4,5,8,9]

Output_tensor = 2x4x6x1
我怎样才能得到这个结果

以下是我的python版本:

input_np = np.zeros((2,8,10,1))
nx, ny = (10, 8)
x = np.linspace(0, 9, nx)
y = np.linspace(0, 7,  ny)
xv, yv = np.meshgrid(x , y)
input_np[0,:,:,0] = xv
input_np[1,:,:,0] = yv

index_Height = [0,1,4,5]
index_Width  = [0,1,4,5,8,9]

output_np = input_np[:,index_Height][:,:,index_Width]
如何在tensorflow上执行此操作?谢谢

只需执行两次
tf.gather()

import tensorflow as tf

inputtensor = tf.constant(input_np)
height_tensor = tf.gather(inputtensor,index_Height,axis=1)
output_tensor = tf.gather(height_tensor,index_Width,axis=2)
print(output_tensor.shape)

with tf.Session() as sess:
    output_tensor_val = sess.run(output_tensor)
    print((output_tensor_val==output_np).all())

(2, 4, 6, 1)
True