Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
如何在TensorFlow中从三维张量中选择行?_Tensorflow - Fatal编程技术网

如何在TensorFlow中从三维张量中选择行?

如何在TensorFlow中从三维张量中选择行?,tensorflow,Tensorflow,我有一个张量logits,其尺寸[批次大小、行数、坐标数](即批次中的每个logit是一个矩阵)。在我的例子中,批大小是2,有4行和4个坐标 logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], [11.0, 10.0, 10.0, 30.0], [12.0, 10.0, 10.0, 20.0], [13.0, 10.0,

我有一个张量
logits
,其尺寸
[批次大小、行数、坐标数]
(即批次中的每个logit是一个矩阵)。在我的例子中,批大小是2,有4行和4个坐标

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                      [11.0, 10.0, 10.0, 30.0],
                      [12.0, 10.0, 10.0, 20.0],
                      [13.0, 10.0, 10.0, 20.0]],
                     [[14.0, 11.0, 21.0, 31.0],
                      [15.0, 11.0, 11.0, 21.0],
                      [16.0, 11.0, 11.0, 21.0],
                      [17.0, 11.0, 11.0, 21.0]]])
我想选择第一批的第一行和第二行,以及第二批的第二行和第四行

indices = tf.constant([[0, 1], [1, 3]])
所以期望的输出是

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                      [11.0, 10.0, 10.0, 30.0]],
                     [[15.0, 11.0, 11.0, 21.0],
                      [17.0, 11.0, 11.0, 21.0]]])

如何使用TensorFlow进行此操作?我尝试使用
tf.gather(logits,index)
,但它没有返回我所期望的结果。谢谢

这在TensorFlow中是可能的,但有点不方便,因为目前仅适用于一维索引,并且仅从张量的第0维选择切片。但是,通过转换参数,使其可以传递到
tf.gather()
,仍然可以有效地解决问题:


请注意,由于这使用and not,它不需要修改
logits
张量中的(可能较大)数据,因此它应该相当有效。

mrry的答案很好,但我认为使用该函数,可以用更少的代码行解决问题(在mrry撰写本文时,此功能可能还不可用):

这会打印出来

[[[ 10.  10.  20.  20.]
  [ 11.  10.  10.  30.]]

 [[ 15.  11.  11.  21.]
  [ 17.  11.  11.  21.]]]

应该从v0.10开始提供。查看更多关于这方面的讨论。

虽然您的答案很好,但我认为今天可以用
tf.gather\u nd
替换,这在您撰写本文时可能还不可用(请参见我的答案)。您是如何将索引从2d更改为3d的(如问题所述)?@Tulsi我不明白你的问题。问题中没有提到3D索引,或者是吗?
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                      [11.0, 10.0, 10.0, 30.0],
                      [12.0, 10.0, 10.0, 20.0],
                      [13.0, 10.0, 10.0, 20.0]],
                     [[14.0, 11.0, 21.0, 31.0],
                      [15.0, 11.0, 11.0, 21.0],
                      [16.0, 11.0, 11.0, 21.0],
                      [17.0, 11.0, 11.0, 21.0]]])

indices = tf.constant([[[0, 0], [0, 1]], [[1, 1], [1, 3]]])

result = tf.gather_nd(logits, indices)
with tf.Session() as sess:
    print(sess.run(result))
[[[ 10.  10.  20.  20.]
  [ 11.  10.  10.  30.]]

 [[ 15.  11.  11.  21.]
  [ 17.  11.  11.  21.]]]