Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/324.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181

Warning: file_get_contents(/data/phpspider/zhask/data//catemap/7/kubernetes/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 从上次训练中恢复训练过的变量_Python_Tensorflow - Fatal编程技术网

Python 从上次训练中恢复训练过的变量

Python 从上次训练中恢复训练过的变量,python,tensorflow,Python,Tensorflow,我试图从上次培训中恢复,但我能够保存模型,但无法恢复它。我有下面的代码,它运行时没有错误。我知道这不是恢复它,因为当我重新开始训练时,损失值会回到大值 有什么帮助吗 ckpt_path = os.path.abspath(os.path.dirname(__file__)) + '/weights/' labels_net, loss = vgg16(crop_size) optimizer = tf.train.AdamOptimizer(learning_rate=0.0001).mini

我试图从上次培训中恢复,但我能够保存模型,但无法恢复它。我有下面的代码,它运行时没有错误。我知道这不是恢复它,因为当我重新开始训练时,损失值会回到大值

有什么帮助吗

ckpt_path = os.path.abspath(os.path.dirname(__file__)) + '/weights/'

labels_net, loss = vgg16(crop_size)
optimizer = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
saver = tf.train.Saver(max_to_keep=3)

# Train
with tf.Session() as sess:

    # Load previous weights
    if os.listdir(ckpt_path) ==[]:
        sess.run(tf.global_variables_initializer())
    else:
        for file in os.listdir(ckpt_path):
            if 'vgg16' in file:
                try:
                    saver = tf.train.import_meta_graph(os.path.join(ckpt_path+file))
                    saver.restore(sess, ckpt_path+'vgg16-2')
                    print('Resuming training....')
                except:
                    sess.run(tf.global_variables_initializer())
            else:
                sess.run(tf.global_variables_initializer())

    print('Epoch', 'Training loss')
    for epoch_i in range(epochs):
        for batch_i in range(batches):

            batch_crops = getBatch(crops_train, batch_i, batch_size)
            batch_labels = getBatch(labels_train, batch_i, batch_size)
            x = sess.graph.get_tensor_by_name('x:0')
            y = sess.graph.get_tensor_by_name('y:0')
            sess.run(optimizer, feed_dict={x: batch_crops, y: batch_labels})#, options=run_options, run_metadata=run_metadata)

        train_loss = sess.run(loss, feed_dict={x: batch_crops, y: batch_labels})   
        print(epoch_i+1, train_loss)
        saver.save(sess, ckpt_path+'vgg16', global_step=2)

我对张量流知之甚少,但是。我认为你加载的文件与保存的文件不一样

您的加载线是saver.restore(sess,ckpt_路径+'vgg16-2')


因此,您正在保存到
vgg16
并从
vgg16-2

加载全局步骤=2将'-2'添加到名称中