Python Tensorflow中的布尔索引
Tensorflow似乎不支持布尔索引。如何在Tensorflow中执行此操作Python Tensorflow中的布尔索引,python,tensorflow,Python,Tensorflow,Tensorflow似乎不支持布尔索引。如何在Tensorflow中执行此操作 import numpy as np A = np.array([3, 4, 5, -1, 6, -1, 7, 8]) mask = (A == -1) print(A) A[mask] = [11, 12] print(A) 要使用布尔数组提取元素,可以使用: 然而,这似乎不支持分配。 如果需要将元素更新为相同的值,请使用tf.where,它可以用于常量和变
import numpy as np
A = np.array([3, 4, 5, -1, 6, -1, 7, 8])
mask = (A == -1)
print(A)
A[mask] = [11, 12]
print(A)
要使用布尔数组提取元素,可以使用: 然而,这似乎不支持分配。 如果需要将元素更新为相同的值,请使用tf.where,它可以用于常量和变量:
a = tf.constant([3, 4, 5, -1, 6, -1, 7, 8])
mask = tf.equal(a, -1)
tf.where(mask, [11] * a.shape[0], a).eval()
# array([ 3, 4, 5, 11, 6, 11, 7, 8], dtype=int32)
如果更新的值是具有不同值的自定义数组,我们可以通过首先将布尔掩码转换为索引来使用,在这种情况下,a需要是一个变量:
a = tf.constant([3, 4, 5, -1, 6, -1, 7, 8])
mask = tf.equal(a, -1)
tf.where(mask, [11] * a.shape[0], a).eval()
# array([ 3, 4, 5, 11, 6, 11, 7, 8], dtype=int32)
您是否也知道我所做的编辑是否可行,因此[11,12]不是[11]而是一个任意列表,其编号与maskThanks中的Trues相同,但如果a不是一个变量呢?
a = tf.Variable([3, 4, 5, -1, 6, -1, 7, 8])
tf.global_variables_initializer().run()
mask = tf.equal(a, -1)
indices = tf.reshape(tf.where(mask), (-1,))
tf.scatter_update(a, indices, [11, 12]).eval()
# array([ 3, 4, 5, 11, 6, 12, 7, 8], dtype=int32)