numba
numba是一个python的JIT编译器,用numba加速的Python函数可以达到接近于机器码的运行速度。只用在python函数上加上 @jit的装饰器,就可以将该函数编译到本地机器码运行。numba甚至还支持cuda目标,将python代码编译到cuda平台。
numba效果
1 | import numba |
954 ms ± 2.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.83 ms ± 14.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
15.7 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
这个结果还是相当优秀的,用numba加速函数的速度可以赶上C语言写的numpy。
(这个表现跟CPU也有很大的关系,在我的M3 Pro芯片上,第二个和第三个表现是差不多的,都是6.27ms/loop上下;然而在我的x86架构的9800X3D芯片上,numba加速效果甚至比numpy操作还快很多。)
(实测瓶颈在内存带宽上,M系芯片内存带宽400GB/s而9800X3D内存带宽只有96GB/s,将np.sum(large_array ** 2) 替换成 np.dot(large_array, large_array)后,9800X3D的速度提升到 161 µs/loop,而M芯片上仅提升到 923 µs/loop)
内存带宽瓶颈 vs 计算优化
- 原始NumPy: ████████████████ 15.7ms (内存带宽受限)
- Numba: ████ 3.86ms (单核计算优化)
- 点积: ▏ 0.16ms (多核+内存+计算全优化)
numba用处
Numba 在加速特定类型的 Python 代码时效果惊人,但对另一些类型的代码则效果有限,甚至可能适得其反。理解其优势和局限性对于有效利用 Numba 至关重要。
| 代码特征 | Numba 效果 | 原因/备注 |
|---|---|---|
| 密集数值循环 | ⭐⭐⭐⭐⭐ | 核心优势,消除循环开销,向量化。 |
| NumPy 数组操作 | ⭐⭐⭐⭐ | 理解内存布局,避免临时数组。复杂操作尤佳。 |
math/numpy** 数学函数** |
⭐⭐⭐⭐ | 内联为高效指令。 |
| 明确静态类型 | ⭐⭐⭐⭐ | 编译优化的基础。 |
| 可并行循环 (独立迭代) | ⭐⭐⭐ | parallel=True 或 @vectorize 可加速。 |
| I/O 操作 (文件/网络/DB) | ⭐ | 瓶颈不在 CPU,编译无用。 |
| 操作 Python list/dict/str | ⭐ (或 🐢) | @njit 不支持或受限;@jit 回退对象模式,极慢。 |
| 调用 Pandas/Sklearn 等库 | ⭐ (或 🐢) | 无法编译库代码,触发回退对象模式,极慢。 |
频繁异常处理 (try/except) |
⭐ | 编译代码中异常开销大,非常规控制流。 |
| 动态/多变类型 | ⭐ | 多次编译开销或回退对象模式。 |
| 非常小的/单次调用函数 | ⭐ (或 ⚠️) | 编译开销 > 执行收益。第一次调用特别慢。 |
| 深度递归 | ⭐ | 支持有限,可能栈溢出或低效。 |
| 复杂 OOP (虚函数/多态) | ⭐ | 不适合 Numba 的编译模型。 |
| 复杂字符串处理 | ⭐ | 支持有限,效率通常不高。 |
最佳实践:
- Profile First: 先用性能分析工具(如
cProfile,line_profiler)找出真正的计算瓶颈。只优化热点函数。 - 瞄准循环和数组: 优先尝试用
@njit加速包含密集数值计算和 NumPy 数组操作的循环。 - 坚持
@njit(nopython): 这是获得最大加速的关键。如果@njit失败,仔细阅读错误信息调整代码或提供类型签名,尽量避免回退到object mode。 - 避免外部调用: 尽量在 JIT 函数内部使用 Numba 支持的特性(
math, 有限np函数, 基本控制流)。将不支持的调用移到 JIT 函数外部。 - 处理编译开销: 对于会被反复调用的小型 JIT 函数,可以考虑在程序初始化时预先“热身”调用一次(用典型输入)来支付编译开销。对于一次性函数,纯 Python 可能更好。
- 类型稳定性: 确保函数内部变量类型一致且明确。必要时使用 Numba 的类型注解。
简而言之:Numba 是数值计算循环和 NumPy 数组处理的加速神器。它不适合加速涉及大量 I/O、复杂 Python 对象操作或调用外部库的代码。 明智地选择应用场景是发挥其威力的关键。
@jit
最常用的装饰器,用@jit装饰的函数,numba会尝试将其编译为LLVM IR,最后编译为你机器的本地机器码。使用这个装饰器装饰的函数,numba会默认优先尝试nopython模式,将代码编译为本地机器码,但是如果遇到numba无法理解的代码(例如pandas代码),则会退化为object模式,使用python解释器执行代码(这样我们就失去了numba带来的优化)。
jit支持以下配置:
nopython
强制启用nopython模式,如果遇到了numba理解不了的代码,则会直接报错。
同时也可以启用forceobj=True强制启用object模式,虽然这样没有什么意义。
在numba 0.59 版本后可以使用 @njit简写替换@jit(nopython=True) (最新的numba版本为0.61)。
nogil
释放python的GIL,这样我们的代码在多线程环境内运行的时候速度会变快。
这个功能不会自动将函数计算内容并行化,而是在多线程环境中调用这个函数时,函数不再会受到GIL的约束。
cache
为避免每次调用 Python 程序时都进行编译,您可以指示 Numba 将函数编译结果写入基于文件的缓存。
- 已编译函数的缓存不是按函数逐个进行的。被缓存的是主 jit 函数,而所有次级函数(由主函数调用的函数)都会被纳入主函数的缓存中。
- 缓存失效机制无法识别不同文件中定义的函数变更。这意味着当一个主 jit 函数调用了从其他模块导入的函数时,这些模块中的更改将不会被检测到,缓存也不会更新。这存在在计算中使用"旧"函数代码的风险。
- 全局变量被视为常量。缓存会记住编译时全局变量的值。在缓存加载时,缓存函数不会重新绑定到全局变量的新值。
parallel
自动利用多核CPU的功能并行化计算,隐含了<font style="color:rgb(64, 64, 64);background-color:rgb(252, 252, 252);">nogil</font>。
这个功能会自动将函数需要计算的内容放到多个核心上去计算,以此加速计算。
fastmath
允许编译器进行更激进的浮点运算优化,以牺牲部分 IEEE 754 标准的严格合规性为代价(牺牲精度),换取更高的计算性能。
适用于需要极致性能以及可以接收微小精度损失的场景,例如科学计算、深度学习。
不适用需要高精度的场景,例如金融数据计算、高精度科学模拟。
函数签名
numba支持在装饰器上指定函数参数的类型和返回值的类型,这样能大大减少编译时间。
1 | from numba import jit, int32 |
常见的类型有
**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">void</font>**是没有返回值的函数的返回类型(当从 Python 调用时,这些函数实际上返回**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">None</font>**)。**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">intp</font>**和**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">uintp</font>**是指针大小的整数(分别表示有符号和无符号)。**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">intc</font>**和**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">uintc</font>**分别等同于 C 语言中的**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">int</font>**和**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">unsigned int</font>**整数类型。**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">int8</font>**,**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">uint8</font>**,**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">int16</font>**,**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">uint16</font>**,**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">int32</font>**,**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">uint32</font>**,**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">int64</font>**,**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">uint64</font>**是相应位宽(8位、16位、32位、64位)的固定宽度整数(分别表示有符号和无符号)。**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">float32</font>**和**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">float64</font>**分别是单精度和双精度浮点数。**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">complex64</font>**和**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">complex128</font>**分别是单精度和双精度复数。- 数组类型可以通过索引任何数字类型来指定,例如:
**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">float32[:]</font>**表示一个一维单精度浮点数数组**<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">int8[:,:]</font>**表示一个二维 8 位有符号整数数组。
@njit
@njit装饰器是@jit(nopython=True)的缩写,效果一样,遇到不能理解的代码直接报错。
@vectorize & @guvectorize
numpy 的通用函数一般分为两种,一种是作用于标量的,称为“通用函数”或 ufuncs;另一种是作用于高维数组和标量的,称为“广义通用函数”或 gufuncs。
Numba 的 vectorize 功能允许将接收标量输入参数的 Python 函数用作 NumPy 通用函数。
函数签名
vectorize() 如果在定义的时候后标注明了函数签名,则会直接编译;如果没有注明函数签名,则会在调用的时候,每一次遇到新类型,重新编译一份代码。
一个vertorize的函数可以有一份签名或者多分签名:
1 | # 一份签名 |