为了账号安全,请及时绑定邮箱和手机立即绑定

如何编译具有可变输入类型的 numba jit'ed 函数?

如何编译具有可变输入类型的 numba jit'ed 函数?

幕布斯6054654 2022-01-18 21:33:06
假设我有一个函数可以同时接受一个int或一个None类型作为输入参数import numba as nbimport numpy as npjitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}@nb.jit("f8(i8)", **jitkw)def get_random(seed=None):    np.random.seed(None)    out = np.random.normal()    return out我希望函数简单地返回一个正态分布的随机数。如果我想要可重现的结果,种子应该是int.get_random(42)>>> 0.4967141530112327get_random(42)>>> 0.4967141530112327get_random(42)>>> 0.4967141530112327如果我想要随机数,seed应保留为None. 但是,如果我不传递参数(因此种子默认为None)或显式传递seed=None,那么 numba 会引发TypeErrorget_random()>>> TypeError: No matching definition for argument type(s) omitted(default=None)get_random(None)>>> TypeError: No matching definition for argument type(s) omitted(default=None)在这种情况下,我该如何编写函数,仍然声明签名和使用nopython模式?我的 numba 版本是 0.43.1
查看完整描述

1 回答

?
慕侠2389804

TA贡献1719条经验 获得超6个赞

第一个问题是 nopython 模式下的 numba 仅接受(从版本 0.43.1 开始)np.random.seed:仅使用整数参数。


因此,很遗憾,您无法通过None.


第二个问题是(据我所知)没有告诉 numba 如何处理缺失值的“单一”签名,但是您可以使用两个签名(是的,它非常冗长):


import numba as nb

import numpy as np


jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}


@nb.jit(

    [nb.types.float64(nb.types.misc.Omitted(None)), 

     nb.types.float64(nb.types.int64)], 

    **jitkw)

def get_random(seed=None):

    return np.random.normal()

只是关于签名的两个部分的简短说明:


如果省略参数,则告诉 numba 用作默认nb.types.float64(nb.types.misc.Omitted(None))类型None

是 nb.types.float64(nb.types.int64)需要整数的签名。

就我个人而言,我不会指定签名,只是让 numba 弄清楚。显式签名在 numba 中很少值得,而且更常见的是,它们会导致代码变慢且不灵活。


查看完整回答
反对 回复 2022-01-18
  • 1 回答
  • 0 关注
  • 166 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信