Từ SciPy sang Diffrax: Cách tôi giải quyết nút thắt hiệu năng trong suy luận Bayesian
Một nhà vũ trụ học đã chia sẻ hành trình chuyển đổi từ SciPy sang Diffrax để giải quyết vấn đề hiệu năng khi thực hiện suy luận Bayesian. Nhờ khả năng biên dịch JIT và tự động phân biệt gradient, Diffrax giúp tăng tốc độ tính toán gấp nhiều lần so với phương pháp truyền thống.

Từ SciPy sang Diffrax: Cách tôi giải quyết nút thắt hiệu năng trong suy luận Bayesian
Là một nhà vũ trụ học lý thuyết, công việc của tôi thường xuyên xoay quanh việc kiểm chứng các mô hình Vũ trụ—như phương trình trạng thái của năng lượng tối, trọng lực sửa đổi hay các trường tachyonic—và trả lời câu hỏi: Dữ liệu thực tế nói gì về các tham số này? Công cụ tôi sử dụng để trả lời câu hỏi đó là suy luận Bayesian (Bayesian inference). Tôi thường chạy thuật toán lấy mẫu lồng nhau (nested sampling) như dynesty với vài nghìn đến vài trăm nghìn lần đánh giá hàm khả năng (likelihood evaluations), tùy thuộc vào độ phức tạp của mô hình.
Trong phần lớn thời gian làm Tiến sĩ, tôi không quá bận tâm đến bộ giải phương trình vi phân thường (ODE solver) nằm bên trong hàm khả năng miễn là nó hoạt động tốt. solve_ivp của SciPy rất đáng tin cậy, nên tôi đã sử dụng nó và chuyển sang việc khác.
Tuy nhiên, mọi thứ thay đổi khi tôi bắt đầu làm việc với một mô hình năng lượng tối tachyonic DBI, trong đó trường năng lượng tối được điều khiển bởi một số hạng động học phi tiêu chuẩn. Các phương trình nền và nhiễu loạn tạo thành một hệ liên kết khá "cứng" (stiff). Mỗi lần gọi hàm khả năng cần giải các ODE này, tính toán khoảng cách cùng chuyển động (comoving distance) và đánh giá mô đun khoảng cách tại dịch chuyển đỏ (redshift) của 30 siêu tân tinh.
Biểu đồ phân bổ thời gian thực thi
Sau khi đo hiệu năng (profiling), tôi nhận thấy chỉ riêng việc giải ODE đã mất 0,4 mili-giây mỗi lần gọi. Trong một lần chạy lấy mẫu lồng nhau với 10⁵ lần đánh giá, con số này lên tới 40 giây—chỉ để giải ODE, chưa tính đến các chi phí quản lý khác. Với một mô hình có 10 tham số, việc lấy gradient thông qua sai phân hữu hạn trung tâm tốn thêm 20 lần giải tiến, biến 0,4 ms thành 8 ms cho mỗi gradient. Tổng cộng là 300 giây, hay khoảng 5 phút, chỉ để tính gradient cho một lần chạy duy nhất.
Đã đến lúc phải thay đổi.
Giải pháp tôi tìm thấy: Diffrax
Sau một ngày tìm kiếm, tôi đã tìm thấy Diffrax, một thư viện bộ giải ODE số được viết hoàn toàn bằng JAX. Đây không phải là một mạng nơ-ron thay thế (neural surrogate) hay một phép xấp xỉ. Đó là cùng các thuật toán Runge–Kutta nhúng mà tôi đã dùng trong SciPy—Tsit5 thay vì RK45, nhưng cùng một họ phương pháp—chỉ khác là được biên dịch, có thể phân biệt và vector hóa.
Ba thuộc tính chính xuất phát từ thiết kế "viết hoàn toàn bằng JAX" bao gồm:
- Biên dịch JIT (JIT compilation): Toàn bộ vòng lặp bước thích nghi được biên dịch thành một nhân XLA duy nhất. Chi phí bằng không cho Python sau lần gọi đầu tiên.
- Tự động phân biệt (Autodiff): Vì mọi thao tác bên trong bộ giải đều là nguyên thủy của JAX,
jax.gradlan truyền gradient qua quá trình giải. Gradient chính xác. Một lần lùi ngược. Bất kể có bao nhiêu tham số. - Vector hóa (vmap): Một lô toàn bộ các vector tham số có thể được giải song song với
jax.vmap. Điều này cực kỳ quan trọng cho việc lấy mẫu lồng nhau.
Cài đặt nó chỉ mất 10 giây: pip install jax diffrax.
Bất ngờ thứ nhất: Tốc độ
solve_ivp: 404 µs mỗi lần gọi.
diffrax sau JIT: 59 µs mỗi lần gọi.
Đó là nhanh hơn ~7 lần.
Tôi đã nhìn chằm chằm vào con số này vài giây khi lần đầu thấy nó. Hãy thành thật về nguồn gốc thực sự của việc tăng tốc này, vì nó không phải là phép màu.
Trong solve_ivp, Python phải quay lại phần phụ trợ C/Cython ở mỗi lần gọi. Bộ nhớ được cấp phát mới. Vòng lặp while thích nghi đi qua trình thông dịch Python, liên tục hỏi: "lỗi cục bộ có quá lớn không? từ chối; nếu không thì tăng bước; lặp lại". Với một bài toán giải 12 bước, đó là 12 vòng điều phối Python, 12 lần cấp phát bộ nhớ và 12 lần tính toán ước lượng lỗi nằm sau khóa thông dịch.
Trong diffrax, lần gọi @jax.jit đầu tiên sẽ truy xuất toàn bộ tính toán—bao gồm cả vòng lặp while thích nghi, được hạ cấp xuống lax.while_loop và giao cho XLA biên dịch thành một nhân mã máy. Mọi lần gọi sau đó thực thi nhân đó trực tiếp. Do đó, không còn Python, không cần cấp phát bộ nhớ và không có chi phí điều phối.
So sánh tốc độ giữa SciPy và Diffrax
Với 100.000 lần đánh giá hàm khả năng, 404 µs so với 59 µs tương đương với 40,4 giây so với 5,9 giây. Đó là sự khác biệt sẽ được khuếch đại khi độ phức tạp của mô hình tăng lên.
Bất ngờ thứ hai: Gradient trở nên "miễn phí"
Đây là phần không chỉ thay đổi quy trình làm việc của tôi mà còn thay đổi cách tôi nghĩ về suy luận. Với scipy, việc lấy một gradient của log-khả năng đối với 2 tham số (Ωₘ, H₀) tốn 4 lần giải tiến (sai phân hữu hạn trung tâm). Khi bạn bắt đầu tăng số lượng tham số, chi phí tăng rất nhanh: 10 tham số nghĩa là 20 lần giải tiến, 50 tham số nghĩa là 100 lần. Chi phí tăng tuyến tính theo số tham số.
Với diffrax, tôi chỉ cần viết:
def loss(theta):
mu_pred = forward_diffrax(theta, z_obs)
return 0.5 * jnp.sum(((mu_pred - mu_obs) / sigma_mu)**2)
grad_fn = jax.jit(jax.grad(loss)) # đây là toàn bộ thay đổi
g = grad_fn(jnp.array([0.3, 70.0])) # gradient chính xác
Bên dưới, autodiff chế độ ngược của JAX tích hợp các phương trình adjoint ngược qua quá trình giải ODE—nhưng tôi không bao giờ phải viết các phương trình đó. Kết quả là một gradient chính xác trong thời gian tương đương với một lần lặp tiến, độc lập với số lượng tham số.
So sánh chi phí tính toán gradient
Lưu ý khi chọn bộ giải và sử dụng
Khi chọn bộ giải, bạn cần lưu ý một chút. Tôi mặc định dùng Tsit5 cho hầu hết mọi thứ và nó xử lý khoảng 95% vấn đề của tôi mà không phàn nàn.
- ODE không cứng (hầu hết các vấn đề vũ trụ học):
dfx.Tsit5()<- bắt đầu từ đây. - Dung sai rất chặt chẽ (< 10⁻⁸):
dfx.Dopri8(). - ODE cứng (stiff):
dfx.Kvaerno5()hoặcdfx.Radau().
Tuy nhiên, có một "cạm bẫy" quan trọng: Dùng số thực 64-bit. JAX mặc định dùng 32-bit để tăng tốc GPU, nhưng vật lý và thiên văn học cần độ chính xác cao hơn. Bạn phải bật nó:
jax.config.update("jax_enable_x64", True)
Nếu quên dòng này, các kết quả vật lý của bạn sẽ sai một cách thảm hại.
Kết luận
Việc chuyển đổi mô hình tiến của tôi sang diffrax không thay đổi vật lý hay phương pháp suy luận. Nó thay đổi tính khả thi về mặt thực tế của việc thực hiện suy luận đó. Một lần chạy lấy mẫu lồng nhau vốn đang tốn nhiều thời gian cho mô hình tiến nay trở nên mất chưa đến một phút. Những gradient vốn tốn 20 lần giải thêm cho mỗi bước giờ về cơ bản là miễn phí.
Đường cong học hỏi chỉ mất một buổi chiều. Việc gỡ lỗi chủ yếu xoay quanh vấn đề 64-bit và sự nhầm lẫn khi khởi tạo JIT. Phần thưởng đã thực sự và tức thì.
Nếu bạn là một nhà vật lý sử dụng scipy cho các lần đánh giá hàm khả năng lặp đi lặp lại và chưa từng xem xét diffrax, tôi hy vọng bài viết này đã cung cấp cho bạn một lý do để làm điều đó.
