gpt4 book ai didi

tensorflow - 找到一个 tensorflow op 依赖的所有变量

转载 作者:行者123 更新时间:2023-12-04 12:47:51 25 4
gpt4 key购买 nike

有没有办法找到给定操作(通常是损失)所依赖的所有变量?我想用它来将这个集合传递给 optimizer.minimize()tf.gradients()使用各种 set().intersection()组合。

到目前为止我找到了op.op.inputs并尝试了一个简单的 BFS,但我从来没有机会 Variable tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 返回的对象或 slim.get_variables()

相应的'Tensor.op._id and之间似乎确实存在对应关系Variables.op._id` 字段,但我不确定这是我应该依赖的东西。

或者也许我一开始就不应该这样做?我可以当然可以在构建我的图表时精心构建我的不相交变量集,但是如果我更改模型,很容易遗漏一些东西。

最佳答案

documentation for tf.Variable.op不是特别清楚,但它确实引用了 the implementation of a tf.Variable 中使用的关键 tf.Operation :任何依赖于 tf.Variable 的操作都将位于该操作的路径上。由于 tf.Operation 对象是可散列的,您可以将它用作 dict 的键,将 tf.Operation 对象映射到相应的 tf.Variable 对象,然后像以前一样执行 BFS:

op_to_var = {var.op: var for var in tf.trainable_variables()}

starting_op = ...
dependent_vars = []

queue = collections.deque()
queue.append(starting_op)

visited = set([starting_op])

while queue:
op = queue.popleft()
try:
dependent_vars.append(op_to_var[op])
except KeyError:
# `op` is not a variable, so search its inputs (if any).
for op_input in op.inputs:
if op_input.op not in visited:
queue.append(op_input.op)
visited.add(op_input.op)

关于tensorflow - 找到一个 tensorflow op 依赖的所有变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42855556/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com