gpt4 book ai didi

python - 删除 TensorFlow 图中除少数节点以外的所有节点

转载 作者:太空狗 更新时间:2023-10-29 20:16:41 25 4
gpt4 key购买 nike

我的 TensorFlow 用例要求我为每个需要处理的实例构建一个新的计算图。这最终会增加内存需求。

除了作为模型参数的一些 tf.Variables 之外,我想删除所有其他节点。其他有类似问题的人发现 tf.reset_default_graph() 很有用,但这会去掉我需要保留的模型参数。

我可以使用什么来删除除这些节点之外的所有节点?

编辑:特定于实例的计算实际上只是意味着我要添加很多新操作。我相信这些操作是内存问题背后的原因。

更新:请参阅最近发布的 tensorflow fold ( https://github.com/tensorflow/fold ),它允许动态构建计算图。

最佳答案

tf.graph 数据结构被设计为仅附加 数据结构。因此无法删除或修改现有节点。通常这不是问题,因为在运行 session 时只处理必要的子图。

您可以尝试将图表的变量复制到新图表中并删除旧图表。要存档,只需运行:

old_graph = tf.get_default_graph() # Save the old graph for later iteration
new_graph = tf.graph() # Create an empty graph
new_graph.set_default() # Makes the new graph default

如果您想遍历旧图中的所有节点,请使用:

for node in old_graph.get_operations():
if node.type == 'Variable':
# read value of variable and copy it into new Graph

或者你可以使用:

for node in old_graph.get_collection('trainable_variables'):
# iterates over all trainable Variabels
# read and create new variable

另请查看 python/framework/ops.py : 1759 以了解更多操作图中节点的方法。

但是,在您乱用 tf.Graph 之前,我强烈建议您考虑一下这是否真的需要。通常可以尝试概括计算并使用共享变量构建一个图,以便您要处理的每个实例都是该图的子图。

关于python - 删除 TensorFlow 图中除少数节点以外的所有节点,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37103002/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com