Python tflite量化如何更改输入数据类型
请参阅文章末尾的可能解决方案Python tflite量化如何更改输入数据类型,python,tensorflow,keras,quantization,tensorflow-lite,Python,Tensorflow,Keras,Quantization,Tensorflow Lite,请参阅文章末尾的可能解决方案 我试图将keras VGFace模型从完全量化到在NPU上运行。该模型是一个Keras模型(不是tf.Keras) 使用TF 1.15进行量化时: print(tf.version.VERSION) num_calibration_steps=5 converter = tf.lite.TFLiteConverter.from_keras_model_file('path_to_model.h5') #converter.post_traini
我试图将keras VGFace模型从完全量化到在NPU上运行。该模型是一个Keras模型(不是tf.Keras) 使用TF 1.15进行量化时:
print(tf.version.VERSION)
num_calibration_steps=5
converter = tf.lite.TFLiteConverter.from_keras_model_file('path_to_model.h5')
#converter.post_training_quantize = True # This only makes the weight in8 but does not initialize model quantization
def representative_dataset_gen():
for _ in range(num_calibration_steps):
pfad='path_to_image(s)'
img=cv2.imread(pfad)
# Get sample input data as a numpy array in a method of your choosing.
yield [img]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
open("quantized_model", "wb").write(tflite_quant_model)
模型已转换,但由于我需要完整的int8量化,我添加:
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8 # or tf.uint8
converter.inference_output_type = tf.int8 # or tf.uint8
出现以下错误消息:
ValueError:无法设置tensor:获取的值类型为UINT8,但输入0的类型应为FLOAT32,名称:input_1
显然,模型的输入仍然需要浮动32
问题:
添加: 使用:
h5_path = 'my_model.h5'
model = keras.models.load_model(h5_path)
model.save(os.getcwd() +'/modelTF2')
使用TF 2.2将h5另存为pb,然后使用converter=TF.lite.TFLiteConverter.from_saved_model(saved_model_dir)
由于TF2.x tflite接受浮点数,并将其转换为uint8s。我认为这可能是一个解决办法。很遗憾,出现以下错误消息:
tf.lite.TFLiteConverter.from_keras_模型提供的“str”对象没有属性“call”
显然,TF2.x无法处理纯keras模型
使用tf.compat.v1.lite.TFLiteConverter.from_keras_model_file()
来解决此错误,只需重复上面的错误,因为我们又回到了“tf 1.15”级别
添加2 另一种解决方案是手动将keras模型传输到tf.keras。如果没有其他解决办法,我将对此进行调查
关于Meghna Natraj的评论 要重新创建模型(使用TF 1.13.x),只需执行以下操作:
pip安装git+https://github.com/rcmalli/keras-vggface.git
及
输入层已连接。太糟糕了,这看起来是个很好/很容易解决的问题。
可能的解决方案
使用TF1.15.3似乎是可行的,我之前使用了1.15.0。我将检查我是否意外地做了其他不同的事情。此操作失败的一个可能原因是模型具有未连接到输出张量的输入张量,即它们可能未使用 这是我复制这个错误的地方。将笔记本开头的
io_类型
修改为tf.uint8
,以查看与您得到的错误类似的错误
解决方案
您需要手动检查模型,查看是否有任何悬空/丢失/未连接到输出的输入,然后将其删除
发布模型的链接,我也可以尝试调试它 嗨,Meghna Natraj,谢谢你的想法。我检查了整个模型。似乎没有什么松散的结局。关于你的回答,请看我的补充。
from keras_vggface.vggface import VGGFace
pretrained_model = VGGFace(model='resnet50', include_top=False, input_shape=(224, 224, 3), pooling='avg') # pooling: None, avg or max
pretrained_model.summary()
pretrained_model.save("my_model.h5") #using h5 extension