
python - Exporting a Keras model to tflite

Reposted. Author: 太空宇宙. Updated: 2023-11-03 11:59:46

I am trying to combine these two examples and create a tflite file for my Android application.

https://medium.com/nybles/create-your-first-image-recognition-classifier-using-cnn-keras-and-tensorflow-backend-6eaab98d14dd

https://medium.com/@xianbao.qian/convert-keras-model-to-tflite-e2bdf28ee2d2

Here is my code:

# Part 1 - Building the CNN

# Importing the Keras libraries and packages
from keras.models import Sequential
from keras.layers import Convolution2D
from keras.layers import MaxPooling2D
from keras.layers import Flatten
from keras.layers import Dense
import tensorflow as tf
from keras.models import load_model


# Initialising the CNN
classifier = Sequential()

# Step 1 - Convolution
classifier.add(Convolution2D(32, 3, 3, input_shape = (64, 64, 3), activation = 'relu'))

# Step 2 - Pooling
classifier.add(MaxPooling2D(pool_size = (2, 2)))

# Adding a second convolutional layer
classifier.add(Convolution2D(32, 3, 3, activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2)))

# Step 3 - Flattening
classifier.add(Flatten())

# Step 4 - Full connection
classifier.add(Dense(output_dim = 128, activation = 'relu'))
classifier.add(Dense(output_dim = 1, activation = 'sigmoid'))

# Compiling the CNN
classifier.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])

# Part 2 - Fitting the CNN to the images

from keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(rescale = 1./255,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)

test_datagen = ImageDataGenerator(rescale = 1./255)

training_set = train_datagen.flow_from_directory('dataset/training_set',
                                                 target_size = (64, 64),
                                                 batch_size = 32,
                                                 class_mode = 'binary')

test_set = test_datagen.flow_from_directory('dataset/test_set',
                                            target_size = (64, 64),
                                            batch_size = 32,
                                            class_mode = 'binary')

classifier.fit_generator(training_set,
                         samples_per_epoch = 80,
                         nb_epoch = 1,
                         validation_data = test_set,
                         nb_val_samples = 20)



output_names = [node.op.name for node in classifier.outputs]
sess = tf.keras.backend.get_session()
frozen_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_names)


tflite_model = tf.contrib.lite.toco_convert(frozen_def, [inputs], output_names)
with tf.gfile.GFile(tflite_graph, 'wb') as f:
    f.write(tflite_model)

On this line:

frozen_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_names)

I get an exception:

tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value conv2d_1/bias
[[Node: _retval_conv2d_1/bias_0_0 = _Retval[T=DT_FLOAT, index=0, _device="/job:localhost/replica:0/task:0/device:CPU:0"](conv2d_1/bias)]]

I am a beginner in machine learning and have no idea what this error means :-(

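Can anyone explain what went wrong? All I need is to process a few folders that each contain many pictures, and then predict which folder a new picture belongs to. Thanks.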
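For the goal described above (one folder per category, then deciding which folder a new picture belongs to), it may help to see how folder names relate to the single sigmoid output of the network. This is a minimal sketch, not part of the original question. It assumes two subfolders named `cats` and `dogs` (hypothetical names, the question does not give them), that labels follow the sorted order of the folder names (which is how Keras's `flow_from_directory` builds its `class_indices`), and a decision threshold of 0.5:

```python
def class_indices_from_folders(folder_names):
    """Map each class folder name to an integer label, in sorted order."""
    return {name: index for index, name in enumerate(sorted(folder_names))}


def folder_for_probability(probability, class_indices, threshold=0.5):
    """Turn a single sigmoid output into the name of the predicted folder."""
    label = 1 if probability >= threshold else 0
    index_to_name = {index: name for name, index in class_indices.items()}
    return index_to_name[label]


# Hypothetical folder names; replace with the real subfolders of dataset/training_set.
indices = class_indices_from_folders(["dogs", "cats"])
print(indices)
print(folder_for_probability(0.83, indices))
```

A single sigmoid unit with `class_mode = 'binary'` can only separate two folders. With more than two, the usual choice would be a softmax output with one unit per folder and `class_mode = 'categorical'`.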

Best answer

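A Keras model can be converted directly to .tflite with the tf.lite.TFLiteConverter.from_session function. Place the following code after fit_generator to export it (the original text says it was tested with tensorflow 1.3.1; tf.lite.TFLiteConverter is only available under that name from TensorFlow 1.13 onward, so 1.13.1 is probably what was meant):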

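# `classifier` is the model built in the question (the original answer called it `model`).
with tf.keras.backend.get_session() as sess:
    # Caution: this assigns fresh initial values to every variable in this session,
    # so verify that the exported model still carries the trained weights.
    sess.run(tf.global_variables_initializer())
    converter = tf.lite.TFLiteConverter.from_session(sess, classifier.inputs, classifier.outputs)
    tflite_model = converter.convert()
    with open("model.tflite", "wb") as f:
        f.write(tflite_model)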
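The session-based call above belongs to the TensorFlow 1.x API. On TensorFlow 2.x the converter can take the Keras model object directly, so there is no separate session whose variables could be left uninitialized. The following is a minimal sketch under that assumption, not the answer author's tested code. The helper name `export_keras_model_to_tflite` and its `path` parameter are hypothetical, and the model is assumed to be built with `tf.keras` rather than the standalone `keras` package:

```python
def export_keras_model_to_tflite(model, path="model.tflite"):
    """Convert an in-memory Keras model to TFLite and write it to `path`.

    Assumes TensorFlow 2.x. Returns the number of bytes written.
    """
    # Imported inside the function so the helper can be defined before
    # TensorFlow is loaded.
    import tensorflow as tf

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()  # serialized flatbuffer as bytes
    with open(path, "wb") as f:
        f.write(tflite_model)
    return len(tflite_model)
```

After training, a call such as `export_keras_model_to_tflite(classifier)` would write `model.tflite` to the working directory, ready to be copied into the Android project.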

Regarding python - exporting a Keras model to tflite, a similar question was found on Stack Overflow: https://stackoverflow.com/questions/53060501/
