Python 如何将tf.gather与tf.where结合使用
由于Python 如何将tf.gather与tf.where结合使用,python,tensorflow,Python,Tensorflow,由于tf.where()的形状,以下操作不起作用。有没有好办法解决这个问题 我想要tensor_y的值,其中tensor_x满足一个条件(例如==值)。重要信息,张量具有批处理dims=1 tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32) tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.
tf.where()
的形状,以下操作不起作用。有没有好办法解决这个问题
我想要tensor_y
的值,其中tensor_x
满足一个条件(例如==值)。重要信息,张量具有批处理dims=1
tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
trues = tf.where(tensor_y ==1)
new_tensor = tf.gather(tensor_y, axis=-1, indices = trues,batch_dims=1)
我现在所做的工作很有效*,但效率不高,我认为:
new_tensor = tf.stack([tf.gather(tensor_y[i,:], tf.where(tensor_x[i,:] == 1)) for i in range(tensor_x.shape[0])])
*有时(我不知道在什么情况下)我会得到他的错误:
所有输入的形状必须匹配:值[0]。形状=[3,1]!=值[1]。形状=[6,1][Op:Pack]名称:stack
这是您需要的吗
tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
new_tensor = tensor_y[(tensor_x==1)]
这是在tensorflow环境中工作的。但在此环境之外,我得到了以下错误:
传递给参数“begin”的值的数据类型bool不在允许值列表中:int32、int64