Flax#
為 JAX 打造的神經網路
Flax 為使用 JAX 進行神經網路研究和開發的開發人員和研究人員提供彈性的端到端使用者體驗。Flax 讓您能夠充分利用 JAX 的強大功能。
Flax 的核心是 NNX - 一個簡化的 API,讓您能夠更輕鬆地在 JAX 中建立、檢查、除錯和分析神經網路。 Flax NNX 具備對 Python 參考語意的首要支援,讓使用者可以使用常規的 Python 物件來表達其模型。Flax NNX 是先前 Flax Linen API 的演進,並經過多年的經驗,才帶來了更簡單且更易於使用的 API。
注意
Flax Linen API 不會在近期內被棄用,因為大多數 Flax 使用者仍然依賴此 API。然而,我們鼓勵新使用者使用 Flax NNX。請查看為何選擇 Flax NNX,以比較 Flax NNX 和 Linen,以及我們建立新 API 的理由。
若要將您的 Flax Linen 程式碼庫遷移到 Flax NNX,請先熟悉 NNX 基礎知識 中的 API,然後依照演進指南開始您的遷移。
特色#
基本用法#
from flax import nnx
import optax
class Model(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = nnx.relu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)
model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing
@nnx.jit # automatic state management for JAX transforms
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x) # call methods directly
return ((y_pred - y) ** 2).mean()
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads) # in-place updates
return loss