在多個裝置上擴展#

本指南示範如何使用 Flax NNX Modules 在[多個裝置和主機](Multi-host and multi-process environments)(例如 GPU、Google TPU 和 CPU)上擴展,使用 JAX 即時編譯機制 (jax.jit)flax.nnx.spmd

概述#

Flax 依賴 JAX 進行數值計算,並在多個裝置(例如 GPU 和 Google TPU)上擴展計算。擴展的核心是 JAX 即時 (jax.jit) 編譯器 jax.jit。在本指南中,您將使用 Flax 自己的 nnx.jit 轉換,它會包裝 jax.jit,並且可以更方便地與 Flax NNX Modules 搭配使用。

注意:若要了解更多關於 Flax 的轉換(例如 nnx.jitnnx.vmap)的資訊,請前往 為什麼選擇 Flax NNX? - 轉換轉換,以及 Flax NNX 與 JAX 轉換

JAX 編譯遵循 單一程式多重資料 (SPMD) 範式。這表示您編寫 Python 程式碼,就像它只在一個裝置上執行一樣,並且 jax.jit 將會自動編譯在多個裝置上執行它。

為了確保編譯效能,您通常需要指示 JAX 如何在裝置之間分片模型的變數。這就是 Flax NNX 的分片中繼資料 API - flax.nnx.spmd - 的用武之地。它可以協助您使用此資訊註解模型的變數。

Flax Linen 使用者請注意flax.nnx.spmd API 類似於 Linen Flax on (p)jit 指南在模型定義層級中描述的內容。然而,由於 Flax NNX 所帶來的優勢,Flax NNX 中的頂層程式碼更簡單,且某些文字說明會更加更新和清晰。

如果您是 JAX 中平行處理的新手,您可以在以下教學中了解更多關於其擴展 API 的資訊

設定#

匯入一些必要的相依性。

注意:本指南使用 --xla_force_host_platform_device_count=8 旗標,以在 Google Colab/Jupyter Notebook 的 CPU 環境中模擬多個裝置。如果您已經在使用多裝置 TPU 環境,則不需要此旗標。

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from typing import *

import numpy as np
import jax
from jax import numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding

from flax import nnx

import optax # Optax for common losses and optimizers.
print(f'You have 8 “fake” JAX devices now: {jax.devices()}')
You have 8 “fake” JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]

以下程式碼展示如何匯入和設定 JAX 層級的裝置 API,遵循 JAX 的 分散式陣列和自動平行化指南

  1. 使用 JAX jax.sharding.Mesh 啟動 2x4 裝置 mesh (8 個裝置)。此配置與 TPU v3-8 (也是 8 個裝置) 相同。

  2. 使用 axis_names 參數以名稱註解每個軸。註解軸名稱的典型方式是 axis_name=('data', 'model'),其中

  • 'data':用於輸入和激活的批次維度資料平行分片的網格維度。

  • 'model':用於在裝置之間分片模型參數的網格維度。

# Create a mesh of two dimensions and annotate each axis with a name.
mesh = Mesh(devices=np.array(jax.devices()).reshape(2, 4),
            axis_names=('data', 'model'))
print(mesh)
Mesh('data': 2, 'model': 4)

定義具有指定分片的模型#

接下來,建立一個名為 DotReluDot 的範例層,它會子類別化 Flax nnx.Module

  • 此層在輸入 x 上執行兩個點積乘法,並在其中間使用 jax.nn.relu (ReLU) 激活函數。

  • 若要使用其理想的分片來註解模型變數,您可以使用 flax.nnx.with_partitioning 包裝其初始化器函數。基本上,這會呼叫 flax.nnx.with_metadata,它會將 .sharding 屬性欄位新增至對應的 nnx.Variable

注意:此註解將會在 Flax NNX 中跨提升轉換適當保留和調整。這表示如果您使用分片註解以及任何修改軸的轉換 (例如 nnx.vmapnnx.scan),您需要透過 transform_metadata 引數來提供該額外軸的分片。請查看Flax NNX 轉換 (transforms) 指南以了解更多資訊。

class DotReluDot(nnx.Module):
  def __init__(self, depth: int, rngs: nnx.Rngs):
    init_fn = nnx.initializers.lecun_normal()

    # Initialize a sublayer `self.dot1` and annotate its kernel with.
    # `sharding (None, 'model')`.
    self.dot1 = nnx.Linear(
      depth, depth,
      kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),
      use_bias=False,  # or use `bias_init` to give it annotation too
      rngs=rngs)

    # Initialize a weight param `w2` and annotate with sharding ('model', None).
    # Note that this is simply adding `.sharding` to the variable as metadata!
    self.w2 = nnx.Param(
      init_fn(rngs.params(), (depth, depth)),  # RNG key and shape for W2 creation
      sharding=('model', None),
    )

  def __call__(self, x: jax.Array):
    y = self.dot1(x)
    y = jax.nn.relu(y)
    # In data parallelism, input / intermediate value's first dimension (batch)
    # will be sharded on `data` axis
    y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', 'model'))
    z = jnp.dot(y, self.w2.value)
    return z

了解分片名稱#

所謂的「分片註解」基本上是裝置軸名稱的元組,例如 'data''model'None。這描述了此 JAX 陣列的每個維度應如何分片 — 要麼跨越其中一個裝置網格維度分片,要麼完全不分片。

因此,當您定義形狀為 (depth, depth) 且註解為 (None, 'model')W1

  • 第一個維度會在所有裝置之間複寫。

  • 第二個維度將會根據裝置網格的 'model' 軸進行分片。這表示 W1 將在這個維度上於裝置 (0, 4)(1, 5)(2, 6)(3, 7) 上進行 4 向分片。

JAX 的分散式陣列與自動平行化指南提供了更多範例和說明。

初始化分片模型#

現在,您已經將註釋附加到 Flax 的 nnx.Variable,但實際的權重尚未分片。如果您直接建立這個模型,所有的 jax.Arrays 仍會卡在裝置 0 上。實際上,您會希望避免這種情況,因為大型模型在這種情況下會「OOM」(導致裝置記憶體不足),而所有其他裝置都沒有被利用。

unsharded_model = DotReluDot(1024, rngs=nnx.Rngs(0))

# You have annotations stuck there, yay!
print(unsharded_model.dot1.kernel.sharding)     # (None, 'model')
print(unsharded_model.w2.sharding)              # ('model', None)

# But the actual arrays are not sharded?
print(unsharded_model.dot1.kernel.value.sharding)  # SingleDeviceSharding
print(unsharded_model.w2.value.sharding)           # SingleDeviceSharding
(None, 'model')
('model', None)
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

在這裡,您應該透過 Flax 的 nnx.jit 來利用 JAX 的編譯機制來建立分片模型。關鍵是在 jit 函式內初始化模型,並在模型狀態上指定分片。

  1. 使用 nnx.get_partition_spec 來剝離附加在模型變數上的 .sharding 註釋。

  2. 呼叫 jax.lax.with_sharding_constraint 將模型狀態與分片註釋綁定。這個 API 會告訴頂層 jit 如何分片變數!

  3. 丟棄未分片的狀態,並根據分片的狀態傳回模型。

  4. 使用 nnx.jit 編譯整個函式,這允許輸出為一個有狀態的 Flax NNX Module

  5. 在裝置網格環境下執行它,以便 JAX 知道要將其分片到哪些裝置。

整個編譯後的 create_sharded_model() 函式將直接產生一個具有分片 JAX 陣列的模型,並且不會發生單裝置的「OOM」!

@nnx.jit
def create_sharded_model():
  model = DotReluDot(1024, rngs=nnx.Rngs(0)) # Unsharded at this moment.
  state = nnx.state(model)                   # The model's state, a pure pytree.
  pspecs = nnx.get_partition_spec(state)     # Strip out the annotations from state.
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)           # The model is sharded now!
  return model

with mesh:
  sharded_model = create_sharded_model()

# They are some `GSPMDSharding` now - not a single device!
print(sharded_model.dot1.kernel.value.sharding)
print(sharded_model.w2.value.sharding)

# Check out their equivalency with some easier-to-read sharding descriptions
assert sharded_model.dot1.kernel.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec(None, 'model')), ndim=2
)
assert sharded_model.w2.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec('model', None)), ndim=2
)
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model',), memory_kind=unpinned_host)

您可以使用 jax.debug.visualize_array_sharding 來檢視任何 1D 或 2D 陣列的分片。

print("sharded_model.dot1.kernel (None, 'model') :")
jax.debug.visualize_array_sharding(sharded_model.dot1.kernel.value)
print("sharded_model.w2 ('model', None) :")
jax.debug.visualize_array_sharding(sharded_model.w2.value)
sharded_model.dot1.kernel (None, 'model') :
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
sharded_model.w2 ('model', None) :
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

關於 jax.lax.with_sharding_constraint(半自動平行化)#

分片 JAX 陣列的關鍵是在 jax.jit 編譯的函式內呼叫 jax.lax.with_sharding_constraint。請注意,如果不在 JAX 裝置網格環境下,它會拋出錯誤。

注意: JAX 文件中的平行程式設計簡介分散式陣列與自動平行化都更詳細地介紹了使用 jax.jit 進行自動平行化,以及使用 jax.jit`jax.lax.with_sharding_constraint 進行半自動平行化。

您可能已經注意到,您也在模型定義中使用了一次 jax.lax.with_sharding_constraint 來約束中間值的分片。這只是為了展示如果您想明確分片非模型變數的值,您可以始終與 Flax NNX API 正交地使用它。

這帶來了一個問題:那麼為什麼要使用 Flax NNX 註釋 API?為什麼不只是在模型定義中加入 JAX 分片約束?最重要的原因是,您仍然需要明確的註釋才能從磁碟上的檢查點載入分片模型。這將在下一節中說明。

從檢查點載入分片模型#

現在您已經學會如何在沒有 OOM 的情況下初始化分片模型,但是如何從磁碟上的檢查點載入它呢?JAX 檢查點程式庫,例如 Orbax,通常支援在提供分片 pytree 的情況下載入分片模型。

您可以使用 Flax 的 nnx.get_named_sharding 來產生這樣的分片 pytree。為了避免任何實際的記憶體配置,請使用 nnx.eval_shape 轉換來產生一個抽象 JAX 陣列的模型,並且只使用其 .sharding 註釋來獲取分片樹。

以下是一個示範使用 Orbax 的 StandardCheckpointer API 的範例。(請前往 Orbax 文件網站以了解其最新和最推薦的 API。)

import orbax.checkpoint as ocp

# Save the sharded state.
sharded_state = nnx.state(sharded_model)
path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(path / 'checkpoint_name', sharded_state)

# Load a sharded state from checkpoint, without `sharded_model` or `sharded_state`.
abs_model = nnx.eval_shape(lambda: DotReluDot(1024, rngs=nnx.Rngs(0)))
abs_state = nnx.state(abs_model)
# Orbax API expects a tree of abstract `jax.ShapeDtypeStruct`
# that contains both sharding and the shape/dtype of the arrays.
abs_state = jax.tree.map(
  lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
  abs_state, nnx.get_named_sharding(abs_state, mesh)
)
loaded_sharded = checkpointer.restore(path / 'checkpoint_name',
                                      target=abs_state)
jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value)
jax.debug.visualize_array_sharding(loaded_sharded.w2.value)
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

編譯訓練迴圈#

現在,在初始化或載入檢查點後,您有一個分片模型。為了執行編譯後的擴展訓練,您還需要分片輸入。

  • 在資料平行化的範例中,訓練資料的批次維度會根據 data 裝置軸進行分片,因此您應該將資料放入 ('data', None) 的分片中。您可以使用 jax.device_put 來執行此操作。

  • 請注意,對於所有輸入都進行正確的分片後,即使沒有 jit 編譯,輸出也會以最自然的方式進行分片。

  • 在下面的範例中,即使在輸出 y 上沒有 jax.lax.with_sharding_constraint,它仍然被分片為 ('data', None)

如果您有興趣了解原因:DotReluDot.__call__ 的第二個 matmul 有兩個分片為 ('data', 'model')('model', None) 的輸入,其中兩個輸入的收縮軸都是 model。因此,發生了 reduce-scatter matmul,並且自然會將輸出分片為 ('data', None)。如果您想在低階數學層面上了解它是如何發生的,請查看 JAX 分片映射集體指南及其範例。

# In data parallelism, the first dimension (batch) will be sharded on the `data` axis.
data_sharding = NamedSharding(mesh, PartitionSpec('data', None))
input = jax.device_put(jnp.ones((8, 1024)), data_sharding)

with mesh:
  output = sharded_model(input)
print(output.shape)
jax.debug.visualize_array_sharding(output)  # Also sharded as `('data', None)`.
(8, 1024)
                                                                                
                                                                                
                                  CPU 0,1,2,3                                   
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                  CPU 4,5,6,7                                   
                                                                                
                                                                                
                                                                                

現在,訓練迴圈的其餘部分非常傳統 - 它幾乎與 Flax NNX Basics 中的範例相同。

  • 只是輸入和標籤也明確地分片了。

  • nnx.jit 將根據其輸入的現有分片方式調整並自動選擇最佳佈局,因此請嘗試針對您自己的模型和輸入進行不同的分片。

optimizer = nnx.Optimizer(sharded_model, optax.adam(1e-3))  # reference sharing

@nnx.jit
def train_step(model, optimizer, x, y):
  def loss_fn(model: DotReluDot):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)

  return loss

input = jax.device_put(jax.random.normal(jax.random.key(1), (8, 1024)), data_sharding)
label = jax.device_put(jax.random.normal(jax.random.key(2), (8, 1024)), data_sharding)

with mesh:
  for i in range(5):
    loss = train_step(sharded_model, optimizer, input, label)
    print(loss)    # Model (over-)fitting to the labels quickly.
1.455235
0.7646729
0.50971293
0.378493
0.28089797

效能分析#

如果您使用的是 Google TPU pod 或 pod slice,您可以建立一個自訂的 block_all() 公用程式函式(如下定義)來測量效能。

%%timeit

def block_all(xs):
  jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
  return xs

with mesh:
  new_state = block_all(train_step(sharded_model, optimizer, input, label))
57.6 ms ± 569 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

邏輯軸註釋#

JAX 的自動 SPMD 鼓勵使用者探索不同的分片佈局,以找到最佳的佈局。為此,在 Flax 中,您可以選擇使用更具描述性的軸名稱進行註釋(而不僅僅是像 'data''model' 這樣的裝置網格軸名稱),只要您提供從別名到裝置網格軸的對應即可。

您可以將對應與註釋一起作為相應 nnx.Variable 的另一個中繼資料提供,或在頂層覆寫它。請查看下面的 LogicalDotReluDot() 範例。

# The mapping from alias annotation to the device mesh.
sharding_rules = (('batch', 'data'), ('hidden', 'model'), ('embed', None))

class LogicalDotReluDot(nnx.Module):
  def __init__(self, depth: int, rngs: nnx.Rngs):
    init_fn = nnx.initializers.lecun_normal()

    # Initialize a sublayer `self.dot1`.
    self.dot1 = nnx.Linear(
      depth, depth,
      kernel_init=nnx.with_metadata(
        # Provide the sharding rules here.
        init_fn, sharding=('embed', 'hidden'), sharding_rules=sharding_rules),
      use_bias=False,
      rngs=rngs)

    # Initialize a weight param `w2`.
    self.w2 = nnx.Param(
      # Didn't provide the sharding rules here to show you how to overwrite it later.
      nnx.with_metadata(init_fn, sharding=('hidden', 'embed'))(
        rngs.params(), (depth, depth))
    )

  def __call__(self, x: jax.Array):
    y = self.dot1(x)
    y = jax.nn.relu(y)
    # Unfortunately the logical aliasing doesn't work on lower-level JAX calls.
    y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', None))
    z = jnp.dot(y, self.w2.value)
    return z

如果您沒有在模型定義中提供所有 sharding_rule 註釋,您可以編寫幾行程式碼將其新增到 Flax 模型的 nnx.State 中,然後再呼叫 nnx.get_partition_specnnx.get_named_sharding

def add_sharding_rule(vs: nnx.VariableState) -> nnx.VariableState:
  vs.sharding_rules = sharding_rules
  return vs

@nnx.jit
def create_sharded_logical_model():
  model = LogicalDotReluDot(1024, rngs=nnx.Rngs(0))
  state = nnx.state(model)
  state = jax.tree.map(add_sharding_rule, state,
                       is_leaf=lambda x: isinstance(x, nnx.VariableState))
  pspecs = nnx.get_partition_spec(state)
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)
  return model

with mesh:
  sharded_logical_model = create_sharded_logical_model()

jax.debug.visualize_array_sharding(sharded_logical_model.dot1.kernel.value)
jax.debug.visualize_array_sharding(sharded_logical_model.w2.value)

# Check out their equivalency with some easier-to-read sharding descriptions.
assert sharded_logical_model.dot1.kernel.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec(None, 'model')), ndim=2
)
assert sharded_logical_model.w2.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec('model', None)), ndim=2
)

with mesh:
  logical_output = sharded_logical_model(input)
  assert logical_output.sharding.is_equivalent_to(
    NamedSharding(mesh, PartitionSpec('data', None)), ndim=2
  )
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

何時使用裝置軸 / 邏輯軸#

選擇何時使用裝置軸或邏輯軸取決於您想要對模型分割的控制程度。

  • 裝置網格軸:

    • 對於更簡單的模型,這可以節省一些額外的程式碼行,將邏輯命名轉換回裝置命名。

    • 中間激活值的分割只能透過 jax.lax.with_sharding_constraint 和裝置網格軸來完成。因此,如果您想要對模型的分割進行超精細的控制,直接在各處使用裝置網格軸名稱可能會比較不容易混淆。

  • 邏輯命名:如果您想要嘗試並找到模型權重的最佳分割佈局,這會很有幫助。