Python: Are these two Keras models equivalent if I concatenate at the output layer and then extract tensor slices?

I am trying to create a model whose final layer has multiple outputs. Below is the model summary for the version with 3 output heads (a code sketch of this architecture follows the summary):

Model: "model_18"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
observations (InputLayer)       [(None, 84, 84, 4)]  0                                            
__________________________________________________________________________________________________
layer1 (Conv2D)                 (None, 20, 20, 32)   8224        observations[0][0]               
__________________________________________________________________________________________________
layer2 (Conv2D)                 (None, 9, 9, 64)     32832       layer1[0][0]                     
__________________________________________________________________________________________________
layer3 (Conv2D)                 (None, 7, 7, 64)     36928       layer2[0][0]                     
__________________________________________________________________________________________________
layer4 (Flatten)                (None, 3136)         0           layer3[0][0]                     
__________________________________________________________________________________________________
agent_indicator (InputLayer)    [(None, 2)]          0                                            
__________________________________________________________________________________________________
concatenate_6 (Concatenate)     (None, 3138)         0           layer4[0][0]                     
                                                                 agent_indicator[0][0]            
__________________________________________________________________________________________________
layer5 (Dense)                  (None, 512)          1607168     concatenate_6[0][0]              
__________________________________________________________________________________________________
gamma_q_layer0 (Dense)          (None, 6)            3078        layer5[0][0]                     
__________________________________________________________________________________________________
gamma_q_layer1 (Dense)          (None, 6)            3078        layer5[0][0]                     
__________________________________________________________________________________________________
gamma_q_layer2 (Dense)          (None, 6)            3078        layer5[0][0]                     
==================================================================================================
Total params: 1,694,386
Trainable params: 1,694,386
Non-trainable params: 0
__________________________________________________________________________________________________
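For reference, here is a minimal sketch of how a model matching this summary could be built. The layer widths are pinned down by the output shapes and parameter counts above, but the kernel sizes, strides, and activations are assumptions (the standard DQN convolution settings, which reproduce the shapes and parameter counts exactly):

import tensorflow as tf
from tensorflow.keras import layers, Model

# Inputs as shown in the summary.
observations = layers.Input(shape=(84, 84, 4), name="observations")
agent_indicator = layers.Input(shape=(2,), name="agent_indicator")

# Conv trunk; kernel sizes and strides are assumed (standard DQN values),
# chosen so the output shapes and parameter counts match the summary.
x = layers.Conv2D(32, 8, strides=4, activation="relu", name="layer1")(observations)
x = layers.Conv2D(64, 4, strides=2, activation="relu", name="layer2")(x)
x = layers.Conv2D(64, 3, strides=1, activation="relu", name="layer3")(x)
x = layers.Flatten(name="layer4")(x)
x = layers.Concatenate()([x, agent_indicator])
x = layers.Dense(512, activation="relu", name="layer5")(x)

# Three separate output heads, each predicting 6 values.
heads = [layers.Dense(6, name=f"gamma_q_layer{i}")(x) for i in range(3)]
model = Model(inputs=[observations, agent_indicator], outputs=heads)
model.summary()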
Since this is a custom model that I am interfacing with a library, this does not work: the library generally expects a single head of shape (batch_size, 6).

If I instead concatenate the output heads into a single layer as shown below and feed that into the library, the library still gets the single tensor it expects; I only have to change the expected shape to (batch_size, 6*3), then slice the tensor apart to extract the individual heads and compute each head's loss separately. Are these two approaches equivalent? (A slicing sketch follows the summary below.)

Model: "model_33"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
observations (InputLayer)       [(None, 84, 84, 4)]  0                                            
__________________________________________________________________________________________________
layer1 (Conv2D)                 (None, 20, 20, 32)   8224        observations[0][0]               
__________________________________________________________________________________________________
layer2 (Conv2D)                 (None, 9, 9, 64)     32832       layer1[0][0]                     
__________________________________________________________________________________________________
layer3 (Conv2D)                 (None, 7, 7, 64)     36928       layer2[0][0]                     
__________________________________________________________________________________________________
layer4 (Flatten)                (None, 3136)         0           layer3[0][0]                     
__________________________________________________________________________________________________
agent_indicator (InputLayer)    [(None, 2)]          0                                            
__________________________________________________________________________________________________
concatenate_18 (Concatenate)    (None, 3138)         0           layer4[0][0]                     
                                                                 agent_indicator[0][0]            
__________________________________________________________________________________________________
layer5 (Dense)                  (None, 512)          1607168     concatenate_18[0][0]             
__________________________________________________________________________________________________
gamma_q_layer0 (Dense)          (None, 6)            3078        layer5[0][0]                     
__________________________________________________________________________________________________
gamma_q_layer1 (Dense)          (None, 6)            3078        layer5[0][0]                     
__________________________________________________________________________________________________
gamma_q_layer2 (Dense)          (None, 6)            3078        layer5[0][0]                     
__________________________________________________________________________________________________
concatenate_19 (Concatenate)    (None, 18)           0           gamma_q_layer0[0][0]             
                                                                 gamma_q_layer1[0][0]             
                                                                 gamma_q_layer2[0][0]             
==================================================================================================
Total params: 1,694,386
Trainable params: 1,694,386
Non-trainable params: 0
__________________________________________________________________________________________________
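Concretely, the only change in this second model is the final Concatenate over the three heads; the individual heads can then be recovered by slicing (or tf.split) along the last axis. A sketch continuing from the code above; the input batches are dummy data for illustration only:

# Concatenate the three heads into a single (batch, 18) output.
concat_out = layers.Concatenate()(heads)
model2 = Model(inputs=[observations, agent_indicator], outputs=concat_out)

# Recover the individual (batch, 6) heads by splitting the prediction.
obs_batch = tf.random.normal((1, 84, 84, 4))   # dummy inputs, illustration only
indicator_batch = tf.random.normal((1, 2))
y = model2([obs_batch, indicator_batch])       # shape (1, 18)
head0, head1, head2 = tf.split(y, num_or_size_splits=3, axis=-1)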
What I am unsure about is how the weights get adjusted during backprop in the two models. In the first model, a loss is computed for each head and the losses are averaged into the final loss, which backprop then uses to adjust the weights. If instead, as in the second model, I extract the heads by tensor slicing, compute each head's loss, and average them into the final loss, will the two loss values be the same?
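Since Concatenate just lays the three head outputs side by side and slicing recovers exactly the same values, each per-head loss, and therefore the averaged loss and the gradients flowing into the shared trunk, should be identical between the two routes. A quick numerical sanity check of that claim, continuing from the imports above and using MSE as a stand-in for the actual loss:

# Pretend predictions/targets for the concatenated head: shape (batch, 18).
y_true = tf.random.normal((4, 18))
y_pred = tf.random.normal((4, 18))

# Route 1: slice the concatenated tensor and compute per-slice MSE losses.
slice_losses = [
    tf.reduce_mean(tf.square(y_true[:, 6 * i:6 * (i + 1)] - y_pred[:, 6 * i:6 * (i + 1)]))
    for i in range(3)
]

# Route 2: split into separate heads first, then compute per-head losses.
head_losses = [
    tf.reduce_mean(tf.square(t - p))
    for t, p in zip(tf.split(y_true, 3, axis=1), tf.split(y_pred, 3, axis=1))
]

print(tf.reduce_mean(slice_losses))   # these two averages print the
print(tf.reduce_mean(head_losses))    # same value, as expected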

Thank you all.