- c - 在位数组中找到第一个零
- linux - Unix 显示有关匹配两种模式之一的文件的信息
- 正则表达式替换多个文件
- linux - 隐藏来自 xtrace 的命令
我设法使用 DNNRegressor 编写了一个 TensorFlow python 程序。我已经训练了模型,并且能够通过手动创建的输入(常量张量)从 Python 模型中获得预测。我还能够以二进制格式导出模型。
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import graph_util
#######################
# Setup
#######################
# Converting Data into Tensors
def input_fn(df, training = True):
# Creates a dictionary mapping from each continuous feature column name (k) to
# the values of that column stored in a constant Tensor.
continuous_cols = {k: tf.constant(df[k].values)
for k in continuous_features}
feature_cols = dict(list(continuous_cols.items()))
if training:
# Converts the label column into a constant Tensor.
label = tf.constant(df[LABEL_COLUMN].values)
# Returns the feature columns and the label.
return feature_cols, label
# Returns the feature columns
return feature_cols
def train_input_fn():
return input_fn(train_df)
def eval_input_fn():
return input_fn(evaluate_df)
#######################
# Data Preparation
#######################
df_train_ori = pd.read_csv('training.csv')
df_test_ori = pd.read_csv('test.csv')
train_df = df_train_ori.head(10000)
evaluate_df = df_train_ori.tail(5)
test_df = df_test_ori.head(1)
MODEL_DIR = "/tmp/model"
BIN_MODEL_DIR = "/tmp/modelBinary"
features = train_df.columns
continuous_features = [feature for feature in features if 'label' not in feature]
LABEL_COLUMN = 'label'
engineered_features = []
for continuous_feature in continuous_features:
engineered_features.append(
tf.contrib.layers.real_valued_column(
column_name=continuous_feature,
dimension=1,
default_value=None,
dtype=tf.int64,
normalizer=None
))
#######################
# Define Our Model
#######################
regressor = tf.contrib.learn.DNNRegressor(
feature_columns=engineered_features,
label_dimension=1,
hidden_units=[128, 256, 512],
model_dir=MODEL_DIR
)
#######################
# Training Our Model
#######################
wrap = regressor.fit(input_fn=train_input_fn, steps=5)
#######################
# Evaluating Our Model
#######################
results = regressor.evaluate(input_fn=eval_input_fn, steps=1)
for key in sorted(results):
print("%s: %s" % (key, results[key]))
#######################
# Save binary model (to be used in Java)
#######################
tfrecord_serving_input_fn = tf.contrib.learn.build_parsing_serving_input_fn(tf.contrib.layers.create_feature_spec_for_parsing(engineered_features))
regressor.export_savedmodel(
export_dir_base=BIN_MODEL_DIR,
serving_input_fn = tfrecord_serving_input_fn,
assets_extra=None,
as_text=False,
checkpoint_path=None,
strip_default_attrs=False)
我的下一步是将模型加载到 java 中并做出一些预测。但是,我在为 Java 模型指定输入时遇到了问题。
import org.tensorflow.*;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
import java.util.List;
import java.util.Map;
public class ModelEvaluator {
public static void main(String[] args) throws Exception {
System.out.println("Using TF version: " + TensorFlow.version());
SavedModelBundle model = SavedModelBundle.load("/tmp/modelBinary/1546510038", "serve");
Session session = model.session();
printSignature(model);
printAllNodes(model);
float[][] km1 = new float[1][1];
km1[0][0] = 10;
Tensor inKm1 = Tensor.create(km1);
float[][] km2 = new float[1][1];
km2[0][0] = 10000;
Tensor inKm2 = Tensor.create(km2);
List<Tensor<?>> outputs = session.runner()
.feed("dnn/input_from_feature_columns/input_from_feature_columns/km1/ToFloat", inKm1)
.feed("dnn/input_from_feature_columns/input_from_feature_columns/km2/ToFloat", inKm2)
.fetch("dnn/regression_head/predictions/Identity:0")
.run();
System.out.println("\n\nOutputs from evaluation:");
for (Tensor<?> output : outputs) {
if (output.dataType() == DataType.STRING) {
System.out.println(new String(output.bytesValue()));
} else {
float[] outArray = new float[1];
output.copyTo(outArray);
System.out.println(outArray[0]);
}
}
}
public static void printAllNodes(SavedModelBundle model) {
model.graph().operations().forEachRemaining(x -> {
System.out.println(x.name() + " " + x.numOutputs());
});
}
/**
* This info can also be obtained from a command prompt via the command:
* saved_model_cli show --dir <dir-to-the-model> --tag_set serve --signature_def serving_default
* <p>
* See this where they also try to input data to a DNN regressor:
* https://github.com/tensorflow/tensorflow/issues/12367
* <p>
* https://github.com/tensorflow/tensorflow/issues/14683
* <p>
* https://github.com/migueldeicaza/TensorFlowSharp/issues/293
*/
public static void printSignature(SavedModelBundle model) throws Exception {
MetaGraphDef m = MetaGraphDef.parseFrom(model.metaGraphDef());
SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
int numInputs = sig.getInputsCount();
int i = 1;
System.out.println("-----------------------------------------------");
System.out.println("MODEL SIGNATURE");
System.out.println("Inputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
}
int numOutputs = sig.getOutputsCount();
i = 1;
System.out.println("Outputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
}
System.out.println("-----------------------------------------------");
}
}
从 java 代码可以看出,我为两个节点提供了输入(用“km1”和“km2”命名)。但我想这不是正确的做法。我猜我需要为节点“input_example_tensor:0”提供输入?
所以问题是:我实际上如何为加载到 java 中的模型创建输入?在 python 中,我必须创建一个包含键“km1”和“km2”的字典,并为两个常量张量赋值。
最佳答案
在 Python 上尝试
feature_spec = tf.feature_column.make_parse_example_spec(columns)
example_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
请查看 build_parsing_serving_input_receiver_fn,以及一个名为 input_example_tensor 的输入,它需要一个序列化的 tf.Example。
在 Java 上,尝试创建一个 Example输入(打包在 org.tensorflow:proto artifact 中),以及一些像这样的代码:
public static void main(String[] args) {
Example example = buildExample(yourFeatureNameAndValueMap);
byte[][] exampleBytes = {example.toByteArray()};
try (Tensor<String> inputBatch = Tensors.create(exampleBytes);
Tensor<Float> output =
yourSession
.runner()
.feed(yourInputsName, inputBatch)
.fetch(yourOutputsName)
.run()
.get(0)
.expect(Float.class)) {
long[] shape = output.shape();
int batchSize = (int) shape[0];
int labelNum = (int) shape[1];
float[][] resultValues = output.copyTo(new float[batchSize][labelNum]);
System.out.println(resultValues);
}
}
public static Example buildExample(Map<String, ?> yourFeatureNameAndValueMap) {
Features.Builder builder = Features.newBuilder();
for (String attr : yourFeatureNameAndValueMap.keySet()) {
Object value = yourFeatureNameAndValueMap.get(attr);
if (value instanceof Float) {
builder.putFeature(attr, feature((Float) value));
} else if (value instanceof float[]) {
builder.putFeature(attr, feature((float[]) value));
} else if (value instanceof String) {
builder.putFeature(attr, feature((String) value));
} else if (value instanceof String[]) {
builder.putFeature(attr, feature((String[]) value));
} else if (value instanceof Long) {
builder.putFeature(attr, feature((Long) value));
} else if (value instanceof long[]) {
builder.putFeature(attr, feature((long[]) value));
} else {
throw new UnsupportedOperationException("Not supported attribute value data type!");
}
}
Features features = builder.build();
Example example = Example.newBuilder()
.setFeatures(features)
.build();
return example;
}
private static Feature feature(String... strings) {
BytesList.Builder b = BytesList.newBuilder();
for (String s : strings) {
b.addValue(ByteString.copyFromUtf8(s));
}
return Feature.newBuilder().setBytesList(b).build();
}
private static Feature feature(float... values) {
FloatList.Builder b = FloatList.newBuilder();
for (float v : values) {
b.addValue(v);
}
return Feature.newBuilder().setFloatList(b).build();
}
private static Feature feature(long... values) {
Int64List.Builder b = Int64List.newBuilder();
for (long v : values) {
b.addValue(v);
}
return Feature.newBuilder().setInt64List(b).build();
}
如果你想自动获取yourInputsName和yourOutputsName,你可以试试
SignatureDef signatureDef;
try {
signatureDef = MetaGraphDef.parseFrom(model.metaGraphDef()).getSignatureDefOrThrow(SIGNATURE_DEF_KEY);
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e.getMessage(), e);
}
String yourInputsName = signatureDef.getInputsOrThrow(SIGNATURE_DEF_INPUT_KEY).getName();
String yourOutputsName = signatureDef.getOutputsOrThrow(SIGNATURE_DEF_OUTPUT_KEY).getName();
关于java,请引用DetectObjects.java .关于Python,请引用wide_deep
关于java - 如何在 Java 中为 TensorFlow DNNRegressor 提供输入?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54091670/
我正在编写一个具有以下签名的 Java 方法。 void Logger(Method method, Object[] args); 如果一个方法(例如 ABC() )调用此方法 Logger,它应该
我是 Java 新手。 我的问题是我的 Java 程序找不到我试图用作的图像文件一个 JButton。 (目前这段代码什么也没做,因为我只是得到了想要的外观第一的)。这是我的主课 代码: packag
好的,今天我在接受采访,我已经编写 Java 代码多年了。采访中说“Java 垃圾收集是一个棘手的问题,我有几个 friend 一直在努力弄清楚。你在这方面做得怎么样?”。她是想骗我吗?还是我的一生都
我的 friend 给了我一个谜语让我解开。它是这样的: There are 100 people. Each one of them, in his turn, does the following
如果我将使用 Java 5 代码的应用程序编译成字节码,生成的 .class 文件是否能够在 Java 1.4 下运行? 如果后者可以工作并且我正在尝试在我的 Java 1.4 应用程序中使用 Jav
有关于why Java doesn't support unsigned types的问题以及一些关于处理无符号类型的问题。我做了一些搜索,似乎 Scala 也不支持无符号数据类型。限制是Java和S
我只是想知道在一个 java 版本中生成的字节码是否可以在其他 java 版本上运行 最佳答案 通常,字节码无需修改即可在 较新 版本的 Java 上运行。它不会在旧版本上运行,除非您使用特殊参数 (
我有一个关于在命令提示符下执行 java 程序的基本问题。 在某些机器上我们需要指定 -cp 。 (类路径)同时执行java程序 (test为java文件名与.class文件存在于同一目录下) jav
我已经阅读 StackOverflow 有一段时间了,现在我才鼓起勇气提出问题。我今年 20 岁,目前在我的家乡(罗马尼亚克卢日-纳波卡)就读 IT 大学。足以介绍:D。 基本上,我有一家提供簿记应用
我有 public JSONObject parseXML(String xml) { JSONObject jsonObject = XML.toJSONObject(xml); r
我已经在 Java 中实现了带有动态类型的简单解释语言。不幸的是我遇到了以下问题。测试时如下代码: def main() { def ks = Map[[1, 2]].keySet()
一直提示输入 1 到 10 的数字 - 结果应将 st、rd、th 和 nd 添加到数字中。编写一个程序,提示用户输入 1 到 10 之间的任意整数,然后以序数形式显示该整数并附加后缀。 public
我有这个 DownloadFile.java 并按预期下载该文件: import java.io.*; import java.net.URL; public class DownloadFile {
我想在 GUI 上添加延迟。我放置了 2 个 for 循环,然后重新绘制了一个标签,但这 2 个 for 循环一个接一个地执行,并且标签被重新绘制到最后一个。 我能做什么? for(int i=0;
我正在对对象 Student 的列表项进行一些测试,但是我更喜欢在 java 类对象中创建硬编码列表,然后从那里提取数据,而不是连接到数据库并在结果集中选择记录。然而,自从我这样做以来已经很长时间了,
我知道对象创建分为三个部分: 声明 实例化 初始化 classA{} classB extends classA{} classA obj = new classB(1,1); 实例化 它必须使用
我有兴趣使用 GPRS 构建车辆跟踪系统。但是,我有一些问题要问以前做过此操作的人: GPRS 是最好的技术吗?人们意识到任何问题吗? 我计划使用 Java/Java EE - 有更好的技术吗? 如果
我可以通过递归方法反转数组,例如:数组={1,2,3,4,5} 数组结果={5,4,3,2,1}但我的结果是相同的数组,我不知道为什么,请帮助我。 public class Recursion { p
有这样的标准方式吗? 包括 Java源代码-测试代码- Ant 或 Maven联合单元持续集成(可能是巡航控制)ClearCase 版本控制工具部署到应用服务器 最后我希望有一个自动构建和集成环境。
我什至不知道这是否可能,我非常怀疑它是否可能,但如果可以,您能告诉我怎么做吗?我只是想知道如何从打印机打印一些文本。 有什么想法吗? 最佳答案 这里有更简单的事情。 import javax.swin
我是一名优秀的程序员,十分优秀!