为了账号安全,请及时绑定邮箱和手机立即绑定

在 NumPy ndArray 中基于布尔值查找最长序列的更有效解决方案

在 NumPy ndArray 中基于布尔值查找最长序列的更有效解决方案

有只小跳蛙 2021-09-28 20:59:22
我搜索我的 ndArray 以查找基于 True 值的最长系列。是否可以选择在不遍历数组的情况下查找最长系列?我已经用 numpy.nonzero 编写了自己的解决方案,但可能有更好的解决方案。import numpy as nparr = np.array([[[1,2,3,4,5],                [6,7,8,9,10],                [11,12,13,14,15],                [16,17,18,19,20],                [21,22,23,24,25]],                [[True,True,True,False,True],                [True,True,True,True,False],                [True,True,False,True,True],                [True,True,True,False,True],                [True,True,True,False,True]]])def getIndices(arr):    arr_to_search = np.nonzero(arr)    arrs = []    prev_el0 = 0    prev_el1 = -1    activ_long = []    for i in range(len(arr_to_search[0])):        if arr_to_search[0][i] == prev_el0:            if arr_to_search[1][i] != prev_el1 + 1:                arrs.append(activ_long)                activ_long = []        else:            arrs.append(activ_long)            activ_long = []        activ_long.append((arr_to_search[0][i],arr_to_search[1][i]))        prev_el0 = arr_to_search[0][i]        prev_el1 = arr_to_search[1][i]    max_len = len(max(arrs,key=len))    longest_arr_list = [a for a in arrs if len(a) == max_len]    return longest_arr_listprint(getIndices(arr[1,:,:]))print(getIndices(arr[1,:,:].T))[[(1, 0), (1, 1), (1, 2), (1, 3)]][[(0, 0), (0, 1), (0, 2), (0, 3), (0, 4)], [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4)]]
查看完整描述

1 回答

?
叮当猫咪

TA贡献1776条经验 获得超12个赞

这是一个 numpy 解决方案,它避免了基于上一个问题的显式循环。


我假设布尔数组名为a. 本质上,我们找到行从 0 到 1 或从 1 到 0 变化的索引,并查看它们之间的差异。通过在前后填充 0,我们确保对于从 0 到 1 的每个转换,还有另一个从 1 到 0 的转换。


为了方便我处理a,并a.T在同一时间,但你可以分开,如果你想要做他们。


m,n = a.shape

A = np.zeros((2*m,n+2))

A[:m,1:-1] = a

A[m:,1:-1] = a.T


dA = np.diff(A)


start = np.where(dA>0)

end = np.where(dA<0)


argmax_run = np.argmax(end[1]-start[1])


row = start[0][argmax_run]

col_start = start[1][argmax_run]

col_end= end[1][argmax_run]-1


max_len = col_end - col_start + 1


print('max run of length {}'.format(max_len))

print('in '+('row' if row<m else'col')+' {}'.format(row%m)+' from '+('col' if row<m else'row')+' {} to {}'.format(col_start,col_end))

为了提高性能和存储,我们可以更改A为布尔数组。由于-1和1在dA上面总是成对出现,我们可以找到start和end如下。


nz = np.nonzero(dA)

start = (nz[0][::2], nz[1][::2])

end = (nz[0][1::2], nz[1][1::2])

请注意,您可以然后完全删除变量start,end因为它们并不是真正需要的。


查看完整回答
反对 回复 2021-09-28
  • 1 回答
  • 0 关注
  • 211 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信