Python 3.x: weird batch normalization behavior after freezing and optimizing the graph


I am running into unusual batch normalization behavior in TF 1.14 (I will update this post with TF 2 results later). After the graph is frozen, accuracy drops dramatically, and I traced the discrepancy back to batch normalization. A smaller model that reproduces the behavior is attached at the end of this post. After training, I freeze the graph with tf.graph_util.convert_variables_to_constants() and then optimize it with optimize_for_inference().

Let's go straight to the results. From training to testing the values barely change, but after freezing, the moving average, beta and gamma change drastically while the moving std (moving_variance) does not (a sketch for checking this directly against the frozen .pb follows the log below):

  • Training: tqdm progress bar plus the printed mov_avg / mov_std / beta / gamma values for each step (output truncated here)
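
To pin down where the values diverge, the frozen graph can be inspected directly. Below is a minimal sketch (assuming the ./dummy/pb/freeze.pb path used in the reproduction code) that prints the BN parameters which convert_variables_to_constants baked into the .pb as Const nodes:

import tensorflow as tf
from tensorflow.python.framework import tensor_util

graph_def = tf.GraphDef()
with tf.gfile.GFile('./dummy/pb/freeze.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# after freezing, moving_mean/moving_variance/beta/gamma are Const nodes
for node in graph_def.node:
    if node.op == 'Const' and 'BN' in node.name:
        print(node.name, tensor_util.MakeNdarray(node.attr['value'].tensor))

The full reproduction code:
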
import tensorflow as tf
from tqdm import tqdm
import numpy as np
from util import print_nodes_name_shape, check_N_mkdir  # author's local helper module
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
from tensorflow.python.framework import dtypes
import os

# dummy data: constant inputs, deterministic targets
inputs = np.ones((8, 2, 2, 1))
outputs = np.arange(8 * 2 * 2 * 3).reshape((8, 2, 2, 3))
input_ph = tf.placeholder(tf.float32, shape=(None, 2, 2, 1), name='input_ph')
output_ph = tf.placeholder(tf.float32, shape=(None, 2, 2, 3), name='output_ph')
is_training = tf.placeholder(tf.bool, shape=[], name='is_training')

# build a one layer Full layer with BN and save 2 ckpts
with tf.name_scope('model'):
    out = tf.reshape(input_ph, shape=(-1, 2 * 2 * 1), name='flatten')
    with tf.variable_scope('dnn1', reuse=False):
        w1 = tf.get_variable('w1', dtype=tf.float32, shape=[4 * 1, 4 * 3], initializer=tf.initializers.glorot_normal())
        b1 = tf.get_variable('b1', dtype=tf.float32, shape=[4 * 3], initializer=tf.initializers.glorot_normal())
    out = tf.matmul(out, w1) + b1
    out = tf.layers.batch_normalization(out, training=is_training, name='BN')
    logits = tf.nn.relu(out)
    logits = tf.reshape(logits, shape=(-1, 2, 2, 3))

with tf.name_scope('loss'):
    MSE = tf.losses.mean_squared_error(labels=output_ph, predictions=logits)

with tf.name_scope('operations'):
    opt = tf.train.AdamOptimizer(learning_rate=0.0001, name='Adam')
    grads = opt.compute_gradients(MSE)
    train_op = opt.apply_gradients(grads, name='apply_grad')
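
# Note: tf.layers.batch_normalization only refreshes moving_mean/moving_variance
# through the ops it registers in tf.GraphKeys.UPDATE_OPS. The canonical pattern
# is to tie those updates to the train op, e.g.:
#     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
#     with tf.control_dependencies(update_ops):
#         train_op = opt.apply_gradients(grads, name='apply_grad')
# This script instead runs the update ops explicitly next to train_op below.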

# train
with tf.Session() as sess:
    # prepare
    graph = tf.get_default_graph()
    print_nodes_name_shape(graph)
    saver = tf.train.Saver()

    # init variables
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    saver.save(sess, './dummy/ckpt/step0')
    # train
    for i in tqdm(range(100)):
        # print moving avg/std, beta and gamma before each training step
        with tf.variable_scope('', reuse=True):
            mov_avg, mov_std, beta, gamma = sess.run([tf.get_variable('BN/moving_mean'),
                                                      tf.get_variable('BN/moving_variance'),
                                                      tf.get_variable('BN/beta'),
                                                      tf.get_variable('BN/gamma')])
            print('\nmov_avg: {}, \nmov_std: {}, \nbeta: {}, \ngamma: {}'.format(mov_avg, mov_std, beta, gamma))
        # run the BN update ops explicitly alongside the train op
        _, _ = sess.run([train_op, update_ops], feed_dict={
            input_ph: inputs,
            output_ph: outputs,
            is_training: True,
        })

    for i in tqdm(range(100)):
        # print moving avg/std, beta and gamma during the test passes
        with tf.variable_scope('', reuse=True):
            mov_avg, mov_std, beta, gamma = sess.run([tf.get_variable('BN/moving_mean'),
                                                      tf.get_variable('BN/moving_variance'),
                                                      tf.get_variable('BN/beta'),
                                                      tf.get_variable('BN/gamma')])
            print('\nmov_avg: {}, \nmov_std: {}, \nbeta: {}, \ngamma: {}'.format(mov_avg, mov_std, beta, gamma))
        # forward pass only (is_training=False), fetching the model's final reshape
        _ = sess.run([graph.get_tensor_by_name('model/Reshape:0')], feed_dict={
            input_ph: inputs,
            is_training: False,
        })

    saver.save(sess, './dummy/ckpt/step100')


def freeze_ckpt_for_inference(ckpt_path=None, conserve_nodes=None):
    # clean graph first
    tf.reset_default_graph()
    # freeze ckpt then convert to pb
    new_input = tf.placeholder(tf.float32, shape=[None, 2, 2, 1], name='new_input')  # must match the original input_ph shape
    new_is_training = tf.placeholder(tf.bool, name='new_is_training')

    restorer = tf.train.import_meta_graph(
        ckpt_path + '.meta',
        input_map={
            'input_ph': new_input,
            'is_training': new_is_training,
        },
        clear_devices=True,
    )
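
    # input_map re-points the checkpoint graph's placeholders ('input_ph',
    # 'is_training') at the fresh placeholders created above, so the frozen
    # graph exposes clean inputs under the new names.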

    input_graph_def = tf.get_default_graph().as_graph_def()
    check_N_mkdir('./dummy/pb/')
    check_N_mkdir('./dummy/tb/')

    # freeze to pb
    with tf.Session() as sess:
        # restore variables
        restorer.restore(sess, ckpt_path)
        # convert variable to constant
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,
            output_node_names=conserve_nodes,
        )
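
        # convert_variables_to_constants bakes the *current* values of
        # BN/moving_mean, BN/moving_variance, BN/beta and BN/gamma into Const
        # nodes; nothing can retrain them past this point, so any drift seen
        # later must come from which tensors the kept output path reads.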

        # save to pb
        with tf.gfile.GFile('./dummy/pb/freeze.pb', 'wb') as f:  # 'wb' stands for write binary
            f.write(output_graph_def.SerializeToString())


def optimize_graph_for_inference(pb_dir=None, conserve_nodes=None):
    tf.reset_default_graph()
    check_N_mkdir(pb_dir)

    # import pb file
    with tf.gfile.FastGFile(pb_dir + 'freeze.pb', "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # optimize graph
    optimize_graph_def = optimize_for_inference(input_graph_def=graph_def,
                                                input_node_names=['new_input', 'new_is_training'],
                                                output_node_names=conserve_nodes,
                                                # one dtype enum per input node
                                                placeholder_type_enum=[dtypes.float32.as_datatype_enum,
                                                                       dtypes.bool.as_datatype_enum])
    with tf.gfile.GFile(pb_dir + 'optimize.pb', 'wb') as f:
        f.write(optimize_graph_def.SerializeToString())

conserve_nodes = ['model/Reshape']
freeze_ckpt_for_inference(ckpt_path='./dummy/ckpt/step100', conserve_nodes=conserve_nodes)
optimize_graph_for_inference(pb_dir='./dummy/pb/', conserve_nodes=conserve_nodes)
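
# The constant-dumping sketch near the top of this post can also be pointed at
# './dummy/pb/optimize.pb' to check whether optimize_for_inference itself changed
# the folded BN values, or whether they were already off in freeze.pb.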

# cleaning
tf.reset_default_graph()

# load pb
with tf.gfile.FastGFile('./dummy/pb/optimize.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

print('\n Now inference*******************************')
# inference
with tf.Session() as sess:
    tf.import_graph_def(graph_def)
    # prepare
    G = tf.get_default_graph()
    print_nodes_name_shape(G)
    new_input = G.get_tensor_by_name('import/new_input:0')
    new_is_training = G.get_tensor_by_name('import/new_is_training:0')
    new_output = G.get_tensor_by_name('import/' + conserve_nodes[-1] + ':0')

    # inference loop
    for i in tqdm(range(100)):
        # print the (now constant) moving avg/std, beta and gamma
        mov_avg, mov_std, beta, gamma = sess.run([G.get_tensor_by_name('import/BN/moving_mean:0'),
                                                  G.get_tensor_by_name('import/BN/moving_variance:0'),
                                                  G.get_tensor_by_name('import/BN/beta:0'),
                                                  G.get_tensor_by_name('import/BN/gamma:0')])
        print('\nmov_avg: {}, \nmov_std: {}, \nbeta: {}, \ngamma: {}'.format(mov_avg, mov_std, beta, gamma))
        # note: new_output is the tensor being fetched, so it must not be fed
        new_out = sess.run([new_output], feed_dict={
            new_input: inputs,
            new_is_training: False,
        })
        # print('out: {}'.format(new_out))