numba是一个python的JIT编译器,用numba加速的Python函数可以达到接近于机器码的运行速度。只用在python函数上加上 @jit的装饰器,就可以将该函数编译到本地机器码运行。numba甚至还支持cuda目标,将python代码编译到cuda平台。

numba效果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import numba
import numpy as np

# 普通的 Python 函数 - 计算数组平方和
def sum_squares(arr):
result = 0.0
for i in range(len(arr)):
result += arr[i] ** 2
return result

# 使用 Numba JIT 加速
@numba.njit # 使用 "no-python" 模式获得最大速度
def sum_squares_numba(arr):
result = 0.0
for i in range(len(arr)):
result += arr[i] ** 2
return result

# 创建一个大型数组
large_array = np.random.rand(10000000)

# 测试速度
%timeit sum_squares(large_array) # 纯 Python,通常较慢
%timeit sum_squares_numba(large_array) # Numba 编译后,通常快很多倍
%timeit np.sum(large_array ** 2) # 对比优化的 NumPy 向量化操作

954 ms ± 2.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

3.83 ms ± 14.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

15.7 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

这个结果还是相当优秀的,用numba加速函数的速度可以赶上C语言写的numpy。

(这个表现跟CPU也有很大的关系,在我的M3 Pro芯片上,第二个和第三个表现是差不多的,都是6.27ms/loop上下;然而在我的x86架构的9800X3D芯片上,numba加速效果甚至比numpy操作还快很多。)

(实测瓶颈在内存带宽上,M系芯片内存带宽400GB/s而9800X3D内存带宽只有96GB/s,将np.sum(large_array ** 2) 替换成 np.dot(large_array, large_array)后,9800X3D的速度提升到 161 µs/loop,而M芯片上仅提升到 923 µs/loop)

内存带宽瓶颈 vs 计算优化

  1. 原始NumPy: ████████████████ 15.7ms (内存带宽受限)
  2. Numba: ████ 3.86ms (单核计算优化)
  3. 点积: ▏ 0.16ms (多核+内存+计算全优化)

numba用处

Numba 在加速特定类型的 Python 代码时效果惊人,但对另一些类型的代码则效果有限,甚至可能适得其反。理解其优势和局限性对于有效利用 Numba 至关重要。

代码特征 Numba 效果 原因/备注
密集数值循环 ⭐⭐⭐⭐⭐ 核心优势,消除循环开销,向量化。
NumPy 数组操作 ⭐⭐⭐⭐ 理解内存布局,避免临时数组。复杂操作尤佳。
math/numpy** 数学函数** ⭐⭐⭐⭐ 内联为高效指令。
明确静态类型 ⭐⭐⭐⭐ 编译优化的基础。
可并行循环 (独立迭代) ⭐⭐⭐ parallel=True@vectorize 可加速。
I/O 操作 (文件/网络/DB) 瓶颈不在 CPU,编译无用。
操作 Python list/dict/str ⭐ (或 🐢) @njit 不支持或受限;@jit 回退对象模式,极慢
调用 Pandas/Sklearn 等库 ⭐ (或 🐢) 无法编译库代码,触发回退对象模式,极慢
频繁异常处理 (try/except) 编译代码中异常开销大,非常规控制流。
动态/多变类型 多次编译开销或回退对象模式。
非常小的/单次调用函数 ⭐ (或 ⚠️) 编译开销 > 执行收益。第一次调用特别慢。
深度递归 支持有限,可能栈溢出或低效。
复杂 OOP (虚函数/多态) 不适合 Numba 的编译模型。
复杂字符串处理 支持有限,效率通常不高。

最佳实践:

  1. Profile First: 先用性能分析工具(如 cProfile, line_profiler)找出真正的计算瓶颈。只优化热点函数。
  2. 瞄准循环和数组: 优先尝试用 @njit 加速包含密集数值计算和 NumPy 数组操作的循环。
  3. 坚持 @njit (nopython): 这是获得最大加速的关键。如果 @njit 失败,仔细阅读错误信息调整代码或提供类型签名,尽量避免回退到 object mode
  4. 避免外部调用: 尽量在 JIT 函数内部使用 Numba 支持的特性(math, 有限 np 函数, 基本控制流)。将不支持的调用移到 JIT 函数外部。
  5. 处理编译开销: 对于会被反复调用的小型 JIT 函数,可以考虑在程序初始化时预先“热身”调用一次(用典型输入)来支付编译开销。对于一次性函数,纯 Python 可能更好。
  6. 类型稳定性: 确保函数内部变量类型一致且明确。必要时使用 Numba 的类型注解。

简而言之:Numba 是数值计算循环和 NumPy 数组处理的加速神器。它不适合加速涉及大量 I/O、复杂 Python 对象操作或调用外部库的代码。 明智地选择应用场景是发挥其威力的关键。

@jit

最常用的装饰器,用@jit装饰的函数,numba会尝试将其编译为LLVM IR,最后编译为你机器的本地机器码。使用这个装饰器装饰的函数,numba会默认优先尝试nopython模式,将代码编译为本地机器码,但是如果遇到numba无法理解的代码(例如pandas代码),则会退化为object模式,使用python解释器执行代码(这样我们就失去了numba带来的优化)。

jit支持以下配置:

nopython

强制启用nopython模式,如果遇到了numba理解不了的代码,则会直接报错。

同时也可以启用forceobj=True强制启用object模式,虽然这样没有什么意义。

在numba 0.59 版本后可以使用 @njit简写替换@jit(nopython=True) (最新的numba版本为0.61)。

nogil

释放python的GIL,这样我们的代码在多线程环境内运行的时候速度会变快。

这个功能不会自动将函数计算内容并行化,而是在多线程环境中调用这个函数时,函数不再会受到GIL的约束。

cache

为避免每次调用 Python 程序时都进行编译,您可以指示 Numba 将函数编译结果写入基于文件的缓存。

  • 已编译函数的缓存不是按函数逐个进行的。被缓存的是主 jit 函数,而所有次级函数(由主函数调用的函数)都会被纳入主函数的缓存中。
  • 缓存失效机制无法识别不同文件中定义的函数变更。这意味着当一个主 jit 函数调用了从其他模块导入的函数时,这些模块中的更改将不会被检测到,缓存也不会更新。这存在在计算中使用"旧"函数代码的风险。
  • 全局变量被视为常量。缓存会记住编译时全局变量的值。在缓存加载时,缓存函数不会重新绑定到全局变量的新值。

parallel

自动利用多核CPU的功能并行化计算,隐含了<font style="color:rgb(64, 64, 64);background-color:rgb(252, 252, 252);">nogil</font>

这个功能会自动将函数需要计算的内容放到多个核心上去计算,以此加速计算。

fastmath

允许编译器进行更激进的浮点运算优化,以牺牲部分 IEEE 754 标准的严格合规性为代价(牺牲精度),换取更高的计算性能。

适用于需要极致性能以及可以接收微小精度损失的场景,例如科学计算、深度学习。

不适用需要高精度的场景,例如金融数据计算、高精度科学模拟。

函数签名

numba支持在装饰器上指定函数参数的类型和返回值的类型,这样能大大减少编译时间

1
2
3
4
5
6
from numba import jit, int32

@jit(int32(int32, int32))
def f(x, y):
# A somewhat trivial example
return x + y

常见的类型有

  • **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">void</font>** 是没有返回值的函数的返回类型(当从 Python 调用时,这些函数实际上返回 **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">None</font>**)。
  • **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">intp</font>****<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">uintp</font>** 是指针大小的整数(分别表示有符号和无符号)。
  • **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">intc</font>****<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">uintc</font>** 分别等同于 C 语言中的 **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">int</font>****<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">unsigned int</font>** 整数类型。
  • **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">int8</font>**, **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">uint8</font>**, **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">int16</font>**, **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">uint16</font>**, **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">int32</font>**, **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">uint32</font>**, **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">int64</font>**, **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">uint64</font>** 是相应位宽(8位、16位、32位、64位)的固定宽度整数(分别表示有符号和无符号)。
  • **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">float32</font>****<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">float64</font>** 分别是单精度双精度浮点数
  • **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">complex64</font>****<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">complex128</font>** 分别是单精度双精度复数
  • 数组类型可以通过索引任何数字类型来指定,例如:
    • **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">float32[:]</font>** 表示一个一维单精度浮点数数组
    • **<font style="color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">int8[:,:]</font>** 表示一个二维 8 位有符号整数数组。

@njit

@njit装饰器是@jit(nopython=True)的缩写,效果一样,遇到不能理解的代码直接报错。

@vectorize & @guvectorize

numpy 的通用函数一般分为两种,一种是作用于标量的,称为“通用函数”或 ufuncs;另一种是作用于高维数组和标量的,称为“广义通用函数”或 gufuncs。

Numba 的 vectorize 功能允许将接收标量输入参数的 Python 函数用作 NumPy 通用函数。

函数签名

vectorize() 如果在定义的时候后标注明了函数签名,则会直接编译;如果没有注明函数签名,则会在调用的时候,每一次遇到新类型,重新编译一份代码。

一个vertorize的函数可以有一份签名或者多分签名:

1
2
3
4
5
6
7
8
9
10
11
12
# 一份签名
@vectorize([float64(float64, float64)])
def f(x, y):
return x + y

# 多份签名,每个签名编译一份源代码
@vectorize([int32(int32, int32),
int64(int64, int64),
float32(float32, float32),
float64(float64, float64)])
def f(x, y):
return x + y

@stencil

@jitclass

@cfunc

@overload