儲存與載入檢查點#
本指南示範如何使用 Orbax 儲存和載入 Flax NNX 模型檢查點。
注意:Flax 團隊不會積極維護用於將模型檢查點儲存和載入到磁碟的程式庫。因此,建議您使用 Orbax 之類的外部程式庫來執行此操作。
在本指南中,您將學習如何
儲存檢查點。
還原檢查點。
在檢查點結構不同的情況下還原檢查點。
執行多進程檢查點。
本指南中使用的 Orbax API 範例僅用於示範目的,如需最新的建議 API,請參閱 Orbax 網站。
注意:Flax 團隊建議使用 Orbax 將檢查點儲存和載入到磁碟,因為我們不會積極維護這些功能的程式庫。
注意:如果您正在尋找 Flax Linen 的舊版
flax.training.checkpoints
套件,它已於 2023 年被棄用,改用 Orbax。相關文件位於 Flax Linen 網站。
設定#
匯入必要的相依性,設定檢查點目錄和範例 Flax NNX 模型 - TwoLayerMLP
- 方法是繼承 nnx.Module
。
from flax import nnx
import orbax.checkpoint as ocp
import jax
from jax import numpy as jnp
import numpy as np
ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
class TwoLayerMLP(nnx.Module):
def __init__(self, dim, rngs: nnx.Rngs):
self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)
self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)
def __call__(self, x):
x = self.linear1(x)
return self.linear2(x)
# Instantiate the model and show we can run it.
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
assert model(x).shape == (3, 4)
儲存檢查點#
JAX 檢查點程式庫(例如 Orbax)可以儲存和載入任何指定的 JAX pytree,它是一個純粹的(可能是巢狀的)jax.Array
的容器(或者,其他框架會稱之為「張量」)。在機器學習的內容中,檢查點通常是模型參數和其他資料(例如最佳化器狀態)的 pytree。
在 Flax NNX 中,您可以透過呼叫 nnx.split
,並取得傳回的 nnx.State
,從 nnx.Module
取得這樣的 pytree。
_, state = nnx.split(model)
nnx.display(state)
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / 'state', state)
還原檢查點#
請注意,您將檢查點儲存為 Flax 的 nnx.State
類別,它也與 nnx.VariableState
和 nnx.Param
類別巢狀。
在檢查點還原時間,您需要在執行階段準備好這些類別,並指示檢查點程式庫 (Orbax) 將您的 pytree 還原回該結構。這可以透過以下方式達成
首先,建立抽象的 Flax NNX 模型(而不配置任何陣列記憶體),並將其抽象變數狀態顯示給檢查點程式庫。
取得狀態後,使用
nnx.merge
取得您的 Flax NNX 模型,並像平常一樣使用它。
# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference.
abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
print('The abstract NNX state (all leaves are abstract arrays):')
nnx.display(abstract_state)
state_restored = checkpointer.restore(ckpt_dir / 'state', abstract_state)
jax.tree.map(np.testing.assert_array_equal, state, state_restored)
print('NNX State restored: ')
nnx.display(state_restored)
# The model is now good to use!
model = nnx.merge(graphdef, state_restored)
assert model(x).shape == (3, 4)
The abstract NNX state (all leaves are abstract arrays):
NNX State restored:
/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.10/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1136: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
warnings.warn(
The abstract NNX state (all leaves are abstract arrays):
NNX State restored:
/Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
warnings.warn(
以純字典儲存和還原#
當與檢查點程式庫(如 Orbax)互動時,您可能偏好使用 Python 內建的容器類型。在這種情況下,您可以使用 nnx.State.to_pure_dict
和 nnx.State.replace_by_pure_dict
API 將 nnx.State
轉換為純巢狀字典並從中轉換。
# Save as pure dict
pure_dict_state = state.to_pure_dict()
nnx.display(pure_dict_state)
checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)
# Restore as a pure dictionary.
restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
abstract_state.replace_by_pure_dict(restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4) # The model still works!
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
當檢查點結構不同時還原#
當您想要載入一些不再符合您目前模型程式碼的過時檢查點時,將檢查點載入為純巢狀字典的功能會派上用場。查看下面的簡單範例。
如果您將檢查點儲存為 nnx.State
而非純字典,此模式也有效。查看 模型手術 指南的 檢查點手術章節,以取得包含程式碼的範例。唯一的不同是您需要在呼叫 nnx.State.replace_by_pure_dict
之前稍微重新處理您的原始字典。
class ModifiedTwoLayerMLP(nnx.Module):
"""A modified version of TwoLayerMLP, which requires bias arrays."""
def __init__(self, dim, rngs: nnx.Rngs):
self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now!
self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now!
def __call__(self, x):
x = self.linear1(x)
return self.linear2(x)
# Accommodate your old checkpoint to the new code.
restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
restored_pure_dict['linear1']['bias'] = jnp.zeros((4,))
restored_pure_dict['linear2']['bias'] = jnp.zeros((4,))
# Same restore code as above.
abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
abstract_state.replace_by_pure_dict(restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4) # The new model works!
nnx.display(model.linear1)
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
多進程檢查點#
在多主機/多進程環境中,您會希望將檢查點還原為跨多個裝置分片。查看 Flax 在多個裝置上擴展 指南中的 從檢查點載入分片模型 章節,以了解如何衍生分片 pytree 並使用它來載入您的檢查點。
注意:JAX 提供多種方法可以在多個主機上同時擴展您的程式碼。當裝置 (CPU/GPU/TPU) 的數量非常大,以至於不同的裝置由不同的主機 (CPU) 管理時,通常會發生這種情況。查看 JAX 的 平行程式設計簡介、在多主機和多進程環境中使用 JAX、分散式陣列和自動平行化,以及 使用
shard_map
進行手動平行化。
其他檢查點功能#
本指南僅使用最簡單的 orbax.checkpoint.StandardCheckpointer
API 來示範如何在 Flax 模型端進行儲存和載入。您可以隨意使用其他工具或程式庫。
此外,請查看 Orbax 網站,以了解其他常用的功能,例如
CheckpointManager
用於追蹤不同步驟的檢查點。Orbax 轉換:一種在載入時(而非在載入後)修改 pytree 結構的方法,本指南中示範了這一點。