gpt4 book ai didi

python - 当 `tape.watch(x)` 在 TensorFlow 中已经是 `x` 时调用 `tf.Variable` 是否可以?

转载 作者:太空宇宙 更新时间:2023-11-03 14:39:27 24 4
gpt4 key购买 nike

考虑以下函数

def foo(x):
with tf.GradientTape() as tape:
tape.watch(x)

y = x**2 + x + 4

return tape.gradient(y, x)

如果函数被称为 foo(tf.constant(3.14)),则调用 tape.watch(x) 是必要的,但不是直接传入一个变量,比如foo(tf.Variable(3.14))

现在我的问题是,即使在直接传入 tf.Variable 的情况下,对 tape.watch(x) 的调用是否安全?还是会因为变量已经被自动监视然后再次手动监视而发生一些奇怪的事情?编写这样可以同时接受 tf.Tensortf.Variable 的通用函数的正确方法是什么?

最佳答案

应该是安全的。一方面,tf.GradientTape.watch 的文档说:

Ensures that tensor is being traced by this tape.

“确保”似乎暗示它将确保它被跟踪以防万一。事实上,文档没有给出任何迹象表明在同一个对象上使用它两次应该是一个问题(尽管如果他们明确说明也不会造成伤害)。

但无论如何,我们都可以深入源码进行检查。最后,在变量上调用 watch(如果它不是变量但路径略有不同,答案最终相同)归结为 WatchVariable C++ 中 GradientTape 类的方法:

void WatchVariable(PyObject* v) {
tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
if (handle == nullptr) {
return;
}
tensorflow::int64 id = FastTensorId(handle.get());

if (!PyErr_Occurred()) {
this->Watch(id);
}

tensorflow::mutex_lock l(watched_variables_mu_);
auto insert_result = watched_variables_.emplace(id, v);

if (insert_result.second) {
// Only increment the reference count if we aren't already watching this
// variable.
Py_INCREF(v);
}
}

该方法的后半部分显示被监视的变量被添加到 watched_variables_,这是一个 std::set,因此再次添加一些东西不会有任何作用。这实际上是稍后检查以确保 Python 引用计数是正确的。上半场基本叫Watch :

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
int64 tensor_id) {
tensor_tape_.emplace(tensor_id, -1);
}

tensor_tape_ 是一个 map (特别是 tensorflow::gtl:FlatMap ,与标准 C++ map 几乎相同),因此如果 tensor_id 已经存在,这将无效.

因此,即使没有明确说明,一切都表明它应该没有问题。

关于python - 当 `tape.watch(x)` 在 TensorFlow 中已经是 `x` 时调用 `tf.Variable` 是否可以?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54479770/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com