模組#
- class flax.nnx.Module(*args, **kwargs)[原始碼]#
所有神經網路模組的基底類別。
層和模型應該繼承這個類別。
Module
可以包含子模組,並以這種方式嵌套在樹狀結構中。子模組可以在__init__
方法內部作為常規屬性賦值。您可以在您的
Module
子類別上定義任意的「前向傳遞」方法。雖然沒有特殊的方法,但__call__
是一個受歡迎的選擇,因為您可以直接呼叫Module
>>> from flax import nnx >>> import jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... x = self.linear1(x) ... x = nnx.relu(x) ... x = self.linear2(x) ... return x >>> x = jnp.ones((1, 2)) >>> model = Model(rngs=nnx.Rngs(0)) >>> y = model(x)
- eval(**attributes)[原始碼]#
將模組設定為評估模式。
eval
使用set_attributes
遞迴地設定所有具有這些屬性的嵌套模組的屬性deterministic=True
和use_running_average=True
。它主要用於控制Dropout
和BatchNorm
模組的執行時行為。範例
>>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) ... self.dropout = nnx.Dropout(0.5) ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs) ... >>> block = Block(2, 5, rngs=nnx.Rngs(0)) >>> block.dropout.deterministic, block.batch_norm.use_running_average (False, False) >>> block.eval() >>> block.dropout.deterministic, block.batch_norm.use_running_average (True, True)
- 參數
**attributes – 傳遞給
set_attributes
的其他屬性。
- iter_children()[原始碼]#
迭代目前模組的所有子
Module
。這個方法類似於iter_modules()
,除了它只迭代直接子代,並且不會進一步遞迴。iter_children
會建立一個產生器,產生鍵和模組實例,其中鍵是一個字串,表示用於存取對應子模組的模組屬性名稱。範例
>>> from flax import nnx ... >>> class SubModule(nnx.Module): ... def __init__(self, din, dout, rngs): ... self.linear1 = nnx.Linear(din, dout, rngs=rngs) ... self.linear2 = nnx.Linear(din, dout, rngs=rngs) ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) ... self.submodule = SubModule(din, dout, rngs=rngs) ... self.dropout = nnx.Dropout(0.5) ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs) ... >>> model = Block(2, 5, rngs=nnx.Rngs(0)) >>> for path, module in model.iter_children(): ... print(path, type(module).__name__) ... batch_norm BatchNorm dropout Dropout linear Linear submodule SubModule
- iter_modules()[原始碼]#
遞迴迭代目前模組的所有巢狀
Module
,包括目前的模組。iter_modules
會建立一個產生器,產生路徑和模組實例,其中路徑是一個字串或整數的元組,表示從根模組到模組的路徑。範例
>>> from flax import nnx ... >>> class SubModule(nnx.Module): ... def __init__(self, din, dout, rngs): ... self.linear1 = nnx.Linear(din, dout, rngs=rngs) ... self.linear2 = nnx.Linear(din, dout, rngs=rngs) ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) ... self.submodule = SubModule(din, dout, rngs=rngs) ... self.dropout = nnx.Dropout(0.5) ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs) ... >>> model = Block(2, 5, rngs=nnx.Rngs(0)) >>> for path, module in model.iter_modules(): ... print(path, type(module).__name__) ... ('batch_norm',) BatchNorm ('dropout',) Dropout ('linear',) Linear ('submodule', 'linear1') Linear ('submodule', 'linear2') Linear ('submodule',) SubModule () Block
- set_attributes(*filters, raise_if_not_found=True, **attributes)[原始碼]#
設定巢狀模組(包括目前模組)的屬性。如果屬性在模組中找不到,則會忽略它。
範例
>>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) ... self.dropout = nnx.Dropout(0.5, deterministic=False) ... self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs) ... >>> block = Block(2, 5, rngs=nnx.Rngs(0)) >>> block.dropout.deterministic, block.batch_norm.use_running_average (False, False) >>> block.set_attributes(deterministic=True, use_running_average=True) >>> block.dropout.deterministic, block.batch_norm.use_running_average (True, True)
可以使用
Filter
來設定特定模組的屬性>>> block = Block(2, 5, rngs=nnx.Rngs(0)) >>> block.set_attributes(nnx.Dropout, deterministic=True) >>> # Only the dropout will be modified >>> block.dropout.deterministic, block.batch_norm.use_running_average (True, False)
- 參數
*filters – 選擇要設定其屬性的模組的篩選器。
raise_if_not_found – 如果為 True(預設值),則如果至少在其中一個選取的模組中找不到一個屬性實例,則會引發 ValueError。
**attributes – 要設定的屬性。
- sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[原始碼]#
sow()
可用於收集中間值,而無需透過每次模組呼叫顯式傳遞容器。sow()
會將值儲存在新的Module
屬性中,以name
表示。該值將由類型為variable_type
的Variable
包裝,這在split()
、state()
和pop()
中進行篩選時非常有用。預設情況下,這些值會儲存在一個元組中,並且每個儲存的值都會附加到結尾。這樣,當同一個模組被多次呼叫時,可以追蹤所有中間值。
使用範例
>>> from flax import nnx >>> import jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x, add=0): ... x = self.linear1(x) ... self.sow(nnx.Intermediate, 'i', x+add) ... x = self.linear2(x) ... return x >>> x = jnp.ones((1, 2)) >>> model = Model(rngs=nnx.Rngs(0)) >>> assert not hasattr(model, 'i') >>> y = model(x) >>> assert hasattr(model, 'i') >>> assert len(model.i.value) == 1 # tuple of length 1 >>> assert model.i.value[0].shape == (1, 3) >>> y = model(x, add=1) >>> assert len(model.i.value) == 2 # tuple of length 2 >>> assert (model.i.value[0] + 1 == model.i.value[1]).all()
或者,可以傳遞自訂的 init/reduce 函數
>>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... x = self.linear1(x) ... self.sow(nnx.Intermediate, 'sum', x, ... init_fn=lambda: 0, ... reduce_fn=lambda prev, curr: prev+curr) ... self.sow(nnx.Intermediate, 'product', x, ... init_fn=lambda: 1, ... reduce_fn=lambda prev, curr: prev*curr) ... x = self.linear2(x) ... return x >>> x = jnp.ones((1, 2)) >>> model = Model(rngs=nnx.Rngs(0)) >>> y = model(x) >>> assert (model.sum.value == model.product.value).all() >>> intermediate = model.sum.value >>> y = model(x) >>> assert (model.sum.value == intermediate*2).all() >>> assert (model.product.value == intermediate**2).all()
- 參數
variable_type – 儲存值的
Variable
類型。通常Intermediate
用於表示中間值。name – 一個字串,表示儲存已播種值的
Module
屬性名稱。value – 要儲存的值。
reduce_fn – 用於將現有值與新值組合的函數。預設值是將值附加到一個元組。
init_fn – 對於儲存的第一個值,
reduce_fn
將會與要儲存的值一起傳遞init_fn
的結果。預設值為空元組。
- train(**attributes)[原始碼]#
將模組設定為訓練模式。
train
使用set_attributes
遞迴地設定所有具有這些屬性的巢狀模組的屬性deterministic=False
和use_running_average=False
。它主要用於控制Dropout
和BatchNorm
模組的執行時行為。範例
>>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) ... # initialize Dropout and BatchNorm in eval mode ... self.dropout = nnx.Dropout(0.5, deterministic=True) ... self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs) ... >>> block = Block(2, 5, rngs=nnx.Rngs(0)) >>> block.dropout.deterministic, block.batch_norm.use_running_average (True, True) >>> block.train() >>> block.dropout.deterministic, block.batch_norm.use_running_average (False, False)
- 參數
**attributes – 傳遞給
set_attributes
的其他屬性。