模組#

class flax.nnx.Module(*args, **kwargs)[原始碼]#

所有神經網路模組的基底類別。

層和模型應該繼承這個類別。

Module 可以包含子模組,並以這種方式嵌套在樹狀結構中。子模組可以在 __init__ 方法內部作為常規屬性賦值。

您可以在您的 Module 子類別上定義任意的「前向傳遞」方法。雖然沒有特殊的方法,但 __call__ 是一個受歡迎的選擇,因為您可以直接呼叫 Module

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = nnx.relu(x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
eval(**attributes)[原始碼]#

將模組設定為評估模式。

eval 使用 set_attributes 遞迴地設定所有具有這些屬性的嵌套模組的屬性 deterministic=Trueuse_running_average=True。它主要用於控制 DropoutBatchNorm 模組的執行時行為。

範例

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
參數

**attributes – 傳遞給 set_attributes 的其他屬性。

iter_children()[原始碼]#

迭代目前模組的所有子 Module。這個方法類似於 iter_modules(),除了它只迭代直接子代,並且不會進一步遞迴。

iter_children 會建立一個產生器,產生鍵和模組實例,其中鍵是一個字串,表示用於存取對應子模組的模組屬性名稱。

範例

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_children():
...  print(path, type(module).__name__)
...
batch_norm BatchNorm
dropout Dropout
linear Linear
submodule SubModule
iter_modules()[原始碼]#

遞迴迭代目前模組的所有巢狀 Module,包括目前的模組。

iter_modules 會建立一個產生器,產生路徑和模組實例,其中路徑是一個字串或整數的元組,表示從根模組到模組的路徑。

範例

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
...   print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
('submodule', 'linear1') Linear
('submodule', 'linear2') Linear
('submodule',) SubModule
() Block
set_attributes(*filters, raise_if_not_found=True, **attributes)[原始碼]#

設定巢狀模組(包括目前模組)的屬性。如果屬性在模組中找不到,則會忽略它。

範例

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

可以使用 Filter 來設定特定模組的屬性

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
參數
  • *filters – 選擇要設定其屬性的模組的篩選器。

  • raise_if_not_found – 如果為 True(預設值),則如果至少在其中一個選取的模組中找不到一個屬性實例,則會引發 ValueError。

  • **attributes – 要設定的屬性。

sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[原始碼]#

sow() 可用於收集中間值,而無需透過每次模組呼叫顯式傳遞容器。sow() 會將值儲存在新的 Module 屬性中,以 name 表示。該值將由類型為 variable_typeVariable 包裝,這在 split()state()pop() 中進行篩選時非常有用。

預設情況下,這些值會儲存在一個元組中,並且每個儲存的值都會附加到結尾。這樣,當同一個模組被多次呼叫時,可以追蹤所有中間值。

使用範例

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x, add=0):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x+add)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')

>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i.value) == 1 # tuple of length 1
>>> assert model.i.value[0].shape == (1, 3)

>>> y = model(x, add=1)
>>> assert len(model.i.value) == 2 # tuple of length 2
>>> assert (model.i.value[0] + 1 == model.i.value[1]).all()

或者,可以傳遞自訂的 init/reduce 函數

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     self.sow(nnx.Intermediate, 'product', x,
...              init_fn=lambda: 1,
...              reduce_fn=lambda prev, curr: prev*curr)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))

>>> y = model(x)
>>> assert (model.sum.value == model.product.value).all()
>>> intermediate = model.sum.value

>>> y = model(x)
>>> assert (model.sum.value == intermediate*2).all()
>>> assert (model.product.value == intermediate**2).all()
參數
  • variable_type – 儲存值的 Variable 類型。通常 Intermediate 用於表示中間值。

  • name – 一個字串,表示儲存已播種值的 Module 屬性名稱。

  • value – 要儲存的值。

  • reduce_fn – 用於將現有值與新值組合的函數。預設值是將值附加到一個元組。

  • init_fn – 對於儲存的第一個值,reduce_fn 將會與要儲存的值一起傳遞 init_fn 的結果。預設值為空元組。

train(**attributes)[原始碼]#

將模組設定為訓練模式。

train 使用 set_attributes 遞迴地設定所有具有這些屬性的巢狀模組的屬性 deterministic=Falseuse_running_average=False。它主要用於控制 DropoutBatchNorm 模組的執行時行為。

範例

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
參數

**attributes – 傳遞給 set_attributes 的其他屬性。