6

考虑以下函数

def foo(x):
  with tf.GradientTape() as tape:
    tape.watch(x)

    y = x**2 + x + 4

  return tape.gradient(y, x)

如果函数被称为 as ,则调用 totape.watch(x)是必要的foo(tf.constant(3.14)),但当它直接传入变量时则不需要,例如foo(tf.Variable(3.14)).

现在我的问题是,tape.watch(x)即使在tf.Variable直接传入的情况下也调用安全?还是会因为变量已经被自动监视然后再次手动监视而发生一些奇怪的事情?编写可以同时接受tf.Tensor和的一般函数的正确方法是什么tf.Variable

4

2 回答 2

6

它应该是安全的。一方面,文档tf.GradientTape.watch说:

确保tensor该磁带正在跟踪它。

“确保”似乎暗示它将确保它被追踪以防万一。事实上,文档没有给出任何迹象表明在同一个对象上使用它两次应该是一个问题(尽管如果他们明确表示它不会受到伤害)。

但无论如何,我们都可以深入源码进行检查。最后,调用watch一个变量(如果它不是变量但路径略有不同,则答案最终相同)归结为C++WatchVariable中类的方法:GradientTape

void WatchVariable(PyObject* v) {
  tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
  if (handle == nullptr) {
    return;
  }
  tensorflow::int64 id = FastTensorId(handle.get());

  if (!PyErr_Occurred()) {
    this->Watch(id);
  }

  tensorflow::mutex_lock l(watched_variables_mu_);
  auto insert_result = watched_variables_.emplace(id, v);

  if (insert_result.second) {
    // Only increment the reference count if we aren't already watching this
    // variable.
    Py_INCREF(v);
  }
}

该方法的后半部分显示了被监视的变量被添加到watched_variables_,这是 a std::set,所以再次添加一些东西不会做任何事情。这实际上稍后会检查以确保 Python 引用计数是正确的。上半场基本上叫Watch

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64 tensor_id) {
  tensor_tape_.emplace(tensor_id, -1);
}

tensor_tape_是一个地图(特别是 a tensorflow::gtl:FlatMap,与标准 C++ 地图几乎相同),所以如果tensor_id已经存在,这将无效。

因此,即使没有明确说明,一切都表明它应该没有问题。

于 2019-02-01T14:11:52.090 回答
0

它旨在供变量使用。从文档

默认情况下,GradientTape 将自动监视在上下文中访问的任何可训练变量。如果您想对监视哪些变量进行细粒度控制,您可以通过将 watch_accessed_variables=False 传递给磁带构造函数来禁用自动跟踪:

with tf.GradientTape(watch_accessed_variables=False) as tape:
  tape.watch(variable_a)
  y = variable_a ** 2  # Gradients will be available for `variable_a`.
  z = variable_b ** 3  # No gradients will be available since `variable_b` is
                       # not being watched.
于 2020-04-07T15:19:52.657 回答