Python tensorflow.nn.conv2d中带有NCHW格式的过滤器形状

Python tensorflow.nn.conv2d中带有NCHW格式的过滤器形状,python,tensorflow,Python,Tensorflow,接下来,我将使用NCHW数据格式,但我不确定要在中使用的过滤器形状 文件说要使用NHWC格式的[过滤高度、过滤宽度、输入通道、输出通道],但不清楚如何处理NCHW 应该使用相同的形状吗?使用相同的过滤器形状应该可以。对函数参数的唯一更改是步长。例如,假设您希望您的体系结构同时使用这两种格式,这也是推荐的: # input -> Tensor in NCHW format if use_nchw: result = tf.nn.conv2d( input=input,

接下来,我将使用NCHW数据格式,但我不确定要在中使用的过滤器形状

文件说要使用NHWC格式的
[过滤高度、过滤宽度、输入通道、输出通道]
,但不清楚如何处理NCHW


应该使用相同的形状吗?

使用相同的过滤器形状应该可以。对函数参数的唯一更改是步长。例如,假设您希望您的体系结构同时使用这两种格式,这也是推荐的:

# input -> Tensor in NCHW format
if use_nchw:
    result = tf.nn.conv2d(
        input=input,
        filter=filter,
        strides=[1, 1, stride, stride],
        data_format='NCHW')
else:
    input_t = tf.transpose(input, [0, 2, 3, 1]) # NCHW to NHWC

    result = tf.nn.conv2d(
        input=input_t,
        filter=filter,
        strides=[1, stride, stride, 1])

    result = tf.transpose(result, [0, 3, 1, 2]) # NHWC to NCHW