1

对于一些自定义代码,我需要运行一个 for 循环来在 Tensorflow 2 中动态创建一个变量(启用急切执行模式)。(在我的自定义代码中,我写入变量的值需要渐变,因此我想跟踪 for 循环中的计算,以便从 autodiff 中获取渐变)。我的代码有效,但速度非常慢。事实上,它比在 numpy 中执行相同的操作要慢几个数量级。

我已经隔离了这个问题,并提供了一个突出问题的玩具代码片段。修复它将允许我修复我的自定义代码。

import numpy as np
import tensorflow as tf
import timeit

N = int(1e5)
data = np.random.randn(N)
def numpy_func(data):
    new_data = np.zeros_like(data)
    for i in range(len(data)):
        new_data[i] = data[i]
    return new_data

def tf_func(data):
    new_data = tf.Variable(tf.zeros_like(data))
    for i in range(len(data)):
        new_data[i].assign(data[i])
    return new_data    

%timeit numpy_func(data)
%timeit tf_func(data)

此代码片段的关键要点是,在 for 循环中,我只需要更新变量的一部分。每次迭代时要更新的切片都不同。用于更新的数据在每次迭代中都不同(在我的自定义代码中,它是依赖于变量切片的一些简单计算的结果,这里我只是使用固定数组来隔离问题。)

我正在使用 Tensorflow 2,并且 TensorFlow 代码理想情况下需要在启用急切执行的情况下运行,因为部分自定义操作依赖于急切执行。

我是 Tensorflow 的新手,我非常感谢解决这个问题的任何帮助。

非常感谢,马克斯

4

1 回答 1

0

当这样使用时,TensorFlow 永远不会很快。理想的解决方案是将您的计算矢量化,因此它不需要您显式循环,但这取决于您正在计算的确切内容(如果您愿意,您可以发布另一个问题)。但是,您可以使用tf.function. 我将您的函数更改new_data为作为输出参数,因为tf.function不允许您在第一次调用后创建变量(但实际上,如果您删除new_data参数,它也可以工作,因为tf.function将在全局范围内找到变量)。

import numpy as np
import tensorflow as tf
import timeit

# Input data
N = int(1e3)
data = np.random.randn(N)

# NumPy
def numpy_func(data, new_data):
    new_data[:] = 0
    for i in range(len(data)):
        new_data[i] = data[i]

new_data = np.zeros_like(data)
%timeit numpy_func(data, new_data)
# 143 µs ± 4.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

# TensorFlow
def tf_func(data, new_data):
    new_data.assign(tf.zeros_like(data))
    for i in range(len(data)):
        new_data[i].assign(data[i])
new_data = tf.Variable(tf.zeros_like(data))
%timeit tf_func(data, new_data)
# 119 ms ± 3.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# tf.function
# This is equivalent to using it as a decorator
tf_func2 = tf.function(tf_func)
new_data = tf.Variable(tf.zeros_like(data))
tf_func2(data, new_data)  # First call is slower
%timeit tf_func2(data, new_data)
# 3.55 ms ± 40.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

这是在 CPU 上运行的,结果在 GPU 上可能会有很大差异。无论如何,如您所见,tf.function它仍然比 NumPy 慢 20 倍以上,但也比 Python 函数快 30 倍以上。

于 2020-04-14T13:57:51.777 回答