將檢查點移轉至 Orbax#

本指南說明如何將 Flax 的檢查點儲存和復原呼叫(flax.training.checkpoints.save_checkpointrestore_checkpoint)轉換成等效的 Orbax 方法。Orbax 提供一個靈活又可自訂的 API 來管理各式物件的檢查點。請注意,當 Flax 的檢查點機制從 flax.training.checkpoints 遷移至 Orbax 時,Flax API 中所有現有功能仍會繼續獲得支援,但 API 會有所變更。

您將透過下列情境學習如何遷移至 Orbax

  • 最常見的用例:儲存/載入和管理檢查點

  • 「輕量化」用例:「純粹」儲存/載入,不使用頂層檢查點管理程式

  • 在沒有目標樹狀結構的情況下,復原檢查點

  • 非同步檢查點

  • 儲存/載入單一 JAX 或 NumPy 陣列

若要深入了解 Orbax,請參閱 快速入門 Colab 筆記本Orbax 官方文件

您可以按一下上方的「在 Colab 中開啟」來執行本指南中的程式碼。

在整個指南中,您將可以比較有和沒有 Orbax 程式碼的範例程式碼。

設定#

# Create some dummy variables for this example.
MAX_STEPS = 5
CKPT_PYTREE = [12, {'bar': np.array((2, 3))}, [1, 4, 10]]
TARGET_PYTREE = [0, {'bar': np.array((0))}, [0, 0, 0]]

最常見的用例:儲存/載入和管理檢查點#

本節涵蓋下列情境

  • 您的原始 Flax save_checkpoint()save_checkpoint_multiprocess() 呼叫包含下列參數:prefixkeepkeep_every_n_steps;或

  • 您想為您的檢查點使用一些自動化管理邏輯(例如,用於刪除舊資料、根據指標/損失刪除資料,等等)。

在這種情況下,您需要使用 orbax.CheckpointManager。這讓您不僅可以儲存和載入模型,還能管理檢查點並自動刪除過時的檢查點。

要升級您的程式碼

  1. 在頂層建立並保留 orbax.CheckpointManager 執行個體,並使用 orbax.CheckpointManagerOptions 自訂。

  2. 在執行階段,呼叫 orbax.CheckpointManager.save() 來儲存您的資料。

  3. 然後,呼叫 orbax.CheckpointManager.restore() 來復原您的資料。

  4. 然後,如果您檢查點包含一些多主機/多處理程序陣列,請將正確的 mesh 傳入 flax.training.orbax_utils.restore_args_from_target() 中,以在復原之前產生正確的 restore_args

例如

CKPT_DIR = '/tmp/orbax_upgrade/'
flax.config.update('flax_use_orbax_checkpointing', False)

# Inside your training loop
for step in range(MAX_STEPS):
  # do training
  checkpoints.save_checkpoint(CKPT_DIR, CKPT_PYTREE, step=step,
                              prefix='test_', keep=3, keep_every_n_steps=2)


checkpoints.restore_checkpoint(CKPT_DIR, target=TARGET_PYTREE, step=4, prefix='test_')
CKPT_DIR = '/tmp/orbax_upgrade/orbax'

# At the top level
mgr_options = orbax.checkpoint.CheckpointManagerOptions(
  create=True, max_to_keep=3, keep_period=2, step_prefix='test')
ckpt_mgr = orbax.checkpoint.CheckpointManager(
  CKPT_DIR,
  orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_options)

# Inside your training loop
for step in range(MAX_STEPS):
  # do training
  save_args = flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE)
  ckpt_mgr.save(step, CKPT_PYTREE, save_kwargs={'save_args': save_args})


restore_args = flax.training.orbax_utils.restore_args_from_target(TARGET_PYTREE, mesh=None)
ckpt_mgr.restore(4, items=TARGET_PYTREE, restore_kwargs={'restore_args': restore_args})

「輕量級」用例:「純粹的」儲存/載入,沒有最高層級的檢查點管理程式#

如果你不想要保留最高層級的檢查點管理程式,你仍然可以用 orbax.checkpoint.Checkpointer 儲存和復原任何個別的檢查點。請注意,這表示你無法使用所有 Orbax 管理功能。

若要移轉到 Orbax 程式碼,可以使用 orbax.checkpoint.Checkpointer.save() 中的 force 參數,而非在 flax.save_checkpoint() 中使用 overwrite 參數。

例如

PURE_CKPT_DIR = '/tmp/orbax_upgrade/pure'
flax.config.update('flax_use_orbax_checkpointing', False)

checkpoints.save_checkpoint(PURE_CKPT_DIR, CKPT_PYTREE, step=0, overwrite=True)
checkpoints.restore_checkpoint(PURE_CKPT_DIR, target=TARGET_PYTREE)
PURE_CKPT_DIR = '/tmp/orbax_upgrade/pure'

ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())  # A stateless object, can be created on the fly.
ckptr.save(PURE_CKPT_DIR, CKPT_PYTREE,
           save_args=flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE), force=True)
ckptr.restore(PURE_CKPT_DIR, item=TARGET_PYTREE,
              restore_args=flax.training.orbax_utils.restore_args_from_target(TARGET_PYTREE, mesh=None))

在沒有目標 pytree 的情況下復原檢查點#

如果你需要在沒有目標 pytree 的情況下復原檢查點,請傳入 item=Noneorbax.checkpoint.Checkpointer,或傳入 items=Noneorbax.CheckpointManager.restore() 方法,這應會觸發復原。

例如

NOTARGET_CKPT_DIR = '/tmp/orbax_upgrade/no_target'
flax.config.update('flax_use_orbax_checkpointing', False)

checkpoints.save_checkpoint(NOTARGET_CKPT_DIR, CKPT_PYTREE, step=0)
checkpoints.restore_checkpoint(NOTARGET_CKPT_DIR, target=None)
NOTARGET_CKPT_DIR = '/tmp/orbax_upgrade/no_target'

# A stateless object, can be created on the fly.
ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
ckptr.save(NOTARGET_CKPT_DIR, CKPT_PYTREE,
           save_args=flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE))
ckptr.restore(NOTARGET_CKPT_DIR, item=None)

非同步檢查點#

若要讓你的檢查點儲存變成非同步,請用 orbax.checkpoint.AsyncCheckpointer 取代 orbax.checkpoint.Checkpointer

然後,你可以呼叫 orbax.checkpoint.AsyncCheckpointer.wait_until_finished() 或 Orbax 的 CheckpointerManager.wait_until_finished() 以等待儲存完成。

有關更多詳細資訊,請閱讀 檢查點指南

你也可以透過非同步管理程式的 Flax API 使用 Orbax AsyncCheckpointer。非同步管理程式會在內部呼叫 wait_until_finished()。這個解決方案沒有被動維護,建議使用 Orbax 非同步檢查點。

例如

ASYNC_CKPT_DIR = '/tmp/orbax_upgrade/async'
flax.config.update('flax_use_orbax_checkpointing', True)
async_manager = checkpoints.AsyncManager()

checkpoints.save_checkpoint(ASYNC_CKPT_DIR, CKPT_PYTREE, step=0, overwrite=True, async_manager=async_manager)
checkpoints.restore_checkpoint(ASYNC_CKPT_DIR, target=TARGET_PYTREE)
ASYNC_CKPT_DIR = '/tmp/orbax_upgrade/async'

import orbax.checkpoint as ocp
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(ASYNC_CKPT_DIR, args=ocp.args.StandardSave(CKPT_PYTREE))
# ... Continue with your work...
# ... Until a time when you want to wait until the save completes:
ckptr.wait_until_finished() # Blocks until the checkpoint saving is completed.
ckptr.restore(ASYNC_CKPT_DIR, args=ocp.args.StandardRestore(TARGET_PYTREE))

儲存/載入單一的 JAX 或 NumPy 陣列#

orbax.checkpoint.PyTreeCheckpointHandler 類別,正如其名稱所示,只能用於 pytrees。因此,如果您需要儲存/還原單一 pytree 葉(例如陣列),請改用 orbax.checkpoint.ArrayCheckpointHandler

例如

ARR_CKPT_DIR = '/tmp/orbax_upgrade/singleton'
flax.config.update('flax_use_orbax_checkpointing', False)

checkpoints.save_checkpoint(ARR_CKPT_DIR, jnp.arange(10), step=0)
checkpoints.restore_checkpoint(ARR_CKPT_DIR, target=None)
ARR_CKPT_DIR = '/tmp/orbax_upgrade/singleton'

ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.ArrayCheckpointHandler())
ckptr.save(ARR_CKPT_DIR, jnp.arange(10))
ckptr.restore(ARR_CKPT_DIR, item=None)

最後的話#

本指南提供從「舊版」Flax 檢查點 API 遷移至 Orbax API 的概觀。Orbax 提供更多功能,而 Orbax 團隊正積極開發新功能。敬請期待並追蹤 官方 Orbax GitHub 存放庫 以進一步瞭解!