3

首先,为这个模糊的标题道歉,我想不出一个合适的名字来解决这个问题。

我有以下格式的 3 个 numpy 数组:

N = ([[13, 14, 15], [2, 5, 7], [4, 6, 8] ... 几十万个元素长

e1 = [1, 0, 0]

e2 = [0, 1, 0]

这个想法是创建第四个数组“v”,它应该具有与“N”相同的维度,但将根据 if 语句给出值。这是我目前拥有的应该更好地解释该问题的内容:

v = np.zeros([len(N), 3])    

for i in range(0, len(N)):
    if((N*e1)[i,0] != 0):
        v[i] = np.cross(N[i],e1)
    else:
        v[i] = np.cross(N[i],e2)

这段代码完成了我的要求,但比预期的时间长(> 5 分钟)。我可以使用任何形式的列表理解或类似概念来提高代码效率吗?

4

2 回答 2

2

您可以使用numpy.where广播替换 if-else 并矢量化该过程,这是一个选项numpy.where

import numpy as np
np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))

这里的一些基准

1)数据设置

N = np.array([np.random.randint(0,10,3) for i in range(1000)])
N

#array([[3, 5, 0],
#       [5, 0, 8],
#       [4, 6, 0],
#       ..., 
#       [9, 4, 2],
#       [6, 9, 3],
#       [2, 9, 2]])

e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])

2)时间

def forloop():
    v = np.zeros([len(N), 3]);    
​
    for i in range(0, len(N)):
        if((N*e1)[i,0] != 0):
            v[i] = np.cross(N[i],e1)
        else:
            v[i] = np.cross(N[i],e2)
    return v

def forloop2():
    v = np.zeros([len(N), 3])    
​
    # Only calculate this one time.
    my_product = N*e1
​
    for i in range(0, len(N)):
        if my_product[i,0] != 0:
            v[i] = np.cross(N[i],e1)
        else:
            v[i] = np.cross(N[i],e2)               
    return v

%timeit forloop()
10 loops, best of 3: 25.5 ms per loop

%timeit forloop2()
100 loops, best of 3: 12.7 ms per loop    

%timeit np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
10000 loops, best of 3: 71.9 µs per loop

3)所有方法的结果检查

v1 = forloop()   

v2 = np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))

v3 = forloop2()

(v3 == v1).all()
# True

(v1 == v2).all()
# True
于 2017-02-03T01:40:52.933 回答
1

我不确定你想要做什么,但我知道为什么这个特定的代码对你来说这么慢。最严重的罪犯是(N*e1)。这是一个简单的计算,它使用 numpy 运行得非常快,但你是在循环内执行它,len(N)时间!

N == 1000000通过将代码拉到循环之外,我可以在不到 15 秒的时间内在我的机器上执行您的代码。下面的例子。

v = np.zeros([len(N), 3])    

# Only calculate this one time.
my_product = N*e1

for i in range(0, len(N)):
    if my_product[i,0] != 0):
        v[i] = np.cross(N[i],e1)
    else:
        v[i] = np.cross(N[i],e2)

另一个答案演示了如何避免 for 循环和 if 语句,以提高代码的可读性,从而提高速度。

于 2017-02-03T01:42:29.963 回答