android - TensorFlow Lite on Android: how do I define the inputs and outputs of the 'runForMultipleInputsOutputs' function?

Reposted. Author: 行者123. Updated: 2023-12-04 15:33:03

I am using TensorFlow Lite on Android, but the runForMultipleInputsOutputs function does not work.

Here is what I did.

1. Build the .tflite file. This is the model source from Colab:

from numpy import mean
from numpy import std
from numpy import dstack
from pandas import read_csv
from matplotlib import pyplot
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Flatten
from keras.layers import Dropout
from keras.layers.convolutional import Conv1D
from keras.layers.convolutional import MaxPooling1D
from keras.utils import to_categorical
import tensorflow as tf  # needed for tf.lite.TFLiteConverter below
from tensorflow import keras

#make the model
n_timesteps, n_features, n_outputs = 128, 9, 6
model = Sequential()
model.add(Conv1D(filters=64, kernel_size=3, activation='relu', input_shape=(n_timesteps,n_features)))
model.add(Conv1D(filters=64, kernel_size=3, activation='relu'))
model.add(Dropout(0.5))
model.add(MaxPooling1D(pool_size=2))
model.add(Flatten())
model.add(Dense(100, activation='relu'))
model.add(Dense(n_outputs, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

#save the model
model.save("/content/gdrive/My Drive/Train_data/accel_trained_model.h5")
model2 = keras.models.load_model("/content/gdrive/My Drive/Train_data/accel_trained_model.h5")
model2.save('/content/gdrive/My Drive/Train_data/tf_accel_trained_model', save_format="tf")

#convert the model and save the .tflite file
converter = tf.lite.TFLiteConverter.from_saved_model('/content/gdrive/My Drive/Train_data/tf_accel_trained_model')
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
with open('/content/gdrive/My Drive/Train_data/converted_model.tflite', 'wb') as f:
    f.write(tflite_model)

2. Add the TensorFlow Lite options to the Android "build.gradle (Module)":
aaptOptions {
    noCompress "tflite"
    noCompress "lite"
}

dependencies {
    implementation 'org.tensorflow:tensorflow-lite:+'
}

3. Load the model on Android:

tflite = getTfliteInterpreter(modelFile);


private Interpreter getTfliteInterpreter(String modelPath) {
    try {
        return new Interpreter(loadModelFile(MainActivity.this, modelPath));
    }
    catch (Exception e) {
        e.printStackTrace();
    }
    return null;
}


private MappedByteBuffer loadModelFile(Activity activity, String MODEL_FILE) throws IOException {
    AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_FILE);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}

4. Build the input and output, then call runForMultipleInputsOutputs:

float[][] inp=new float[128][9];
float[][] out=new float[][]{{0, 0, 0, 0, 0, 0}};

java.util.Map<Integer, Object> outputs = new java.util.HashMap();
outputs.put(0, out);

tflite.runForMultipleInputsOutputs(inp,outputs);

Result: an error. I do not know what the correct inputs and outputs for runForMultipleInputsOutputs are.

2020-03-19 22:00:45.219 14799-14799/com.example.tensorflowlite E/AndroidRuntime: FATAL EXCEPTION: main
Process: com.example.tensorflowlite, PID: 14799
java.lang.NullPointerException: Attempt to invoke virtual method 'void org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(java.lang.Object[], java.util.Map)' on a null object reference
at com.example.tensorflowlite.MainActivity$1.onClick(MainActivity.java:93)
at android.view.View.performClick(View.java:6597)
at android.view.View.performClickInternal(View.java:6574)
at android.view.View.access$3100(View.java:778)
at android.view.View$PerformClick.run(View.java:25885)
at android.os.Handler.handleCallback(Handler.java:873)
at android.os.Handler.dispatchMessage(Handler.java:99)
at android.os.Looper.loop(Looper.java:193)
at android.app.ActivityThread.main(ActivityThread.java:6669)
at java.lang.reflect.Method.invoke(Native Method)
at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:493)
at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:858)

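Best answer

I found out what the problem was.

First, the Keras model should not be converted to a TensorFlow SavedModel. Convert the Keras model directly to a TensorFlow Lite model (the .tflite file). Here is the code for saving and converting the model: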

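import tensorflow as tf
from tensorflow import keras

# Loading the model is not required by from_keras_model_file, which reads the .h5 file itself.
model2 = keras.models.load_model("/content/gdrive/My Drive/Train_data/accel_trained_model.h5")
# from_keras_model_file is the TF 1.x API. On TF 2.x the equivalent is
# tf.lite.TFLiteConverter.from_keras_model(model2).
converter = tf.lite.TFLiteConverter.from_keras_model_file("/content/gdrive/My Drive/Train_data/accel_trained_model.h5")
tflite_model = converter.convert()
with open('/content/gdrive/My Drive/Train_data/converted_model.tflite', 'wb') as f:
    f.write(tflite_model)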

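Second, I changed the input on Android. You can check the shapes the model expects for its input and output on Android like this:

// Assumed: input and outi are the interpreter's first input and output tensors.
Tensor input = tflite.getInputTensor(0);
Tensor outi = tflite.getOutputTensor(0);
Log.d("Tag", Arrays.toString(input.shape()));
Log.d("Tag", Arrays.toString(outi.shape()));

In my case, the input and output shapes were as follows.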

#input shape Log
2020-03-20 21:33:59.608 20035-20035/com.example.tensorflowlite D/Tag: [1, 128, 9]
#output shape Log
2020-03-20 21:33:59.608 20035-20035/com.example.tensorflowlite D/Tag: [1, 6]
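
The logged shapes can be compared against the Java arrays before the model is run. The following is only a minimal sketch of such a check, not part of the original answer: the class name ShapeCheck and its methods are hypothetical helpers, and it assumes fully allocated rectangular arrays (no null rows). It uses only the JDK, so it can be tried outside Android.

```java
import java.lang.reflect.Array;
import java.util.Arrays;

// Hypothetical helper (not a TensorFlow Lite API): reads the dimensions of a nested Java array.
final class ShapeCheck {

    /** Returns the dimensions of a rectangular nested array, outermost dimension first. */
    static int[] shapeOf(Object array) {
        // Count the array nesting depth from the runtime type, e.g. float[][][] has rank 3.
        int rank = 0;
        Class<?> type = array.getClass();
        while (type.isArray()) {
            rank++;
            type = type.getComponentType();
        }
        int[] shape = new int[rank];
        Object current = array;
        for (int i = 0; i < rank; i++) {
            shape[i] = Array.getLength(current);
            if (shape[i] == 0) {
                break; // nothing to descend into; remaining dimensions stay 0
            }
            if (i < rank - 1) {
                current = Array.get(current, 0); // assumes every row has the same length
            }
        }
        return shape;
    }

    /** True when the array has exactly the expected dimensions. */
    static boolean matches(Object array, int[] expected) {
        return Arrays.equals(shapeOf(array), expected);
    }

    public static void main(String[] args) {
        int[] modelInputShape = {1, 128, 9}; // taken from the input shape log above
        System.out.println(matches(new float[128][9], modelInputShape));    // false: batch dimension missing
        System.out.println(matches(new float[1][128][9], modelInputShape)); // true
    }
}
```

With the shapes from the log, the question's float[128][9] input fails the check because the leading batch dimension of 1 is missing, while float[1][128][9] passes.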

So I changed the input and output shapes, like this:

float[][][] inp=new float[1][128][9];
float[][] out=new float[][]{{0, 0, 0, 0, 0, 0}};

tflite = getTfliteInterpreter(modelFile);
tflite.run(inp, out);

private Interpreter getTfliteInterpreter(String modelPath) {
    try {
        return new Interpreter(loadModelFile(MainActivity.this, modelPath));
    }
    catch (Exception e) {
        e.printStackTrace();
    }
    return null;
}


private MappedByteBuffer loadModelFile(Activity activity, String MODEL_FILE) throws IOException {
    AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_FILE);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
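
A side note on getTfliteInterpreter: it returns null whenever loading throws, which is why the stack trace in the question reports runForMultipleInputsOutputs being invoked "on a null object reference" rather than the underlying loading error. The following is only a sketch of a fail-fast alternative, not something the original answer proposes: FailFast and loadOrThrow are hypothetical names, and the sketch uses only the JDK so it can be run without Android.

```java
import java.util.concurrent.Callable;

// Hypothetical helper: surfaces a loading failure immediately instead of returning null.
final class FailFast {

    /** Runs the loader and rethrows any failure with its original cause attached. */
    static <T> T loadOrThrow(String what, Callable<T> loader) {
        T value;
        try {
            value = loader.call();
        } catch (Exception e) {
            // Keep the cause so the real problem (missing asset, bad model file) stays visible.
            throw new IllegalStateException("Could not load " + what, e);
        }
        if (value == null) {
            throw new IllegalStateException("Loader for " + what + " returned null");
        }
        return value;
    }

    public static void main(String[] args) {
        // On Android the call might look like this (assumes lambda support is enabled):
        //   tflite = FailFast.loadOrThrow("TFLite model",
        //           () -> new Interpreter(loadModelFile(MainActivity.this, modelPath)));
        String loaded = loadOrThrow("demo value", () -> "ok");
        System.out.println(loaded);
    }
}
```

With this pattern the crash would point at the model-loading step and carry its cause, instead of appearing later at the first call on the interpreter.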

After these changes, it worked well.
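
For completeness, the stack trace in the question shows the signature runForMultipleInputsOutputs(java.lang.Object[], java.util.Map): the first argument is an array with one element per model input, and the map goes from output index to output buffer. The question passed the float[128][9] array itself, which compiles because a float[][] is also an Object[] in Java, but it then reads as 128 separate inputs. The sketch below shows how the arguments could be packaged for a model with one input and one output; MultiIoArgs and its methods are hypothetical helpers, and the actual interpreter call is left as a comment because it needs a loaded model.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper: packages one input tensor and one output buffer
// in the form runForMultipleInputsOutputs expects.
final class MultiIoArgs {

    /** Wraps a single input tensor so the Object[] has exactly one element. */
    static Object[] singleInput(Object tensor) {
        return new Object[] {tensor};
    }

    /** Maps output index 0 to the buffer that will receive the result. */
    static Map<Integer, Object> singleOutput(Object buffer) {
        Map<Integer, Object> outputs = new HashMap<>();
        outputs.put(0, buffer);
        return outputs;
    }

    public static void main(String[] args) {
        float[][][] inp = new float[1][128][9]; // matches the logged input shape [1, 128, 9]
        float[][] out = new float[1][6];        // matches the logged output shape [1, 6]

        Object[] inputs = singleInput(inp);
        Map<Integer, Object> outputs = singleOutput(out);
        // With a loaded interpreter: tflite.runForMultipleInputsOutputs(inputs, outputs);

        // The pitfall from the question: an unwrapped 2-D array is accepted as Object[].
        Object[] unwrapped = new float[128][9];
        System.out.println(inputs.length + " input(s); the unwrapped array would count as " + unwrapped.length);
    }
}
```

For a model with a single input and a single output, run(inp, out) as used in the answer is the simpler call.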

This question, "android - TensorFlow Lite on Android: how do I define the inputs and outputs of the 'runForMultipleInputsOutputs' function?", comes from a similar question found on Stack Overflow: https://stackoverflow.com/questions/60767578/
