Python 沿着张量的第二维度聚集元素

Python 沿着张量的第二维度聚集元素,python,tensorflow,Python,Tensorflow,假设值和张量T都具有形状(N,K)。现在,如果我们从矩阵的角度来考虑它们,我想为T的每一行获取对应于值最大的索引的行元素。我可以很容易地找到这些索引 max_indicies = tf.argmax(T, 1) 它返回一个形状为(N)的张量。现在,我如何从T中收集这些索引,从而得到形状N?我试过了 result = tf.gather(T,max_indices) 但是它没有做正确的事情-它返回一些形状(N,K),这意味着它没有收集任何东西。您可以使用 比如说, import tensorf

假设
值和张量
T
都具有形状
(N,K)
。现在,如果我们从矩阵的角度来考虑它们,我想为
T
的每一行获取对应于
最大的索引的行元素。我可以很容易地找到这些索引

max_indicies = tf.argmax(T, 1)
它返回一个形状为
(N)
的张量。现在,我如何从
T
中收集这些索引,从而得到形状
N
?我试过了

result = tf.gather(T,max_indices)
但是它没有做正确的事情-它返回一些形状
(N,K)
,这意味着它没有收集任何东西。

您可以使用

比如说,

import tensorflow as tf

sess = tf.InteractiveSession()

values = tf.constant([[0, 0, 0, 1],
                      [0, 1, 0, 0],
                      [0, 0, 1, 0]])

T = tf.constant([[0, 1, 2 ,  3],
                 [4, 5, 6 ,  7],
                 [8, 9, 10, 11]])

max_indices = tf.argmax(values, axis=1)
# If T.get_shape()[0] is None, you can replace it with tf.shape(T)[0].
result = tf.gather_nd(T, tf.stack((tf.range(T.get_shape()[0], 
                                            dtype=max_indices.dtype),
                                   max_indices),
                                  axis=1))

print(result.eval())

但是,当
T
的等级较高时,使用
tf.collection\n和
会有点尴尬。我将我当前的解决方案发布到了。对于高维
T

的情况,可能有更好的解决方案。谢谢您,先生。