Python 如何在Tensorflow中用二维张量索引三维张量?

Python 如何在Tensorflow中用二维张量索引三维张量?,python,tensorflow,Python,Tensorflow,我试图用一个二维张量来索引Tensorflow中的一个三维张量。例如,我有形状为[2,3,4]的x: [[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] 我想用另一个y形状的张量[2,3]对它进行索引,其中y的每个元素对x的最后一个维度进行索引。例如,如果我们有y像: [[0, 2, 3], [1,

我试图用一个二维张量来索引Tensorflow中的一个三维张量。例如,我有形状为[2,3,4]的
x

[[[ 0,  1,  2,  3],
  [ 4,  5,  6,  7],
  [ 8,  9, 10, 11]],

 [[12, 13, 14, 15],
  [16, 17, 18, 19],
  [20, 21, 22, 23]]]
我想用另一个
y
形状的张量
[2,3]
对它进行索引,其中
y
的每个元素对
x
的最后一个维度进行索引。例如,如果我们有
y
像:

[[0, 2, 3],
 [1, 0, 2]]
输出的形状应为
[2,3]

[[0, 6, 11],
 [13, 16, 22]]

使用
tf.meshgrid
创建索引,然后使用
tf.gather\n
提取元素:

# create a list of indices for except the last axis
idx_except_last = tf.meshgrid(*[tf.range(s) for s in x.shape[:-1]], indexing='ij')

# concatenate with last axis indices
idx = tf.stack(idx_except_last + [y], axis=-1)

# gather elements based on the indices
tf.gather_nd(x, idx).eval()

# array([[ 0,  6, 11],
#        [13, 16, 22]])

如果x的形状未知怎么办?例如,x和y的第一个维度是在运行时计算的,我不知道图形构造期间的值<如果参数为
Dimension(None)
,则code>tf.range函数将失败。对于动态形状,请尝试
x.get\u shape()
以获得
x.shape