Flax#

為 JAX 打造的神經網路


Flax 為使用 JAX 進行神經網路的研究人員和開發人員提供彈性的端到端使用者體驗。Flax 使您能夠使用 JAX 的完整功能。

Flax 的核心是 NNX - 一個簡化的 API,可以更輕鬆地在 JAX 中建立、檢查、偵錯和分析神經網路。 Flax NNX 對 Python 參考語義具有一流的支援,讓使用者可以使用常規 Python 物件來表示他們的模型。Flax NNX 是先前 Flax Linen API 的演進,並且經過多年的經驗才帶來更簡單且更易於使用的 API。

注意

Flax Linen API 在近期內不會被棄用,因為大多數 Flax 使用者仍然依賴此 API。但是,建議新使用者使用 Flax NNX。請查看 為何選擇 Flax NNX 以比較 Flax NNX 和 Linen,以及我們創建新 API 的原因。

若要將您的 Flax Linen 程式碼庫移至 Flax NNX,請先熟悉 NNX 基礎中的 API,然後按照演進指南開始您的遷移。

功能#

Pythonic

Flax NNX 支援使用常規 Python 物件,提供直觀且可預測的開發體驗。

簡單

Flax NNX 依賴 Python 的物件模型,這為使用者帶來了簡單性,並提高了開發速度。

表達力強

Flax NNX 允許透過其 篩選器系統精細控制模型的狀態。

熟悉

Flax NNX 可以透過 Functional API 非常輕鬆地將物件與常規 JAX 程式碼整合。

基本用法#

from flax import nnx
import optax


class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing

@nnx.jit  # automatic state management for JAX transforms
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # in-place updates

  return loss

安裝#

透過 pip 安裝

pip install flax

或從儲存庫安裝最新版本

pip install git+https://github.com/google/flax.git

了解更多#

Flax NNX 基礎
nnx_basics.html
MNIST 教學
mnist_tutorial.html
Flax Linen 到 Flax NNX
guides/linen_to_nnx.html