
java - How to create a TensorProto from a Java Map

Reposted. Author: 行者123. Updated: 2023-12-02 00:36:02

I can create a TensorProto from single values such as floats:

import org.tensorflow.framework.DataType;

// create TensorProto with 3 floats
org.tensorflow.framework.TensorProto.Builder tensorProtoBuilder = org.tensorflow.framework.TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
tensorProtoBuilder.addFloatVal(1.0f);
tensorProtoBuilder.addFloatVal(2.0f);
tensorProtoBuilder.addFloatVal(5.0f);

// create TensorShapeProto
org.tensorflow.framework.TensorShapeProto.Builder tensorShapeBuilder = org.tensorflow.framework.TensorShapeProto.newBuilder();
tensorShapeBuilder.addDim(org.tensorflow.framework.TensorShapeProto.Dim.newBuilder().setSize(3));

// set shape for proto
tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());

// build proto
org.tensorflow.framework.TensorProto proto = tensorProtoBuilder.build();

System.out.println(proto.toString());

But how do I create a TensorProto from key-value pairs? The key-value pairs represent the features I want to feed into a classification model:

"country": "ireland", "currency": "euro"

Is this a case of two arrays, one containing the keys (feature names) and the other containing the values?
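For illustration only, splitting the map into two parallel arrays is easy in plain Java. The sketch below (the class `ParallelArrays` and method `split` are hypothetical, not part of any TensorFlow API) copies the map into a `TreeMap` so the keys come out sorted and `keys[i]` always pairs with `values[i]`. Whether a served model can actually consume a "keys" tensor plus a "values" tensor depends entirely on its serving signature; the answer below takes a different route and sends `tf.Example` protos instead.

```java
import java.util.Map;
import java.util.TreeMap;

public class ParallelArrays {

    // Hypothetical helper: split a feature map into parallel key/value arrays.
    // The TreeMap copy sorts the keys, so keys[i] always pairs with values[i]
    // and the order does not depend on the source map's iteration order.
    static String[][] split(Map<String, String> features) {
        TreeMap<String, String> sorted = new TreeMap<>(features);
        String[] keys = sorted.keySet().toArray(new String[0]);
        String[] values = sorted.values().toArray(new String[0]);
        return new String[][] {keys, values};
    }

    public static void main(String[] args) {
        Map<String, String> features = Map.of("country", "ireland", "currency", "euro");
        String[][] kv = split(features);
        System.out.println(String.join(",", kv[0])); // country,currency
        System.out.println(String.join(",", kv[1])); // ireland,euro
    }
}
```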

Edit: Is there a way to convert a collection of org.tensorflow.example.Feature objects into a TensorProto?

How to provide input for a TensorFlow DNNRegressor in Java?

Best answer

I didn't get any feedback on this question.

However, I managed to solve the problem, so I hope my answer and code will help someone in the future.

Note: there is no need to create a TensorProto for this problem.

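In short, I had to:

- create Feature objects
- add them to a Features object
- add the Features object to an Example object
- add the Example object to an ExampleList object
- create an Input object
- add the ExampleList object to the Input object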

Here is the working code:

import com.google.protobuf.ByteString;
import com.google.protobuf.Int64Value;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import tensorflow.serving.Classification;
import tensorflow.serving.Classification.ClassificationRequest;
import tensorflow.serving.PredictionServiceGrpc;
import tensorflow.serving.PredictionServiceGrpc.PredictionServiceBlockingStub;
import tensorflow.serving.Model.ModelSpec;
import org.tensorflow.example.Example;
import tensorflow.serving.InputOuterClass.Input;
import tensorflow.serving.InputOuterClass.ExampleList;
import org.tensorflow.example.Features;

import java.util.HashMap;
import java.util.Map;


public class GrpcClient {

    public static void main(String[] args) throws Exception {

        String host = "localhost";
        int port = 8500; // TensorFlow Serving's default gRPC port
        String modelName = "Model_1";
        long modelVersion = 1568645807L;

        // create a plaintext (unencrypted) channel
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
        try {
            PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);

            // create ClassificationRequest
            ClassificationRequest.Builder classificationRequestBuilder = ClassificationRequest.newBuilder();

            // create ModelSpec
            ModelSpec.Builder modelSpecBuilder = ModelSpec.newBuilder();
            modelSpecBuilder.setName(modelName);
            modelSpecBuilder.setVersion(Int64Value.of(modelVersion));
            modelSpecBuilder.setSignatureName("serving_default");

            // set model for request
            classificationRequestBuilder.setModelSpec(modelSpecBuilder);

            // map of input features
            Map<String, Float> myMap = new HashMap<>();
            myMap.put("feature1", 1.0f);
            myMap.put("feature2", 0.0f);

            // convert map to example list
            ExampleList.Builder exampleListBuilder = ExampleList.newBuilder();
            exampleListBuilder.addExamples(buildExample(myMap));
            ExampleList exampleList = exampleListBuilder.build();

            // create input
            Input.Builder inputBuilder = Input.newBuilder();
            inputBuilder.setExampleList(exampleList);

            // set input for request
            Input input = inputBuilder.build();
            classificationRequestBuilder.setInput(input);

            // build request
            ClassificationRequest request = classificationRequestBuilder.build();

            // run classification
            Classification.ClassificationResponse response = stub.classify(request);

            System.out.println(response.toString());
        } finally {
            // release the channel's resources even if the call fails
            channel.shutdown();
        }
    }

    public static Example buildExample(Map<String, ?> featureMap) {
        Features.Builder featuresBuilder = Features.newBuilder();
        for (Map.Entry<String, ?> entry : featureMap.entrySet()) {
            String attr = entry.getKey();
            Object value = entry.getValue();
            if (value instanceof Float) {
                featuresBuilder.putFeature(attr, feature((Float) value));
            } else if (value instanceof float[]) {
                featuresBuilder.putFeature(attr, feature((float[]) value));
            } else if (value instanceof String) {
                featuresBuilder.putFeature(attr, feature((String) value));
            } else if (value instanceof String[]) {
                featuresBuilder.putFeature(attr, feature((String[]) value));
            } else if (value instanceof Long) {
                featuresBuilder.putFeature(attr, feature((Long) value));
            } else if (value instanceof long[]) {
                featuresBuilder.putFeature(attr, feature((long[]) value));
            } else {
                throw new UnsupportedOperationException(
                        "Unsupported value type for feature '" + attr + "': " + value);
            }
        }
        Features features = featuresBuilder.build();
        return Example.newBuilder().setFeatures(features).build();
    }

    // String values are stored as UTF-8 bytes in a BytesList
    private static org.tensorflow.example.Feature feature(String... strings) {
        org.tensorflow.example.BytesList.Builder b = org.tensorflow.example.BytesList.newBuilder();
        for (String s : strings) {
            b.addValue(ByteString.copyFromUtf8(s));
        }
        return org.tensorflow.example.Feature.newBuilder().setBytesList(b).build();
    }

    // float values are stored in a FloatList
    private static org.tensorflow.example.Feature feature(float... values) {
        org.tensorflow.example.FloatList.Builder b = org.tensorflow.example.FloatList.newBuilder();
        for (float v : values) {
            b.addValue(v);
        }
        return org.tensorflow.example.Feature.newBuilder().setFloatList(b).build();
    }

    // long values are stored in an Int64List
    private static org.tensorflow.example.Feature feature(long... values) {
        org.tensorflow.example.Int64List.Builder b = org.tensorflow.example.Int64List.newBuilder();
        for (long v : values) {
            b.addValue(v);
        }
        return org.tensorflow.example.Feature.newBuilder().setInt64List(b).build();
    }
}
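`buildExample` only accepts `Float`, `Long`, `String` and the arrays `float[]`, `long[]`, `String[]`; anything else, such as an `Integer` from `map.put("age", 42)` or a `Double` from `map.put("score", 0.5)`, falls through to the `UnsupportedOperationException` branch. One possible workaround, sketched below with a hypothetical `FeatureNormalizer` class that uses only the JDK, is to convert those values before calling `buildExample`: `Integer`/`Short`/`Byte` are widened to `Long`, `int[]` to `long[]`, and `Double` is narrowed to `Float`. The `Double` conversion can lose precision, but the Example format's `FloatList` holds 32-bit floats anyway.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class FeatureNormalizer {

    // Hypothetical helper: convert common boxed types into the ones that
    // buildExample(...) dispatches on (Float, Long, String and their arrays).
    static Object normalize(Object value) {
        if (value instanceof Integer || value instanceof Short || value instanceof Byte) {
            return ((Number) value).longValue();   // widen to Long -> Int64List
        }
        if (value instanceof Double) {
            return ((Double) value).floatValue();  // narrow to Float -> FloatList (may lose precision)
        }
        if (value instanceof int[]) {
            int[] ints = (int[]) value;
            long[] longs = new long[ints.length];
            for (int i = 0; i < ints.length; i++) {
                longs[i] = ints[i];                // widen each element to long
            }
            return longs;
        }
        // everything else is returned as-is (unsupported types still fail in buildExample)
        return value;
    }

    // Normalize every value, keeping the original key order.
    static Map<String, Object> normalizeAll(Map<String, ?> features) {
        Map<String, Object> out = new LinkedHashMap<>();
        for (Map.Entry<String, ?> e : features.entrySet()) {
            out.put(e.getKey(), normalize(e.getValue()));
        }
        return out;
    }

    public static void main(String[] args) {
        Map<String, Object> raw = new LinkedHashMap<>();
        raw.put("country", "ireland");
        raw.put("age", 42);
        raw.put("score", 0.5);
        Map<String, Object> fixed = normalizeAll(raw);
        // prints: country -> String, age -> Long, score -> Float (one per line)
        fixed.forEach((k, v) -> System.out.println(k + " -> " + v.getClass().getSimpleName()));
    }
}
```

The normalized map can then be passed to `buildExample(...)` in place of `myMap`. String features like the question's `"country": "ireland"` pass through unchanged and end up as `BytesList` features.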

For more on "java - How to create a TensorProto from a Java Map", see the similar question on Stack Overflow: https://stackoverflow.com/questions/57973223/
