Python Keras中的LSTM:顺序API和函数API的参数数量不同
使用顺序API 如果我使用Keras的顺序API和以下代码创建LSTM:Python Keras中的LSTM:顺序API和函数API的参数数量不同,python,machine-learning,keras,lstm,recurrent-neural-network,Python,Machine Learning,Keras,Lstm,Recurrent Neural Network,使用顺序API 如果我使用Keras的顺序API和以下代码创建LSTM: from keras.models import Sequential from keras.layers import LSTM model = Sequential() model.add(LSTM(2, input_dim=3)) from keras.models import Model from keras.layers import Input from keras.layers import LSTM
from keras.models import Sequential
from keras.layers import LSTM
model = Sequential()
model.add(LSTM(2, input_dim=3))
from keras.models import Model
from keras.layers import Input
from keras.layers import LSTM
inputs = Input(shape=(3, 1))
lstm = LSTM(2)(inputs)
model = Model(input=inputs, output=lstm)
然后
返回48参数,如中所示,该参数正常
快速详细信息:
使用功能API
但是,如果我使用以下代码对函数API执行相同的操作:
from keras.models import Sequential
from keras.layers import LSTM
model = Sequential()
model.add(LSTM(2, input_dim=3))
from keras.models import Model
from keras.layers import Input
from keras.layers import LSTM
inputs = Input(shape=(3, 1))
lstm = LSTM(2)(inputs)
model = Model(input=inputs, output=lstm)
然后
返回32个参数
为什么会有这样的差异?不同之处在于,当您将
input\u dim=x
传递到RNN层(包括LSTM层)时,这意味着输入形状是(无,x)
,即存在不同数量的时间步,其中每个时间步都是长度向量x
。但是,在函数API示例中,您将shape=(3,1)
指定为输入形状,这意味着有3个时间步,每个时间步都有一个特性。因此,参数的数量为:4*输出尺寸*(输出尺寸+输入尺寸+1)=4*2*(2+1+1)=32
,这是模型摘要中显示的数字
此外,如果您使用Keras 2.x.x,在使用RNN层的input\u dim
参数的情况下,您会得到警告:
用户警告:循环中的input\u dim
和input\u length
参数
不推荐使用图层。改用input\u shape
用户警告:更新对keras2api的LSTM调用:LSTM(2,input_-shape=(None,3))
我按如下方式解决了这个问题:
Case 1:
m (input) = 3
n (output) = 2
params = 4 * ( (input * output) + (output ^ 2) + output)
= 4 * (3*2 + 2^2 + 2)
= 4 * (6 + 4 + 2)
= 4 * 12
= 48
Case 2:
m (input) = 1
n (output) = 2
params = 4 * ( (input * output) + (output ^ 2) + output)
= 4 * (1*2 + 2^2 + 2)
= 4 * (2 + 4 + 2)
= 4 * 8
= 32
如果在函数API中,我将inputs=Input(shape=(3,1))替换为inputs=Input(shape=(1,3)),我将得到48个参数,正如预期的那样。谢谢
Case 1:
m (input) = 3
n (output) = 2
params = 4 * ( (input * output) + (output ^ 2) + output)
= 4 * (3*2 + 2^2 + 2)
= 4 * (6 + 4 + 2)
= 4 * 12
= 48
Case 2:
m (input) = 1
n (output) = 2
params = 4 * ( (input * output) + (output ^ 2) + output)
= 4 * (1*2 + 2^2 + 2)
= 4 * (2 + 4 + 2)
= 4 * 8
= 32