Save and load checkpoints#

This guide demonstrates how to save and load Flax checkpoints with Orbax.

Orbax provides a variety of features for saving and loading model data, which you will learn about in this document:

  • Support for various array types and storage formats

  • Asynchronous saving to reduce training wait time

  • Versioning and automatic bookkeeping of past checkpoints

  • Flexible transformations to tweak and load old checkpoints

  • jax.sharding-based APIs to save and load in multi-host scenarios


Ongoing migration to Orbax

Flax's legacy flax.training.checkpoints API will be deprecated on July 30, 2023 in favor of Orbax.

  • If you are new to Flax: use the new orbax.checkpoint API, as demonstrated below.

  • If you have legacy flax.training.checkpoints code in your project: consider the following options:

    • Migrate your code to Orbax (recommended): migrate your API calls to the orbax.checkpoint API following this migration guide.

    • Automatically use the Orbax backend: add flax.config.update('flax_use_orbax_checkpointing', True) to your project, which makes your flax.training.checkpoints calls automatically use the Orbax backend to save your checkpoints.

      • Scheduled switch: this will become the default mode after May 2023 (tentative date).

      • If you encounter any issues in the automatic migration process, go to the Orbax-as-backend troubleshooting section.


For backward compatibility, this guide shows the Orbax-equivalent calls in Flax's legacy flax.training.checkpoints API.

If you need to learn more about orbax.checkpoint, refer to the Orbax docs.

Setup#

Install/upgrade Flax and Orbax. For JAX installation with GPU/TPU support, visit this section on GitHub.

Note: Before running import jax, create eight fake CPU devices to mimic a multi-host environment in this notebook. Note that the import order here matters. The os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' command works only with the CPU backend, which means it won't work with GPU/TPU acceleration if you run this notebook in Google Colab. If you are already running the code on multiple devices (for example, in a 4x2 TPU environment), you can skip running the next cell.

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from typing import Optional, Any
import shutil

import numpy as np
import jax
from jax import random, numpy as jnp

import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint

import optax
WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.
ckpt_dir = '/tmp/flax_ckpt'

if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)  # Remove any existing checkpoints from the last notebook run.

Save checkpoints#

In Orbax and Flax, you can save and load any given JAX pytree. This includes not only typical Python and NumPy containers, but also customized classes extended from flax.struct.dataclass. That means you can store almost any data generated: not only your model parameters, but any arrays/dictionaries, metadata/configs, and so on.
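For instance, a custom container marked with flax.struct.dataclass is registered as a pytree and can therefore be checkpointed like any other. Below is a minimal sketch of such a class; the EMAState name and its fields are illustrative and not part of this guide's running example.

from typing import Any
from flax import struct

@struct.dataclass
class EMAState:
    # Static (non-array) metadata is excluded from the pytree leaves.
    decay: float = struct.field(pytree_node=False)
    # Array-valued fields become pytree leaves and are checkpointed.
    params: Any = None

# An instance can be bundled into a checkpoint like any other pytree.
ema = EMAState(decay=0.999, params={'w': np.zeros(3)})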

First, create a pytree with many data structures and containers, and play with it:

# A simple model with one linear layer.
key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))      # A simple JAX array.
model = nn.Dense(features=3)
variables = model.init(key2, x1)

# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
# Define your class with `@flax.struct.dataclass` decorator to make it compatible.
tx = optax.sgd(learning_rate=0.001)      # An Optax SGD optimizer.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)
# Perform a simple gradient update similar to the one during a normal training workflow.
state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))

# Some arbitrary nested pytree with a dictionary and a NumPy array.
config = {'dimensions': np.array([5, 3])}

# Bundle everything together.
ckpt = {'model': state, 'config': config, 'data': [x1]}
ckpt
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695322343.254588       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
{'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState())),
 'config': {'dimensions': array([5, 3])},
 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],      dtype=float32)]}

With Orbax#

Save the checkpoint directly with orbax.checkpoint.PyTreeCheckpointer to the /tmp/flax_ckpt/orbax/single_save directory.

Note: An optional save_args is provided. It is recommended for performance speedups, as it bundles the smaller arrays in your pytree into a single large file instead of multiple smaller files.

from flax.training import orbax_utils

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save('/tmp/flax_ckpt/orbax/single_save', ckpt, save_args=save_args)

Next, to take advantage of versioning and automatic bookkeeping, you need to wrap orbax.checkpoint.CheckpointManager over orbax.checkpoint.PyTreeCheckpointer.

In addition, provide orbax.checkpoint.CheckpointManagerOptions to customize your needs, such as how often and on what criteria you prefer to delete older checkpoints. Refer to the documentation for a full list of the options offered.

orbax.checkpoint.CheckpointManager should be placed at the top level, outside your training steps, to manage your saves.

options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(
    '/tmp/flax_ckpt/orbax/managed', orbax_checkpointer, options)

# Inside a training loop
for step in range(5):
    # ... do your training
    checkpoint_manager.save(step, ckpt, save_kwargs={'save_args': save_args})

os.listdir('/tmp/flax_ckpt/orbax/managed')  # Because max_to_keep=2, only steps 3 and 4 are retained
['4', '3']

With the legacy API#

Here is how to save with the legacy Flax checkpointing utilities (note that this offers fewer management features compared with orbax.checkpoint.CheckpointManagerOptions):

# Import Flax Checkpoints.
from flax.training import checkpoints

checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
                            target=ckpt,
                            step=0,
                            overwrite=True,
                            keep=2)
'/tmp/flax_ckpt/flax-checkpointing/checkpoint_0'

Restore checkpoints#

With Orbax#

In Orbax, call .restore() on an orbax.checkpoint.PyTreeCheckpointer or an orbax.checkpoint.CheckpointManager to restore your checkpoint in the raw pytree format.

raw_restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save')
raw_restored
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': {'opt_state': [None, None],
  'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),
   'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
          [ 0.11050402, -0.8765793 ,  0.9800635 ],
          [ 0.36260957,  0.18276349, -0.6856061 ],
          [-0.8519373 , -0.6416717 , -0.4818122 ],
          [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},
  'step': 1}}

Note that the step number is required by a CheckpointManager. You can also use .latest_step() to find the latest step available.

step = checkpoint_manager.latest_step()  # step = 4
checkpoint_manager.restore(step)
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': {'opt_state': [None, None],
  'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),
   'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
          [ 0.11050402, -0.8765793 ,  0.9800635 ],
          [ 0.36260957,  0.18276349, -0.6856061 ],
          [-0.8519373 , -0.6416717 , -0.4818122 ],
          [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},
  'step': 1}}

With the legacy API#

Note that, with the ongoing migration to Orbax, flax.training.checkpoints.restore_checkpoint can automatically identify whether a checkpoint was saved in the legacy Flax format or with the Orbax backend, and restore the pytree correctly. So adding flax.config.update('flax_use_orbax_checkpointing', True) won't prevent you from restoring your old checkpoints.

Here is how to restore a checkpoint with the legacy API:

raw_restored = checkpoints.restore_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=None)
raw_restored
{'config': {'dimensions': array([5, 3])},
 'data': {'0': array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)},
 'model': {'opt_state': {'0': None, '1': None},
  'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),
   'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
          [ 0.11050402, -0.8765793 ,  0.9800635 ],
          [ 0.36260957,  0.18276349, -0.6856061 ],
          [-0.8519373 , -0.6416717 , -0.4818122 ],
          [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},
  'step': 1}}

Restore with custom dataclasses#

With Orbax#

  • The pytrees restored in the previous examples are in the form of raw dictionaries, whereas the original pytrees contain custom dataclasses such as TrainState and the optax state.

  • This is because, at restore time, the program does not yet know which structure the checkpoint originally belonged to.

  • To solve this, you should first provide an example pytree so that Orbax or Flax knows exactly which structure to restore to.

This section demonstrates how to explicitly set up any custom Flax dataclass to have the same structure as the saved checkpoint.

Note: Data saved as JAX NumPy arrays (jnp.array) will be restored as NumPy arrays (numpy.array). This does not affect your work, because JAX automatically converts NumPy arrays back to JAX arrays once computation starts.
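A quick sanity check of this behavior, reusing the ckpt and raw_restored objects from the cells above (this assumes a recent JAX version in which jax.Array supports isinstance checks):

assert isinstance(ckpt['data'][0], jax.Array)           # saved as a JAX array...
assert isinstance(raw_restored['data'][0], np.ndarray)  # ...restored as a NumPy array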

empty_state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=jax.tree_util.tree_map(np.zeros_like, variables['params']),  # the values of the tree leaves don't matter
    tx=tx,
)
empty_config = {'dimensions': np.array([0, 0])}
target = {'model': empty_state, 'config': empty_config, 'data': [jnp.zeros_like(x1)]}
state_restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save', item=target)
state_restored
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}

With the legacy API#

Alternatively, you can restore from the Orbax CheckpointManager and with the legacy Flax code as follows:

checkpoint_manager.restore(4, items=target)
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}
checkpoints.restore_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=target)
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
{'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState())),
 'config': {'dimensions': array([5, 3])},
 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],      dtype=float32)]}

It's generally recommended to refactor out the process of initializing a checkpoint's structure (for example, a TrainState), so that saving/loading is easier and less error-prone. This is because functions and complex objects like apply_fn and tx (the optimizer) cannot be serialized into the checkpoint file and must be initialized by code.
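One possible way to structure this is sketched below. The create_empty_state helper is illustrative (not a Flax API) and reuses the model, key2, x1 and tx objects defined earlier in this guide.

def create_empty_state(model, rng, sample_input, tx):
    # Build a TrainState skeleton whose leaf values are placeholders.
    # Re-running this at restore time ensures that non-serializable members
    # (apply_fn, tx) always come from code rather than from the checkpoint.
    variables = model.init(rng, sample_input)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=jax.tree_util.tree_map(np.zeros_like, variables['params']),
        tx=tx,
    )

# Build the restore target from code, then fill it with values from disk.
restore_target = {'model': create_empty_state(model, key2, x1, tx),
                  'config': {'dimensions': np.array([0, 0])},
                  'data': [jnp.zeros_like(x1)]}
restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save', item=restore_target)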

Restore when checkpoint structures differ#

During development, your checkpoint structure will change as you modify your model, add or remove fields while tuning, and so on.

This section explains how to load old data into your new code.

Below is a simple example: a CustomTrainState extended from flax.training.train_state.TrainState that carries an extra field called batch_stats, which a real-world model would need when applying batch normalization.

Here, you save the new CustomTrainState as step 5, while step 4 contains the old/previous TrainState.

class CustomTrainState(train_state.TrainState):
    batch_stats: Any = None

custom_state = CustomTrainState.create(
    apply_fn=state.apply_fn,
    params=state.params,
    tx=state.tx,
    batch_stats=np.arange(10),
)

custom_ckpt = {'model': custom_state, 'config': config, 'data': [x1]}
# Use a custom state to read the old `TrainState` checkpoint.
custom_target = {'model': custom_state, 'config': None, 'data': [jnp.zeros_like(x1)]}

# Save it in Orbax.
custom_save_args = orbax_utils.save_args_from_target(custom_ckpt)
checkpoint_manager.save(5, custom_ckpt, save_kwargs={'save_args': custom_save_args})
True

It is recommended to keep your checkpoints up to date with your pytree dataclass definitions. However, you might sometimes be forced to restore checkpoints with incompatible reference objects at runtime. When this happens, the checkpoint restoration will try its best to respect the structure of the reference object given.

Below are examples of a few common scenarios.

Scenario 1: When a reference object is partial#

If your reference object is a subtree of the checkpoint, the restoration will ignore the additional fields and restore a checkpoint with the same structure as the reference.

As in the example below, the batch_stats field in CustomTrainState was ignored, and the checkpoint was restored as a TrainState.

This can also be useful for reading only part of a checkpoint.

restored = checkpoint_manager.restore(5, items=target)
assert not hasattr(restored, 'batch_stats')
assert type(restored['model']) == train_state.TrainState
restored
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': TrainState(step=0, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}

Scenario 2: When a checkpoint is partial#

On the other hand, if the reference object contains values that are not available in the checkpoint, the checkpointing code will, by default, raise an error warning that some data is not compatible.

To bypass the error, you need to pass an Orbax transform that teaches Orbax how to convert this checkpoint into the structure of custom_target.

In this case, pass a default {} to let Orbax use the values from custom_target to fill in the blanks. This allows you to restore an old checkpoint into the new data structure, the CustomTrainState.

try:
    checkpoint_manager.restore(4, items=custom_target)
except KeyError as e:
    print(f'KeyError when target state has an unmentioned field: {e}')
    print('')

# Step 4 is an original `TrainState`, without the `batch_stats`
custom_restore_args = orbax_utils.restore_args_from_target(custom_target)
restored = checkpoint_manager.restore(4, items=custom_target,
                                      restore_kwargs={'transforms': {}, 'restore_args': custom_restore_args})
assert type(restored['model']) == CustomTrainState
np.testing.assert_equal(restored['model'].batch_stats,
                        custom_target['model'].batch_stats)
restored
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
KeyError when target state has an unmentioned field: 'batch_stats'
{'config': None,
 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],      dtype=float32)],
 'model': CustomTrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))}

With Orbax#

If you have already saved your checkpoints with the Orbax backend, you can use orbax_transforms to access this transforms argument within the Flax API.

# Save in the "Flax-with-Orbax" backend.
flax.config.update('flax_use_orbax_checkpointing', True)
checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
                            target=ckpt,
                            step=4,
                            overwrite=True,
                            keep=2)

checkpoints.restore_checkpoint('/tmp/flax_ckpt/flax-checkpointing', target=custom_target, step=4,
                               orbax_transforms={})
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
{'model': CustomTrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])),
 'config': None,
 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],      dtype=float32)]}

With the legacy API#

A similar thing is possible with the legacy flax.training.checkpoints API, but it's not as flexible as Orbax transformations.

You need to restore the checkpoint to a raw dictionary with target=None, modify the structure accordingly, and then deserialize it back to the original target.

# Save using the legacy Flax `checkpoints` API.
flax.config.update('flax_use_orbax_checkpointing', False)
checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
                            target=ckpt,
                            step=5,
                            overwrite=True,
                            keep=2)

# Pass no target to get a raw state dictionary first.
raw_state_dict = checkpoints.restore_checkpoint('/tmp/flax_ckpt/flax-checkpointing', target=None, step=5)
# Add/remove fields as needed.
raw_state_dict['model']['batch_stats'] = np.flip(np.arange(10))
# Restore the classes with the correct target now.
flax.serialization.from_state_dict(custom_target, raw_state_dict)
{'model': CustomTrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])),
 'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)]}

Asynchronous checkpointing#

Checkpointing is I/O heavy, and if you have a large amount of data to save, it may be worthwhile to put the save into a background thread while continuing with your training.

You can do this by creating an orbax.checkpoint.AsyncCheckpointer in place of the orbax.checkpoint.PyTreeCheckpointer.

Note: You should use the same async_checkpointer to handle all your async saves across your training steps, so that it can make sure a previous async save is done before the next one begins. This enables bookkeeping, such as keep (the number of checkpoints) and overwrite, to be consistent across steps.

Whenever you want to explicitly wait until an async save is done, call async_checkpointer.wait_until_finished().

# `orbax.checkpoint.AsyncCheckpointer` needs some multi-process initialization, because it was
# originally designed for multi-process large model checkpointing.
# For Python notebooks or other single-process settings, just set up with `num_processes=1`.
# Refer to https://jax.readthedocs.io/en/latest/multi_process.html
# for how to set it up in multi-process scenarios.
jax.distributed.initialize("localhost:8889", num_processes=1, process_id=0)

async_checkpointer = orbax.checkpoint.AsyncCheckpointer(
    orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=50)

# Save your job:
async_checkpointer.save('/tmp/flax_ckpt/orbax/single_save_async', ckpt, save_args=save_args)
# ... Continue with your work...

# ... Until a time when you want to wait until the save completes:
async_checkpointer.wait_until_finished()  # Blocks until the checkpoint saving is completed.
async_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save_async', item=target)
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}

If you are using the Orbax CheckpointManager, pass in the async_checkpointer when initializing it. Then, in practice, call async_checkpoint_manager.wait_until_finished() instead.

async_checkpoint_manager = orbax.checkpoint.CheckpointManager(
    '/tmp/flax_ckpt/orbax/managed_async', async_checkpointer, options)
async_checkpoint_manager.wait_until_finished()

Multi-host/multi-process checkpointing#

JAX provides a few ways to scale up your code across multiple hosts at the same time. This usually happens when the number of devices (CPU/GPU/TPU) is so large that different devices are managed by different hosts (CPUs). To get started with JAX in multi-process settings, check out Using JAX in multi-host and multi-process environments and the distributed array guide.

In the Single Program Multi Data (SPMD) paradigm with JAX jit, a large multi-process array can have its data sharded across different devices. (Note that JAX pjit and jit have been merged into a single unified interface. To learn about compiling and executing JAX functions in multi-host or multi-core environments, refer to this guide and the jax.Array migration guide.) When a multi-process array is serialized, each host dumps its data shards to a single shared storage, such as a Google Cloud bucket.

Orbax supports saving and loading pytrees with multi-process arrays in the same fashion as single-process pytrees. However, it's recommended to use the asynchronous orbax.checkpoint.AsyncCheckpointer to save large multi-process arrays on another thread, so that you can perform computation alongside the saves. With pure Orbax, saving checkpoints in a multi-process context uses the same API as in a single-process context.

from jax.sharding import PartitionSpec, NamedSharding

# Create an array sharded across multiple devices.
mesh_shape = (4, 2)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, ('x', 'y'))

mp_array = jax.device_put(np.arange(8 * 2).reshape(8, 2),
                          NamedSharding(mesh, PartitionSpec('x', 'y')))

# Make it a pytree.
mp_ckpt = {'model': mp_array}
async_checkpoint_manager.save(0, mp_ckpt)
async_checkpoint_manager.wait_until_finished()

When restoring a checkpoint with multi-process arrays, you need to specify what sharding each array should be restored back to. Otherwise, they will all be restored as large np.arrays on process 0, costing time and memory.

(Note: since this notebook runs as a single process, the arrays will be restored as np.arrays even though sharding is provided.)

With Orbax#

Orbax lets you specify this by passing a pytree of shardings in restore_args. If you already have a reference pytree whose arrays all have the correct sharding, you can use orbax_utils.restore_args_from_target to convert it into the restore_args that Orbax needs.

# The reference doesn't need to be as large as your checkpoint!
# Just make sure it has the `.sharding` you want.
mp_smaller = jax.device_put(np.arange(8).reshape(4, 2),
                            NamedSharding(mesh, PartitionSpec('x', 'y')))
ref_ckpt = {'model': mp_smaller}

restore_args = orbax_utils.restore_args_from_target(ref_ckpt)
async_checkpoint_manager.restore(
    0, items=ref_ckpt, restore_kwargs={'restore_args': restore_args})
{'model': Array([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]], dtype=int32)}

With legacy Flax: use save_checkpoint_multiprocess#

In legacy Flax, to save multi-process arrays, use flax.training.checkpoints.save_checkpoint_multiprocess() in place of save_checkpoint(), with the same arguments.

If your checkpoint is too large, you can specify timeout_secs in the checkpointer and give it more time to finish writing.

async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=50)
checkpoints.save_checkpoint_multiprocess(ckpt_dir,
                                         mp_ckpt,
                                         step=3,
                                         overwrite=True,
                                         keep=4,
                                         orbax_checkpointer=async_checkpointer)
'/tmp/flax_ckpt/checkpoint_3'
mp_restored = checkpoints.restore_checkpoint(ckpt_dir,
                                             target=ref_ckpt,
                                             step=3,
                                             orbax_checkpointer=async_checkpointer)
mp_restored
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
{'model': Array([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]], dtype=int32)}

Orbax-as-backend troubleshooting#

As an intermediate stage of the migration (from the legacy Flax checkpoints API to Orbax), the flax.training.checkpoints API will start to use Orbax as its backend for saving checkpoints from May 15, 2023.

Checkpoints saved with the Orbax backend can be read by flax.training.checkpoints.restore_checkpoint or orbax.checkpoint.PyTreeCheckpointer.

In code, this is equivalent to setting the config flag flax.config.flax_use_orbax_checkpointing to True by default. You can override this value at any time in your project with flax.config.update('flax_use_orbax_checkpointing', <BoolValue>).

In general, this automatic migration will not affect most users. However, you may run into issues if your API usage follows some specific patterns. Check out the sections below for troubleshooting.

If your devices hang when writing checkpoints#

If you are running in a multi-host environment (usually more than 8 TPU devices) and your devices hang when writing checkpoints, check whether your code conforms to the following pattern (that is, save_checkpoint runs only on host 0):

if jax.process_index() == 0:
  flax.training.checkpoints.save_checkpoint(...)

Unfortunately, this is a legacy pattern that will be deprecated and no longer supported, because in a multi-process environment the checkpoint-writing code should be coordinated across hosts rather than triggered only on host 0. Replacing the code above with the following should resolve the hang:

flax.training.checkpoints.save_checkpoint_multiprocess(...)

If you are not saving pytrees#

Orbax saves checkpoints with orbax.checkpoint.PyTreeCheckpointHandler, which means it only saves pytrees.

If you want to save a single array or number, you have two options (a minimal sketch of option 2 follows the list):

  1. Use orbax.checkpoint.ArrayCheckpointHandler to save it, following this migration section.

  2. Wrap it in a pytree and save as usual.
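For option 2, here is a minimal sketch of the wrapping approach, reusing the orbax_checkpointer from earlier in this guide; the path and key names are illustrative.

best_accuracy = np.float32(0.95)  # a bare number, not a pytree

# Wrapping it in a dict makes it a valid pytree for PyTreeCheckpointHandler.
orbax_checkpointer.save('/tmp/flax_ckpt/orbax/scalar_save', {'best_accuracy': best_accuracy})

restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/scalar_save')
np.testing.assert_allclose(restored['best_accuracy'], best_accuracy)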