我正在尝试索引张量以从一维张量中获取切片或单个元素。我发现使用numpy
索引[:]
和slice vs tf.gather
(几乎 30-40% )的方式时存在显着的性能差异。
我还观察到,tf.gather
与 tensor 相比,在标量(循环未堆叠的张量)上使用时会产生很大的开销。这是一个已知的问题 ?
示例代码(低效):
for node_idxs in graph.nodes():
node_indice_list = tf.unstack(node_idxs)
result = []
for nodeid in node_indices_list:
x = tf.gather(..., nodeid)
y = tf.gather(..., nodeid)
result.append(tf.mul(x,y))
return tf.stack(result)
与示例代码(高效)相反:
for node_idxs in graph.nodes():
x = tf.gather(..., node_idxs)
y = tf.gather(..., node_idxs)
return tf.mul(x, y)
我知道第一个低效的实现是做更多的拆栈、堆叠然后循环和更多的收集操作,但是当我正在操作的节点顺序是几百个节点时,我没想到会减速 100 倍(拆栈和收集的开销在这么慢的单个标量上,在第一种情况下,我有更多的收集操作,每个操作都在单个元素上操作,而不是偏移张量)。是否有更快的索引方式,我尝试了 numpy 和 slice,结果比收集慢。