Keras-如何在自动编码器上仅提取部分的组合模型
我已经使用KERAS编译了2个模型(分类和自动编码器),我能够评估模型,并且没有问题按照以下方式运行Keras-如何在自动编码器上仅提取部分的组合模型,keras,autoencoder,Keras,Autoencoder,我已经使用KERAS编译了2个模型(分类和自动编码器),我能够评估模型,并且没有问题按照以下方式运行 model.compile(loss={'classification': 'categorical_crossentropy', 'autoencoder': 'mean_squared_error'}, optimizer='adam', metrics={'cla
model.compile(loss={'classification': 'categorical_crossentropy',
'autoencoder': 'mean_squared_error'},
optimizer='adam',
metrics={'classification': 'accuracy'})
history = model.fit(x_train,
{'classification': y_train, 'autoencoder': x_train},
batch_size=300,
epochs=1,
validation_data= (x_test, {'classification': y_test}),
verbose=1)
第二部分要求我只利用autoencoder上的模型部分,并可视化8个图像样本。请参考下面的代码,它无法运行,因为代码是针对整个模型的,如何在autoencoder上仅提取模型的部分以绘制图像
# Generate reconstructions
num_reconstructions = 8
samples = x_test[:num_reconstructions]
targets = y_test[:num_reconstructions]
reconstructions = model.autoencoder.predict(samples)
import numpy as np
# Plot reconstructions
for i in np.arange(0, num_reconstructions):
# Get the sample and the recoax = pp.subplot(111)nstruction
sample = samples[i][:, :, 0]
reconstruction = reconstructions[i][:, :, 0]
input_class = targets[i]
# Matplotlib preparations
fig, axes = plt.subplots(1, 2)
# Plot sample and reconstruciton
axes[0].imshow(sample)
axes[0].set_title('Original image')
axes[1].imshow(reconstruction)
axes[1].set_title('Reconstruction with Conv2DTranspose')
fig.suptitle(f'MNIST target = {input_class}')
plt.show()
我的网络架构师参考如下:
- 我知道这样做的一种方法是在网络架构之后重新训练一个只有自动编码器的模型,但这将是一个不同的模型,它与之前评估的模型不同,损失/准确度对应于问题开始时一起评估的自动编码器/分类
- 这可以毫无问题地完成
我重新提出了您模型的一个版本:
inp = Input((28,28,1))
enc = Conv2D(63, 3, padding='same')(inp)
enc = MaxPool2D()(enc)
clas = Flatten()(enc)
clas = Dense(1000)(clas)
clas = Dropout(0.3)(clas)
clas = Dense(10, activation='softmax', name='classification')(clas)
dec = Dense(1000)(enc)
dec = Conv2DTranspose(63, 3, padding='same')(dec)
dec = Conv2D(1, 3, padding='same')(dec)
dec = UpSampling2D(name='autoencoder')(dec)
model = Model(inp, [clas,dec])
model.compile(loss={'classification': 'sparse_categorical_crossentropy', 'autoencoder': 'mean_squared_error'},
optimizer='adam',
metrics={'classification': 'accuracy'})
我创建虚拟数据并拟合整个结构(分类+自动编码器)
安装后,我只提取自动编码器部分
autoenc = Model(inp, dec)
rec2 = autoenc.predict(X)
检查结果(rec1必须等于rec2)
下面是完整的运行示例:不要忘记向上投票并接受它作为答案;-)
autoenc = Model(inp, dec)
rec2 = autoenc.predict(X)
(rec1 == rec2).all() # True ===> correct