4

我正在尝试索引张量以从一维张量中获取切片或单个元素。我发现使用numpy索引[:]slice vs tf.gather (几乎 30-40% )的方式时存在显着的性能差异。

我还观察到,tf.gather与 tensor 相比,在标量(循环未堆叠的张量)上使用时会产生很大的开销。这是一个已知的问题 ?

示例代码(低效):

for node_idxs in graph.nodes():
    node_indice_list = tf.unstack(node_idxs)
    result = []
    for nodeid in node_indices_list:
        x = tf.gather(..., nodeid)
        y = tf.gather(..., nodeid)
        result.append(tf.mul(x,y))
return tf.stack(result)

与示例代码(高效)相反:

for node_idxs in graph.nodes():
    x = tf.gather(..., node_idxs)
    y = tf.gather(..., node_idxs)
return tf.mul(x, y)

我知道第一个低效的实现是做更多的拆栈、堆叠然后循环和更多的收集操作,但是当我正在操作的节点顺序是几百个节点时,我没想到会减速 100 倍(拆栈和收集的开销在这么慢的单个标量上,在第一种情况下,我有更多的收集操作,每个操作都在单个元素上操作,而不是偏移张量)。是否有更快的索引方式,我尝试了 numpy 和 slice,结果比收集慢。

4

1 回答 1

1

首先,代码并没有真正比较收集与 Numpy 索引 - 它比较矢量化索引(tf.gather)与循环索引(Python“for”循环)。循环很慢也就不足为奇了。

请注意,tensor[idxs]在 Tensorflow 中,类 Numpy 的索引无论如何都受到限制:

只有整数、切片 ( :)、省略号 ( ...)、tf.newaxis ( None) 和标量 tf.int32/tf.int64 张量是有效的索引

所以tf.gather用于一般应用。

于 2020-12-19T15:06:34.340 回答