Python Tensorflow有效重叠加法

Python Tensorflow有效重叠加法,python,tensorflow,Python,Tensorflow,我试图在Tensorflow中实现,但我正在努力将numpyoutput\u seq[start:end]+=chunk转换为Tensorflow。现在我是output\u seq=output\u seq+tf.pad(chunk,[[start,length-end]]),但在长序列上这真的很慢 我也有一种预感,也许你可以用聚集/分散来做一些小把戏,但我不太明白。以下是我的暴力尝试: import tensorflow as tf input = [[1, 2, 3, 4], [5, 6,

我试图在Tensorflow中实现,但我正在努力将numpy
output\u seq[start:end]+=chunk
转换为Tensorflow。现在我是
output\u seq=output\u seq+tf.pad(chunk,[[start,length-end]])
,但在长序列上这真的很慢

我也有一种预感,也许你可以用聚集/分散来做一些小把戏,但我不太明白。以下是我的暴力尝试:

import tensorflow as tf

input = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]

def overlap_add(overlap):
    with tf.Graph().as_default(), tf.Session() as sess:

        x = tf.constant(input)

        num_chunks = tf.shape(x)[0]
        chunk_size = tf.shape(x)[1]
        hop_length = chunk_size - overlap
        out_len = chunk_size + hop_length * (num_chunks - 1)

        y = tf.zeros((out_len,), dtype=tf.int32)

        def body(i, y):
            j = i * hop_length
            padding = [[j, out_len - (j + chunk_size)]]
            chunk = x[i]
            y = y + tf.pad(chunk, padding)
            return (i + 1, y)

        i = tf.constant(0)
        i, y = tf.while_loop(
            cond=lambda i, _: tf.less(i, num_chunks),
            body=body,
            loop_vars=[i, y])

        return sess.run(y)


for i in range(4):
    print 'overlap_add(%d): %s' % (i, overlap_add(i))

# overlap_add(0): [ 1  2  3  4  5  6  7  8  9 10 11 12]
# overlap_add(1): [ 1  2  3  9  6  7 17 10 11 12]
# overlap_add(2): [ 1  2  8 10 16 18 11 12]
# overlap_add(3): [ 1  7 18 21 19 12]

也可以在Tensorflow中使用切片:

a[1:3].assign(a[1:3] + b[1:3]).eval()
由于某些原因,未实现assign_add。对我来说,这好像是一只虫子

a[1:3].assign_add(b[1:3]).eval() # Doesn't work

更新:现在Tensorflow本身有一个函数

旧答案: 搜索文档,发现
未排序的段\u和

import tensorflow as tf

input = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]

def tf_repeat(a, repeats):
    return tf.reshape(tf.tile(tf.reshape(a, [-1, 1]),
                              [1, repeats]), [-1])

def overlap_add(overlap):
    with tf.Graph().as_default(), tf.Session() as sess:

        x = tf.constant(input)
        x_flat = tf.reshape(x, [-1])

        num_chunks = tf.shape(x)[0]
        chunk_size = tf.shape(x)[1]
        hop_len = chunk_size - overlap
        flat_len = num_chunks * chunk_size
        out_len = chunk_size + hop_len * (num_chunks - 1)

        # e.g. [0,1,2,3, 2,3,4,5, 4,5,6,7] for overlap == 2
        indexes = tf.range(flat_len) - tf_repeat(tf.range(num_chunks), chunk_size) * overlap

        return sess.run(tf.unsorted_segment_sum(x_flat, indexes, out_len))


for i in range(4):
    print 'overlap_add(%d): %s' % (i, overlap_add(i))

# overlap_add(0): [ 1  2  3  4  5  6  7  8  9 10 11 12]
# overlap_add(1): [ 1  2  3  9  6  7 17 10 11 12]
# overlap_add(2): [ 1  2  8 10 16 18 11 12]
# overlap_add(3): [ 1  7 18 21 19 12]

非常感谢。只有变量
才支持引发
切片赋值。我将
y
更改为tf.Variable,但在while_循环体中
y
不再是变量,而是
,错误仍然存在。