2 回答

TA贡献1804条经验 获得超8个赞
请务必贴出您目前已经尝试过的代码。至于您的 Numba 版本,我认为是其中的某些写法导致了性能不佳。
例子
import numpy as np
import numba as nb
import time
@nb.njit(fastmath=True)
def your_function(fx, fy, aPx, aPy, img):
    """Return the magnitude of a weighted sum of complex exponentials.

    Evaluates, with an explicit double loop compiled by Numba::

        abs(sum over i, j of
            img[i, j] * exp(-2j * pi * (fx * aPx[i, j] + fy * aPy[i, j])))

    Parameters
    ----------
    fx, fy
        Multipliers applied to ``aPx[i, j]`` and ``aPy[i, j]`` in the exponent.
    aPx, aPy, img
        Arrays read with two indices ``[i, j]``. The loop bounds come from
        ``aPx.shape``; ``aPy`` and ``img`` are assumed to cover the same
        index range (not checked here).

    Returns
    -------
    The absolute value of the accumulated complex sum.
    """
    # The factor -1j*2*pi is loop-invariant, so build it once up front.
    phase_factor = -2j * np.pi
    # Start from a complex zero so the accumulator keeps a single type for the
    # whole function, and avoid the name ``sum`` so the builtin is not shadowed.
    total = 0j
    for i in range(aPx.shape[0]):
        for j in range(aPx.shape[1]):
            phase = fx * aPx[i, j] + fy * aPy[i, j]
            total += img[i, j] * np.exp(phase_factor * phase)
    return np.abs(total)
@nb.njit(fastmath=True, parallel=True)
def your_function_p(fx, fy, aPx, aPy, img):
    """Parallel variant: magnitude of a weighted sum of complex exponentials.

    Evaluates::

        abs(sum over i, j of
            img[i, j] * exp(-2j * pi * (fx * aPx[i, j] + fy * aPy[i, j])))

    The outer loop uses ``nb.prange`` under ``parallel=True``; the ``+=`` on
    the scalar accumulator is the reduction across iterations.

    Parameters
    ----------
    fx, fy
        Multipliers applied to ``aPx[i, j]`` and ``aPy[i, j]`` in the exponent.
    aPx, aPy, img
        Arrays read with two indices ``[i, j]``. The loop bounds come from
        ``aPx.shape``; ``aPy`` and ``img`` are assumed to cover the same
        index range (not checked here).

    Returns
    -------
    The absolute value of the accumulated complex sum.
    """
    # The factor -1j*2*pi is loop-invariant, so build it once up front.
    phase_factor = -2j * np.pi
    # Start from a complex zero so the reduction variable keeps a single type,
    # and avoid the name ``sum`` so the builtin is not shadowed.
    total = 0j
    for i in nb.prange(aPx.shape[0]):
        for j in range(aPx.shape[1]):
            phase = fx * aPx[i, j] + fy * aPy[i, j]
            total += img[i, j] * np.exp(phase_factor * phase)
    return np.abs(total)
# Benchmark both variants.
# The function gets compiled at the first call, so each variant is called once
# before the timed loop to keep compilation out of the measurement.
# You may also use cache=True, which (per the original note) only works in
# single threaded code.
# NOTE(review): fx, fy, aPx, aPy and img are not defined in this snippet; they
# are expected to come from the asker's own setup code.
for benchmarked in (your_function, your_function_p):
    F = benchmarked(fx, fy, aPx, aPy, img)  # warm-up call
    # perf_counter is the clock intended for measuring short durations.
    start = time.perf_counter()
    for _ in range(20):
        F = benchmarked(fx, fy, aPx, aPy, img)
    end = time.perf_counter()
    print("Elapsed (after compilation) = %s" % (end - start))
    print(F)
计时(4C/8T)
your_version: 2.45s
Numba single threaded: 0.17s
Numba parallel: 0.07s
添加回答
举报