gpt4 book ai didi

python - TensorFlow 对于简单网络表现不佳吗?

转载 作者:行者123 更新时间:2023-11-30 09:11:17 24 4
gpt4 key购买 nike

我一直在各种框架中尝试简单的基本(入门教程级别)神经网络,但对我在 TensorFlow 中看到的性能感到困惑。

例如,来自 Michael Nielsen's tutorial 的简单网络(在具有 30 个隐藏节点的网络中使用 L2 随机梯度下降的 MNIST 数字识别)的性能比稍微适应的网络要差得多(在所有相同的参数下,每个周期花费大约 8 倍的时间) (按照 one of the tutorial exercises 中的建议使用小批量矢量化) Nielsen's basic NumPy code 的版本。

在单 CPU 上运行的 TensorFlow 是否总是表现如此糟糕?我应该调整哪些设置来提高性能?或者 TensorFlow 是否只在更复杂的网络或学习机制上表现出色,因此预计它不会在如此简单的玩具案例中表现出色?

<小时/>
from __future__ import (absolute_import, print_function, division, unicode_literals)

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time


def weight_variable(shape):
return tf.Variable(tf.truncated_normal(shape, stddev=0.1))


def bias_variable(shape):
return tf.Variable(tf.constant(0.1, shape=shape))


mnist = input_data.read_data_sets("./data/", one_hot=True)

sess = tf.Session()

# Inputs and outputs
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])

# Model parameters
W1 = weight_variable([784, 30])
b1 = bias_variable([30])
o1 = tf.nn.sigmoid(tf.matmul(x, W1) + b1, name='o1')
W2 = weight_variable([30, 10])
b2 = bias_variable([10])
y = tf.nn.softmax(tf.matmul(o1, W2) + b2, name='y')

sess.run(tf.initialize_all_variables())

loss = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
loss += 0.1/1000 * (tf.nn.l2_loss(W1) + tf.nn.l2_loss(W2))

train_step = tf.train.GradientDescentOptimizer(0.15).minimize(loss)

accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))


for ep in range(30):
for mb in range(int(len(mnist.train.images)/40)):
batch_xs, batch_ys = mnist.train.next_batch(40)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

最佳答案

是的,我希望在 CPU 上运行的手工编码的专门简单网络会比 tensorflow 网络运行得更快。原因通常与tensorflow使用的图评估系统有关。

使用 tensorflow 的好处是当您有更复杂的算法并且您希望能够首先测试正确性,然后能够轻松地将其移植到使用更多机器和更多处理单元时。

例如,您可以尝试的一件事是在具有 GPU 的计算机上运行您的代码,然后发现在不更改代码中的任何内容的情况下,您将获得加速,可能比您链接的手动编码示例更快。可以看到,手写的代码要移植到GPU上需要付出相当大的努力。

关于python - TensorFlow 对于简单网络表现不佳吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36780106/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com