Python 如何在tensorflow中复制numpy.choose()?

Python 如何在tensorflow中复制numpy.choose()?,python,numpy,tensorflow,Python,Numpy,Tensorflow,我正在尝试有效地复制numpy的ndarray.choose()方法 下面是一个我正在寻找的numpy示例: b = np.arange(15).reshape(3, 5) c = np.array([1,0,4]) c.choose(b.T) # trying to replicate in tensorflow -> array([ 1, 5, 14]) 我所能做的最好的事情就是生成一个批量大小的平方矩阵(如果批量大小很大,则该矩阵很大),并取其对角线: tf_b = tf

我正在尝试有效地复制numpy的
ndarray.choose()
方法

下面是一个我正在寻找的numpy示例:

b = np.arange(15).reshape(3, 5)
c = np.array([1,0,4])
c.choose(b.T)  # trying to replicate in tensorflow
  -> array([ 1,  5, 14]) 
我所能做的最好的事情就是生成一个批量大小的平方矩阵(如果批量大小很大,则该矩阵很大),并取其对角线:

tf_b = tf.constant(b)
tf_c = tf.constant(c)
sess.run(tf.diag_part(tf.gather(tf.transpose(tf_b), tf_c)))
-> array([ 1,  5, 14])

有没有一种方法在第一维度上是线性的(而不是平方的)?

是的,有一种更简单的方法。将
b
数组展平为一维,使其成为
[0,1,2,…,13,14]
。获取一个索引数组,该数组在您正在获取的“选择”数量范围内(在您的案例中为3)。这将是
[0,1,2]
。将该范围乘以原始形状的第二个维度,即每个选项的选项数(本例中为5)。这将为您提供
[0,5,10]
。然后将索引添加到此中以获得
[1,5,14]
。现在可以调用tf.gather()

下面是我从中获取的一些代码,它们对RNN输出执行类似的操作。你的想法会略有不同,但想法是一样的

index = tf.range(0, batch_size) * max_length + (length - 1)
flat = tf.reshape(output, [-1, out_size])
relevant = tf.gather(flat, index)
return relevant

从大局来看,操作相当简单。使用range操作获取每行开头的索引,然后添加每行所在位置的索引。我认为在1D中做这件事是最容易的,所以我们把它展平。

你的
numpy
代码相当于
b[np.arange(3),c]
choose
有一条注释,禁止在
选项中使用单个数组(如
b.T
。在
numpy
中,此索引的一维版本是
b.flat[np.arange(b.shape[0])*b.shape[1]+c]