Python: Splitting a Keras model at an arbitrary layer

Tags: python, python-3.x, keras, keras-layer

I'm trying to write a function that splits a Keras model at a user-specified layer. I have the following code:

from keras.models import Sequential

def return_split_models(model, layer):
    model_f, model_h = Sequential(), Sequential()
    for current_layer in range(0, layer+1):
        model_f.add(model.layers[current_layer])
    for current_layer in range(layer+1, len(model.layers)):
        model_h.add(model.layers[current_layer])
    return model_f, model_h

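However, when we return model_h and call summary() on it, we get a ValueError saying that the model has never been called. From other posts, this seems to be related to the input of model_h, but I can't find an example that generalizes to any specified layer. Does anyone have any guidance?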

You need to add an InputLayer to model_h:

from keras.layers import InputLayer
from keras.models import Sequential

def return_split_models(model, layer):
    model_f, model_h = Sequential(), Sequential()
    for current_layer in range(0, layer+1):
        model_f.add(model.layers[current_layer])
    # give model_h an explicit input: the input shape of its first layer, minus the batch axis
    model_h.add(InputLayer(input_shape=model.layers[layer+1].input_shape[1:]))
    for current_layer in range(layer+1, len(model.layers)):
        model_h.add(model.layers[current_layer])
    return model_f, model_h

For example:

from keras.models import Sequential
from keras.layers import Dense

model = Sequential()
model.add(Dense(50,input_shape=(100,)))
model.add(Dense(40))
model.add(Dense(30))
model.add(Dense(20))
model.add(Dense(10))

model_f, model_h = return_split_models(model, 2)
print(model_f.summary())
print(model_h.summary())

# Output (each trailing "None" is the return value of summary(), printed by print()):
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 50)                5050      
_________________________________________________________________
dense_2 (Dense)              (None, 40)                2040      
_________________________________________________________________
dense_3 (Dense)              (None, 30)                1230      
=================================================================
Total params: 8,320
Trainable params: 8,320
Non-trainable params: 0
_________________________________________________________________
None
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_4 (Dense)              (None, 20)                620       
_________________________________________________________________
dense_5 (Dense)              (None, 10)                210       
=================================================================
Total params: 830
Trainable params: 830
Non-trainable params: 0
_________________________________________________________________
None
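
The numbers in the two summaries can be checked by hand. The following is a minimal plain-Python sketch (no Keras needed), assuming that each Dense layer uses a bias, so its parameter count is inputs * units + units. The names dense_params, in_dims and h_input_shape are only illustrative and do not come from Keras.

```python
# Plain-Python check of the summaries above; Keras is not required.
# Assumption: a Dense layer with bias has inputs * units + units parameters.

def dense_params(inputs, units):
    """Weights (inputs * units) plus one bias per unit."""
    return inputs * units + units

input_dim = 100                   # from input_shape=(100,)
units = [50, 40, 30, 20, 10]      # the five Dense layers in the example
split = 2                         # same index passed to return_split_models

# Input width of each layer: the model input, then each previous layer's units.
in_dims = [input_dim] + units[:-1]
params = [dense_params(i, u) for i, u in zip(in_dims, units)]

params_f = params[:split + 1]     # layers 0..split go to model_f
params_h = params[split + 1:]     # the remaining layers go to model_h

# Shape handed to InputLayer: input of layer split+1 without the batch axis.
h_input_shape = (in_dims[split + 1],)

print(params_f, sum(params_f))    # [5050, 2040, 1230] 8320
print(params_h, sum(params_h))    # [620, 210] 830
print(h_input_shape)              # (30,)
```

The totals match the two summaries (8,320 and 830), and (30,) is the shape that model.layers[layer+1].input_shape[1:] yields for layer = 2, since dense_4 consumes the 30 outputs of dense_3.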