Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 在Keras中指定模型编译的多重损失函数_Python_Tensorflow_Keras_Deep Learning - Fatal编程技术网

Python 在Keras中指定模型编译的多重损失函数

Python 在Keras中指定模型编译的多重损失函数,python,tensorflow,keras,deep-learning,Python,Tensorflow,Keras,Deep Learning,我想指定两个损失函数,一个用于对象类,它是交叉熵,另一个用于边界框,它是均方误差。如何在model.compile中指定每个具有相应损失函数的输出 model = Sequential() model.add(Dense(128, activation='relu')) out_last_dense = model.add(Dense(128, activation='relu')) object_type = model.add(Dense(1, activation='softmax'))

我想指定两个损失函数,一个用于对象类,它是交叉熵,另一个用于边界框,它是均方误差。如何在model.compile中指定每个具有相应损失函数的输出

model = Sequential()

model.add(Dense(128, activation='relu'))
out_last_dense = model.add(Dense(128, activation='relu'))
object_type = model.add(Dense(1, activation='softmax'))(out_last_dense)
object_coordinates = model.add(Dense(4, activation='softmax'))(out_last_dense)

/// here is the problem i want to specify loss function for object type and coordinates
model.compile(loss= keras.losses.categorical_crossentropy,
   optimizer= 'sgd', metrics=['accuracy'])

首先,您不能在这里使用顺序API,因为您的模型有两个输出层(即,您编写的内容都是错误的,并且会引起错误)。相反,您必须使用:

现在,您可以根据上面给出的名称并使用字典为每个输出层指定损失函数(以及度量):

model.compile(loss={'type': 'binary_crossentropy', 'coord': 'mse'}, 
              optimizer='sgd', metrics={'type': 'accuracy', 'coord': 'mae'})
此外,请注意,您正在使用softmax作为激活函数,我已将其更改为上面的
sigomid
linear
。这是因为:1)在具有一个单元的层上使用softmax没有意义(如果有两个以上的类,则应使用softmax),2)另一层预测坐标,因此使用softmax根本不合适(除非问题公式允许您这样做)

model.compile(loss={'type': 'binary_crossentropy', 'coord': 'mse'}, 
              optimizer='sgd', metrics={'type': 'accuracy', 'coord': 'mae'})