狀態#

class flax.nnx.State(mapping, /, *, _copy=True)[原始碼]#

一個類似於 pytree 的結構,包含從可哈希且可比較的鍵到葉節點的 Mapping。葉節點可以是任何類型,但 VariableStateVariable 最常見。

filter(first, /, *filters)[原始碼]#

將一個 State 過濾成一個或多個 State。使用者必須至少傳遞一個 Filter (例如 Variable)。這個方法類似於 split(),但過濾器可以是非詳盡的。

範例用法

>>> from flax import nnx

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batchnorm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> state = nnx.state(model)
>>> param = state.filter(nnx.Param)
>>> batch_stats = state.filter(nnx.BatchStat)
>>> param, batch_stats = state.filter(nnx.Param, nnx.BatchStat)
參數
  • first – 第一個過濾器

  • *filters – 可選的額外過濾器,用於將狀態分組為互斥的子狀態。

返回

一個或多個 States,等於傳遞的過濾器數量。

static merge(state, /, *states)[原始碼]#

split() 相反。

merge 接受一個或多個 State 並建立一個新的 State

範例用法

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batchnorm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
>>> params.linear.bias.value += 1

>>> state = nnx.State.merge(params, batch_stats)
>>> nnx.update(model, state)
>>> assert (model.linear.bias.value == jnp.array([1, 1, 1])).all()
參數
  • state – 一個 State 物件。

  • *states – 額外的 State 物件。

返回

合併後的 State

split(first, /, *filters)[原始碼]#

將一個 State 分割成一個或多個 State。使用者必須至少傳遞一個 Filter (例如 Variable),且過濾器必須是詳盡的 (也就是說,它們必須涵蓋 State 中所有 Variable 類型)。

範例用法

>>> from flax import nnx

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batchnorm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> state = nnx.state(model)
>>> param, batch_stats = state.split(nnx.Param, nnx.BatchStat)
參數
  • first – 第一個過濾器

  • *filters – 可選的額外過濾器,用於將狀態分組為互斥的子狀態。

返回

一個或多個 States,等於傳遞的過濾器數量。