为了账号安全,请及时绑定邮箱和手机立即绑定

有效地返回数组中第一个值满足条件的索引

有效地返回数组中第一个值满足条件的索引

紫衣仙女 2019-11-20 12:48:42
我需要在满足条件的1d NumPy数组或Pandas数值系列中找到第一个值的索引。数组很大,索引可能在数组的开始或结尾附近,或者可能根本不满足条件。我无法提前告诉您哪种可能性更大。如果不满足条件,则返回值为-1。我考虑了几种方法。尝试1# func(arr) returns a Boolean arrayidx = next(iter(np.where(func(arr))[0]), -1)但这通常太慢,因为func(arr)在整个数组上应用矢量化函数,而不是在满足条件时停止。具体来说,在数组开始附近满足条件时,这很昂贵。尝试2np.argmax是稍快,但无法确定何时条件永不满足:np.random.seed(0)arr = np.random.rand(10**7)assert next(iter(np.where(arr > 0.999999)[0]), -1) == np.argmax(arr > 0.999999)%timeit next(iter(np.where(arr > 0.999999)[0]), -1)  # 21.2 ms%timeit np.argmax(arr > 0.999999)                    # 17.7 msnp.argmax(arr > 1.0)返回0,当条件,即一个实例并不满足。尝试3# func(arr) returns a Boolean scalaridx = next((idx for idx, val in enumerate(arr) if func(arr)), -1)但这在数组末尾附近满足条件时太慢了。大概是因为生成器表达式的大量__next__调用产生了昂贵的开销。这是否总是一种折衷方案,或者对于通用而言func,是否有办法有效地提取第一个索引?标杆管理对于基准测试,假定func值大于给定常数时查找索引:# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0import numpy as npnp.random.seed(0)arr = np.random.rand(10**7)m = 0.9n = 0.999999# Start of array benchmark%timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms%timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs# End of array benchmark%timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms%timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms
查看完整描述

3 回答

?
胡说叔叔

TA贡献1804条经验 获得超8个赞

numba

有了numba它可以优化这两个场景。从语法上讲,您只需要构造一个带有简单for循环的函数:


from numba import njit


@njit

def get_first_index_nb(A, k):

    for i in range(len(A)):

        if A[i] > k:

            return i

    return -1


idx = get_first_index_nb(A, 0.9)

Numba通过JIT(“及时”)编译代码并利用CPU级别的优化来提高性能。一个常规的 for无环路@njit装饰通常会慢比你已经尝试了在条件满足后期的情况下的方法。


对于Pandas数值系列df['data'],您可以简单地将NumPy表示提供给JIT编译的函数:


idx = get_first_index_nb(df['data'].values, 0.9)

概括

由于numba允许将函数用作参数,并且假设传递的函数也可以JIT编译,则可以找到一种方法来计算第n个索引,其中满足任意条件的条件func。


@njit

def get_nth_index_count(A, func, count):

    c = 0

    for i in range(len(A)):

        if func(A[i]):

            c += 1

            if c == count:

                return i

    return -1


@njit

def func(val):

    return val > 0.9


# get index of 3rd value where func evaluates to True

idx = get_nth_index_count(arr, func, 3)

对于第三个最后的值,可以喂相反,arr[::-1]和否定的结果len(arr) - 1,则- 1需要考虑0索引。


绩效基准

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0


np.random.seed(0)

arr = np.random.rand(10**7)

m = 0.9

n = 0.999999


@njit

def get_first_index_nb(A, k):

    for i in range(len(A)):

        if A[i] > k:

            return i

    return -1


def get_first_index_np(A, k):

    for i in range(len(A)):

        if A[i] > k:

            return i

    return -1


%timeit get_first_index_nb(arr, m)                                 # 375 ns

%timeit get_first_index_np(arr, m)                                 # 2.71 µs

%timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms

%timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs


%timeit get_first_index_nb(arr, n)                                 # 204 µs

%timeit get_first_index_np(arr, n)                                 # 44.8 ms

%timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms

%timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms


查看完整回答
反对 回复 2019-11-20
?
Smart猫小萌

TA贡献1911条经验 获得超7个赞

我也想做类似的事情,发现这个问题中提出的解决方案并没有真正帮助我。特别是,numba对我来说,解决方案比问题本身中介绍的更常规的方法慢得多。我有一个times_all列表,通常为数万个元素的数量级,并且想要找到第一个元素的索引times_all大于a 的索引time_event。而且我有数千个time_event。我的解决方案是将其times_all分成例如100个元素的块,首先确定time_event属于哪个时间段,保留该时间段的第一个元素的索引,然后找到该时间段中的哪个索引,然后将两个索引相加。这是最少的代码。对我来说,它的运行速度比本页中的其他解决方案快几个数量级。


def event_time_2_index(time_event, times_all, STEPS=100):

    import numpy as np

    time_indices_jumps = np.arange(0, len(times_all), STEPS)

    time_list_jumps = [times_all[idx] for idx in time_indices_jumps]


    time_list_jumps_idx = next((idx for idx, val in enumerate(time_list_jumps)\

                          if val > time_event), -1)

    index_in_jumps = time_indices_jumps[time_list_jumps_idx-1]

    times_cropped = times_all[index_in_jumps:]

    event_index_rel = next((idx for idx, val in enumerate(times_cropped) \

                      if val > time_event), -1)


    event_index = event_index_rel + index_in_jumps

    return event_index


查看完整回答
反对 回复 2019-11-20
  • 3 回答
  • 0 关注
  • 1341 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信