模型手術#
模型手術是指對現有神經網路的建構模塊和參數進行修改的動作,例如層替換、參數或狀態操作,甚至是「猴子修補」。在本指南中,您將學習如何在 Flax NNX 中使用多種真實場景執行模型手術
Pythonic
nnx.Module
操作:使用 Pythonic 方法操作給定模型下的子Module
。操作抽象模型或狀態:在不配置記憶體的情況下操作
flax.nnx.Module
和狀態的關鍵技巧。從原始狀態到模型的檢查點手術:當參數狀態與現有模型程式碼不相容時,如何操作參數狀態。
部分初始化:如何使用簡單方法或記憶體有效率的方法從頭開始初始化模型的一部分。
from typing import *
from pprint import pprint
import functools
import jax
from jax import lax, numpy as jnp, tree_util as jtu
from jax.sharding import PartitionSpec, Mesh, NamedSharding
from jax.experimental import mesh_utils
import flax
from flax import nnx
import flax.traverse_util
import numpy as np
import orbax.checkpoint as orbax
key = jax.random.key(0)
class TwoLayerMLP(nnx.Module):
def __init__(self, dim, rngs: nnx.Rngs):
self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
def __call__(self, x):
x = self.linear1(x)
return self.linear2(x)
Pythonic nnx.Module
操作#
當出現以下情況時,執行模型手術會更容易:
您已經有一個完全成熟且載入正確參數的模型;以及
您不打算變更模型定義程式碼。
您可以在其子 Module
上執行各種 Pythonic 操作,例如子 Module
交換、Module
共用、變數共用和猴子修補
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
np.testing.assert_allclose(model(x), model.linear2(model.linear1(x)))
# Sub-`Module` swapping.
original1, original2 = model.linear1, model.linear2
model.linear1, model.linear2 = model.linear2, model.linear1
np.testing.assert_allclose(model(x), original1(original2(x)))
# `Module` sharing (tying all weights together).
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
model.linear2 = model.linear1
assert not hasattr(nnx.state(model), 'linear2')
np.testing.assert_allclose(model(x), model.linear1(model.linear1(x)))
# Variable sharing (weight-tying).
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate
assert hasattr(nnx.state(model), 'linear2')
assert hasattr(nnx.state(model)['linear2'], 'bias')
assert not hasattr(nnx.state(model)['linear2'], 'kernel')
# Monkey-patching.
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
def awesome_layer(x): return x
model.linear2 = awesome_layer
np.testing.assert_allclose(model(x), model.linear1(x))
建立不含記憶體配置的抽象模型或狀態#
若要執行更複雜的模型手術,您可以使用的關鍵技術是在不配置任何實際參數資料的情況下建立和操作抽象模型或狀態。這可以加快試驗疊代速度,並消除對記憶體限制的擔憂。
若要建立抽象模型
建立一個返回有效 Flax NNX 模型的函數;以及
在其上執行
nnx.eval_shape
(而不是jax.eval_shape
)。
現在,您可以像平常一樣使用 nnx.split
來取得其抽象狀態。請注意,真實模型中應該是 jax.Array
的所有欄位現在都具有抽象 jax.ShapeDtypeStruct
類型,僅具有形狀/dtype/分片資訊。
abs_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
gdef, abs_state = nnx.split(abs_model)
pprint(abs_state)
State({
'linear1': {
'bias': VariableState(
type=Param,
value=ShapeDtypeStruct(shape=(4,), dtype=float32)
),
'kernel': VariableState(
type=Param,
value=ShapeDtypeStruct(shape=(4, 4), dtype=float32)
)
},
'linear2': {
'bias': VariableState(
type=Param,
value=ShapeDtypeStruct(shape=(4,), dtype=float32)
),
'kernel': VariableState(
type=Param,
value=ShapeDtypeStruct(shape=(4, 4), dtype=float32)
)
}
})
當您使用真實的 jax.Array
填滿每個 nnx.VariableState
pytree 葉節點的 value
屬性時,抽象模型就等同於真實模型。
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
abs_state['linear1']['kernel'].value = model.linear1.kernel
abs_state['linear1']['bias'].value = model.linear1.bias
abs_state['linear2']['kernel'].value = model.linear2.kernel
abs_state['linear2']['bias'].value = model.linear2.bias
nnx.update(abs_model, abs_state)
np.testing.assert_allclose(abs_model(x), model(x)) # They are equivalent now!
檢查點手術#
掌握抽象狀態技術後,您可以對任何檢查點(或執行階段參數 pytree)執行任意操作,使其與給定的模型程式碼相符,然後呼叫 nnx.update
來合併它們。
如果您嘗試大幅變更模型程式碼,例如從 Flax Linen 遷移到 Flax NNX,而且舊的權重不再自然相容,這可能會很有幫助。
讓我們在這裡執行一個簡單的範例
# Save a version of model into a checkpoint
checkpointer = orbax.PyTreeCheckpointer()
old_model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
checkpointer.save(f'/tmp/nnx-surgery-state', nnx.state(model), force=True)
在這個新模型中,子 Module
的名稱已從 linear(1|2)
重新命名為 layer(1|2)
。由於 pytree 結構已變更,因此不可能使用新的模型狀態結構直接載入舊的檢查點
class ModifiedTwoLayerMLP(nnx.Module):
def __init__(self, dim, rngs: nnx.Rngs):
self.layer1 = nnx.Linear(dim, dim, rngs=rngs) # no longer linear1!
self.layer2 = nnx.Linear(dim, dim, rngs=rngs)
def __call__(self, x):
x = self.layer1(x)
return self.layer2(x)
abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
try:
with_item = checkpointer.restore('/tmp/nnx-surgery-state', item=nnx.state(abs_model))
print(with_item)
except Exception as e:
print(f'This will throw error: {type(e)}: {e}')
This will throw error: <class 'ValueError'>: Dict key mismatch; expected keys: ['linear1', 'linear2']; dict: {'layer1': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}, 'layer2': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}}.
但是,您可以將參數 pytree 載入為原始字典,執行重新命名,並產生與新的模型定義保證相容的新狀態。
def process_raw_dict(raw_state_dict):
flattened = nnx.traversals.flatten_mapping(raw_state_dict)
# Cut the '.value' postfix on every leaf path.
flattened = {(path[:-1] if path[-1] == 'value' else path): value
for path, value in flattened.items()}
return nnx.traversals.unflatten_mapping(flattened)
# Make your local change on the checkpoint dictionary.
raw_dict = checkpointer.restore('/tmp/nnx-surgery-state')
pprint(raw_dict)
raw_dict['layer1'] = raw_dict.pop('linear1')
raw_dict['layer2'] = raw_dict.pop('linear2')
# Fit it into the model state.
abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graph_def, state = nnx.split(abs_model)
state.replace_by_pure_dict(process_raw_dict(raw_dict))
restored_model = nnx.merge(graph_def, state)
np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))
{'linear1': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},
'kernel': {'value': Array([[-0.80345297, -0.34071913, -0.9408296 , 0.01005968],
[ 0.26146442, 1.1247735 , 0.54563737, -0.374164 ],
[ 1.0281805 , -0.6798804 , -0.1488401 , 0.05694951],
[-0.44308168, -0.60587114, 0.434087 , -0.40541083]], dtype=float32)}},
'linear2': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},
'kernel': {'value': Array([[ 0.21010089, 0.8289361 , 0.04589564, 0.5422644 ],
[ 0.41914317, 0.84359694, -0.47937787, -0.49135214],
[-0.46072108, 0.4630125 , 0.39276958, -0.9441406 ],
[-0.6690758 , -0.18474789, -0.57622856, 0.4821079 ]], dtype=float32)}}}
/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.10/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1136: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
warnings.warn(
部分初始化#
在某些情況下(例如使用 LoRA (低秩適應)),您可能只想隨機初始化模型的部分參數。這可以透過以下方式實現:
簡單的部分初始化;或
記憶體有效率的部分初始化。
簡單的部分初始化#
若要執行簡單的部分初始化,您可以直接初始化整個模型,然後換入預先訓練的參數。但是,如果您的修改需要重新建立稍後會捨棄的模組參數,這種方法可能會在中途配置額外的記憶體。以下是這方面的範例。
注意: 您可以使用
jax.live_arrays()
來檢查任何給定時間所有存在於記憶體中的陣列。當您多次執行單一 Jupyter 筆記本儲存格時(由於舊 Python 變數的垃圾收集),此呼叫可能會「搞砸」。但是,在筆記本中重新啟動 Python 核心並從頭開始執行程式碼將始終產生相同的輸出。
# Some pretrained model state
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))
simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42)))
print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')
# In this line, extra kernel and bias is created inside the new LoRALinear!
# They are wasted, because you are going to use the kernel and bias in `old_state` anyway.
simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42))
print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}'
' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)')
nnx.update(simple_model, old_state)
print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'
' (2 discarded - only lora_a & lora_b are used in model)')
Number of jax arrays in memory at start: 38
Number of jax arrays in memory midway: 42 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)
Number of jax arrays in memory at end: 40 (2 discarded - only lora_a & lora_b are used in model)
記憶體有效率的部分初始化#
若要執行記憶體有效率的部分初始化,請使用 nnx.jit
的有效編譯程式碼,以確保只初始化您需要的狀態參數
# Some pretrained model state
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))
# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!
@nnx.jit(donate_argnums=0)
def partial_init(old_state, rngs):
model = TwoLayerMLP(4, rngs=rngs)
# Create a new state.
model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs)
# Add the existing state.
nnx.update(model, old_state)
return model
print(f'Number of JAX Arrays in memory at start: {len(jax.live_arrays())}')
# Note that `old_state` will be deleted after this `partial_init` call.
good_model = partial_init(old_state, nnx.Rngs(42))
print(f'Number of JAX Arrays in memory at end: {len(jax.live_arrays())}'
' (2 new created - lora_a and lora_b)')
Number of JAX Arrays in memory at start: 44
Number of JAX Arrays in memory at end: 46 (2 new created - lora_a and lora_b)