它应该是安全的。一方面,文档tf.GradientTape.watch
说:
确保tensor
该磁带正在跟踪它。
“确保”似乎暗示它将确保它被追踪以防万一。事实上,文档没有给出任何迹象表明在同一个对象上使用它两次应该是一个问题(尽管如果他们明确表示它不会受到伤害)。
但无论如何,我们都可以深入源码进行检查。最后,调用watch
一个变量(如果它不是变量但路径略有不同,则答案最终相同)归结为C++WatchVariable
中类的方法:GradientTape
void WatchVariable(PyObject* v) {
tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
if (handle == nullptr) {
return;
}
tensorflow::int64 id = FastTensorId(handle.get());
if (!PyErr_Occurred()) {
this->Watch(id);
}
tensorflow::mutex_lock l(watched_variables_mu_);
auto insert_result = watched_variables_.emplace(id, v);
if (insert_result.second) {
// Only increment the reference count if we aren't already watching this
// variable.
Py_INCREF(v);
}
}
该方法的后半部分显示了被监视的变量被添加到watched_variables_
,这是 a std::set
,所以再次添加一些东西不会做任何事情。这实际上稍后会检查以确保 Python 引用计数是正确的。上半场基本上叫Watch
:
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
int64 tensor_id) {
tensor_tape_.emplace(tensor_id, -1);
}
tensor_tape_
是一个地图(特别是 a tensorflow::gtl:FlatMap
,与标准 C++ 地图几乎相同),所以如果tensor_id
已经存在,这将无效。
因此,即使没有明确说明,一切都表明它应该没有问题。