I am using TensorFlow to run a random forest model. Code:
import tensorflow as tf
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.ops import resources
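from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST with integer labels; "/tmp/data/" is an assumed download directory.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)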
num_steps = 50000 # Total steps to train
batch_size = 1024 # The number of samples per batch
num_classes = 10 # The 10 digits
num_features = 784 # Each image is 28x28 pixels
num_trees = 10
max_nodes = 1000
X = tf.placeholder(tf.float32, shape=[None, num_features])
Y = tf.placeholder(tf.int32, shape=[None])
hparams = tensor_forest.ForestHParams(num_classes=num_classes,
                                      num_features=num_features,
                                      num_trees=num_trees,
                                      max_nodes=max_nodes).fill()
forest_graph = tensor_forest.RandomForestGraphs(params=hparams)
train_op = forest_graph.training_graph(X, Y)
loss_op = forest_graph.training_loss(X, Y)
infer_op = forest_graph.inference_graph(X)
# tf.argmax replaces the deprecated tf.arg_max alias.
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
init_vars = tf.group(tf.global_variables_initializer(), resources.initialize_resources(resources.shared_resources()))
sess = tf.Session()
sess.run(init_vars)
test_x, test_y = mnist.test.images, mnist.test.labels
for i in range(1, num_steps + 1):
    batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)
    _, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
    if i % 100 == 0 or i == 1:
        acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})
        print('step %i, loss: %f, acc: %f' % (i, l, acc))
    if i % 100 == 0:
        print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))
print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))
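As a plain-Python illustration of what `accuracy_op` computes, the sketch below takes the arg-max of each row of per-class scores and reports the fraction of rows where it equals the label. The helper names `argmax` and `accuracy` and the sample scores are made up for illustration; this is not TensorFlow code.

```python
def argmax(row):
    """Return the index of the largest value in row (first one on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(scores, labels):
    """Fraction of rows whose arg-max class equals the label."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    hits = [1.0 if argmax(row) == label else 0.0
            for row, label in zip(scores, labels)]
    return sum(hits) / len(hits)


# Made-up class scores for three samples and three classes.
scores = [[0.1, 0.7, 0.2],
          [0.8, 0.1, 0.1],
          [0.3, 0.3, 0.4]]
labels = [1, 0, 1]
print(accuracy(scores, labels))
```

Here the first two rows are classified correctly and the third is not, so two of three predictions match.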
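Question: how do I save the model and restore it for prediction? This is the latest version of the TensorFlow random forest; I am using tf 1.2 and the code above works. I found that some people use TensorForestEstimator, but it does not work with tf 1.2. TensorFlow changes so often!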
Best answer
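Saving the model is easy, but restoring it nearly killed me. No matter what I did, I always got a "FertileStatsResourceHandleOp" error. In the end, I added the following two statements before restoring, and it worked.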
hparams = tensor_forest.ForestHParams(num_classes=num_classes,
                                      num_features=num_features,
                                      num_trees=num_trees,
                                      max_nodes=max_nodes).fill()
forest_graph = tensor_forest.RandomForestGraphs(params=hparams)
The full code is as follows:
X = tf.placeholder(tf.float32, shape=[None, num_features],name="input_x")
Y = tf.placeholder(tf.int32, shape=[None], name="input_y")
hparams = tensor_forest.ForestHParams(num_classes=num_classes,
                                      num_features=num_features,
                                      num_trees=num_trees,
                                      max_nodes=max_nodes).fill()
forest_graph = tensor_forest.RandomForestGraphs(params=hparams)
train_op = forest_graph.training_graph(X, Y)
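loss_op = forest_graph.training_loss(X, Y)
# Inference op, built as in the question's code.
infer_op = forest_graph.inference_graph(X)
# Named ops so they can be looked up by name after restoring.
correct_prediction = tf.argmax(infer_op, 1, name="predictions")
accuracy_op = tf.reduce_mean(tf.cast(tf.equal(correct_prediction, tf.cast(Y, tf.int64)), tf.float32), name="accuracy")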
init_vars = tf.group(tf.global_variables_initializer(), resources.initialize_resources(resources.shared_resources()))
sess = tf.Session()
sess.run(init_vars)
test_x, test_y = mnist.test.images, mnist.test.labels
saver = tf.train.Saver(save_relative_paths=True, max_to_keep=10)
checkpoint_prefix = 'checkpoints/model'
for i in range(1, num_steps + 1):
    batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)
    _, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
    if i % 10 == 0 or i == 1:
        acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})
        print('step %i, loss: %f, acc: %f' % (i, l, acc))
    if i % 10 == 0:
        print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))
        # Write a checkpoint every 10 steps; the Saver keeps the 10 most recent.
        path = saver.save(sess, checkpoint_prefix, global_step=i)
        print("last Saved model checkpoint to {} at step {}".format(path, i))
print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))
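The `name=` arguments in the code above matter for restoring: in TensorFlow 1.x an op's output tensors are named `<op_name>:<output_index>`, so an op named `input_x` exposes the tensor `input_x:0`. The sketch below only illustrates that naming convention in plain Python; the helpers `tensor_name` and `split_tensor_name` are hypothetical and not part of TensorFlow.

```python
def tensor_name(op_name, output_index=0):
    """Build the graph-level name of an op's output tensor."""
    return "{}:{}".format(op_name, output_index)


def split_tensor_name(name):
    """Split 'op_name:index' into (op_name, index)."""
    op_name, _, index = name.rpartition(":")
    if not op_name or not index.isdigit():
        raise ValueError("not a tensor name: {!r}".format(name))
    return op_name, int(index)


# The four names assigned in the training code.
for op in ["input_x", "input_y", "predictions", "accuracy"]:
    print(op, "->", tensor_name(op))
```

With these names, `graph.get_tensor_by_name("input_x:0")` refers to the same tensor as `graph.get_operation_by_name("input_x").outputs[0]`.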
# Restore the model. The two statements below are the ones added before restoring.
hparams = tensor_forest.ForestHParams(num_classes=num_classes,
                                      num_features=num_features,
                                      num_trees=num_trees,
                                      max_nodes=max_nodes).fill()
forest_graph = tensor_forest.RandomForestGraphs(params=hparams)
checkpoint_file = tf.train.latest_checkpoint('checkpoints')
graph = tf.Graph()
with graph.as_default():
    session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        # Rebuild the graph from the .meta file, then load the saved values.
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file), clear_devices=True)
        saver.restore(sess, checkpoint_file)
        # Look up the placeholders and ops by the names given at training time.
        input_x = graph.get_operation_by_name("input_x").outputs[0]
        input_y = graph.get_operation_by_name("input_y").outputs[0]
        predictions = graph.get_operation_by_name("predictions").outputs[0]
        accuracy = graph.get_operation_by_name("accuracy").outputs[0]
        acc = sess.run(accuracy, {input_x: test_x, input_y: test_y})
        predictions = sess.run(predictions, {input_x: test_x})
        print(predictions)
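To make the checkpoint paths used above easier to follow, here is a small plain-Python sketch of how they are formed: `saver.save(sess, prefix, global_step=i)` returns `prefix-i`, the graph definition is stored next to it with a `.meta` suffix, and `max_to_keep=10` keeps only the ten newest checkpoints. The helper names and the 200-step run are assumptions for illustration; the code simulates the bookkeeping and does not call TensorFlow.

```python
from collections import deque


def checkpoint_path(prefix, step):
    """Path prefix the Saver reports when global_step is given."""
    return "{}-{}".format(prefix, step)


def meta_graph_path(checkpoint):
    """File holding the serialized graph for a checkpoint prefix."""
    return "{}.meta".format(checkpoint)


def retained_checkpoints(prefix, steps, max_to_keep):
    """Simulate keeping only the newest max_to_keep checkpoints."""
    kept = deque(maxlen=max_to_keep)
    for step in steps:
        kept.append(checkpoint_path(prefix, step))
    return list(kept)


# Hypothetical run: a checkpoint every 10 steps for 200 steps.
kept = retained_checkpoints("checkpoints/model", range(10, 201, 10), max_to_keep=10)
latest = kept[-1]
print(latest)
print(meta_graph_path(latest))
```

In this simulated run, `latest` plays the role of the value returned by `tf.train.latest_checkpoint('checkpoints')`, and `meta_graph_path(latest)` is the argument passed to `tf.train.import_meta_graph`.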
For "python - How to save/restore tensorflow's tensor_forest model?", the corresponding question on Stack Overflow is: https://stackoverflow.com/questions/48542446/