使用 int vs float 数据类型会导致执行不同的代码路径:
float 的堆栈跟踪如下所示:
(gdb) backtr
#0 0x007865a0 in dgemm_ () from /usr/lib/libblas.so.3gf
#1 0x007559d5 in cblas_dgemm () from /usr/lib/libblas.so.3gf
#2 0x00744108 in dotblas_matrixproduct (__NPY_UNUSED_TAGGEDdummy=0x0, args=(<numpy.ndarray at remote 0x85d9090>, <numpy.ndarray at remote 0x85d9090>),
kwargs=0x0) at numpy/core/blasdot/_dotblas.c:798
#3 0x08088ba1 in PyEval_EvalFrameEx ()
...
..而 int 的堆栈跟踪如下所示:
(gdb) backtr
#0 LONG_dot (ip1=0xb700a280 "\t", is1=4, ip2=0xb737dc64 "\a", is2=4000, op=0xb6496fc4 "", n=1000, __NPY_UNUSED_TAGGEDignore=0x85fa960)
at numpy/core/src/multiarray/arraytypes.c.src:3076
#1 0x00659d9d in PyArray_MatrixProduct2 (op1=<numpy.ndarray at remote 0x85dd628>, op2=<numpy.ndarray at remote 0x85dd628>, out=0x0)
at numpy/core/src/multiarray/multiarraymodule.c:847
#2 0x00742b93 in dotblas_matrixproduct (__NPY_UNUSED_TAGGEDdummy=0x0, args=(<numpy.ndarray at remote 0x85dd628>, <numpy.ndarray at remote 0x85dd628>),
kwargs=0x0) at numpy/core/blasdot/_dotblas.c:254
#3 0x08088ba1 in PyEval_EvalFrameEx ()
...
这两个调用都导致了 dotblas_matrixproduct,但似乎 float 调用保留在 BLAS 库中(可能访问一些优化良好的代码),而 int 调用被踢回 numpy 的 PyArray_MatrixProduct2。
所以这要么是一个错误,要么是 BLAS 只是不支持 matrixproduct 中的整数类型(这似乎不太可能)。
这是一个简单且便宜的解决方法:
af = a.astype(float)
np.dot(af, af).astype(int)