3 Answers
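1864 experience points, more than 2 upvotes
David did come up with a great solution for finding the closest higher price at a later time. However, what I actually wanted was the next occurrence of a higher price at a later time. A colleague and I worked out this solution together.
Keep a stack of (index, price) tuples.
Iterate over all rows (index i).
While the stack is non-empty and the price on top of the stack is lower than prices[i], pop it and fill the popped index with times[i].
Push (i, prices[i]) onto the stack.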
import numpy as np
import pandas as pd
df = pd.DataFrame({'time': [15, 30, 45, 60, 75, 90], 'price': [10.00, 10.01, 10.00, 10.01, 10.02, 9.99]})
print(df)
   time  price
0    15  10.00
1    30  10.01
2    45  10.00
3    60  10.01
4    75  10.02
5    90   9.99
times = df['time'].to_numpy()
prices = df['price'].to_numpy()
stack = []
next_times = np.full(len(df), np.nan)
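for i in range(len(df)):
    # Row i resolves every waiting row whose price is lower than prices[i].
    while stack and prices[i] > stack[-1][1]:
        stack_time_index, stack_price = stack.pop()
        next_times[stack_time_index] = times[i]
    # Row i now waits for its own next higher price.
    stack.append((i, prices[i]))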
df['next_time'] = next_times
print(df)
   time  price  next_time
0    15  10.00       30.0
1    30  10.01       75.0
2    45  10.00       60.0
3    60  10.01       75.0
4    75  10.02        NaN
5    90   9.99        NaN
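This solution actually runs very fast. I'm not entirely sure, but I believe the complexity is close to O(n): it is a single pass over the DataFrame, and each row is pushed onto the stack once and popped at most once. It performs so well because the stack is inherently sorted, with the largest price at the bottom and the smallest price at the top.
Here is my test on the actual DataFrame: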
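The O(n) estimate can be checked empirically. The sketch below is not part of the answer above: it re-implements the same stack loop on plain Python lists, counts pushes and pops, asserts that the prices on the stack never increase from bottom to top, and cross-checks the result against a brute-force scan. The names next_higher_time_stack and next_higher_time_brute are hypothetical helpers introduced only for this illustration.

```python
import math


def next_higher_time_stack(times, prices):
    """Stack-based scan; returns (result, pushes, pops)."""
    result = [math.nan] * len(prices)
    stack = []  # indices whose next higher price has not been seen yet
    pushes = pops = 0
    for i, price in enumerate(prices):
        # Every waiting index with a lower price is resolved by row i.
        while stack and price > prices[stack[-1]]:
            result[stack.pop()] = times[i]
            pops += 1
        stack.append(i)
        pushes += 1
        # Invariant check (itself O(n), for illustration only):
        # prices on the stack never increase from bottom to top.
        assert all(prices[a] >= prices[b] for a, b in zip(stack, stack[1:]))
    return result, pushes, pops


def next_higher_time_brute(times, prices):
    """Reference O(n^2) scan used only to cross-check the stack version."""
    result = []
    for i, price in enumerate(prices):
        later = (times[j] for j in range(i + 1, len(prices)) if prices[j] > price)
        result.append(next(later, math.nan))
    return result


times = [15, 30, 45, 60, 75, 90]
prices = [10.00, 10.01, 10.00, 10.01, 10.02, 9.99]

fast, pushes, pops = next_higher_time_stack(times, prices)
slow = next_higher_time_brute(times, prices)
print(fast)
print(slow)
print(pushes, pops)
```

On the sample data this performs 6 pushes and 4 pops. In general the inner while loop can run at most n times in total across the whole scan, which is where the linear bound comes from.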
print(f'{len(df):,.0f} rows with {len(df["price"].unique()):,.0f} unique prices ranging from ${df["price"].min():,.2f} to ${df["price"].max():,.2f}')
667,037 rows with 11,786 unique prices ranging from $1,857.52 to $2,022.00
def find_next_time_with_greater_price(df):
    times = df['time'].to_numpy()
    prices = df['price'].to_numpy()
    stack = []
    next_times = np.full(len(df), np.nan)
    for i in range(len(df)):
        while stack and prices[i] > stack[-1][1]:
            stack_time_index, stack_price = stack.pop()
            next_times[stack_time_index] = times[i]
        stack.append((i, prices[i]))
    return next_times
%timeit -n10 -r10 df['next_time'] = find_next_time_with_greater_price(df)
434 ms ± 11.8 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
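1779 experience points, more than 6 upvotes
This returned in under 7 seconds for me on a variant of the DataFrame with 1,000,000 rows and 162,000 unique prices. So, since you are running it on 660,000 rows and 12,000 unique prices, I think the speed-up would be 100x-1000x.
Your problem is more complicated because the closest higher price has to occur at a later time. I had to attack it from a few different angles (as you mentioned in your comment about my np.where(), it breaks down into several different steps).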
import numpy as np
import pandas as pd
df = pd.DataFrame({'time': [15, 30, 45, 60, 75, 90], 'price': [10.00, 10.01, 10.00, 10.01, 10.02, 9.99]})
def bisect_right(a, x, lo=0, hi=None):
    if lo < 0:
        raise ValueError('lo must be non-negative')
    if hi is None:
        hi = len(a)
    while lo < hi:
        mid = (lo + hi) // 2
        if x < a[mid]:
            hi = mid
        else:
            lo = mid + 1
    return lo

def get_closest_higher(df, col, val):
    higher_idx = bisect_right(df[col].values, val)
    return higher_idx
df = df.sort_values(['price', 'time']).reset_index(drop=True)
df['next_time'] = df['price'].apply(lambda x: get_closest_higher(df, 'price', x))
df['next_time'] = df['next_time'].map(df['time'])
df['next_time'] = np.where(df['next_time'] <= df['time'], np.nan, df['next_time'] )
df = df.sort_values('time').reset_index(drop=True)
df['next_time'] = np.where(df['price'].shift(-1) > df['price'],
                           df['time'].shift(-1),
                           df['next_time'])
df['next_time'] = df['next_time'].ffill()
df['next_time'] = np.where(df['next_time'] <= df['time'], np.nan, df['next_time'])
df
Out[1]:
   time  price  next_time
0    15  10.00       30.0
1    30  10.01       75.0
2    45  10.00       60.0
3    60  10.01       75.0
4    75  10.02        NaN
5    90   9.99        NaN
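The hand-written bisect_right above appears to follow the same algorithm as bisect.bisect_right from the Python standard library, and np.searchsorted with side='right' returns the same insertion points for a whole array at once. The sketch below is an assumption about how the per-row apply could be replaced, not part of the original answer, and it only covers the lookup step on the sorted price column.

```python
import bisect

import numpy as np
import pandas as pd

df = pd.DataFrame({'time': [15, 30, 45, 60, 75, 90],
                   'price': [10.00, 10.01, 10.00, 10.01, 10.02, 9.99]})
df = df.sort_values(['price', 'time']).reset_index(drop=True)

sorted_prices = df['price'].to_numpy()

# Standard-library lookup, one value at a time.
idx_bisect = [bisect.bisect_right(sorted_prices, p) for p in sorted_prices]

# Vectorized lookup: position of the first strictly higher price for every row.
idx_np = np.searchsorted(sorted_prices, sorted_prices, side='right')

print(idx_bisect)
print(idx_np.tolist())
```

An index equal to len(df) means that no higher price exists. The remaining steps of the answer (mapping the index to a time and discarding times that are not later) would still be needed afterwards.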
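1735 experience points, more than 5 upvotes
When I timed these with %timeit on this sample, the solutions below were faster, but on a larger DataFrame they were much slower than your solution. It would be interesting to see whether any of these 3 solutions is faster on your larger DataFrame.
I hope someone else can post a more efficient solution. Here are a few different answers:
You can do this in one line with next, looping over the time and price columns simultaneously with zip. The generator expression passed to next is written just like a list comprehension, but with round brackets instead of square brackets, and next returns only the first value for which the conditions are True. You also need to pass None as the second argument to next, so that rows with no match return None instead of raising an error.
You need to pass axis=1 to apply, so that the lambda receives one row at a time and can compare that row against the columns.
This should improve performance, because iteration stops as soon as the first value is returned and moves on to the next row, so you do not loop over the entire column.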
import pandas as pd
df = pd.DataFrame({'time': [15, 30, 45, 60, 75, 90], 'price': [10.00, 10.01, 10.00, 10.01, 10.02, 9.99]})
print(df)
   time  price
0    15  10.00
1    30  10.01
2    45  10.00
3    60  10.01
4    75  10.02
5    90   9.99
df['next_time'] = (df.apply(lambda x: next((z for (y, z) in zip(df['price'], df['time'])
                                            if y > x['price'] if z > x['time']), None), axis=1))
df
Out[1]:
   time  price  next_time
0    15  10.00       30.0
1    30  10.01       75.0
2    45  10.00       60.0
3    60  10.01       75.0
4    75  10.02        NaN
5    90   9.99        NaN
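As you can see, a list comprehension returns the same result, but in theory it will be much slower, because it always scans every row before the first element is taken, and the total number of iterations grows significantly for large DataFrames.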
df['next_time'] = (df.apply(lambda x: [z for (y, z) in zip(df['price'], df['time'])
                                       if y > x['price'] if z > x['time']], axis=1)).str[0]
df
Out[2]:
   time  price  next_time
0    15  10.00       30.0
1    30  10.01       75.0
2    45  10.00       60.0
3    60  10.01       75.0
4    75  10.02        NaN
5    90   9.99        NaN
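To illustrate why next on a generator expression can do less work than a list comprehension, the sketch below counts how many candidate rows each variant inspects for the first row of the sample data. It uses plain lists rather than a DataFrame, and the counting wrapper counted is a hypothetical helper added only for this demonstration.

```python
times = [15, 30, 45, 60, 75, 90]
prices = [10.00, 10.01, 10.00, 10.01, 10.02, 9.99]


def counted(pairs, counter):
    """Yield pairs unchanged while counting how many were consumed."""
    for pair in pairs:
        counter[0] += 1
        yield pair


row_time, row_price = times[0], prices[0]

# Generator + next: stops at the first later row with a higher price.
lazy_count = [0]
first_match = next(
    (t for p, t in counted(zip(prices, times), lazy_count)
     if p > row_price and t > row_time),
    None,  # default returned when no row matches
)

# List comprehension: always scans every row before taking element 0.
eager_count = [0]
all_matches = [t for p, t in counted(zip(prices, times), eager_count)
               if p > row_price and t > row_time]

print(first_match, lazy_count[0])
print(all_matches[0], eager_count[0])
```

For the first row, the generator stops after inspecting 2 rows while the list comprehension inspects all 6. The saving depends on how early a match occurs: for a row with no later higher price, both variants still scan the whole column, so the worst case remains quadratic.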
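Another option is to create a function that uses some numpy together with np.where():
import numpy as np

def closest(x):
    try:
        # For each row, the time at which its running-maximum price was first reached
        lst = df.groupby(df['price'].cummax())['time'].transform('first')
        lst = np.asarray(lst)
        lst = lst[lst > x]  # keep only times later than x
        idx = (np.abs(lst - x)).argmin()
        return lst[idx]
    except ValueError:
        # argmin raises ValueError on an empty array; fall through and return None
        pass

df['next_time'] = np.where((df['price'].shift(-1) > df['price']),
                           df['time'].shift(-1),
                           df['time'].apply(lambda x: closest(x)))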