将预训练保存的模型从NCHW转换为NHWC,使其与Tensorflow Lite兼容

将预训练保存的模型从NCHW转换为NHWC,使其与Tensorflow Lite兼容,tensorflow,keras,Tensorflow,Keras,我已经将一个模型从PyTorch转换为Keras,并使用后端提取tensorflow图。由于PyTorch的数据格式是NCHW,因此提取和保存的模型也是NCHW。将模型转换为TFLite时,由于格式为NCHW,因此无法进行转换。有没有办法将整个图形转换为NHCW?最好有一个数据格式与TFLite匹配的图形,以便更快地进行推理。一种方法是手动将转置操作插入到图形中,如以下示例所示: 不幸的是,目前没有办法将NCHW图形转换为NHWC;如果以后您想使用TF lite运行,您必须从NHWC图形开始训

我已经将一个模型从PyTorch转换为Keras,并使用后端提取tensorflow图。由于PyTorch的数据格式是NCHW,因此提取和保存的模型也是NCHW。将模型转换为TFLite时,由于格式为NCHW,因此无法进行转换。有没有办法将整个图形转换为NHCW?

最好有一个数据格式与TFLite匹配的图形,以便更快地进行推理。一种方法是手动将转置操作插入到图形中,如以下示例所示:


不幸的是,目前没有办法将NCHW图形转换为NHWC;如果以后您想使用TF lite运行,您必须从NHWC图形开始训练

import tensorflow as tf

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

with tf.Session(config=config) as session:

    kernel = tf.ones(shape=[5, 5, 3, 64])
    images = tf.ones(shape=[64,24,24,3])

    imgs = tf.transpose(images, [0, 3, 1, 2]) # NHWC -> NCHW
    conv = tf.nn.conv2d(imgs, kernel, [1, 1, 1, 1], padding='SAME', data_format = 'NCHW')
    conv = tf.transpose(conv, [0, 2, 3, 1]) # NCHW -> NHWC

    print("conv=",conv.eval())