Python 如何在TensorFlow中实现Numpy where索引?

Python 如何在TensorFlow中实现Numpy where索引?,python,numpy,tensorflow,Python,Numpy,Tensorflow,我有以下使用numpy.where的操作: mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32) index = np.array([[1,0,0],[0,1,0],[0,0,1]]) mat[np.where(index>0)] = 100 print(mat) 如何在TensorFlow中实现等效 mat = np.array([[1, 2, 3], [4, 5, 6], [

我有以下使用
numpy.where
的操作:

    mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
    index = np.array([[1,0,0],[0,1,0],[0,0,1]])
    mat[np.where(index>0)] = 100
    print(mat)
如何在TensorFlow中实现等效

mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
tf_mat = tf.constant(mat)
tf_index = tf.constant(index)
indi = tf.where(tf_index>0)
tf_mat[indi] = -1   <===== not allowed 
mat=np.array([[1,2,3],[4,5,6],[7,8,9]],dtype=np.int32)
index=np.数组([[1,0,0],[0,1,0],[0,0,1]]
tf_mat=tf.常数(mat)
tf_指数=tf常数(指数)
indi=tf.where(tf_指数>0)

tf_mat[indi]=-1您可以通过
tf获取索引。在
tf中,您可以运行索引,或者使用
tf.gather
从原始数组收集数据,或者使用
tf.scatter\u update
更新原始数据,
tf.scatter\u nd\u update
进行多维更新

mat = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=tf.int32)
index = tf.Variable([[1,0,0],[0,1,0],[0,0,1]])
idx = tf.where(index>0)
tf.scatter_nd_update(mat, idx, /*values you want*/)
请注意,更新值应与idx的第一个维度大小相同


请参见

假设您想要的是用一些替换的元素创建一个新的张量,而不是更新一个变量,您可以这样做:

import numpy as np
import tensorflow as tf

mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
tf_mat = tf.constant(mat)
tf_index = tf.constant(index)
tf_mat = tf.where(tf_index > 0, -tf.ones_like(tf_mat), tf_mat)
with tf.Session() as sess:
    print(sess.run(tf_mat))
输出:

[[-1  2  3]
 [ 4 -1  6]
 [ 7  8 -1]]

np.where()
在这里不需要:
mat[index>0]=100
无论如何,如何在tensorflow中达到相同的目的?
idx
scatter\u update
中使用的是1D,
tf.where
返回此处不工作!如果您想使用多个1D更新,请使用
tf.scatter\u-update
而不是
tf.scatter\u-update
是,
tf.scatter\u-update
有效,这是@jdehesa建议的
tf.where()
的替代方法。谢谢你,这不是我想要的@罗比,你能澄清一下你到底想要什么吗?或者答案不符合你的需要呢?@Roby哎呀,对不起,我刚刚意识到我在输出端复制了错误的矩阵(代码仍然很好…),现在修复了。