Python 如何根据tensorflow中的列条件获取张量值的索引

Python 如何根据tensorflow中的列条件获取张量值的索引,python,tensorflow,slice,indices,Python,Tensorflow,Slice,Indices,我有一个像这样的张量: sim_topics = [[0.65 0. 0. 0. 0.42 0. 0. 0.51 0. 0.34 0.] [0. 0.51 0. 0. 0.52 0. 0. 0. 0.53 0.42 0.] [0. 0.32 0. 0.50 0.34 0. 0. 0.39 0.32 0.52 0.] [0. 0.23 0.3

我有一个像这样的张量:

sim_topics = [[0.65 0.   0.   0.   0.42  0.   0.   0.51 0.   0.34 0.]
              [0.   0.51 0.   0.   0.52  0.   0.   0.   0.53 0.42 0.]
              [0.   0.32 0.   0.50 0.34  0.   0.   0.39 0.32 0.52 0.]
              [0.   0.23 0.37 0.   0.    0.37 0.37 0.   0.47 0.39 0.3 ]]
[[0.65 0. 0.   0.   0.42  0.   0.   0.51 0.   0.34 0.]
 [0.   0. 0.   0.   0.52  0.   0.   0.   0.   0.42 0.]
 [0.   0. 0.   0.   0.34  0.   0.   0.39 0.   0.52 0.]
 [0.   0. 0.37 0.   0.    0.37 0.   0.   0.   0.39 0.]]
result = tf.multiply(sim_topics, tf.cast(masked_t, dtype=tf.float64))
我想根据张量条件得到这个张量中的指数:

masked_t = [True  False  True  False True True False True False True False]
所以输出应该是这样的:

sim_topics = [[0.65 0.   0.   0.   0.42  0.   0.   0.51 0.   0.34 0.]
              [0.   0.51 0.   0.   0.52  0.   0.   0.   0.53 0.42 0.]
              [0.   0.32 0.   0.50 0.34  0.   0.   0.39 0.32 0.52 0.]
              [0.   0.23 0.37 0.   0.    0.37 0.37 0.   0.47 0.39 0.3 ]]
[[0.65 0. 0.   0.   0.42  0.   0.   0.51 0.   0.34 0.]
 [0.   0. 0.   0.   0.52  0.   0.   0.   0.   0.42 0.]
 [0.   0. 0.   0.   0.34  0.   0.   0.39 0.   0.52 0.]
 [0.   0. 0.37 0.   0.    0.37 0.   0.   0.   0.39 0.]]
result = tf.multiply(sim_topics, tf.cast(masked_t, dtype=tf.float64))
所以条件是作用在初始张量的柱上。实际上,我需要元素的索引,它们在
maske\u t
中为真

因此,指数应为:

[[0, 0],
 [1,0],
 [2, 0],
 [3,0],
 [0,2],
 [1,2],
 [2,2],
 [3,2],
 ....]]
实际上,这种方法在我按行操作时有效,但在这里,我想根据条件选择特定列,这样会引发s不兼容错误:

out = tf.cast(tf.zeros(shape=tf.shape(sim_topics), dtype=tf.float64), tf.float64)
indices = tf.where(tf.where(masked_t, out, sim_topics))

您可以直接获得所需的张量,如下所示:

sim_topics = [[0.65 0.   0.   0.   0.42  0.   0.   0.51 0.   0.34 0.]
              [0.   0.51 0.   0.   0.52  0.   0.   0.   0.53 0.42 0.]
              [0.   0.32 0.   0.50 0.34  0.   0.   0.39 0.32 0.52 0.]
              [0.   0.23 0.37 0.   0.    0.37 0.37 0.   0.47 0.39 0.3 ]]
[[0.65 0. 0.   0.   0.42  0.   0.   0.51 0.   0.34 0.]
 [0.   0. 0.   0.   0.52  0.   0.   0.   0.   0.42 0.]
 [0.   0. 0.   0.   0.34  0.   0.   0.39 0.   0.52 0.]
 [0.   0. 0.37 0.   0.    0.37 0.   0.   0.   0.39 0.]]
result = tf.multiply(sim_topics, tf.cast(masked_t, dtype=tf.float64))
让广播完成蒙面主题与模拟主题大小相同的工作