Flax NNX 與 JAX 轉換的比較#
本指南描述了Flax NNX 轉換和JAX 轉換之間的差異,以及如何在它們之間無縫切換或並排使用。此處的範例將重點介紹 nnx.jit
、jax.jit
、nnx.grad
和 jax.grad
函數轉換 (transforms)。
首先,讓我們設定匯入並產生一些虛擬資料
from flax import nnx
import jax
x = jax.random.normal(jax.random.key(0), (1, 2))
y = jax.random.normal(jax.random.key(1), (1, 3))
差異#
Flax NNX 轉換可以轉換非純函數並進行修改和副作用:- Flax NNX 轉換使您能夠轉換將 Flax NNX 圖形物件作為引數的函數 - 例如 nnx.Module
、nnx.Rngs
、nnx.Optimizer
等等 - 甚至包括那些狀態將被修改的物件。- 相比之下,這些類型的物件在 JAX 轉換中無法被識別。
Flax NNX 函數式 API 提供了一種將圖形結構轉換為 pytrees 並返回的方法。透過在每個函數邊界執行此操作,您可以有效地將圖形結構與任何 JAX 轉換一起使用,並以與函數純度一致的方式傳播狀態更新。
Flax NNX 自訂轉換,例如 nnx.jit
和 nnx.grad
,只是移除了樣板程式碼,因此程式碼看起來像是有狀態的。
以下是一個使用 nnx.jit
和 nnx.grad
轉換的範例,與使用 jax.jit
和 jax.grad
轉換的程式碼進行比較。
請注意:
Flax NNX 轉換函數的函數簽名可以直接接受
nnx.Linear
nnx.Module
實例,並對Module
進行有狀態的更新。JAX 轉換函數的函數簽名只能接受 pytree 註冊的
nnx.State
和nnx.GraphDef
物件,並且必須傳回它們的更新副本,以保持轉換函數的純度。
@nnx.jit
def train_step(model, x, y):
def loss_fn(model):
return ((model(x) - y) ** 2).mean()
grads = nnx.grad(loss_fn)(model)
params = nnx.state(model, nnx.Param)
params = jax.tree_util.tree_map(
lambda p, g: p - 0.1 * g, params, grads
)
nnx.update(model, params)
model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)
@jax.jit
def train_step(graphdef, state, x, y):
def loss_fn(graphdef, state):
model = nnx.merge(graphdef, state)
return ((model(x) - y) ** 2).mean()
grads = jax.grad(loss_fn, argnums=1)(graphdef, state)
model = nnx.merge(graphdef, state)
params = nnx.state(model, nnx.Param)
params = jax.tree_util.tree_map(
lambda p, g: p - 0.1 * g, params, grads
)
nnx.update(model, params)
return nnx.split(model)
graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)
混合使用 Flax NNX 和 JAX 轉換#
Flax NNX 轉換和 JAX 轉換可以混合使用,只要程式碼中 JAX 轉換的函數是純函數,並且具有 JAX 可以識別的有效引數類型。
@nnx.jit
def train_step(model, x, y):
def loss_fn(graphdef, state):
model = nnx.merge(graphdef, state)
return ((model(x) - y) ** 2).mean()
grads = jax.grad(loss_fn, 1)(*nnx.split(model))
params = nnx.state(model, nnx.Param)
params = jax.tree_util.tree_map(
lambda p, g: p - 0.1 * g, params, grads
)
nnx.update(model, params)
model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)
@jax.jit
def train_step(graphdef, state, x, y):
model = nnx.merge(graphdef, state)
def loss_fn(model):
return ((model(x) - y) ** 2).mean()
grads = nnx.grad(loss_fn)(model)
params = nnx.state(model, nnx.Param)
params = jax.tree_util.tree_map(
lambda p, g: p - 0.1 * g, params, grads
)
nnx.update(model, params)
return nnx.split(model)
graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)