Python 如何命名Keras fit输出中的自定义指标

Python 如何命名Keras fit输出中的自定义指标,python,keras,Python,Keras,我在培训Keras模型时使用自定义指标。它工作正常,只是model.fit_generator(…)输出中的度量名称不可解释(注意:Tensorboard也使用了这些错误名称) 这是我正在做的一个可复制的例子:度量使用一个参数(除了预测和基本事实),因此我定义了一个工厂来生成无参数度量函数,类似于: def my_dummy_metric(y_true, y_pred, the_param=1.0): return the_param * keras.backend.ones((1))

我在培训Keras模型时使用自定义指标。它工作正常,只是
model.fit_generator(…)
输出中的度量名称不可解释(注意:Tensorboard也使用了这些错误名称)

这是我正在做的一个可复制的例子:度量使用一个参数(除了预测和基本事实),因此我定义了一个工厂来生成无参数度量函数,类似于:

def my_dummy_metric(y_true, y_pred, the_param=1.0):
    return the_param * keras.backend.ones((1))

def my_metric_factory(the_param=1.0):
    def fn(y_true, y_pred):
        return my_dummy_metric(y_true, y_pred, the_param=the_param)

    return fn

my_second_metric = my_metric_factory(2.0)
my_other_metric = my_metric_factory(3.14)
然后我编译并训练我的模型:

model.compile(my_optim, my_loss, [my_second_metric, my_other_metric])
history = model.fit_generator(...)
print(history.params['metrics'])
我的问题是
history
中的度量名称是
fn
fn\u 1
val\u fn
val\u fn\u 1
。Tensorboard也使用这些名称,您需要了解implem细节才能理解它们

相反,使用简单的自定义函数时,如果没有工厂,则不会出现此问题:

model.compile(my_optim, my_loss, [my_dummy_metric])
history = model.fit_generator(...)
print(history.params['metrics'])
在基于工厂的用例中,是否有可能获得
my_XXX_metric
作为输出名称


环境:使用keras2.2.4、tf1.14.0、python3.7是的,这是可能的。在公制工厂中,只需设置一个适当的公制函数的
\uuu name\uuu
。例如:

def my_metric_factory(the_param=1.0):
    def fn(y_true, y_pred):
        return my_dummy_metric(y_true, y_pred, the_param=the_param)

    fn.__name__ = 'metricname_{}'.format(the_param)
    return fn

如果有人将TF 2.0或更高版本与默认度量(非自定义度量)一起使用,您可以在度量中尝试名称参数,如下所示:

# Recall inherits Metric class and similarly name of most of the metrics could be changed
recall = tf.keras.metrics.Recall(name = 'custom_name')
model.compile(loss=..., optimizer=...,
                        metrics=[recall])
model.fit(...)

如果您想为度量提供自定义名称,例如precision@k, recall@k,TopKCategoricalAccuracy,您可以按照如下方式操作:

k = 5
precision = tf.keras.metrics.Precision(top_k=k, name = 'precision_' + str(k))
model.compile(loss=..., optimizer=...,
                        metrics=[precision])
model.fit(...)

资料来源: