gpt4 book ai didi

tensorflow - 如何从 tflite 模型中获取权重?

转载 作者:行者123 更新时间:2023-12-04 17:34:28 32 4
gpt4 key购买 nike

我有一个简单的网络,我已经使用 tensorflow 进行了剪枝和量化。我专门按照本教程在我的网络上应用: https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras#convert_to_tensorflow_lite

最后,我得到了 tflite 文件。我想从这个文件中提取权重。我如何从这个量化模型中获得权重?我知道从“h5”文件而不是“tflite”文件获取权重的方法。或者在对模型执行量化后是否有任何其他方法可以保存“h5”文件?

最佳答案

创建一个 tflite 解释器并(可选)执行推理。 tflite_interpreter.get_tensor_details() 将给出具有权重、偏差、它们的比例、zero_points..等的字典列表。

'''
Create interpreter, allocate tensors
'''
tflite_interpreter = tf.lite.Interpreter(model_path='model_file.tflite')
tflite_interpreter.allocate_tensors()

'''
Check input/output details
'''
input_details = tflite_interpreter.get_input_details()
output_details = tflite_interpreter.get_output_details()

print("== Input details ==")
print("name:", input_details[0]['name'])
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("name:", output_details[0]['name'])
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])

'''
Run prediction (optional), input_array has input's shape and dtype
'''
tflite_interpreter.set_tensor(input_details[0]['index'], input_array)
tflite_interpreter.invoke()
output_array = tflite_interpreter.get_tensor(output_details[0]['index'])

'''
This gives a list of dictionaries.
'''
tensor_details = tflite_interpreter.get_tensor_details()

for dict in tensor_details:
i = dict['index']
tensor_name = dict['name']
scales = dict['quantization_parameters']['scales']
zero_points = dict['quantization_parameters']['zero_points']
tensor = tflite_interpreter.tensor(i)()

print(i, type, name, scales.shape, zero_points.shape, tensor.shape)

'''
See note below
'''
  • Conv2D 层将具有三个与其关联的字典:kernel、bias、conv_output,每个字典都有其尺度、zero_points 和张量。
  • tensor - 是具有内核权重或偏差的 np 数组。对于 conv_output 或者 activation,这没有任何意义(不是中间输出)
  • 对于conv kernel的字典,tensor的形状是(cout, k , k,cin)

关于tensorflow - 如何从 tflite 模型中获取权重?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57197914/

32 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com