Flax 基礎#
Flax NNX 是一個新的簡化 API,旨在讓在 JAX 中創建、檢查、除錯和分析神經網路變得更容易。它透過新增對 Python 參考語義的一流支援來實現這一點。這允許使用者使用常規的 Python 物件來表達他們的模型,這些物件被建模為 PyGraph(而不是 pytrees),從而實現參考共享和可變性。這種 API 設計應該讓 PyTorch 或 Keras 的使用者感到賓至如歸。
首先,使用 pip
安裝 Flax 並導入必要的依賴項
# ! pip install -U flax
from flax import nnx
import jax
import jax.numpy as jnp
Flax NNX 模組系統#
Flaxnnx.Module
與 Flax Linen 或 Haiku 中其他 Module
系統的主要區別在於,在 NNX 中一切都是顯式的。這意味著,除了其他事項之外,nnx.Module
本身直接持有狀態(例如參數),PRNG 狀態由使用者線程化,並且所有形狀資訊必須在初始化時提供(沒有形狀推斷)。
讓我們從建立一個 Linear
nnx.Module
開始。如下所示,動態狀態通常儲存在 nnx.Param
中,而靜態狀態(所有 NNX 未處理的類型),例如整數或字串,則直接儲存。類型為 jax.Array
和 numpy.ndarray
的屬性也被視為動態狀態,儘管將它們儲存在 nnx.Variable
中,例如 Param
,是首選的方法。此外,nnx.Rngs
物件可用於根據傳遞給建構函式的根 PRNG 金鑰獲取新的唯一金鑰。
class Linear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
key = rngs.params()
self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))
self.din, self.dout = din, dout
def __call__(self, x: jax.Array):
return x @ self.w + self.b
另請注意,可以使用 value
屬性存取 nnx.Variable
的內部值,但為了方便起見,它們實現了所有數值運算符,並且可以直接在算術表達式中使用(如上面的程式碼所示)。
要初始化 Flax nnx.Module
,您只需呼叫建構函式,並且通常會及早建立 Module
的所有參數。由於 nnx.Module
持有自己的狀態方法,因此您可以直接呼叫它們,而無需單獨的 apply
方法。這對於除錯非常方便,讓您可以直接檢查模型的整個結構。
model = Linear(2, 5, rngs=nnx.Rngs(params=0))
y = model(x=jnp.ones((1, 2)))
print(y)
nnx.display(model)
[[1.245453 0.74195766 0.8553282 0.6763327 1.2617068 ]]
上述由 nnx.display
產生的視覺化是使用很棒的 Treescope 函式庫產生的。
具狀態的計算#
實作層,例如 nnx.BatchNorm
,需要在前向傳遞期間執行狀態更新。在 Flax NNX 中,您只需要建立一個 nnx.Variable
並在前向傳遞期間更新其 .value
。
class Count(nnx.Variable): pass
class Counter(nnx.Module):
def __init__(self):
self.count = Count(jnp.array(0))
def __call__(self):
self.count += 1
counter = Counter()
print(f'{counter.count.value = }')
counter()
print(f'{counter.count.value = }')
counter.count.value = Array(0, dtype=int32, weak_type=True)
counter.count.value = Array(1, dtype=int32, weak_type=True)
通常在 JAX 中會避免可變參考。但是,Flax NNX 提供了健全的機制來處理它們,如本指南的後續章節所示。
巢狀模組#
Flax nnx.Module
可以用於在巢狀結構中組合其他 Module
。這些可以直接作為屬性分配,也可以分配在任何(巢狀)pytree 類型的屬性內,例如 list
、dict
、tuple
等。
下面的範例展示了如何透過子類化 nnx.Module
來定義一個簡單的 MLP
。該模型由兩個 Linear
層、一個 nnx.Dropout
層和一個 nnx.BatchNorm
層組成。
class MLP(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
self.linear1 = Linear(din, dmid, rngs=rngs)
self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.linear2 = Linear(dmid, dout, rngs=rngs)
def __call__(self, x: jax.Array):
x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
return self.linear2(x)
model = MLP(2, 16, 5, rngs=nnx.Rngs(0))
y = model(x=jnp.ones((3, 2)))
nnx.display(model)
在 Flax 中,nnx.Dropout
是一個有狀態的模組,它儲存一個 nnx.Rngs
物件,以便它可以在前向傳遞期間生成新的遮罩,而無需使用者每次傳遞新的金鑰。
模型手術#
預設情況下,Flax nnx.Module
是可變的。這意味著它們的結構可以隨時更改,這使得 模型手術 相當容易,因為任何子 Module
屬性都可以替換為任何其他屬性,例如新的 Module
、現有的共享 Module
、不同類型的 Module
等。此外,nnx.Variable
也可以修改或替換/共享。
下面的範例展示了如何將先前範例中 MLP
模型中的 Linear
層替換為 LoraLinear
層
class LoraParam(nnx.Param): pass
class LoraLinear(nnx.Module):
def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs):
self.linear = linear
self.A = LoraParam(jax.random.normal(rngs(), (linear.din, rank)))
self.B = LoraParam(jax.random.normal(rngs(), (rank, linear.dout)))
def __call__(self, x: jax.Array):
return self.linear(x) + x @ self.A @ self.B
rngs = nnx.Rngs(0)
model = MLP(2, 32, 5, rngs=rngs)
# Model surgery.
model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs)
model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs)
y = model(x=jnp.ones((3, 2)))
nnx.display(model)
Flax 轉換#
Flax NNX 轉換(transforms) 擴展了 JAX 轉換,以支援 nnx.Module
和其他物件。它們作為其等效 JAX 對應項的超集,並額外具有感知物件狀態和提供其他 API 來轉換物件的能力。
Flax 轉換的主要功能之一是保留參考語義,這意味著只要在轉換規則內合法,在轉換內部發生的任何物件圖變更都會傳播到外部。在實踐中,這意味著可以使用命令式程式碼表達 Flax 程式,從而大大簡化使用者體驗。
在下面的範例中,您定義一個 train_step
函式,該函式接受一個 MLP
模型、一個 nnx.Optimizer
和一批資料,並傳回該步驟的損失。損失和梯度是使用 nnx.value_and_grad
轉換在 loss_fn
上計算的。梯度會傳遞給優化器的 nnx.Optimizer.update
方法,以更新 model
的參數。
import optax
# An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer.
model = MLP(2, 16, 10, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing
@nnx.jit # Automatic state management
def train_step(model, optimizer, x, y):
def loss_fn(model: MLP):
y_pred = model(x)
return jnp.mean((y_pred - y) ** 2)
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads) # In place updates.
return loss
x, y = jnp.ones((5, 2)), jnp.ones((5, 10))
loss = train_step(model, optimizer, x, y)
print(f'{loss = }')
print(f'{optimizer.step.value = }')
loss = Array(1.0000255, dtype=float32)
optimizer.step.value = Array(1, dtype=uint32)
在此範例中發生了兩件事值得一提
對每個
nnx.BatchNorm
和nnx.Dropout
層狀態的更新會自動從loss_fn
內部傳播到train_step
,一直到外部的model
參考。optimizer
持有model
的可變參考 - 這種關係保留在train_step
函式內部,從而可以使用優化器單獨更新模型的參數。
注意
對於小型模型,nnx.jit
會有效能上的額外負擔,請參閱效能考量指南以取得更多資訊。
掃描多層 (Scan over layers)#
下一個範例使用 Flax 的 nnx.vmap
來建立多個 MLP 層的堆疊,並使用 nnx.scan
來將堆疊的每一層迭代地應用到輸入。
在下面的程式碼中,請注意以下幾點:
自訂的
create_model
函數接收一個 key 並回傳一個MLP
物件。由於您建立了五個 key 並在create_model
上使用了nnx.vmap
,因此會建立一個包含 5 個MLP
物件的堆疊。nnx.scan
用於將堆疊中的每個MLP
迭代地應用到輸入x
。nnx.scan
(有意識地) 與jax.lax.scan
不同,而是模仿nnx.vmap
,這樣更具表達力。nnx.scan
允許指定多個輸入、每個輸入/輸出的掃描軸,以及 carry 的位置。對於
nnx.BatchNorm
和nnx.Dropout
層的State
更新,會由nnx.scan
自動傳播。
@nnx.vmap(in_axes=0, out_axes=0)
def create_model(key: jax.Array):
return MLP(10, 32, 10, rngs=nnx.Rngs(key))
keys = jax.random.split(jax.random.key(0), 5)
model = create_model(keys)
@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def forward(model: MLP, x):
x = model(x)
return x
x = jnp.ones((3, 10))
y = forward(model, x)
print(f'{y.shape = }')
nnx.display(model)
y.shape = (3, 10)
Flax NNX 轉換是如何實現這一點的?為了了解 Flax NNX 物件如何與 JAX 轉換互動,下一節將說明 Flax NNX Functional API。
Flax Functional API#
Flax NNX Functional API 在參考/物件語義和值/pytree 語義之間建立了明確的界限。它還允許對 Flax Linen 和 Haiku 使用者習慣的狀態進行相同程度的細粒度控制。Flax NNX Functional API 包含三個基本方法:nnx.split
、nnx.merge
和 nnx.update
。
以下是一個使用 Functional API 的 StatefulLinear
nnx.Module
的範例。它包含:
一些
nnx.Param
nnx.Variable
s;以及一個自訂的
Count()
nnx.Variable
類型,用於追蹤每次正向傳遞時都會增加的整數純量狀態。
class Count(nnx.Variable): pass
class StatefulLinear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))
self.count = Count(jnp.array(0, dtype=jnp.uint32))
def __call__(self, x: jax.Array):
self.count += 1
return x @ self.w + self.b
model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 3)))
nnx.display(model)
State 和 GraphDef#
可以使用 nnx.split
函數將 Flax 的 nnx.Module
分解為 nnx.State
和 nnx.GraphDef
。
nnx.State
是一個從字串到nnx.Variable
或巢狀State
的Mapping
。nnx.GraphDef
包含重建nnx.Module
圖形所需的所有靜態資訊,它類似於 JAX 的 JAX 的PyTreeDef
。
graphdef, state = nnx.split(model)
nnx.display(graphdef, state)
拆分、合併和更新#
Flax 的 nnx.merge
是 nnx.split
的反向操作。它接收 nnx.GraphDef
+ nnx.State
並重建 nnx.Module
。以下範例示範了這一點:
透過依序使用
nnx.split
和nnx.merge
,任何Module
都可以被提升以在任何 JAX 轉換中使用。nnx.update
可以使用給定的nnx.State
的內容就地更新物件。此模式用於將狀態從轉換傳播回外部的來源物件。
print(f'{model.count.value = }')
# 1. Use `nnx.split` to create a pytree representation of the `nnx.Module`.
graphdef, state = nnx.split(model)
@jax.jit
def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]:
# 2. Use `nnx.merge` to create a new model inside the JAX transformation.
model = nnx.merge(graphdef, state)
# 3. Call the `nnx.Module`
y = model(x)
# 4. Use `nnx.split` to propagate `nnx.State` updates.
_, state = nnx.split(model)
return y, state
y, state = forward(graphdef, state, x=jnp.ones((1, 3)))
# 5. Update the state of the original `nnx.Module`.
nnx.update(model, state)
print(f'{model.count.value = }')
model.count.value = Array(1, dtype=uint32)
model.count.value = Array(2, dtype=uint32)
此模式的關鍵洞察是,在轉換內容 (包括基礎的 eager 解釋器) 中使用可變參考是可以的,但在跨越邊界時必須使用 Functional API。
為什麼模組不只是 pytree? 主要原因是很容易意外地失去共享參考的追蹤,例如,如果您透過 JAX 邊界傳遞兩個具有共享 Module
的 nnx.Module
,您將會悄悄地失去該共享。Flax 的 Functional API 使此行為明確化,因此更容易推理。
細粒度狀態控制#
有經驗的 Flax Linen 或 Haiku API 使用者可能會發現,將所有狀態放在單一結構中並不總是最好的選擇,因為在某些情況下,您可能希望以不同的方式處理狀態的不同子集。這是在與 JAX 轉換互動時常見的情況。
例如:
當與
jax.grad
互動時,並非每個模型狀態都可以或應該被區分。或者,有時,在使用
jax.lax.scan
時,需要指定模型狀態的哪一部分是 carry,哪一部分不是 carry。
為了處理這個問題,Flax NNX API 具有 nnx.split
,它允許您傳遞一個或多個 nnx.filterlib.Filter
s 來將 nnx.Variable
s 分割成互斥的 nnx.State
s。Flax NNX 在 API 中使用 Filter
建立 State
群組 (例如 nnx.split
、nnx.state()
和許多 NNX 轉換)。
下面的範例顯示了最常見的 Filter
s:
# Use `nnx.Variable` type `Filter`s to split into multiple `nnx.State`s.
graphdef, params, counts = nnx.split(model, nnx.Param, Count)
nnx.display(params, counts)
注意: nnx.filterlib.Filter
s 必須是詳盡的,如果一個值沒有被匹配,將會引發錯誤。
如預期般,nnx.merge
和 nnx.update
方法自然地會使用多個 State
。
# Merge multiple `State`s
model = nnx.merge(graphdef, params, counts)
# Update with multiple `State`s
nnx.update(model, params, counts)