2 回答
TA贡献1859条经验 获得超6个赞
方法#1
这是一个基于views. 利用np.argwhere( docs ) 返回满足条件的元素的索引,在本例中为成员资格。——
def view1D(a, b): # a, b are arrays
a = np.ascontiguousarray(a)
b = np.ascontiguousarray(b)
void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
return a.view(void_dt).ravel(), b.view(void_dt).ravel()
def argwhere_nd(a,b):
A,B = view1D(a,b)
return np.argwhere(A[:,None] == B)
方法#2
这是另一个O(n)更好的性能,尤其是在大型阵列上 -
def argwhere_nd_searchsorted(a,b):
A,B = view1D(a,b)
sidxB = B.argsort()
mask = np.isin(A,B)
cm = A[mask]
idx0 = np.flatnonzero(mask)
idx1 = sidxB[np.searchsorted(B,cm, sorter=sidxB)]
return idx0, idx1 # idx0 : indices in A, idx1 : indices in B
方法#3
另一个O(n)使用argsort()-
def argwhere_nd_argsort(a,b):
A,B = view1D(a,b)
c = np.r_[A,B]
idx = np.argsort(c,kind='mergesort')
cs = c[idx]
m0 = cs[:-1] == cs[1:]
return idx[:-1][m0],idx[1:][m0]-len(A)
示例使用与之前相同的输入运行 -
In [650]: argwhere_nd_searchsorted(a,b)
Out[650]: (array([0, 1]), array([2, 0]))
In [651]: argwhere_nd_argsort(a,b)
Out[651]: (array([0, 1]), array([2, 0]))
TA贡献1842条经验 获得超12个赞
您可以利用自动广播:
np.argwhere(np.all(a.reshape(3,1,-1) == b,2))
这导致
array([[0, 2],
[1, 0]])
注为花车你可能要更换==用np.islclose()
添加回答
举报