效能考量#

目前,Flax nnx.jit 在純 Python 中遍歷物件圖,這很慢且會增加開銷。這就是為什麼為了解決這個問題,Flax 團隊將開發一個名為 flaxlib 的 Rust 擴充功能,以加速 graph.py 中的一些遍歷邏輯。這將類似於 JAX 團隊透過引入 jaxlib 來解決標準 JAX pytrees 的類似問題(請參閱 Flax PR #4196 中的第一個步驟)。

然而,有兩件事需要考慮

非同步調度#

benchmarks/nnx_simple_training.py 中,我們正在增加層寬(每層的特徵),並測量使用 nnx.jitjax.jit 訓練的相同模型的總訓練時間。

如下圖所示,在達到一定的模型大小後,遍歷所花費的時間會被非同步調度完全吸收。當 Python 能夠完成目前的 for 迴圈步驟,並到達下一個 train_step,而 JAX 仍然沒有完成先前的 train_step 時,就會發生這種情況。

performance-graph

這意味著您只需要擔心小型模型的 nnx.jit 開銷。如果您正在使用小型模型,請查看下一節,了解如何消除開銷。

降低 Python 開銷#

要消除 Python 開銷,您可以結合使用常規的 jax.jitnnx.splitnnx.merge 來移出遍歷邏輯。

為了了解如何做到這一點,讓我們先建立以下簡單的 Model

from flax import nnx
import jax
import jax.numpy as jnp
import optax

class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

接下來,讓我們建立一個使用 nnx.jittrain_step() 函式,輸入 modeloptimizermetrics,它們都是 Flax NNX 物件

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing
metrics = nnx.MultiMetric(
  loss=nnx.metrics.Average('loss'),
)

@nnx.jit  # <== currently slow
def train_step(model, optimizer, metrics, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # in-place updates
  metrics.update(loss=loss)

  return loss
  
for _ in range(10):
  x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
  loss = train_step(model, optimizer, metrics, x, y)

為了加快速度,在開始訓練迴圈之前,我們可以對 train_step() 的所有輸入 Flax NNX 物件使用 nnx.split 來建立 graphdefstate pytrees,這些 pytrees 的遍歷速度更快。

接下來,我們變更 train_step() 以接受 graphdefstate,並在 train_step() 的開頭和結尾使用 nnx.mergennx.split 來在物件及其 pytree 表示形式之間來回切換。儘管 nnx.splitnnx.merge 很慢,但這並不重要,因為它們只會在追蹤期間執行一次。

完成此設定後,我們可以將 train_step() 函式變更為使用 jax.jit 而不是 nnx.jit

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adamw(1e-3))  # reference sharing
metrics = nnx.MultiMetric(
  loss=nnx.metrics.Average('loss'),
)
# split before training loop
graphdef, state = nnx.split((model, optimizer, metrics))

@jax.jit  # regular JAX
def train_step(graphdef, state, x, y):
  # merge at the beginning of the function
  model, optimizer, metrics = nnx.merge(graphdef, state)

  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  metrics.update(loss=loss)

  # split at the end of the function
  _, state = nnx.split((model, optimizer, metrics))

  # return new state
  return state, loss

for _ in range(10):
  x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
  state, loss = train_step(graphdef, state, x, y)

# update objects after training
nnx.update((model, optimizer, metrics), state)

請注意,我們只對 jit 執行此操作。您仍然可以使用其他 Flax 轉換,例如上述範例中顯示的 nnx.value_and_grad,因為它們的開銷已經被外部的 jit 吸收了。

在訓練迴圈完成後(或在需要時),我們可以使用 Flax nnx.update 將 Flax NNX 物件(例如 modeloptimizermetrics)更新為新的 state