Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/282.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181

Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python Tensorflow:SavedModelBuilder,如何以最佳验证精度保存模型_Python_Tensorflow_Tensorflow Serving_Tflearn - Fatal编程技术网

Python Tensorflow:SavedModelBuilder,如何以最佳验证精度保存模型

Python Tensorflow:SavedModelBuilder,如何以最佳验证精度保存模型,python,tensorflow,tensorflow-serving,tflearn,Python,Tensorflow,Tensorflow Serving,Tflearn,我已经阅读了tensorflow文档,但找不到使用SavedModelBuilder类以最佳验证精度保存模型的方法。 我正在使用tflearn构建模型,下面是我尝试过的工作,但这需要花费很多时间,我分别为每个时代运行fit方法并保存模型 for i in range(epoch): model.fit(trainX, trainY, n_epoch=1, validation_set=(testX, testY), show_metric=True, batch_size=8)

我已经阅读了tensorflow文档,但找不到使用SavedModelBuilder类以最佳验证精度保存模型的方法。 我正在使用tflearn构建模型,下面是我尝试过的工作,但这需要花费很多时间,我分别为每个时代运行fit方法并保存模型

for i in range(epoch):
    model.fit(trainX, trainY, n_epoch=1, validation_set=(testX, testY), show_metric=True, batch_size=8)
    builder = tf.saved_model.builder.SavedModelBuilder('/tmp/serving/model/' + str(i))
    builder.add_meta_graph_and_variables(model.session,
                                     ['TRAINING'],
                                     signature_def_map={
                                         'predict': prediction_sig
                                     })
    builder.save()

请建议是否有更好的方法。

找到了。它可以通过tflearn回调来实现。 谢谢


你知道我是怎么做到的吗?有相关代码吗?我需要在我的代码中输入这个,而不需要学习。@WeiLiu,我已经编辑了答案,请告诉我它是否适合您
class SaveModelCallback(tflearn.callbacks.Callback):
def __init__(self, accuracy_threshold):
    self.accuracy_threshold = accuracy_threshold
    self.accuracy = []
    self.max_accuracy = -1

def on_epoch_end(self, training_state):
    self.accuracy.append(training_state.global_acc)
    if training_state.val_acc > self.accuracy_threshold and training_state.val_acc > self.max_accuracy:
        self.max_accuracy = training_state.val_acc
        epoch = training_state.epoch
        self.save_model(epoch)

def save_model(self, epoch):
    print('saved epoch ' + str(epoch))
    builder = tf.saved_model.builder.SavedModelBuilder('/tmp/serving/model/' + str(epoch))
    builder.add_meta_graph_and_variables(model.session,
                                         [tf.saved_model.tag_constants.SERVING],
                                         signature_def_map={
                                             'predict': prediction_sig,
                                             tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                                                 classification_signature,
                                         })
    builder.save()

callback = SaveModelCallback(accuracy_threshold=0.8)
model.fit(trainX, trainY, n_epoch=200, validation_set=(testX, testY), show_metric=True, batch_size=8,
          callbacks=callback)