为了账号安全,请及时绑定邮箱和手机立即绑定

如何使用 scipy.spatial.KDTree.query_ball_point

如何使用 scipy.spatial.KDTree.query_ball_point

慕盖茨4494581 2022-12-20 11:26:45
我正在尝试使用 Kdtree 数据结构从数组中移除最近的点,最好不要 for 循环。import sysimport timeimport scipy.spatialclass KDTree:    """    Nearest neighbor search class with KDTree    """    def __init__(self, data):        # store kd-tree        self.tree = scipy.spatial.cKDTree(data)    def search(self, inp, k=1):        """        Search NN        inp: input data, single frame or multi frame        """        if len(inp.shape) >= 2:  # multi input            index = []            dist = []            for i in inp.T:                idist, iindex = self.tree.query(i, k=k)                index.append(iindex)                dist.append(idist)            return index, dist        dist, index = self.tree.query(inp, k=k)        return index, dist    def search_in_distance(self, inp, r):        """        find points with in a distance r        """        index = self.tree.query_ball_point(inp, r)        return np.asarray(index)import numpy as npimport matplotlib.pyplot as pltimport matplotlib.animation as animationstart = time.time()fig, ar = plt.subplots()t = 0R = 50.0u = R *np.cos(t)v = R *np.sin(t)x = np.linspace(-100,100,51)y = np.linspace(-100,100,51)xx, yy = np.meshgrid(x,y)points =np.vstack((xx.ravel(),yy.ravel())).TTree = KDTree(points)ind = Tree.search_in_distance([u, v],10.0)ar.scatter(points[:,0],points[:,1],c='k',s=1)infected = points[ind]ar.scatter(infected[:,0],infected[:,1],c='r',s=5)def animate(i):    global R,t,start,points    ar.clear()    u = R *np.cos(t)    v = R *np.sin(t)    ind = Tree.search_in_distance([u, v],10.0)    ar.scatter(points[:,0],points[:,1],c='k',s=1)    infected = points[ind]    ar.scatter(infected[:,0],infected[:,1],c='r',s=5)    #points = np.delete(points,ind)    t+=0.01    end = time.time()    if end - start != 0:        print((end - start), end="\r")        start = endani = animation.FuncAnimation(fig, animate, interval=20)plt.show()  但无论我做什么,我都无法让 np.delete 处理 ball_query 方法返回的索引。我错过了什么?我想让红色点在点数组的每次迭代中消失。
查看完整描述

1 回答

?
红糖糍粑

TA贡献1815条经验 获得超6个赞

您的points数组是一个 Nx2 矩阵。您的ind索引是行索引列表。你需要的是指定你需要删除的轴,最终是这样的:

points = np.delete(points,ind,axis=0)

此外,一旦删除索引,请注意下一次迭代/计算中丢失的索引。也许您想要一个副本来删除点和绘图,另一个副本用于您不从中删除的计算。


查看完整回答
反对 回复 2022-12-20
  • 1 回答
  • 0 关注
  • 196 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信