gpt4 book ai didi

tensorflow - 使用 TensorFlow v2.2 将 Keras .h5 模型转换为 TFLite .tflite

转载 作者:行者123 更新时间:2023-12-05 07:08:01 37 4
gpt4 key购买 nike

我正在尝试将我使用 Keras 定义的网络转换为 tflite。网络如下:

model = tf.keras.Sequential([
# Embedding
tf.keras.layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[BATCH_SIZE, None]),
# GRU unit
tf.keras.layers.GRU(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
# Fully connected layer
tf.keras.layers.Dense(vocab_size)
])

但是,当我尝试导出到 .tflite 时,似乎由于存在 GRU 层而出现问题。

# Save trained model in .h5 format
keras_file = 'inference.h5'
tf.keras.models.save_model(model, keras_file)

# Load .h5 model with custom loss function
model = load_model('inference.h5', custom_objects={'loss': loss})

# Converting a tf.Keras model to a TensorFlow Lite model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

有错误:

ValueError: Input 0 of node sequential_8/gru_2/AssignVariableOp was passed float from sequential_8/gru_2/68029:0 incompatible with expected resource.

有解决这个问题的方法吗?

最佳答案

现在不支持stateful,你可以试试set stateful=False吗?

关于tensorflow - 使用 TensorFlow v2.2 将 Keras .h5 模型转换为 TFLite .tflite,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61975032/

37 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com