將檢查點移轉至 Orbax#
本指南說明如何將 Flax 的檢查點儲存和復原呼叫(flax.training.checkpoints.save_checkpoint 和 restore_checkpoint)轉換成等效的 Orbax 方法。Orbax 提供一個靈活又可自訂的 API 來管理各式物件的檢查點。請注意,當 Flax 的檢查點機制從 flax.training.checkpoints
遷移至 Orbax 時,Flax API 中所有現有功能仍會繼續獲得支援,但 API 會有所變更。
您將透過下列情境學習如何遷移至 Orbax
最常見的用例:儲存/載入和管理檢查點
「輕量化」用例:「純粹」儲存/載入,不使用頂層檢查點管理程式
在沒有目標樹狀結構的情況下,復原檢查點
非同步檢查點
儲存/載入單一 JAX 或 NumPy 陣列
若要深入了解 Orbax,請參閱 快速入門 Colab 筆記本 和 Orbax 官方文件。
您可以按一下上方的「在 Colab 中開啟」來執行本指南中的程式碼。
在整個指南中,您將可以比較有和沒有 Orbax 程式碼的範例程式碼。
設定#
# Create some dummy variables for this example.
MAX_STEPS = 5
CKPT_PYTREE = [12, {'bar': np.array((2, 3))}, [1, 4, 10]]
TARGET_PYTREE = [0, {'bar': np.array((0))}, [0, 0, 0]]
最常見的用例:儲存/載入和管理檢查點#
本節涵蓋下列情境
您的原始 Flax
save_checkpoint()
或save_checkpoint_multiprocess()
呼叫包含下列參數:prefix
、keep
、keep_every_n_steps
;或您想為您的檢查點使用一些自動化管理邏輯(例如,用於刪除舊資料、根據指標/損失刪除資料,等等)。
在這種情況下,您需要使用 orbax.CheckpointManager
。這讓您不僅可以儲存和載入模型,還能管理檢查點並自動刪除過時的檢查點。
要升級您的程式碼
在頂層建立並保留
orbax.CheckpointManager
執行個體,並使用orbax.CheckpointManagerOptions
自訂。在執行階段,呼叫
orbax.CheckpointManager.save()
來儲存您的資料。然後,呼叫
orbax.CheckpointManager.restore()
來復原您的資料。然後,如果您檢查點包含一些多主機/多處理程序陣列,請將正確的
mesh
傳入flax.training.orbax_utils.restore_args_from_target()
中,以在復原之前產生正確的restore_args
。
例如
CKPT_DIR = '/tmp/orbax_upgrade/'
flax.config.update('flax_use_orbax_checkpointing', False)
# Inside your training loop
for step in range(MAX_STEPS):
# do training
checkpoints.save_checkpoint(CKPT_DIR, CKPT_PYTREE, step=step,
prefix='test_', keep=3, keep_every_n_steps=2)
checkpoints.restore_checkpoint(CKPT_DIR, target=TARGET_PYTREE, step=4, prefix='test_')
CKPT_DIR = '/tmp/orbax_upgrade/orbax'
# At the top level
mgr_options = orbax.checkpoint.CheckpointManagerOptions(
create=True, max_to_keep=3, keep_period=2, step_prefix='test')
ckpt_mgr = orbax.checkpoint.CheckpointManager(
CKPT_DIR,
orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_options)
# Inside your training loop
for step in range(MAX_STEPS):
# do training
save_args = flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE)
ckpt_mgr.save(step, CKPT_PYTREE, save_kwargs={'save_args': save_args})
restore_args = flax.training.orbax_utils.restore_args_from_target(TARGET_PYTREE, mesh=None)
ckpt_mgr.restore(4, items=TARGET_PYTREE, restore_kwargs={'restore_args': restore_args})
「輕量級」用例:「純粹的」儲存/載入,沒有最高層級的檢查點管理程式#
如果你不想要保留最高層級的檢查點管理程式,你仍然可以用 orbax.checkpoint.Checkpointer
儲存和復原任何個別的檢查點。請注意,這表示你無法使用所有 Orbax 管理功能。
若要移轉到 Orbax 程式碼,可以使用 orbax.checkpoint.Checkpointer.save()
中的 force
參數,而非在 flax.save_checkpoint()
中使用 overwrite
參數。
例如
PURE_CKPT_DIR = '/tmp/orbax_upgrade/pure'
flax.config.update('flax_use_orbax_checkpointing', False)
checkpoints.save_checkpoint(PURE_CKPT_DIR, CKPT_PYTREE, step=0, overwrite=True)
checkpoints.restore_checkpoint(PURE_CKPT_DIR, target=TARGET_PYTREE)
PURE_CKPT_DIR = '/tmp/orbax_upgrade/pure'
ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) # A stateless object, can be created on the fly.
ckptr.save(PURE_CKPT_DIR, CKPT_PYTREE,
save_args=flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE), force=True)
ckptr.restore(PURE_CKPT_DIR, item=TARGET_PYTREE,
restore_args=flax.training.orbax_utils.restore_args_from_target(TARGET_PYTREE, mesh=None))
在沒有目標 pytree 的情況下復原檢查點#
如果你需要在沒有目標 pytree 的情況下復原檢查點,請傳入 item=None
到 orbax.checkpoint.Checkpointer
,或傳入 items=None
到 orbax.CheckpointManager
的 .restore()
方法,這應會觸發復原。
例如
NOTARGET_CKPT_DIR = '/tmp/orbax_upgrade/no_target'
flax.config.update('flax_use_orbax_checkpointing', False)
checkpoints.save_checkpoint(NOTARGET_CKPT_DIR, CKPT_PYTREE, step=0)
checkpoints.restore_checkpoint(NOTARGET_CKPT_DIR, target=None)
NOTARGET_CKPT_DIR = '/tmp/orbax_upgrade/no_target'
# A stateless object, can be created on the fly.
ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
ckptr.save(NOTARGET_CKPT_DIR, CKPT_PYTREE,
save_args=flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE))
ckptr.restore(NOTARGET_CKPT_DIR, item=None)
非同步檢查點#
若要讓你的檢查點儲存變成非同步,請用 orbax.checkpoint.AsyncCheckpointer
取代 orbax.checkpoint.Checkpointer
。
然後,你可以呼叫 orbax.checkpoint.AsyncCheckpointer.wait_until_finished()
或 Orbax 的 CheckpointerManager.wait_until_finished()
以等待儲存完成。
有關更多詳細資訊,請閱讀 檢查點指南。
你也可以透過非同步管理程式的 Flax API 使用 Orbax AsyncCheckpointer。非同步管理程式會在內部呼叫 wait_until_finished()。這個解決方案沒有被動維護,建議使用 Orbax 非同步檢查點。
例如
ASYNC_CKPT_DIR = '/tmp/orbax_upgrade/async'
flax.config.update('flax_use_orbax_checkpointing', True)
async_manager = checkpoints.AsyncManager()
checkpoints.save_checkpoint(ASYNC_CKPT_DIR, CKPT_PYTREE, step=0, overwrite=True, async_manager=async_manager)
checkpoints.restore_checkpoint(ASYNC_CKPT_DIR, target=TARGET_PYTREE)
ASYNC_CKPT_DIR = '/tmp/orbax_upgrade/async'
import orbax.checkpoint as ocp
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(ASYNC_CKPT_DIR, args=ocp.args.StandardSave(CKPT_PYTREE))
# ... Continue with your work...
# ... Until a time when you want to wait until the save completes:
ckptr.wait_until_finished() # Blocks until the checkpoint saving is completed.
ckptr.restore(ASYNC_CKPT_DIR, args=ocp.args.StandardRestore(TARGET_PYTREE))
儲存/載入單一的 JAX 或 NumPy 陣列#
orbax.checkpoint.PyTreeCheckpointHandler
類別,正如其名稱所示,只能用於 pytrees。因此,如果您需要儲存/還原單一 pytree 葉(例如陣列),請改用 orbax.checkpoint.ArrayCheckpointHandler
。
例如
ARR_CKPT_DIR = '/tmp/orbax_upgrade/singleton'
flax.config.update('flax_use_orbax_checkpointing', False)
checkpoints.save_checkpoint(ARR_CKPT_DIR, jnp.arange(10), step=0)
checkpoints.restore_checkpoint(ARR_CKPT_DIR, target=None)
ARR_CKPT_DIR = '/tmp/orbax_upgrade/singleton'
ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.ArrayCheckpointHandler())
ckptr.save(ARR_CKPT_DIR, jnp.arange(10))
ckptr.restore(ARR_CKPT_DIR, item=None)
最後的話#
本指南提供從「舊版」Flax 檢查點 API 遷移至 Orbax API 的概觀。Orbax 提供更多功能,而 Orbax 團隊正積極開發新功能。敬請期待並追蹤 官方 Orbax GitHub 存放庫 以進一步瞭解!