1 回答
TA贡献2037条经验 获得超6个赞
这是对代码的一种重写,使其更适合numba.jit. 这并不完全是矢量化解决方案,但我发现该基准测试速度提高了 230 倍。
from numba import jit
from scipy import spatial
@jit
def D_from_cost(cost, D):
# operates on D inplace
ns, nt = cost.shape
for i in range(ns):
for j in range(nt):
D[i+1, j+1] = cost[i,j]+min(D[i, j+1], D[i+1, j], D[i, j])
# avoiding the list creation inside mean enables better jit performance
# D[i+1, j+1] = cost[i,j]+min([D[i, j+1], D[i+1, j], D[i, j]])
@jit
def get_d(D, matchidx):
ns = D.shape[0] - 1
nt = D.shape[1] - 1
d = D[ns,nt]
matchidx[0,0] = ns - 1
matchidx[0,1] = nt - 1
i = ns
j = nt
for k in range(1, ns+nt+3):
idx = 0
if not (D[i-1,j] <= D[i,j-1] and D[i-1,j] <= D[i-1,j-1]):
if D[i,j-1] <= D[i-1,j-1]:
idx = 1
else:
idx = 2
if idx == 0 and i > 1 and j > 0:
# matchidx.append([i-2, j-1])
matchidx[k,0] = i - 2
matchidx[k,1] = j - 1
i -= 1
elif idx == 1 and i > 0 and j > 1:
# matchidx.append([i-1, j-2])
matchidx[k,0] = i-1
matchidx[k,1] = j-2
j -= 1
elif idx == 2 and i > 1 and j > 1:
# matchidx.append([i-2, j-2])
matchidx[k,0] = i-2
matchidx[k,1] = j-2
i -= 1
j -= 1
else:
break
return d, matchidx[:k]
def seqdist2(seq1, seq2):
ns = len(seq1)
nt = len(seq2)
cost = spatial.distance_matrix(seq1, seq2)
# initialize and update D
D = np.full((ns+1, nt+1), np.inf)
D[0, 0] = 0
D_from_cost(cost, D)
matchidx = np.zeros((ns+nt+2,2), dtype=np.int)
d, matchidx = get_d(D, matchidx)
return d, matchidx[::-1].tolist()
assert seqdist2(seq1, seq2) == seqdist(seq1, seq2)
%timeit seqdist2(seq1, seq2) # 1000 loops, best of 3: 365 µs per loop
%timeit seqdist(seq1, seq2) # 10 loops, best of 3: 86.1 ms per loop
以下是一些变化:
cost
是使用 计算的spatial.distance_matrix
。的定义
idx
被一堆丑陋的 if 语句取代,这使得编译的代码更快。min([D[i, j+1], D[i+1, j], D[i, j]])
替换为min(D[i, j+1], D[i+1, j], D[i, j])
,即我们不取列表的最小值,而是取三个值的最小值。这导致了令人惊讶的加速jit
。matchidx
被预先分配为 numpy 数组,并在输出之前截断为正确的大小。
添加回答
举报