Numpy Tensorflow:如何在没有np的情况下根据条件随机选择元素?

Numpy Tensorflow:如何在没有np的情况下根据条件随机选择元素?,numpy,tensorflow,Numpy,Tensorflow,我有3个tensorflow数组(a,b,valid_entries),它们共享前两个维度[T,N,?]。其中一个数组的“有效\u项”具有带布尔值的形状[T,N,1]。我想随机抽样T*M2个索引元组(M

我有3个tensorflow数组(
a
b
valid_entries
),它们共享前两个维度
[T,N,?]
。其中一个数组的“有效\u项”具有带布尔值的形状
[T,N,1]
。我想随机抽样
T*M
2个索引元组(
M
),这样所有这些索引的
有效项[T,M]==1

换句话说,对于每个时间步,我想从
a
b
中随机选择M个有效条目

我认为在numpy中,这个任务可以通过执行以下操作来解决(为了简单起见,让我们跳过第一维度T):

然而,所有这些都需要在Tensorflow中发生


非常感谢您的帮助

这里有一个函数可以:

import tensorflow as tf

def sample_indices(valid, m, seed=None):
    valid = tf.convert_to_tensor(valid)
    n = tf.size(valid)
    # Flatten boolean tensor
    valid_flat = tf.reshape(valid, [n])
    # Get flat indices where the tensor is true
    valid_idx = tf.boolean_mask(tf.range(n), valid_flat)
    # Shuffled valid indices
    valid_idx_shuffled = tf.random.shuffle(valid_idx, seed=seed)
    # Pick sample from shuffled indices
    valid_idx_sample = valid_idx_shuffled[:m]
    # Unravel indices
    return tf.transpose(tf.unravel_index(valid_idx_sample, tf.shape(valid)))

with tf.Graph().as_default(), tf.Session() as sess:
    valid = [[ True,  True, False,  True],
             [False,  True,  True, False],
             [False,  True, False, False]]
    m = 4
    print(sess.run(sample_indices(valid, m, seed=0)))
    # [[1 1]
    #  [1 2]
    #  [0 1]
    #  [2 1]]
这个
样本索引
对于任何形状的布尔张量都是通用的。如果在您的例子中,
valid\u entries
具有shape
(T,N,1)
,那么您将得到一个输出为shape
(M,3)
的张量,尽管您可以忽略最后一列,因为它总是为零(或者您可以通过
tf.squence(valid\u entries,axis=2)


注意:最后一个
tf.transpose
只是将一个带有形状
(样本大小,数值尺寸)
的张量作为输出,而不是相反。但是,如果
m
相当大,并且您不介意维度的顺序,您可以跳过它以节省一点时间和内存,因为(与它的NumPy对应物不同)
tf.transpose
产生一个全新的张量。

太棒了,谢谢!仅供参考,然后我使用
tf.gather
对索引列表进行索引/切片。
import tensorflow as tf

def sample_indices(valid, m, seed=None):
    valid = tf.convert_to_tensor(valid)
    n = tf.size(valid)
    # Flatten boolean tensor
    valid_flat = tf.reshape(valid, [n])
    # Get flat indices where the tensor is true
    valid_idx = tf.boolean_mask(tf.range(n), valid_flat)
    # Shuffled valid indices
    valid_idx_shuffled = tf.random.shuffle(valid_idx, seed=seed)
    # Pick sample from shuffled indices
    valid_idx_sample = valid_idx_shuffled[:m]
    # Unravel indices
    return tf.transpose(tf.unravel_index(valid_idx_sample, tf.shape(valid)))

with tf.Graph().as_default(), tf.Session() as sess:
    valid = [[ True,  True, False,  True],
             [False,  True,  True, False],
             [False,  True, False, False]]
    m = 4
    print(sess.run(sample_indices(valid, m, seed=0)))
    # [[1 1]
    #  [1 2]
    #  [0 1]
    #  [2 1]]