gpt4 book ai didi

python - Tensorflow:是否可以打印 session 数?

转载 作者:行者123 更新时间:2023-12-01 07:20:28 28 4
gpt4 key购买 nike

tensorflow 中是否可以有多个 session ?是否可以在 tensorflow 中打印 session 数?

def test_print_number_of_sessions():
sess1 = tf.Session()
sess2 = tf.Session()

//print_number_of_sessions

最佳答案

每个图表可以有多个 session ,但没有直接的方法来获取图表中的所有打开的 session 。该图的内部 C 数据结构确实有 a collection with all the existing sessions ,但不幸的是 Python 的对应部分(tf.Graph 对象的 ._c_graph 属性)只是一个没有类型信息的不透明指针。

一种可能的解决方案是使用您自己的 session 包装器来跟踪每个图表的打开 session 。这是一种可能的方法。

import tensorflow as tf
import collections

class TrackedSession(tf.Session):
_sessions = collections.defaultdict(list)
def __init__(self, target='', graph=None, config=None):
super(tf.Session, self).__init__(target=target, graph=graph, config=config)
TrackedSession._sessions[self.graph].append(self)
def close(self):
super(tf.Session, self).close()
TrackedSession._sessions[self.graph].remove(self)
@classmethod
def get_open_sessions(cls, g=None):
g = g or tf.get_default_graph()
return list(cls._sessions[g])

print(TrackedSession.get_open_sessions())
# []
sess1 = TrackedSession()
print(TrackedSession.get_open_sessions())
# [<__main__.TrackedSession object at 0x000001D75B0C77F0>]
sess2 = TrackedSession()
print(TrackedSession.get_open_sessions())
# [<__main__.TrackedSession object at 0x000001D75B0C77F0>, <__main__.TrackedSession object at 0x000001D75B0C7A58>]
sess1.close()
print(TrackedSession.get_open_sessions())
# [<__main__.TrackedSession object at 0x000001D75B0C7A58>]
sess2.close()
print(TrackedSession.get_open_sessions())
# []

但这会限制您使用此自定义 session 类型,根据具体情况,这可能不够好(例如,如果 session 是由某些外部代码打开的,例如当您使用 Keras 时)。

关于python - Tensorflow:是否可以打印 session 数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57725588/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com