建立模型后,Python Keras Sequential API 将每一层都替换成了抽象类 'ModuleWrapper'
我正在尝试使用 TensorFlow 2.5 的 Keras API 创建一个 Sequential(序列)模型。
训练模型之后,我发现无法保存模型,因为层 `ModuleWrapper` 没有实现其配置(`get_config`)。这让我非常困惑,因为我并没有使用任何名为 “ModuleWrapper” 的层,也没有使用任何自定义层。
经过大量测试,我发现 Keras Sequential API 不知为何无法识别它自己的层,而是用抽象类(?)`ModuleWrapper` 把它们替换掉了。
如果有人能解释为什么会发生这种情况,我将不胜感激。
导入(imports)
模型
`model.summary()` 的输出
`print(model.layers)` 的输出
[,,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
]
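您应该按如下方式导入模块,并且不要在同一组导入中把 TF 2.x 自带的 `tf.keras` 与旧的独立 `keras` 包混用。原因是:`tf.keras.Sequential.add()` 遇到一个是 `tf.Module`、但不是 `tf.keras` `Layer` 子类的对象(例如来自另一个 `keras` 包的层)时,会把它包装成 `ModuleWrapper`;而 `ModuleWrapper` 无法提供层配置,所以模型无法保存。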
import tensorflow as tf # version 2.5
from tensorflow import keras
from tensorflow.keras.layers import LeakyReLU, Softmax
from tensorflow.keras.layers import Conv2D, MaxPooling2D, SeparableConv2D
from tensorflow.keras.layers import Dense, Flatten, Dropout, Reshape, Activation
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LSTM
此外,模型中所有层的名称都必须唯一。但在你的模型中 `dropout5` 出现了两次,请修改其中一个。
def create_model(input_shape=(180, 18, 1)):
    """Build the CNN + LSTM classifier as a ``tf.keras`` Sequential model.

    Five Conv2D -> BatchNormalization -> MaxPooling2D -> Dropout blocks
    extract features. The resulting [batch, width, height, features] map is
    reshaped so that ``width`` becomes the LSTM time axis. Two stacked LSTMs
    then summarise the sequence, and a dense head ends in a 3-way softmax.

    Every layer used here must come from ``tensorflow.keras``. Mixing in
    layers from the standalone ``keras`` package is what produced the
    ``ModuleWrapper`` layers that could not be saved.

    Args:
        input_shape: Per-sample input shape ``(width, height, channels)``
            (``channels_last``), without the batch dimension. Defaults to
            ``(180, 18, 1)``, the shape that was previously hard-coded.

    Returns:
        A ``keras.Sequential`` named ``"CPDP_4h_1dim"``. It is already built,
        because the input shape is declared on the first layer.
    """
    data_format = 'channels_last'
    # BatchNormalization must normalise over the channel axis.
    batch_norm_axis = -1 if data_format == 'channels_last' else 1
    conv_activation = 'relu'
    padding = 'same'

    # (filters, kernel_size, dropout rate) for feature blocks 1..5.
    conv_blocks = (
        (64, (6, 6), 0.35),
        (128, (6, 6), 0.35),
        (128, (3, 3), 0.15),
        (256, (3, 3), 0.25),
        (256, (3, 3), 0.25),
    )
    # (dense layer name, units, dropout layer name) for the classifier head;
    # every head dropout uses rate 0.15.
    # The dropout after dense1 used to be named "dropout5" as well, which
    # collides with conv block 5's dropout. Sequential requires unique layer
    # names, so it now has its own name.
    dense_blocks = (
        ("dense1", 256, "dropout_dense1"),
        ("dense15", 256, "dropout51"),
        ("dense2", 128, "dropout6"),
        ("dense3", 64, "dropout7"),
    )

    model = keras.Sequential(name="CPDP_4h_1dim")
    # The input shape has to be declared on the FIRST layer for the model to
    # be built here. It was previously given to Conv1 (the second layer), so
    # the model had to be built separately afterwards.
    model.add(BatchNormalization(name="batch0", axis=batch_norm_axis,
                                 input_shape=tuple(input_shape)))
    for i, (filters, kernel_size, rate) in enumerate(conv_blocks, start=1):
        model.add(Conv2D(name=f"Conv{i}", filters=filters,
                         kernel_size=kernel_size, padding=padding,
                         activation=conv_activation, data_format=data_format))
        model.add(BatchNormalization(name=f"batch{i}", axis=batch_norm_axis))
        # A 2x2 pool with stride 1 trims one row/column per block.
        model.add(MaxPooling2D(name=f"pool{i}", pool_size=(2, 2),
                               strides=(1, 1), data_format=data_format))
        model.add(Dropout(name=f"dropout{i}", rate=rate))

    # The feature map is [batch, width, height, features], and the LSTM
    # expects [batch, timesteps, features]. So width becomes the time axis,
    # and height*features is folded into each step's feature vector.
    # Reading the sizes from the built model replaces the hard-coded
    # (175, 13*256), which was only valid for the default input_shape.
    # -1 lets Reshape infer the number of timesteps.
    _, _, height, channels = model.output_shape
    model.add(Reshape((-1, height * channels), name="reshape_for_lstm"))
    model.add(LSTM(name="lstm1", units=512, return_sequences=True, dropout=0.25))
    model.add(LSTM(name="lstm2", units=256, return_sequences=False, dropout=0.15))
    # lstm2 (return_sequences=False) already yields [batch, units]. Flatten
    # is therefore a no-op, kept so the layer list stays the same.
    model.add(Flatten(name="flatten1"))

    for dense_name, units, dropout_name in dense_blocks:
        model.add(Dense(name=dense_name, units=units))
        model.add(Activation('relu'))
        model.add(Dropout(name=dropout_name, rate=0.15))
    model.add(Dense(name="dense4", units=3))
    model.add(Activation('softmax'))
    return model
# Instantiate the network, then build it for batches of (180, 18, 1) samples;
# the leading None leaves the batch size unspecified.
model = create_model()
model.build((None, 180, 18, 1))
Model: "CPDP_4h_1dim"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
module_wrapper_472 (ModuleWr (None, 180, 18, 1) 4
_________________________________________________________________
module_wrapper_473 (ModuleWr (None, 180, 18, 64) 2368
_________________________________________________________________
module_wrapper_474 (ModuleWr (None, 180, 18, 64) 256
_________________________________________________________________
module_wrapper_475 (ModuleWr (None, 179, 17, 64) 0
_________________________________________________________________
module_wrapper_476 (ModuleWr (None, 179, 17, 64) 0
_________________________________________________________________
module_wrapper_477 (ModuleWr (None, 179, 17, 128) 295040
_________________________________________________________________
module_wrapper_478 (ModuleWr (None, 179, 17, 128) 512
_________________________________________________________________
module_wrapper_479 (ModuleWr (None, 178, 16, 128) 0
_________________________________________________________________
module_wrapper_480 (ModuleWr (None, 178, 16, 128) 0
_________________________________________________________________
module_wrapper_481 (ModuleWr (None, 178, 16, 128) 147584
_________________________________________________________________
module_wrapper_482 (ModuleWr (None, 178, 16, 128) 512
_________________________________________________________________
module_wrapper_483 (ModuleWr (None, 177, 15, 128) 0
_________________________________________________________________
module_wrapper_484 (ModuleWr (None, 177, 15, 128) 0
_________________________________________________________________
module_wrapper_485 (ModuleWr (None, 177, 15, 256) 295168
_________________________________________________________________
module_wrapper_486 (ModuleWr (None, 177, 15, 256) 1024
_________________________________________________________________
module_wrapper_487 (ModuleWr (None, 176, 14, 256) 0
_________________________________________________________________
module_wrapper_488 (ModuleWr (None, 176, 14, 256) 0
_________________________________________________________________
module_wrapper_489 (ModuleWr (None, 176, 14, 256) 590080
_________________________________________________________________
module_wrapper_490 (ModuleWr (None, 176, 14, 256) 1024
_________________________________________________________________
module_wrapper_491 (ModuleWr (None, 175, 13, 256) 0
_________________________________________________________________
module_wrapper_492 (ModuleWr (None, 175, 13, 256) 0
_________________________________________________________________
module_wrapper_493 (ModuleWr (None, 175, 3328) 0
_________________________________________________________________
module_wrapper_494 (ModuleWr (None, 175, 512) 7866368
_________________________________________________________________
module_wrapper_495 (ModuleWr (None, 256) 787456
_________________________________________________________________
module_wrapper_496 (ModuleWr (None, 256) 0
_________________________________________________________________
module_wrapper_497 (ModuleWr (None, 256) 65792
_________________________________________________________________
module_wrapper_498 (ModuleWr (None, 256) 0
_________________________________________________________________
module_wrapper_499 (ModuleWr (None, 256) 0
_________________________________________________________________
module_wrapper_500 (ModuleWr (None, 256) 65792
_________________________________________________________________
module_wrapper_501 (ModuleWr (None, 256) 0
_________________________________________________________________
module_wrapper_502 (ModuleWr (None, 256) 0
_________________________________________________________________
module_wrapper_503 (ModuleWr (None, 128) 32896
_________________________________________________________________
module_wrapper_504 (ModuleWr (None, 128) 0
_________________________________________________________________
module_wrapper_505 (ModuleWr (None, 128) 0
_________________________________________________________________
module_wrapper_506 (ModuleWr (None, 64) 8256
_________________________________________________________________
module_wrapper_507 (ModuleWr (None, 64) 0
_________________________________________________________________
module_wrapper_508 (ModuleWr (None, 64) 0
_________________________________________________________________
module_wrapper_509 (ModuleWr (None, 3) 195
_________________________________________________________________
module_wrapper_510 (ModuleWr (None, 3) 0
=================================================================
Total params: 10,160,327
Trainable params: 10,158,661
Non-trainable params: 1,666
_________________________________________________________________
[<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42845faf90>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840f7f90>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840f2c90>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840f2b90>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840f2490>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42843426d0>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840e3710>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840f9c90>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840fd590>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840fb310>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840f9a90>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840ed3d0>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840edf90>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840ed290>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840e7a50>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840e73d0>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840e4690>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840ddf10>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840c8b10>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f4284097290>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f4284097690>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f4284097950>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840a2050>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840a2ad0>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840a2e50>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840aa350>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840aaad0>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840aaf10>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840aad50>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840b6710>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840fb990>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840b63d0>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840b6a10>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840b69d0>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840c1110>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840c1e90>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f42840c12d0>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f428404e1d0>,
<tensorflow.python.keras.engine.functional.ModuleWrapper at 0x7f428404eb10>]
import tensorflow as tf # version 2.5
from tensorflow import keras
from tensorflow.keras.layers import LeakyReLU, Softmax
from tensorflow.keras.layers import Conv2D, MaxPooling2D, SeparableConv2D
from tensorflow.keras.layers import Dense, Flatten, Dropout, Reshape, Activation
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LSTM