2 回答
TA贡献1780条经验 获得超5个赞
因此,在摆弄更多之后,我能够在 NumPy 的向量化和 Numba 的 JIT 编译器的帮助下大大减少运行时间。回到原来的代码:
arr = mp.RawArray(ctypes.c_uint, n*m)
def fun(i):
for j in range(i-1,0,-1):
count = 0
for k in range(0,m):
count += (arr[i*m+k] == arr[j*m+k])
if count/m > 0.7:
return (i,j)
return ()
我们可以省略最底层的return陈述,也可以count完全摒弃使用的想法,留给我们:
def fun(i):
for j in range(i-1,0,-1):
if sum(arr[i*m+k] == arr[j*m+k] for k in range(m)) > 0.7*m:
return (i,j)
然后,我们将数组更改arr为 NumPy 格式:
np_arr = np.frombuffer(arr,dtype='int32').reshape(m,n)
这里需要注意的重要一点是,我们不使用 NumPy 数组作为从多个进程写入的共享内存数组,从而避免了开销陷阱。
最后,我们应用 Numba 的装饰器并sum以向量形式重写该函数,使其与新数组一起工作:
import numba as nb
@nb.njit(fastmath=True,parallel=True)
def fun(i):
for j in range(i-1, 0, -1):
if np.sum(np_arr[i] == np_arr[j]) > 0.7*m:
return (i,j)
这将运行时间减少到7.9s,这对我来说绝对是一个胜利。
TA贡献1809条经验 获得超8个赞
由于您根本无法更改数组特征,因此我认为您坚持使用O(n^2)。 numpy将获得一些矢量化,但会更改共享数组的其他人的访问权限。从最里面的操作开始:
for k in range(0,m):
count += (arr[i][k] == arr[j][k])
将此更改为单行分配:
count = sum(arr[i][k] == arr[j][k] for k in range(m))
现在,如果这确实是一个数组,而不是一个列表列表,请使用数组包的向量化来简化循环,一次一个:
count = sum(arr[i] == arr[j]) # results in a vector of counts
您现在可以返回j索引 where count[j] / m > 0.7。请注意,没有必要i为每一个都返回:它在函数内是常量,并且调用程序已经具有该值。您的数组包可能有一对可以返回这些索引的矢量化索引操作。如果您正在使用numpy,这些很容易在本网站上查找。
添加回答
举报