gpt4 book ai didi

memory-management - 为什么这个 tensorflow 循环需要这么多内存?

转载 作者:行者123 更新时间:2023-12-04 07:44:12 24 4
gpt4 key购买 nike

我有一个复杂网络的人为版本:

import tensorflow as tf

a = tf.ones([1000])
b = tf.ones([1000])

for i in range(int(1e6)):
a = a * b

我的直觉是这应该需要很少的内存。只是用于初始数组分配的空间和一串命令,这些命令利用节点并在每一步覆盖存储在张量“a”中的内存。但是内存使用量增长得非常快。

这里发生了什么,当我计算张量并多次覆盖它时,如何减少内存使用量?

编辑:

感谢 Yaroslav 的建议,解决方案原来是使用 while_loop 来最小化图上的节点数。这很好用,速度更快,需要的内存少得多,并且都包含在图中。
import tensorflow as tf

a = tf.ones([1000])
b = tf.ones([1000])

cond = lambda _i, _1, _2: tf.less(_i, int(1e6))
body = lambda _i, _a, _b: [tf.add(_i, 1), _a * _b, _b]

i = tf.constant(0)
output = tf.while_loop(cond, body, [i, a, b])

with tf.Session() as sess:
result = sess.run(output)
print(result)

最佳答案

您的 a*b命令转换为 tf.mul(a, b) ,相当于 tf.mul(a, b, g=tf.get_default_graph()) .此命令添加一个 Mul节点到当前 Graph对象,因此您正尝试添加 100 万 Mul节点到当前图。这也有问题,因为您无法序列化大于 2GB 的 Graph 对象,一旦您处理如此大的图,有些检查可能会失败。

我建议阅读 Programming Models for Deep Learning由 MXNet 人员提供。 TensorFlow 在他们的术语中是“符号”编程,您将其视为必要的。

要使用 Python 循环获得您想要的结果,您可以构造一次乘法运算,然后使用 feed_dict 重复运行它。提供更新

mul_op = a*b
result = sess.run(a)
for i in range(int(1e6)):
result = sess.run(mul_op, feed_dict={a: result})

为了提高效率,您可以使用 tf.Variable对象和 var.assign避免 Python<->TensorFlow 数据传输

关于memory-management - 为什么这个 tensorflow 循环需要这么多内存?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38751120/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com