为了账号安全,请及时绑定邮箱和手机立即绑定

熊猫:提高滚动窗口的速度(应用自定义功能)

熊猫:提高滚动窗口的速度(应用自定义功能)

红颜莎娜 2022-01-18 17:26:27
我正在使用此代码funcX使用滚动窗口在我的数据框上应用函数 ()。主要问题是这个数据框 ( data) 的大小非常大,我正在寻找一种更快的方法来完成这项任务。import numpy as npdef funcX(x):    x = np.sort(x)    xd = np.delete(x, 25)    med = np.median(xd)    return (np.abs(x - med)).mean() + medmed_out = data.var1.rolling(window = 51, center = True).apply(funcX, raw = True)使用这个函数的唯一原因是计算出的中位数是去掉中间值后的中位数。所以.median()在滚动窗口的末尾添加是不同的。
查看完整描述

1 回答

?
慕村9548890

TA贡献1884条经验 获得超4个赞

为了有效,窗口算法必须链接两个重叠窗口的结果。


在这里,与 :med0中位数,排序后的元素med中的中位数 x \ med0,xl之前的元素med和xg之后的元素,可以看作: medfuncX(x)


<|x-med|> + med = [sum(xg) - sum(xl) - |med0-med|] / windowsize + med  

因此,一个想法是维护一个表示已排序当前窗口的缓冲区,sum(xg)并且sum(xl). 使用 Numba 即时编译,这里会出现非常好的性能。


首先是缓冲区管理:


init对第一个窗口进行排序并计算 left( xls) 和 right( xgs) 总和。


import numpy as np

import numba

windowsize = 51 #odd, >1

halfsize = windowsize//2


@numba.njit

def init(firstwindow):

    buffer = np.sort(firstwindow)

    xls = buffer[:halfsize].sum()

    xgs = buffer[-halfsize:].sum()   

    return buffer,xls,xgs

shift是线性部分。它更新缓冲区,保持它的排序。np.searchsorted计算 中的插入和删除位置O(log(windowsize))。这是技术性的xin<xout,因为xout<xin不是对称的情况。


@numba.njit

def shift(buffer,xin,xout):

    i_in = np.searchsorted(buffer,xin) 

    i_out = np.searchsorted(buffer,xout)

    if xin <= xout :

        buffer[i_in+1:i_out+1] = buffer[i_in:i_out] 

        buffer[i_in] = xin                        

    else:

        buffer[i_out:i_in-1] = buffer[i_out+1:i_in]                      

        buffer[i_in-1] = xin

    return i_in, i_out

update更新缓冲区和左右部分的总和。这是技术性的xin<xout,因为xout<xin不是对称的情况。


@numba.njit

def update(buffer,xls,xgs,xin,xout):

    xl,x0,xg = buffer[halfsize-1:halfsize+2]

    i_in,i_out = shift(buffer,xin,xout)


    if i_out < halfsize:

        xls -= xout

        if i_in <= halfsize:

            xls += xin

        else:    

            xls += x0

    elif i_in < halfsize:

        xls += xin - xl


    if i_out > halfsize:

        xgs -= xout

        if i_in > halfsize:

            xgs += xin

        else:    

            xgs += x0

    elif i_in > halfsize+1:

        xgs += xin - xg


    return buffer, xls, xgs

func相当于原来funcX的on buffer。O(1).


@numba.njit       

def func(buffer,xls,xgs):

    med0 = buffer[halfsize]

    med  = (buffer[halfsize-1] + buffer[halfsize+1])/2

    if med0 > med:

        return (xgs-xls+med0-med) / windowsize + med

    else:               

        return (xgs-xls+med-med0) / windowsize + med    

med是全局函数。O(data.size * windowsize).


@numba.njit

def med(data):

    res = np.full_like(data, np.nan)

    state = init(data[:windowsize])

    res[halfsize] = func(*state)

    for i in range(windowsize, data.size):

        xin,xout = data[i], data[i - windowsize]

        state = update(*state, xin, xout)

        res[i-halfsize] = func(*state)

    return res 

表现 :


import pandas

data=pandas.DataFrame(np.random.rand(10**5))


%time res1=data[0].rolling(window = windowsize, center = True).apply(funcX, raw = True)

Wall time: 10.8 s


res2=med(data[0].values)


np.allclose((res1-res2)[halfsize:-halfsize],0)

Out[112]: True


%timeit res2=med(data[0].values)

40.4 ms ± 462 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

它快了 250 倍,窗口大小 = 51。一小时变成了 15 秒。


查看完整回答
反对 回复 2022-01-18
  • 1 回答
  • 0 关注
  • 195 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号