Tensorflow 图形模式下的GRU/RNN状态与急切执行模式下的GRU/RNN状态

Tensorflow 图形模式下的GRU/RNN状态与急切执行模式下的GRU/RNN状态,tensorflow,keras,deep-learning,Tensorflow,Keras,Deep Learning,我有同样的一段代码,先是在急切执行模式下编写,然后是在图形模式下编写。现在,我不太明白为什么GRU状态在图形模式下不保留,而在渴望模式下工作良好 以下是急切模式代码: import tensorflow as tf import xxhash import numpy as np tf.enable_eager_execution() rnn_units = 1024 def hash_code(arr): return xxhash.xxh64(arr).hexdigest()

我有同样的一段代码,先是在急切执行模式下编写,然后是在图形模式下编写。现在,我不太明白为什么GRU状态在图形模式下不保留,而在渴望模式下工作良好

以下是急切模式代码:

import tensorflow as tf 
import xxhash
import numpy as np 
tf.enable_eager_execution()
rnn_units = 1024 
def hash_code(arr): 
    return xxhash.xxh64(arr).hexdigest()

model = tf.keras.Sequential([tf.keras.layers.GRU(rnn_units,
                        return_sequences=True,
                        stateful=True,
                        recurrent_initializer='glorot_uniform', batch_input_shape=[1, None, 256])])

lstm_wt = np.load('lstm_wt.npy', allow_pickle=True) # fixed weights for comparison 
lstm_re_wt = np.load('lstm_re_wt.npy', allow_pickle=True)
lstm_bias = np.load('lstm_bias.npy', allow_pickle=True)
model.layers[0].set_weights([lstm_wt, lstm_re_wt, lstm_bias])

op_embed = np.load('op_embed.npy', allow_pickle=True) # fixed input 
op_lstm = model(op_embed)
print(hash_code(op_lstm.numpy()))

op_lstm = model(op_embed)
print(hash_code(op_lstm.numpy()))

model.layers[0].reset_states() # now reset the state, you'll get back the initial output. 
op_lstm = model(op_embed)
print(hash_code(op_lstm.numpy()))
此代码的输出:

d092fdb4739588a3
cdfdf8b8e292c6e8
d092fdb4739588a3
现在,图形模型代码:

import tensorflow as tf 
import xxhas
import numpy as np 

# checking lstm 
op_embed = np.load('op_embed.npy', allow_pickle=True)
# load op_embed, lstm weights 
lstm_wt = np.load('lstm_wt.npy', allow_pickle=True)
lstm_re_wt = np.load('lstm_re_wt.npy', allow_pickle=True)
lstm_bias = np.load('lstm_bias.npy', allow_pickle=True)

rnn_units = 1024 
layers = tf.keras.layers.GRU(rnn_units,
                        return_sequences=True,
                        stateful=True,
                        recurrent_initializer='glorot_uniform')
x_placeholder = tf.placeholder(shape=op_embed.shape, dtype=tf.float32)
op_lstm = layers(x_placeholder)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
layers.set_weights([lstm_wt, lstm_re_wt, lstm_bias])
tf.assign(layers.weights[0],lstm_wt ).eval(sess)
tf.assign(layers.weights[1], lstm_re_wt).eval(sess)
tf.assign(layers.weights[2], lstm_bias).eval(sess)
print('keras op hash',xxhash.xxh64(sess.run(op_lstm, feed_dict={x_placeholder:op_embed})).hexdigest())
print('keras op hash',xxhash.xxh64(sess.run(op_lstm, feed_dict={x_placeholder:op_embed})).hexdigest())

输出:

keras op hash d092fdb4739588a3
keras op hash d092fdb4739588a3
对于如何在图形模式下修复这种模糊性并保留状态,有什么见解吗?
以前有人问过类似的问题,但没有回答

在此处指定解决方案(答案部分),即使该解决方案存在于问题中提供的解决方案中,也是为了社区的利益

递归神经网络
RNN
GRU
LSTM
)在默认情况下以
非急切模式
/
图形模式
执行时,会丢失其
状态

如果我们想保留
状态
,我们需要在
RNN
调用期间传递
初始状态
,如下所示:

current_state = np.zeros((1,1))
state_placeholder = tf.placeholder(tf.float32, shape=[1, 1])
output, state = rnn(x, initial_state=state_placeholder)
然后,在执行输出时,我们还需要传递
状态
,以及
输入
feed\u dict

那么,代码,

print('keras op hash',xxhash.xxh64(sess.run(op_lstm, feed_dict={x_placeholder:op_embed})).hexdigest())
可以替换为

for _ in range(No_Of_TimeSteps):
        op_val, state_val = sess.run([op_lstm, state], feed_dict={x_placeholder:op_embed})).hexdigest(),
                            state_placeholder: current_state.astype(np.float32)})
        current_state = state_val
        print('keras op hash',xxhash.xxh64(op_val))

希望这有帮助。学习愉快

在此处指定解决方案(答案部分),即使该解决方案存在于问题中提供的解决方案中,也是为了社区的利益

递归神经网络
RNN
GRU
LSTM
)在默认情况下以
非急切模式
/
图形模式
执行时,会丢失其
状态

如果我们想保留
状态
,我们需要在
RNN
调用期间传递
初始状态
,如下所示:

current_state = np.zeros((1,1))
state_placeholder = tf.placeholder(tf.float32, shape=[1, 1])
output, state = rnn(x, initial_state=state_placeholder)
然后,在执行输出时,我们还需要传递
状态
,以及
输入
feed\u dict

那么,代码,

print('keras op hash',xxhash.xxh64(sess.run(op_lstm, feed_dict={x_placeholder:op_embed})).hexdigest())
可以替换为

for _ in range(No_Of_TimeSteps):
        op_val, state_val = sess.run([op_lstm, state], feed_dict={x_placeholder:op_embed})).hexdigest(),
                            state_placeholder: current_state.astype(np.float32)})
        current_state = state_val
        print('keras op hash',xxhash.xxh64(op_val))

希望这有帮助。学习愉快

我真的解决了这个问题。我在上面的问题的链接上回答了这个问题。我实际上解决了这个问题。我在上面问题的链接上回答了这个问题。