为了账号安全,请及时绑定邮箱和手机立即绑定

使用 numba 优化 while 循环以实现容错

使用 numba 优化 while 循环以实现容错

紫衣仙女 2022-10-18 19:46:39
我在使用 numba 进行优化时有疑问。我正在编写一个定点迭代来计算一个名为 gamma 的数组的值,它满足方程 f(gamma)=gamma。我正在尝试使用 python 包 Numba 优化此功能。看起来如下。@jitdef fixed_point(gamma_guess):    for i in range(17):        gamma_guess=f(gamma_guess)    return gamma_guessNumba 能够很好地优化这个功能,因为它知道它将执行多少次操作,17 次,并且运行速度很快。但是我需要控制我想要的伽玛的误差容限,我的意思是,一个伽玛和下一个通过定点迭代获得的伽玛之差应该小于某个数字epsilon = 0.01,然后我尝试了@jitdef fixed_point(gamma_guess):    err=1000    gamma_old=gamma_guess.copy()    while(error>0.01):        gamma_guess=f(gamma_guess)        err=np.max(abs(gamma_guess-gamma_old))        gamma_old=gamma_guess.copy()    return gamma_guess它也可以工作并计算所需的结果,但不如上次实现快,它要慢得多。我认为这是因为 Numba 无法很好地优化 while 循环,因为我们不知道它何时会停止。有没有办法可以优化它并像上次实现一样快地运行?编辑:这是我正在使用的 ffrom scipy import fftpack as spS=0.01Amu=0.7@jit def f(gammaa,z,zal,kappa):    ka=sp.diff(kappa)    gamma0=gammaa    for i in range(N):        suma=0        for j in range(N):            if (abs(j-i))%2 ==1:                if((z[i]-z[j])==0):                    suma+=(gamma0[j]/(z[i]-z[j]))           gamma0[i]=2.0*Amu*np.real(-(zal[i]/z[i])+zal[i]*(1.0/(2*np.pi*1j))*suma*2*h)+S*ka[i]    return  gamma0我总是用作初始猜测,np.ones(2048)*0.5传递给我的函数的其他参数是z=np.cos(alphas)+1j*(np.sin(alphas)+0.1)、zal=-np.sin(alphas)+1j*np.cos(alphas)和kappa=np.ones(2048)alphas=np.arange(0,2*np.pi,2*np.pi/2048)
查看完整描述

1 回答

?
函数式编程

TA贡献1807条经验 获得超9个赞

我做了一个小测试脚本,看看我是否可以重现你的错误:


import numba as nb


from IPython import get_ipython

ipython = get_ipython()


@nb.jit(nopython=True)

def f(x):

    return (x+1)/x



def fixed_point_for(x):

    for _ in range(17):

        x = f(x)

    return x


@nb.jit(nopython=True)

def fixed_point_for_nb(x):

    for _ in range(17):

        x = f(x)

    return x


def fixed_point_while(x):

    error=1

    x_old = x

    while error>0.01:

        x = f(x)

        error = abs(x_old-x)

        x_old = x

    return x


@nb.jit(nopython=True)

def fixed_point_while_nb(x):

    error=1

    x_old = x

    while error>0.01:

        x = f(x)

        error = abs(x_old-x)

        x_old = x

    return x


print("for loop without numba:")

ipython.magic("%timeit fixed_point_for(10)")


print("for loop with numba:")

ipython.magic("%timeit fixed_point_for_nb(10)")


print("while loop without numba:")

ipython.magic("%timeit fixed_point_while(10)")


print("for loop with numba:")

ipython.magic("%timeit fixed_point_while_nb(10)")

由于我不了解您的f情况,因此我只使用了我能想到的最简单的稳定功能。然后,我在使用和不numba使用循环的情况下运行测试。我机器上的结果是:forwhile


for loop without numba:

3.35 µs ± 8.72 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

for loop with numba:

282 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

while loop without numba:

1.86 µs ± 7.09 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

for loop with numba:

214 ns ± 1.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

产生以下想法:

  • 不可能,你的函数是不可优化的,因为你的for循环很快(至少你是这么说的;你没有测试过numba吗?)。

  • 如您所想,您的函数可能需要更多的循环才能收敛

  • 我们使用不同的软件版本。我的版本是:

    • numba  0.49.0

    • numpy  1.18.3

    • python 3.8.2


查看完整回答
反对 回复 2022-10-18
  • 1 回答
  • 0 关注
  • 127 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信