Python tensorflow中的批量规范化:变量和性能

Python tensorflow中的批量规范化:变量和性能,python,tensorflow,batch-normalization,Python,Tensorflow,Batch Normalization,我想在批处理规范化层的变量上添加条件操作。具体地说,在浮动中训练,然后在微调二次训练阶段量化。为此,我想在变量(均值和var的scale、shift和exp移动平均数)上添加tf.cond操作 我将tf.layers.batch\u normalization替换为我编写的batchnorm层(见下文) 这个函数工作得很好(即,我使用两个函数获得相同的度量),并且我可以将任何管道添加到变量中(在batchnorm操作之前)问题是性能(运行时)急剧下降(即,通过简单地用我自己的函数替换layers

我想在批处理规范化层的变量上添加条件操作。具体地说,在浮动中训练,然后在微调二次训练阶段量化。为此,我想在变量(均值和var的scale、shift和exp移动平均数)上添加tf.cond操作

我将
tf.layers.batch\u normalization
替换为我编写的batchnorm层(见下文)

这个函数工作得很好(即,我使用两个函数获得相同的度量),并且我可以将任何管道添加到变量中(在batchnorm操作之前)问题是性能(运行时)急剧下降(即,通过简单地用我自己的函数替换layers.batchnorm,存在一个x2因子,如下所述)

如果您能在以下问题中提供帮助,我将不胜感激:

  • 关于如何提高我的解决方案的性能(减少运行时间)有什么想法吗
  • 是否可以在batchnorm操作之前将我自己的运算符添加到layers.batchnorm的变量管道中
  • 对于同样的问题,还有其他解决方案吗

谢谢大家!

tf.nn.fused_batch_norm
经过优化,达到了目的

我必须创建两个子图,每个模式一个子图,因为
fused\u batch\u norm
的接口不采用条件训练/测试模式(is\u training是bool而不是张量,所以它的图形不是条件的)。我在后面添加了条件(见下文)。然而,即使有这两个子图,它的
tf.layers.batch\u normalization
运行时也大致相同

这里是最终的解决方案(我仍然非常感谢任何关于改进的评论或建议):

def batchnorm(self, x, name, epsilon=0.001, decay=0.99):
    epsilon = tf.to_float(epsilon)
    decay = tf.to_float(decay)
    with tf.variable_scope(name):
        shape = x.get_shape().as_list()
        channels_num = shape[3]
        # scale factor
        gamma = tf.get_variable("gamma", shape=[channels_num], initializer=tf.constant_initializer(1.0), trainable=True)
        # shift value
        beta = tf.get_variable("beta", shape=[channels_num], initializer=tf.constant_initializer(0.0), trainable=True)
        moving_mean = tf.get_variable("moving_mean", channels_num, initializer=tf.constant_initializer(0.0), trainable=False)
        moving_var = tf.get_variable("moving_var", channels_num, initializer=tf.constant_initializer(1.0), trainable=False)
        batch_mean, batch_var = tf.nn.moments(x, axes=[0, 1, 2]) # per channel

        update_mean = moving_mean.assign((decay * moving_mean) + ((1. - decay) * batch_mean))
        update_var = moving_var.assign((decay * moving_var) + ((1. - decay) * batch_var))

        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_var)

        bn_mean = tf.cond(self.is_training, lambda: tf.identity(batch_mean), lambda: tf.identity(moving_mean))
        bn_var = tf.cond(self.is_training, lambda: tf.identity(batch_var), lambda: tf.identity(moving_var))

        with tf.variable_scope(name + "_batchnorm_op"):
            inv = tf.math.rsqrt(bn_var + epsilon)
            inv *= gamma
            output = ((x*inv) - (bn_mean*inv)) + beta

    return output
def batchnorm(self, x, name, epsilon=0.001, decay=0.99):
    with tf.variable_scope(name):
        shape = x.get_shape().as_list()
        channels_num = shape[3]
        # scale factor
        gamma = tf.get_variable("gamma", shape=[channels_num], initializer=tf.constant_initializer(1.0), trainable=True)
        # shift value
        beta = tf.get_variable("beta", shape=[channels_num], initializer=tf.constant_initializer(0.0), trainable=True)
        moving_mean = tf.get_variable("moving_mean", channels_num, initializer=tf.constant_initializer(0.0), trainable=False)
        moving_var = tf.get_variable("moving_var", channels_num, initializer=tf.constant_initializer(1.0), trainable=False)

        (output_train, batch_mean, batch_var) = tf.nn.fused_batch_norm(x,
                                                                 gamma,
                                                                 beta,  # pylint: disable=invalid-name
                                                                 mean=None,
                                                                 variance=None,
                                                                 epsilon=epsilon,
                                                                 data_format="NHWC",
                                                                 is_training=True,
                                                                 name="_batchnorm_op")
        (output_test, _, _) = tf.nn.fused_batch_norm(x,
                                                     gamma,
                                                     beta,  # pylint: disable=invalid-name
                                                     mean=moving_mean,
                                                     variance=moving_var,
                                                     epsilon=epsilon,
                                                     data_format="NHWC",
                                                     is_training=False,
                                                     name="_batchnorm_op")

        output = tf.cond(self.is_training, lambda: tf.identity(output_train), lambda: tf.identity(output_test))

        update_mean = moving_mean.assign((decay * moving_mean) + ((1. - decay) * batch_mean))
        update_var = moving_var.assign((decay * moving_var) + ((1. - decay) * batch_var))
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_var)

    return output