Number of parameters in RNN cells in Keras/Tensorflow (SimpleRNN, GRU and LSTM)
I am a bit confused by the number of parameters that Keras/Tensorflow 2.1.0 reports for the different RNN layers. For SimpleRNN, the following code works as expected:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import SimpleRNN, LSTM, GRU, Dense
model = tf.keras.Sequential()
model.add(SimpleRNN(1,input_shape=(None,1), use_bias=True))
model.summary()
and reports 3 parameters (input weight, recurrent weight, and bias), see below. When I use use_bias=False, it reports 2 parameters, as expected:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
simple_rnn (SimpleRNN) (None, 1) 3
=================================================================
Total params: 3
Trainable params: 3
Non-trainable params: 0
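The count above can be reproduced with the standard formula for a simple RNN cell: one input kernel of shape (input_dim, units), one recurrent kernel of shape (units, units), and optionally one bias vector of length units. A small sketch (the function name is mine, not a Keras API):

```python
def simple_rnn_params(units, input_dim, use_bias=True):
    # kernel: input_dim x units, recurrent_kernel: units x units, bias: units
    return units * (input_dim + units) + (units if use_bias else 0)

print(simple_rnn_params(1, 1, use_bias=True))   # 3, as in the summary above
print(simple_rnn_params(1, 1, use_bias=False))  # 2
```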
For an LSTM layer it also works as expected (at least for me). The code:
model = tf.keras.Sequential()
model.add(LSTM(1,input_shape=(None,1), use_bias=True))
model.summary()
results in 12 parameters, since we have 4 gates with three parameters each. With use_bias=False I get 8 parameters:
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_1 (LSTM) (None, 1) 12
=================================================================
Total params: 12
Trainable params: 12
Non-trainable params: 0
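The same kind of arithmetic works for the LSTM: four gates (input, forget, cell, output), each with an input kernel, a recurrent kernel, and optionally a bias. A sketch along the same lines (again, the function name is mine):

```python
def lstm_params(units, input_dim, use_bias=True):
    gates = 4  # input, forget, cell, output
    per_gate = units * (input_dim + units) + (units if use_bias else 0)
    return gates * per_gate

print(lstm_params(1, 1, use_bias=True))   # 12, matching the summary
print(lstm_params(1, 1, use_bias=False))  # 8
```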
However, when I check the GRU cell, I get strange results:
model = tf.keras.Sequential()
model.add(GRU(1,input_shape=(None,1), use_bias=True))
model.summary()
With use_bias=True I get 12 parameters, and with use_bias=False I get 6 parameters. The GRU has three gates, so I would expect 9 parameters, including 3 biases (one per gate). So there seem to be three additional biases:
Model: "sequential_3"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
gru_1 (GRU) (None, 1) 12
=================================================================
Total params: 12
Trainable params: 12
Non-trainable params: 0
Does anyone know why this layer has these three additional biases?