Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181

Warning: file_get_contents(/data/phpspider/zhask/data//catemap/8/linq/3.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Tensorflow 如何从tflite模型中获取权重?_Tensorflow - Fatal编程技术网

Tensorflow 如何从tflite模型中获取权重?

Tensorflow 如何从tflite模型中获取权重?,tensorflow,Tensorflow,我有一个简单的网络,我用tensorflow做了修剪和量化。我特别遵循本教程在我的网络上应用: 最后,我得到了tflite文件。我想从这个文件中提取权重。如何从这个量化模型中获得权重?我知道从“h5”文件获取权重的方法,但不知道从“tflite”文件获取权重的方法。或者在模型上执行量化后,是否有其他方法保存“h5”文件?我已使用Netron解决了此问题。在Netron中,权重可以另存为numpy数组。 创建tflite解释器并(可选)执行推断。tflite_解释器。get_tensor_det

我有一个简单的网络,我用tensorflow做了修剪和量化。我特别遵循本教程在我的网络上应用:


最后,我得到了tflite文件。我想从这个文件中提取权重。如何从这个量化模型中获得权重?我知道从“h5”文件获取权重的方法,但不知道从“tflite”文件获取权重的方法。或者在模型上执行量化后,是否有其他方法保存“h5”文件?

我已使用Netron解决了此问题。在Netron中,权重可以另存为numpy数组。
创建tflite解释器并(可选)执行推断。tflite_解释器。get_tensor_details()将提供一个字典列表,其中包含权重、偏差、刻度、零点等等

'''
Create interpreter, allocate tensors
'''
tflite_interpreter = tf.lite.Interpreter(model_path='model_file.tflite')
tflite_interpreter.allocate_tensors()

'''
Check input/output details
'''
input_details = tflite_interpreter.get_input_details()
output_details = tflite_interpreter.get_output_details()

print("== Input details ==")
print("name:", input_details[0]['name'])
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("name:", output_details[0]['name'])
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])

'''
Run prediction (optional), input_array has input's shape and dtype
'''
tflite_interpreter.set_tensor(input_details[0]['index'], input_array)
tflite_interpreter.invoke()
output_array = tflite_interpreter.get_tensor(output_details[0]['index'])

'''
This gives a list of dictionaries. 
'''
tensor_details = tflite_interpreter.get_tensor_details()

for dict in tensor_details:
    i = dict['index']
    tensor_name = dict['name']
    scales = dict['quantization_parameters']['scales']
    zero_points = dict['quantization_parameters']['zero_points']
    tensor = tflite_interpreter.tensor(i)()

    print(i, type, name, scales.shape, zero_points.shape, tensor.shape)

    '''
    See note below
    '''
  • Conv2D层将有三个DICT与之关联:内核、偏差、conv_输出,每个DICT都有其刻度、零点和张量
  • 张量-是具有核权重或偏差的np数组。对于conv_输出或激活,这并不意味着什么(不是中间输出)
  • 对于conv-kernel的字典,张量是成形的(cout,k,k,cin)