Flax NNX 與 JAX 轉換的比較

Flax NNX 與 JAX 轉換的比較#

本指南描述了Flax NNX 轉換JAX 轉換之間的差異,以及如何在它們之間無縫切換或並排使用。此處的範例將重點介紹 nnx.jitjax.jitnnx.gradjax.grad 函數轉換 (transforms)。

首先,讓我們設定匯入並產生一些虛擬資料

from flax import nnx
import jax

x = jax.random.normal(jax.random.key(0), (1, 2))
y = jax.random.normal(jax.random.key(1), (1, 3))

差異#

Flax NNX 轉換可以轉換非純函數並進行修改和副作用:- Flax NNX 轉換使您能夠轉換將 Flax NNX 圖形物件作為引數的函數 - 例如 nnx.Modulennx.Rngsnnx.Optimizer 等等 - 甚至包括那些狀態將被修改的物件。- 相比之下,這些類型的物件在 JAX 轉換中無法被識別。

Flax NNX 函數式 API 提供了一種將圖形結構轉換為 pytrees 並返回的方法。透過在每個函數邊界執行此操作,您可以有效地將圖形結構與任何 JAX 轉換一起使用,並以與函數純度一致的方式傳播狀態更新。

Flax NNX 自訂轉換,例如 nnx.jitnnx.grad,只是移除了樣板程式碼,因此程式碼看起來像是有狀態的。

以下是一個使用 nnx.jitnnx.grad 轉換的範例,與使用 jax.jitjax.grad 轉換的程式碼進行比較。

請注意:

  • Flax NNX 轉換函數的函數簽名可以直接接受 nnx.Linear nnx.Module 實例,並對 Module 進行有狀態的更新。

  • JAX 轉換函數的函數簽名只能接受 pytree 註冊的 nnx.Statennx.GraphDef 物件,並且必須傳回它們的更新副本,以保持轉換函數的純度。

@nnx.jit
def train_step(model, x, y):
  def loss_fn(model):
    return ((model(x) - y) ** 2).mean()
  grads = nnx.grad(loss_fn)(model)
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)
@jax.jit
def train_step(graphdef, state, x, y):
  def loss_fn(graphdef, state):
    model = nnx.merge(graphdef, state)
    return ((model(x) - y) ** 2).mean()
  grads = jax.grad(loss_fn, argnums=1)(graphdef, state)

  model = nnx.merge(graphdef, state)
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)
  return nnx.split(model)

graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)

混合使用 Flax NNX 和 JAX 轉換#

Flax NNX 轉換和 JAX 轉換可以混合使用,只要程式碼中 JAX 轉換的函數是純函數,並且具有 JAX 可以識別的有效引數類型。

@nnx.jit
def train_step(model, x, y):
  def loss_fn(graphdef, state):
    model = nnx.merge(graphdef, state)
    return ((model(x) - y) ** 2).mean()
  grads = jax.grad(loss_fn, 1)(*nnx.split(model))
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)
@jax.jit
def train_step(graphdef, state, x, y):
  model = nnx.merge(graphdef, state)
  def loss_fn(model):
    return ((model(x) - y) ** 2).mean()
  grads = nnx.grad(loss_fn)(model)
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)
  return nnx.split(model)

graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)