Python 2.7 如何在Tensorflow中使用扫描来传递之前的几个状态。

Python 2.7 如何在Tensorflow中使用扫描来传递之前的几个状态。,python-2.7,tensorflow,Python 2.7,Tensorflow,我将使用tf.scan函数修改其他人为可变长度序列共享的DRAW(Deep Recurrence Attentive Writer)代码。因此,我需要将原始代码中的for循环更改为适合scan函数的结构。下面是代码的原始部分 ... for t in range(T): c_prev = tf.zeros((batch_size,img_size)) if t==0 else cs[t-1] x_hat=x-tf.sigmoid(c_prev) # error image

我将使用
tf.scan
函数修改其他人为可变长度序列共享的DRAW(Deep Recurrence Attentive Writer)代码。因此,我需要将原始代码中的for循环更改为适合scan函数的结构。下面是代码的原始部分

...
for t in range(T):
    c_prev = tf.zeros((batch_size,img_size)) if t==0 else cs[t-1]
    x_hat=x-tf.sigmoid(c_prev) # error image
    r=read(x,x_hat,h_dec_prev)
    h_enc,enc_state=encode(enc_state,tf.concat(1,[r,h_dec_prev]))
    z,mus[t],logsigmas[t],sigmas[t]=sampleQ(h_enc)
    h_dec,dec_state=decode(dec_state,z)
    cs[t]=c_prev+write(h_dec) # store results
    h_dec_prev=h_dec
    DO_SHARE=True # from now on, share variables
...
为了使用
tf.scan
,我需要通过几个先前的状态(
c_prev
h_dec_prev
)。但是,正如我所知,作为中的一个例子,scan只能为循环获取一个张量(是否正确?)

似乎只有一个
a
,它应该是张量。在这种情况下,我能想象的唯一可能的方法是将几个不同的状态张量展平并连接起来。但我担心它会弄乱代码,使速度大大降低,尤其是当状态大小都不同时。有没有有效(快速)的方法来处理这类问题

elems = np.array([1, 2, 3, 4, 5, 6])
sum = scan(lambda a, x: a + x, elems)