Tensorflow:是否可以打印会话数量?

wqsoz72f  于 2023-02-16  发布在  其他
关注(0)|答案(2)|浏览(81)

可以在tensorflow中有多个会话吗?可以在tensorflow中打印会话数吗?

def test_print_number_of_sessions():
    sess1 = tf.Session()
    sess2 = tf.Session()

    //print_number_of_sessions
ua4mk5z4

ua4mk5z41#

每个图可以有多个会话,但是没有直接的方法来获取图中所有打开的会话,图的内部C数据结构确实有一个包含所有现有会话的集合,但是不幸的是,Python对应的部分(tf.Graph对象的._c_graph属性)只是一个不透明的指针,没有类型信息。
一个可能的解决方案是使用您自己的会话 Package 器来跟踪每个图中打开的会话,这是一种可能的方法。

import tensorflow as tf
import collections

class TrackedSession(tf.Session):
    _sessions = collections.defaultdict(list)
    def __init__(self, target='', graph=None, config=None):
        super(tf.Session, self).__init__(target=target, graph=graph, config=config)
        TrackedSession._sessions[self.graph].append(self)
    def close(self):
        super(tf.Session, self).close()
        TrackedSession._sessions[self.graph].remove(self)
    @classmethod
    def get_open_sessions(cls, g=None):
        g = g or tf.get_default_graph()
        return list(cls._sessions[g])

print(TrackedSession.get_open_sessions())
# []
sess1 = TrackedSession()
print(TrackedSession.get_open_sessions())
# [<__main__.TrackedSession object at 0x000001D75B0C77F0>]
sess2 = TrackedSession()
print(TrackedSession.get_open_sessions())
# [<__main__.TrackedSession object at 0x000001D75B0C77F0>, <__main__.TrackedSession object at 0x000001D75B0C7A58>]
sess1.close()
print(TrackedSession.get_open_sessions())
# [<__main__.TrackedSession object at 0x000001D75B0C7A58>]
sess2.close()
print(TrackedSession.get_open_sessions())
# []

但是这限制了您使用这种定制会话类型,根据场景的不同,这种类型可能不够好(例如,如果会话是由一些外部代码打开的,例如当您使用Keras时)。

rsl1atfo

rsl1atfo2#

这可能会有所帮助:

tf.InteractiveSession._active_session_count

相关问题