如何使用tensorflow slim的批量标准化?

如何使用tensorflow slim的批量标准化?,tensorflow,neural-network,Tensorflow,Neural Network,我发现了一些关于在TensorFlow中使用批处理规范化的问题,但没有一个是关于它在slim中的包装器的 我正在尝试使用批处理规范化来训练MNIST数字分类器。虽然训练性能足够高,但验证或测试性能较差 我只构建了一个图,并通过了作为tf.placeholder的is_training,如下所示(每个conv和fc层都使用BN): 我还添加了控件依赖项,如下所示: update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) if update_o

我发现了一些关于在TensorFlow中使用批处理规范化的问题,但没有一个是关于它在slim中的包装器的

我正在尝试使用批处理规范化来训练MNIST数字分类器。虽然训练性能足够高,但验证或测试性能较差

我只构建了一个图,并通过了作为tf.placeholder的
is_training
,如下所示(每个conv和fc层都使用BN):

我还添加了控件依赖项,如下所示:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
    updates = tf.group(*update_ops)
    cross_entropy = control_flow_ops.with_dependencies([updates], cross_entropy)
对于培训阶段,我使用:

sess.run([net['cross_entropy'], net['accuracy']],
                                feed_dict={net['x']: batch_xs,
                                           net['y_']: batch_ys,
                                           net['keep_prob']: 1.0,
                                           net['is_training']: True})
sess.run(net['accuracy'], feed_dict={net['x']: batch_xs,
                                                net['y_']: batch_ys,
                                                net['keep_prob']: 1.0,
                                                net['is_training']: False})
对于验证阶段,我使用:

sess.run([net['cross_entropy'], net['accuracy']],
                                feed_dict={net['x']: batch_xs,
                                           net['y_']: batch_ys,
                                           net['keep_prob']: 1.0,
                                           net['is_training']: True})
sess.run(net['accuracy'], feed_dict={net['x']: batch_xs,
                                                net['y_']: batch_ys,
                                                net['keep_prob']: 1.0,
                                                net['is_training']: False})
出于测试目的,我将经过训练的模型转储到检查点,然后将
is\u training
作为False通过。同样,它的性能很差

那有什么问题吗?是关于
重用
参数吗?或者我需要自己维护BN层中的
gamma
beta
变量

为了便于复制,这是我的代码(将
阶段
设置为
训练
以训练模型并进行验证,
测试
以从检查点和测试恢复):

我终于解决了问题,请参见 详情请参阅。粗略地说,需要使用
slim.learning.create_train_op
来创建train op,并且必须耐心等待移动均值/方差参数来预热