狀態#
- class flax.nnx.State(mapping, /, *, _copy=True)[原始碼]#
一個類似於 pytree 的結構,包含從可哈希且可比較的鍵到葉節點的
Mapping
。葉節點可以是任何類型,但VariableState
和Variable
最常見。- filter(first, /, *filters)[原始碼]#
將一個
State
過濾成一個或多個State
。使用者必須至少傳遞一個Filter
(例如Variable
)。這個方法類似於split()
,但過濾器可以是非詳盡的。範例用法
>>> from flax import nnx >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batchnorm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> state = nnx.state(model) >>> param = state.filter(nnx.Param) >>> batch_stats = state.filter(nnx.BatchStat) >>> param, batch_stats = state.filter(nnx.Param, nnx.BatchStat)
- 參數
first – 第一個過濾器
*filters – 可選的額外過濾器,用於將狀態分組為互斥的子狀態。
- 返回
一個或多個
States
,等於傳遞的過濾器數量。
- static merge(state, /, *states)[原始碼]#
與
split()
相反。merge
接受一個或多個State
並建立一個新的State
。範例用法
>>> from flax import nnx >>> import jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batchnorm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat) >>> params.linear.bias.value += 1 >>> state = nnx.State.merge(params, batch_stats) >>> nnx.update(model, state) >>> assert (model.linear.bias.value == jnp.array([1, 1, 1])).all()
- 參數
state – 一個
State
物件。*states – 額外的
State
物件。
- 返回
合併後的
State
。
- split(first, /, *filters)[原始碼]#
將一個
State
分割成一個或多個State
。使用者必須至少傳遞一個Filter
(例如Variable
),且過濾器必須是詳盡的 (也就是說,它們必須涵蓋State
中所有Variable
類型)。範例用法
>>> from flax import nnx >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batchnorm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> state = nnx.state(model) >>> param, batch_stats = state.split(nnx.Param, nnx.BatchStat)
- 參數
first – 第一個過濾器
*filters – 可選的額外過濾器,用於將狀態分組為互斥的子狀態。
- 返回
一個或多個
States
,等於傳遞的過濾器數量。