3 回答
TA贡献1804条经验 获得超8个赞
numba
有了numba它可以优化这两个场景。从语法上讲,您只需要构造一个带有简单for循环的函数:
from numba import njit
@njit
def get_first_index_nb(A, k):
for i in range(len(A)):
if A[i] > k:
return i
return -1
idx = get_first_index_nb(A, 0.9)
Numba通过JIT(“及时”)编译代码并利用CPU级别的优化来提高性能。一个常规的 for无环路@njit装饰通常会慢比你已经尝试了在条件满足后期的情况下的方法。
对于Pandas数值系列df['data'],您可以简单地将NumPy表示提供给JIT编译的函数:
idx = get_first_index_nb(df['data'].values, 0.9)
概括
由于numba允许将函数用作参数,并且假设传递的函数也可以JIT编译,则可以找到一种方法来计算第n个索引,其中满足任意条件的条件func。
@njit
def get_nth_index_count(A, func, count):
c = 0
for i in range(len(A)):
if func(A[i]):
c += 1
if c == count:
return i
return -1
@njit
def func(val):
return val > 0.9
# get index of 3rd value where func evaluates to True
idx = get_nth_index_count(arr, func, 3)
对于第三个最后的值,可以喂相反,arr[::-1]和否定的结果len(arr) - 1,则- 1需要考虑0索引。
绩效基准
# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999
@njit
def get_first_index_nb(A, k):
for i in range(len(A)):
if A[i] > k:
return i
return -1
def get_first_index_np(A, k):
for i in range(len(A)):
if A[i] > k:
return i
return -1
%timeit get_first_index_nb(arr, m) # 375 ns
%timeit get_first_index_np(arr, m) # 2.71 µs
%timeit next(iter(np.where(arr > m)[0]), -1) # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1) # 2.5 µs
%timeit get_first_index_nb(arr, n) # 204 µs
%timeit get_first_index_np(arr, n) # 44.8 ms
%timeit next(iter(np.where(arr > n)[0]), -1) # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1) # 39.2 ms
TA贡献1911条经验 获得超7个赞
我也想做类似的事情,发现这个问题中提出的解决方案并没有真正帮助我。特别是,numba对我来说,解决方案比问题本身中介绍的更常规的方法慢得多。我有一个times_all列表,通常为数万个元素的数量级,并且想要找到第一个元素的索引times_all大于a 的索引time_event。而且我有数千个time_event。我的解决方案是将其times_all分成例如100个元素的块,首先确定time_event属于哪个时间段,保留该时间段的第一个元素的索引,然后找到该时间段中的哪个索引,然后将两个索引相加。这是最少的代码。对我来说,它的运行速度比本页中的其他解决方案快几个数量级。
def event_time_2_index(time_event, times_all, STEPS=100):
import numpy as np
time_indices_jumps = np.arange(0, len(times_all), STEPS)
time_list_jumps = [times_all[idx] for idx in time_indices_jumps]
time_list_jumps_idx = next((idx for idx, val in enumerate(time_list_jumps)\
if val > time_event), -1)
index_in_jumps = time_indices_jumps[time_list_jumps_idx-1]
times_cropped = times_all[index_in_jumps:]
event_index_rel = next((idx for idx, val in enumerate(times_cropped) \
if val > time_event), -1)
event_index = event_index_rel + index_in_jumps
return event_index
添加回答
举报