Python TensorFlow检查点保存和读取

Python TensorFlow检查点保存和读取,python,io,tensorflow,Python,Io,Tensorflow,我有一个基于TensorFlow的神经网络和一组变量 培训功能如下: def train(load = True, step) """ Defining the neural network is skipped here """ train_step = tf.train.AdamOptimizer(1e-4).minimize(mse) # Saver saver = tf.train.Saver() if not load:

我有一个基于TensorFlow的神经网络和一组变量

培训功能如下:

def train(load = True, step)
    """
    Defining the neural network is skipped here
    """

    train_step = tf.train.AdamOptimizer(1e-4).minimize(mse)
    # Saver
    saver = tf.train.Saver()

    if not load:
        # Initalizing variables
        sess.run(tf.initialize_all_variables())
    else:
        saver.restore(sess, 'Variables/map.ckpt')
        print 'Model Restored!'

    # Perform stochastic gradient descent
    for i in xrange(step):
        train_step.run(feed_dict = {x: train, y_: label})

    # Save model
    save_path = saver.save(sess, 'Variables/map.ckpt')
    print 'Model saved in file: ', save_path
    print 'Training Done!'
# First train
train(False, 1)
# Following train
for i in xrange(10):
    train(True, 10)
我是这样调用培训功能的:

def train(load = True, step)
    """
    Defining the neural network is skipped here
    """

    train_step = tf.train.AdamOptimizer(1e-4).minimize(mse)
    # Saver
    saver = tf.train.Saver()

    if not load:
        # Initalizing variables
        sess.run(tf.initialize_all_variables())
    else:
        saver.restore(sess, 'Variables/map.ckpt')
        print 'Model Restored!'

    # Perform stochastic gradient descent
    for i in xrange(step):
        train_step.run(feed_dict = {x: train, y_: label})

    # Save model
    save_path = saver.save(sess, 'Variables/map.ckpt')
    print 'Model saved in file: ', save_path
    print 'Training Done!'
# First train
train(False, 1)
# Following train
for i in xrange(10):
    train(True, 10)
我进行这种培训是因为我需要向我的模型提供不同的数据集。但是,如果我以这种方式调用train函数,TensorFlow将生成错误消息,指示它无法从文件中读取保存的模型

经过一些实验后,我发现这是因为检查点保存很慢。在将文件写入磁盘之前,下一列函数将开始读取,从而生成错误

我尝试使用time.sleep()函数在每次调用之间进行一些延迟,但没有成功


有人知道如何解决这种写/读错误吗?多谢各位

代码中存在一个微妙的问题:每次调用
train()
函数时,会向同一张TensorFlow图中添加更多节点,用于所有模型变量和神经网络的其余部分。这意味着每次构造
tf.train.Saver()
,它都包含以前调用
train()
的所有变量。每次重新创建模型时,都会使用额外的
\N
后缀创建变量,以给它们一个唯一的名称:

  • 用变量
    var\u a
    var\u b
    构造的储蓄器
  • 用变量
    var\u a
    var\u b
    var\u a\u 1
    var\u b\u 1
    构造的储蓄器
  • 由变量
    var\u a
    var\u b
    var\u a\u 1
    var\u b\u 1
    var\u a\u 2
    构成的储蓄器
  • 等等
  • tf.train.Saver
    的默认行为是将每个变量与相应op的名称相关联。这意味着
    var\u a\u 1
    不会从
    var\u a
    初始化,因为它们最终的名称不同

    解决方案是每次调用
    train()
    时创建一个新的图形。修复此问题的最简单方法是更改主程序,以便为对
    train()
    的每次调用创建一个新的图形,如下所示:

    # First train
    with tf.Graph().as_default():
        train(False, 1)
    
    # Following train
    for i in xrange(10):
        with tf.Graph().as_default():
            train(True, 10)
    

    ……或等价地,可以在()/代码>函数>

    内移动< <代码> 块。因此,在图中添加节点的行为类似于C++中的类/对象的行为?每次train()函数完成时,图形对象都不会被破坏。如果我继续添加与W1,b1同名的变量,它将切换到W1_1和b1_1,从而使加载失败。我的理解正确吗?这个问题是由于训练过程结束时没有调用析构函数造成的吗?非常感谢。本质上,除非您显式构造
    tf.Graph
    并使用
    with
    构造将其设置为默认值,否则所有节点都将添加到仅在流程结束时销毁的全局图中。(这并不理想,但它使其他一些用例变得更容易。)使用
    with
    块可确保在块的末尾取消图形的注册,这将为您提供所需的行为并避免内存泄漏!