Flax#

為 JAX 打造的神經網路


Flax 為使用 JAX 進行神經網路研究和開發的開發人員和研究人員提供彈性的端到端使用者體驗。Flax 讓您能夠充分利用 JAX 的強大功能。

Flax 的核心是 NNX - 一個簡化的 API,讓您能夠更輕鬆地在 JAX 中建立、檢查、除錯和分析神經網路。 Flax NNX 具備對 Python 參考語意的首要支援,讓使用者可以使用常規的 Python 物件來表達其模型。Flax NNX 是先前 Flax Linen API 的演進,並經過多年的經驗,才帶來了更簡單且更易於使用的 API。

注意

Flax Linen API 不會在近期內被棄用,因為大多數 Flax 使用者仍然依賴此 API。然而,我們鼓勵新使用者使用 Flax NNX。請查看為何選擇 Flax NNX,以比較 Flax NNX 和 Linen,以及我們建立新 API 的理由。

若要將您的 Flax Linen 程式碼庫遷移到 Flax NNX,請先熟悉 NNX 基礎知識 中的 API,然後依照演進指南開始您的遷移。

特色#

Pythonic

Flax NNX 支援使用常規的 Python 物件,提供直覺且可預測的開發體驗。

簡單

Flax NNX 依賴 Python 的物件模型,這可讓使用者更簡單,並加快開發速度。

富有表現力

Flax NNX 可透過其過濾器系統對模型的狀態進行精細的控制。

熟悉

Flax NNX 透過函數式 API,讓物件可以非常容易地與常規 JAX 程式碼整合。

基本用法#

from flax import nnx
import optax


class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing

@nnx.jit  # automatic state management for JAX transforms
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # in-place updates

  return loss

安裝#

透過 pip 安裝

pip install flax

或從存放庫安裝最新版本

pip install git+https://github.com/google/flax.git

了解更多#

Flax NNX 基礎知識
nnx_basics.html
MNIST 教學
mnist_tutorial.html
Flax Linen 到 Flax NNX
guides/linen_to_nnx.html