Tensorflow中4-D张量的tf.nn.top_k指数的二进制掩码?

Tensorflow中4-D张量的tf.nn.top_k指数的二进制掩码?,tensorflow,Tensorflow,我有一个形状为(10,32,32,128)的四维张量。我想为所有前N个元素生成一个二进制掩码 arr = tf.random_normal(shape=(10, 32, 32, 128)) values, indices = tf.nn.top_k(arr, N=64) 我的问题是,如果有人在寻找答案,如何使用tf.nn.top_k返回的索引获得与arr形状相同的二进制掩码:开始了 K = 64 arr = tf.random_normal(shape=(10, 32, 32, 128)) v

我有一个形状为(10,32,32,128)的四维张量。我想为所有前N个元素生成一个二进制掩码

arr = tf.random_normal(shape=(10, 32, 32, 128))
values, indices = tf.nn.top_k(arr, N=64)

我的问题是,如果有人在寻找答案,如何使用
tf.nn.top_k
返回的
索引获得与
arr
形状相同的二进制掩码:开始了

K = 64
arr = tf.random_normal(shape=(10, 32, 32, 128))
values, indices = tf.nn.top_k(arr, k=K, sorted=False)

temp_indices = tf.meshgrid(*[tf.range(d) for d in (tf.unstack(
       tf.shape(arr)[:(arr.get_shape().ndims - 1)]) + [K])], indexing='ij')
temp_indices = tf.stack(temp_indices[:-1] + [indices], axis=-1)
full_indices = tf.reshape(temp_indices, [-1, arr.get_shape().ndims])
values = tf.reshape(values, [-1])

mask_st = tf.SparseTensor(indices=tf.cast(
      full_indices, dtype=tf.int64), values=tf.ones_like(values), dense_shape=arr.shape)
mask = tf.sparse_tensor_to_dense(tf.sparse_reorder(mask_st))

我看到这些是arr、值、索引的形状。但是,生成二进制掩码的条件是什么?除了位于
索引
位置的值之外,所有值都应该为零。您找到解决方案了吗?@Perm.Questiin请参见下面的答案。