Flax 基礎#

Flax NNX 是一個新的簡化 API,旨在讓在 JAX 中創建、檢查、除錯和分析神經網路變得更容易。它透過新增對 Python 參考語義的一流支援來實現這一點。這允許使用者使用常規的 Python 物件來表達他們的模型,這些物件被建模為 PyGraph(而不是 pytrees),從而實現參考共享和可變性。這種 API 設計應該讓 PyTorch 或 Keras 的使用者感到賓至如歸。

首先,使用 pip 安裝 Flax 並導入必要的依賴項

# ! pip install -U flax
from flax import nnx
import jax
import jax.numpy as jnp

Flax NNX 模組系統#

Flaxnnx.ModuleFlax LinenHaiku 中其他 Module 系統的主要區別在於,在 NNX 中一切都是顯式的。這意味著,除了其他事項之外,nnx.Module 本身直接持有狀態(例如參數),PRNG 狀態由使用者線程化,並且所有形狀資訊必須在初始化時提供(沒有形狀推斷)。

讓我們從建立一個 Linear nnx.Module 開始。如下所示,動態狀態通常儲存在 nnx.Param 中,而靜態狀態(所有 NNX 未處理的類型),例如整數或字串,則直接儲存。類型為 jax.Arraynumpy.ndarray 的屬性也被視為動態狀態,儘管將它們儲存在 nnx.Variable 中,例如 Param,是首選的方法。此外,nnx.Rngs 物件可用於根據傳遞給建構函式的根 PRNG 金鑰獲取新的唯一金鑰。

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    return x @ self.w + self.b

另請注意,可以使用 value 屬性存取 nnx.Variable 的內部值,但為了方便起見,它們實現了所有數值運算符,並且可以直接在算術表達式中使用(如上面的程式碼所示)。

要初始化 Flax nnx.Module,您只需呼叫建構函式,並且通常會及早建立 Module 的所有參數。由於 nnx.Module 持有自己的狀態方法,因此您可以直接呼叫它們,而無需單獨的 apply 方法。這對於除錯非常方便,讓您可以直接檢查模型的整個結構。

model = Linear(2, 5, rngs=nnx.Rngs(params=0))
y = model(x=jnp.ones((1, 2)))

print(y)
nnx.display(model)
[[1.245453   0.74195766 0.8553282  0.6763327  1.2617068 ]]

上述由 nnx.display 產生的視覺化是使用很棒的 Treescope 函式庫產生的。

具狀態的計算#

實作層,例如 nnx.BatchNorm,需要在前向傳遞期間執行狀態更新。在 Flax NNX 中,您只需要建立一個 nnx.Variable 並在前向傳遞期間更新其 .value

class Count(nnx.Variable): pass

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0))

  def __call__(self):
    self.count += 1

counter = Counter()
print(f'{counter.count.value = }')
counter()
print(f'{counter.count.value = }')
counter.count.value = Array(0, dtype=int32, weak_type=True)
counter.count.value = Array(1, dtype=int32, weak_type=True)

通常在 JAX 中會避免可變參考。但是,Flax NNX 提供了健全的機制來處理它們,如本指南的後續章節所示。

巢狀模組#

Flax nnx.Module 可以用於在巢狀結構中組合其他 Module。這些可以直接作為屬性分配,也可以分配在任何(巢狀)pytree 類型的屬性內,例如 listdicttuple 等。

下面的範例展示了如何透過子類化 nnx.Module 來定義一個簡單的 MLP。該模型由兩個 Linear 層、一個 nnx.Dropout 層和一個 nnx.BatchNorm 層組成。

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

model = MLP(2, 16, 5, rngs=nnx.Rngs(0))

y = model(x=jnp.ones((3, 2)))

nnx.display(model)

在 Flax 中,nnx.Dropout 是一個有狀態的模組,它儲存一個 nnx.Rngs 物件,以便它可以在前向傳遞期間生成新的遮罩,而無需使用者每次傳遞新的金鑰。

模型手術#

預設情況下,Flax nnx.Module 是可變的。這意味著它們的結構可以隨時更改,這使得 模型手術 相當容易,因為任何子 Module 屬性都可以替換為任何其他屬性,例如新的 Module、現有的共享 Module、不同類型的 Module 等。此外,nnx.Variable 也可以修改或替換/共享。

下面的範例展示了如何將先前範例中 MLP 模型中的 Linear 層替換為 LoraLinear

class LoraParam(nnx.Param): pass

class LoraLinear(nnx.Module):
  def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs):
    self.linear = linear
    self.A = LoraParam(jax.random.normal(rngs(), (linear.din, rank)))
    self.B = LoraParam(jax.random.normal(rngs(), (rank, linear.dout)))

  def __call__(self, x: jax.Array):
    return self.linear(x) + x @ self.A @ self.B

rngs = nnx.Rngs(0)
model = MLP(2, 32, 5, rngs=rngs)

# Model surgery.
model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs)
model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs)

y = model(x=jnp.ones((3, 2)))

nnx.display(model)

Flax 轉換#

Flax NNX 轉換(transforms) 擴展了 JAX 轉換,以支援 nnx.Module 和其他物件。它們作為其等效 JAX 對應項的超集,並額外具有感知物件狀態和提供其他 API 來轉換物件的能力。

Flax 轉換的主要功能之一是保留參考語義,這意味著只要在轉換規則內合法,在轉換內部發生的任何物件圖變更都會傳播到外部。在實踐中,這意味著可以使用命令式程式碼表達 Flax 程式,從而大大簡化使用者體驗。

在下面的範例中,您定義一個 train_step 函式,該函式接受一個 MLP 模型、一個 nnx.Optimizer 和一批資料,並傳回該步驟的損失。損失和梯度是使用 nnx.value_and_grad 轉換在 loss_fn 上計算的。梯度會傳遞給優化器的 nnx.Optimizer.update 方法,以更新 model 的參數。

import optax

# An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer.
model = MLP(2, 16, 10, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing

@nnx.jit  # Automatic state management
def train_step(model, optimizer, x, y):
  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # In place updates.

  return loss

x, y = jnp.ones((5, 2)), jnp.ones((5, 10))
loss = train_step(model, optimizer, x, y)

print(f'{loss = }')
print(f'{optimizer.step.value = }')
loss = Array(1.0000255, dtype=float32)
optimizer.step.value = Array(1, dtype=uint32)

在此範例中發生了兩件事值得一提

  1. 對每個 nnx.BatchNormnnx.Dropout 層狀態的更新會自動從 loss_fn 內部傳播到 train_step,一直到外部的 model 參考。

  2. optimizer 持有 model 的可變參考 - 這種關係保留在 train_step 函式內部,從而可以使用優化器單獨更新模型的參數。

注意
對於小型模型,nnx.jit 會有效能上的額外負擔,請參閱效能考量指南以取得更多資訊。

掃描多層 (Scan over layers)#

下一個範例使用 Flax 的 nnx.vmap 來建立多個 MLP 層的堆疊,並使用 nnx.scan 來將堆疊的每一層迭代地應用到輸入。

在下面的程式碼中,請注意以下幾點:

  1. 自訂的 create_model 函數接收一個 key 並回傳一個 MLP 物件。由於您建立了五個 key 並在 create_model 上使用了 nnx.vmap,因此會建立一個包含 5 個 MLP 物件的堆疊。

  2. nnx.scan 用於將堆疊中的每個 MLP 迭代地應用到輸入 x

  3. nnx.scan (有意識地) 與 jax.lax.scan 不同,而是模仿 nnx.vmap,這樣更具表達力。nnx.scan 允許指定多個輸入、每個輸入/輸出的掃描軸,以及 carry 的位置。

  4. 對於 nnx.BatchNormnnx.Dropout 層的 State 更新,會由 nnx.scan 自動傳播。

@nnx.vmap(in_axes=0, out_axes=0)
def create_model(key: jax.Array):
  return MLP(10, 32, 10, rngs=nnx.Rngs(key))

keys = jax.random.split(jax.random.key(0), 5)
model = create_model(keys)

@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def forward(model: MLP, x):
  x = model(x)
  return x

x = jnp.ones((3, 10))
y = forward(model, x)

print(f'{y.shape = }')
nnx.display(model)
y.shape = (3, 10)

Flax NNX 轉換是如何實現這一點的?為了了解 Flax NNX 物件如何與 JAX 轉換互動,下一節將說明 Flax NNX Functional API。

Flax Functional API#

Flax NNX Functional API 在參考/物件語義和值/pytree 語義之間建立了明確的界限。它還允許對 Flax Linen 和 Haiku 使用者習慣的狀態進行相同程度的細粒度控制。Flax NNX Functional API 包含三個基本方法:nnx.splitnnx.mergennx.update

以下是一個使用 Functional API 的 StatefulLinear nnx.Module 的範例。它包含:

class Count(nnx.Variable): pass

class StatefulLinear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.count = Count(jnp.array(0, dtype=jnp.uint32))

  def __call__(self, x: jax.Array):
    self.count += 1
    return x @ self.w + self.b

model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 3)))

nnx.display(model)

State 和 GraphDef#

可以使用 nnx.split 函數將 Flax 的 nnx.Module 分解為 nnx.Statennx.GraphDef

graphdef, state = nnx.split(model)

nnx.display(graphdef, state)

拆分、合併和更新#

Flax 的 nnx.mergennx.split 的反向操作。它接收 nnx.GraphDef + nnx.State 並重建 nnx.Module。以下範例示範了這一點:

  • 透過依序使用 nnx.splitnnx.merge,任何 Module 都可以被提升以在任何 JAX 轉換中使用。

  • nnx.update 可以使用給定的 nnx.State 的內容就地更新物件。

  • 此模式用於將狀態從轉換傳播回外部的來源物件。

print(f'{model.count.value = }')

# 1. Use `nnx.split` to create a pytree representation of the `nnx.Module`.
graphdef, state = nnx.split(model)

@jax.jit
def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]:
  # 2. Use `nnx.merge` to create a new model inside the JAX transformation.
  model = nnx.merge(graphdef, state)
  # 3. Call the `nnx.Module`
  y = model(x)
  # 4. Use `nnx.split` to propagate `nnx.State` updates.
  _, state = nnx.split(model)
  return y, state

y, state = forward(graphdef, state, x=jnp.ones((1, 3)))
# 5. Update the state of the original `nnx.Module`.
nnx.update(model, state)

print(f'{model.count.value = }')
model.count.value = Array(1, dtype=uint32)
model.count.value = Array(2, dtype=uint32)

此模式的關鍵洞察是,在轉換內容 (包括基礎的 eager 解釋器) 中使用可變參考是可以的,但在跨越邊界時必須使用 Functional API。

為什麼模組不只是 pytree? 主要原因是很容易意外地失去共享參考的追蹤,例如,如果您透過 JAX 邊界傳遞兩個具有共享 Modulennx.Module,您將會悄悄地失去該共享。Flax 的 Functional API 使此行為明確化,因此更容易推理。

細粒度狀態控制#

有經驗的 Flax LinenHaiku API 使用者可能會發現,將所有狀態放在單一結構中並不總是最好的選擇,因為在某些情況下,您可能希望以不同的方式處理狀態的不同子集。這是在與 JAX 轉換互動時常見的情況。

例如:

  • 當與 jax.grad 互動時,並非每個模型狀態都可以或應該被區分。

  • 或者,有時,在使用 jax.lax.scan 時,需要指定模型狀態的哪一部分是 carry,哪一部分不是 carry。

為了處理這個問題,Flax NNX API 具有 nnx.split,它允許您傳遞一個或多個 nnx.filterlib.Filters 來將 nnx.Variables 分割成互斥的 nnx.States。Flax NNX 在 API 中使用 Filter 建立 State 群組 (例如 nnx.splitnnx.state() 和許多 NNX 轉換)。

下面的範例顯示了最常見的 Filters:

# Use `nnx.Variable` type `Filter`s to split into multiple `nnx.State`s.
graphdef, params, counts = nnx.split(model, nnx.Param, Count)

nnx.display(params, counts)

注意: nnx.filterlib.Filters 必須是詳盡的,如果一個值沒有被匹配,將會引發錯誤。

如預期般,nnx.mergennx.update 方法自然地會使用多個 State

# Merge multiple `State`s
model = nnx.merge(graphdef, params, counts)
# Update with multiple `State`s
nnx.update(model, params, counts)