在tensorflow中,除了零之外,我如何获得最小值索引

在tensorflow中,除了零之外,我如何获得最小值索引,tensorflow,Tensorflow,我想得到张量的最小值索引,但值不是0 a = np.array([[0, 3, 9, 0], [0, 0, 5, 7]]) tensor_a = tf.constant(a, dtype=tf.int32) max_index = tf.argmax(tensor_a, axis=1) 上面的代码定义了一个常量张量,如果我使用tf.argmax,我将得到索引[2,3]。如何得到第一行中的3和第二行中的5的索引,最小值而不是零。我想要得到的真正索引是[1,2]。 如何在t

我想得到张量的最小值索引,但值不是0

a = np.array([[0, 3, 9, 0],
            [0, 0, 5, 7]])
tensor_a = tf.constant(a, dtype=tf.int32)
max_index = tf.argmax(tensor_a, axis=1)
上面的代码定义了一个常量张量,如果我使用tf.argmax,我将得到索引[2,3]。如何得到第一行中的3和第二行中的5的索引,最小值而不是零。我想要得到的真正索引是[1,2]。
如何在tensorflow中实现它,谢谢。

这很可怕,但它可以工作:

with tf.Session() as sess:
    a = np.array([[0, 3, 9, 0],
                [0, 0, 5, 7]])
    tensor_a = tf.constant(a, dtype=tf.int64)
    row_max = tf.reshape(tf.reduce_max(a, axis=-1), [-1, 1]) + 1
    max_index = tf.argmin(tf.where(tensor_a > 0, tensor_a, row_max * tf.ones_like(tensor_a)), axis=1)
    print(max_index.eval()) # -> [1 2]

你在找什么?指数2和3处的最大值分别为9和5。你在找什么?我想得到除零以外的最小值。如何获得张量_a中的最小值索引(应排除0)。在上面的例子中,因为我不需要0,所以第1行的最小值是3,第2行的最小值是5。代码应该返回[1,2]。类似于
tf.min(tf.boolean\u mask(x>0))
的东西可能是什么?