gpt4 book ai didi

python - 复制 tensorflow 图

转载 作者:太空狗 更新时间:2023-10-29 22:12:45 31 4
gpt4 key购买 nike

复制 TensorFlow 图表并使其保持最新的最佳方式是什么?

理想情况下,我想将复制的图形放在另一个设备上(例如从 GPU 到 CPU),然后不时更新副本。

最佳答案

简答:您可能需要 checkpoint files (permalink)。


长答案:

让我们弄清楚这里的设置。我假设您有两个设备 A 和 B,并且您在 A 上进行训练并在 B 上运行推理。您希望定期使用在另一台设备上训练期间发现的新参数来更新运行推理的设备上的参数。上面链接的教程是一个很好的起点。它向您展示了 tf.train.Saver 对象的工作原理,您在这里不需要任何更复杂的东西。

这是一个例子:

import tensorflow as tf

def build_net(graph, device):
with graph.as_default():
with graph.device(device):
# Input placeholders
inputs = tf.placeholder(tf.float32, [None, 784])
labels = tf.placeholder(tf.float32, [None, 10])
# Initialization
w0 = tf.get_variable('w0', shape=[784,256], initializer=tf.contrib.layers.xavier_initializer())
w1 = tf.get_variable('w1', shape=[256,256], initializer=tf.contrib.layers.xavier_initializer())
w2 = tf.get_variable('w2', shape=[256,10], initializer=tf.contrib.layers.xavier_initializer())
b0 = tf.Variable(tf.zeros([256]))
b1 = tf.Variable(tf.zeros([256]))
b2 = tf.Variable(tf.zeros([10]))
# Inference network
h1 = tf.nn.relu(tf.matmul(inputs, w0)+b0)
h2 = tf.nn.relu(tf.matmul(h1,w1)+b1)
output = tf.nn.softmax(tf.matmul(h2,w2)+b2)
# Training network
cross_entropy = tf.reduce_mean(-tf.reduce_sum(labels * tf.log(output), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# Your checkpoint function
saver = tf.train.Saver()
return tf.initialize_all_variables(), inputs, labels, output, optimizer, saver

训练程序的代码:

def programA_main():
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Build training network on device A
graphA = tf.Graph()
init, inputs, labels, _, training_net, saver = build_net(graphA, '/cpu:0')
with tf.Session(graph=graphA) as sess:
sess.run(init)
for step in xrange(1,10000):
batch = mnist.train.next_batch(50)
sess.run(training_net, feed_dict={inputs: batch[0], labels: batch[1]})
if step%100==0:
saver.save(sess, '/tmp/graph.checkpoint')
print 'saved checkpoint'

...和推理程序的代码:

def programB_main():
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Build inference network on device B
graphB = tf.Graph()
init, inputs, _, inference_net, _, saver = build_net(graphB, '/cpu:0')
with tf.Session(graph=graphB) as sess:
batch = mnist.test.next_batch(50)

saver.restore(sess, '/tmp/graph.checkpoint')
print 'loaded checkpoint'
out = sess.run(inference_net, feed_dict={inputs: batch[0]})
print out[0]

import time; time.sleep(2)

saver.restore(sess, '/tmp/graph.checkpoint')
print 'loaded checkpoint'
out = sess.run(inference_net, feed_dict={inputs: batch[0]})
print out[1]

如果您启动训练程序,然后启动推理程序,您会看到推理程序产生两个不同的输出(来自同一个输入批处理)。这是它选取训练程序已检查点的参数的结果。

现在,这个程序显然不是您的终点。我们不做任何真正的同步,你必须决定关于检查点的“定期”意味着什么。但这应该让您了解如何将参数从一个网络同步到另一个网络。

最后一个警告:这意味着这两个网络必然是确定性的。 TensorFlow 中存在已知的不确定性元素(例如 this ),因此如果您需要完全相同的答案,请小心。但这是在多个设备上运行的残酷事实。

祝你好运!

关于python - 复制 tensorflow 图,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37801137/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com