Python tensorflow中带条件的求和

Python tensorflow中带条件的求和,python,tensorflow,Python,Tensorflow,给我一个随机行的二维张量。在应用了tf.math.greater()和tf.cast(tf.int32)之后,剩下的张量是0和1。现在,我想将reduce sum应用到该矩阵上,但有一个条件:如果至少有一个1求和,后面跟着一个0,那么我也要删除后面的所有1,这意味着1 01应该导致1,而不是2 我试图用tf.scan()解决这个问题,但是我还没有找到一个能够处理0开始的函数,因为行可能看起来像:0 0 1 一个想法是将矩阵的下半部分设置为1(bc我知道对角线上剩下的都是0),然后使用类似tf.s

给我一个随机行的二维张量。在应用了
tf.math.greater()
tf.cast(tf.int32)
之后,剩下的张量是0和1。现在,我想将reduce sum应用到该矩阵上,但有一个条件:如果至少有一个1求和,后面跟着一个0,那么我也要删除后面的所有1,这意味着
1 01
应该导致
1
,而不是
2

我试图用
tf.scan()
解决这个问题,但是我还没有找到一个能够处理0开始的函数,因为行可能看起来像:
0 0 1
一个想法是将矩阵的下半部分设置为1(bc我知道对角线上剩下的都是0),然后使用类似
tf.scan()
run的函数来过滤斑点(请参见下面的代码和错误消息)

设z为tf.cast之后的矩阵。

helper = tf.matrix_band_part(tf.ones_like(z), -1, 0)
z = tf.math.logical_or(tf.cast(z, tf.bool), tf.cast(helper,tf.bool))
z = tf.cast(z, tf.int32)
z = tf.scan(lambda a, x: x if a == 1 else 0 ,z)
导致:


ValueError:value([])的形状不兼容,应为([5])
IIUC,这是一种无需扫描或循环即可完成所需操作的方法。它可能有点复杂,实际上是迭代列两次(一个cumsum和一个cumprod),但是作为向量化操作,我认为它可能更快。代码是TF2.x,但在TF1.x中运行相同的代码(除了最后一行)


因此,为了看看我是否理解正确,对于每一行,您想计算连续行的第一个“组”中的行数,对吗?不使用
cumsum
(以及以后的
reduce\u max
),而通过
cumprod
,只使用
mask
,可能更快,最后直接减少总和。@Albert谢谢你的反馈。但是我使用
cumsum
来定义掩码,找到
1
之后的第一个
0
(我简化了制作掩码的方法,因为我意识到它比应该的复杂得多)。@jdehesa啊,我快速浏览了那部分。我认为您只需执行
mask=cumprod(tf.cast(tf.not_equal(a,0),tf.uint8))
。这不是更容易吗?@Albert问题是OP希望计算一个
1
之后的第一个
0
的总和,因此以
0
开头的行的总和将失败。
import tensorflow as tf

# Example data
a = tf.constant([[0, 0, 0, 0],
                 [1, 0, 0, 0],
                 [0, 1, 1, 0],
                 [0, 1, 0, 1],
                 [1, 1, 1, 0],
                 [1, 1, 0, 1],
                 [0, 1, 1, 1],
                 [1, 1, 1, 1]])
# Cumsum columns
c = tf.math.cumsum(a, axis=1)
# Column-wise differences
diffs = tf.concat([tf.ones([tf.shape(c)[0], 1], c.dtype), c[:, 1:] - c[:, :-1]], axis=1)
# Find point where we should not sum anymore (cumsum is not zero and difference is zero)
cutoff = tf.equal(a, 0) & tf.not_equal(c, 0)
# Make mask
mask = tf.math.cumprod(tf.dtypes.cast(~cutoff, tf.uint8), axis=1)
# Compute result
result = tf.reduce_max(c * tf.dtypes.cast(mask, c.dtype), axis=1)
print(result.numpy())
# [0 1 2 1 3 2 3 4]