模組#

Flax 模組系統。

類別 flax.linen.Module[原始碼]#

所有神經網路模組的基礎類別。

層與模型應以此類別建立子類別。

所有 Flax 模組均為 Python 3.7 資料類別。由於資料類別採用 __init__,請改寫 setup(),其會自動呼叫以初始化模組。

模組可以包含子模組,並以此方式巢狀放置在樹狀結構中。子模型可指定為 setup() 方法中的常規屬性。

您可以在模組子類別中定義任意的「前向傳遞」方法。雖然沒有任何方法具有特殊狀況,但 __call__ 是一個常見選項,只要使用模組執行個體,就如同在使用函式一樣

>>> from flax import linen as nn
>>> from typing import Tuple

>>> class Module(nn.Module):
...   features: Tuple[int, ...] = (16, 4)

...   def setup(self):
...     self.dense1 = nn.Dense(self.features[0])
...     self.dense2 = nn.Dense(self.features[1])

...   def __call__(self, x):
...     return self.dense2(nn.relu(self.dense1(x)))

對於較簡潔的模組實作(其子模組定義與其使用位於同位置),您可以選用 compact() 包覆器。

__setattr__(name, val)[原始碼]#

設定此模組的屬性。

我們覆寫 setattr 僅用於透過特殊 setup() 函式中子模組的指定支援 Pythonic 命名。

self.submodule_name = MyModule(...)

我們也支援清單和其他一般 pytrees,例如

self.submodules = [MyModule0(..), MyModule1(..), ...]
參數
  • name – 要設定的屬性。

  • val – 屬性的值。

apply(變數, *自變數, rng=, 方法=, 可變動=, 擷取中間值=, **關鍵字自變數)[來源]#

將模組方法運用到變數,並傳回輸出與修改後的變數。

請注意,如果想呼叫 apply 於不同的類別方法而非 __call__,應設定 方法。例如,假設 Transformers 模組有一個名為 encode 的方法,則下列會針對該方法呼叫 apply

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Transformer(nn.Module):
...   def encode(self, x):
...     ...

>>> x = jnp.ones((16, 9))
>>> model = Transformer()
>>> variables = model.init(jax.random.key(0), x, method=Transformer.encode)

>>> encoded = model.apply(variables, x, method=Transformer.encode)

如果提供函式實體,則使用未繫結函式。例如,下列範例與上方範例相同

>>> encoded = model.apply(variables, x, method=model.encode)

您也可以將字串傳遞至模組的可呼叫屬性。例如,之前的範例可以寫成

>>> encoded = model.apply(variables, x, method='encode')

請注意 方法 也可能是未定義於 Transformer 中的函式。在這種情況下,函式應至少有一個代表模組類別實體的自變數

>>> def other_fn(instance, x):
...   # instance.some_module_attr(...)
...   instance.encode
...   ...

>>> model.apply(variables, x, method=other_fn)

如果您傳遞單一 PRNGKey,Flax 將使用它來提供 '參數' RNG 串流。如果您想使用不同的 RNG 串流或需要使用多個串流,您可以傳遞將每個 RNG 串流名稱對應至其相應 PRNGKey 的字典至 apply。如果在 RNG 串流名稱上呼叫 self.make_rng(name) 而該名稱未由使用者傳遞,則預設會使用 '參數' RNG 串流。

範例

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, add_noise=False):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     if add_noise:
...       # Add gaussian noise
...       noise_key = self.make_rng('noise')
...       x = x + jax.random.normal(noise_key, x.shape)
...
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)}
>>> variables = module.init(rngs, x)
>>> out0 = module.apply(variables, x, add_noise=True, rngs=rngs)

>>> rngs['noise'] = jax.random.key(0)
>>> out1 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # different output (key(1) vs key(0))
>>> np.testing.assert_raises(AssertionError, np.testing.assert_allclose, out0, out1)

>>> del rngs['noise']
>>> # self.make_rng('noise') will default to using the 'params' RNG stream
>>> out2 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # same output (key(0))
>>> np.testing.assert_allclose(out1, out2)

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> out3 = module.apply(variables, x, add_noise=True, rngs=jax.random.key(0))
>>> # same output (key(0))
>>> np.testing.assert_allclose(out2, out3)
參數
  • 變數 – 包含由變數集合鍵控的變數的字典。請參閱flax.core.variables以取得有關變數的更多詳細資料。

  • *args – 傳遞至指定的套用方法的名稱參數。

  • rngs – 要初始化 PRNG 序列的 PRNGKey 字典。「params」PRNG 序列用於初始化參數。

  • method – 要呼叫套用函式的函式。這通常是模組中的函式。如果已提供,套用此方法。如果未提供,則套用模組的__call__方法。也可以提供字串以指定方法名稱。

  • mutable – 可以是布林、字串或清單。指定應視為可變的集合:bool:所有/沒有集合是可變的。 str:單一可變集合的集合名稱。 list:可變集合名稱的清單。

  • capture_intermediates – 如果True,則會擷取「intermediates」集合中所有模組的中間回傳值。預設情況下,只會儲存所有__call__方法的回傳值。可以傳遞函式來變更篩選器行為。篩選器函式會採用模組執行個體和方法名稱,並傳回布林值,用於指示該方法呼叫的輸出是否應儲存。

  • **kwargs – 傳遞至指定的套用方法的關鍵字參數。

傳回

如果mutable為 False,傳回輸出。如果任何集合是可變的,則傳回(output, vars),其中vars是她經過修改的集合的字典。

bind(variables, *args, rngs=None, mutable=False)[原始碼]#

透過繫結變數和 RNG,建立互動式模組執行個體。

bind 提供一個直接非透過函式 apply 轉換之模組「互動式」實例。此方法對於除錯和像筆記本般的互動式使用範例特別有用,其中函式將限制區分為不同儲存格的程式碼的能力。

一旦變數(和 RNG(可選))繫結至 Module,它就會變成一個有狀態物件。請注意,慣用語法 JAX 是函式的,因此互動式實例和純粹的 JAX API 不是很搭配。只有在互動式實驗中才應使用 bind(),在所有其他情況下,我們強烈建議使用者改用 apply()

範例

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class AutoEncoder(nn.Module):
...   def setup(self):
...     self.encoder = nn.Dense(3)
...     self.decoder = nn.Dense(5)
...
...   def __call__(self, x):
...     return self.decoder(self.encoder(x))

>>> x = jnp.ones((16, 9))
>>> ae = AutoEncoder()
>>> variables = ae.init(jax.random.key(0), x)
>>> model = ae.bind(variables)
>>> z = model.encoder(x)
>>> x_reconstructed = model.decoder(z)
參數
  • 變數 – 包含由變數集合鍵控的變數的字典。請參閱flax.core.variables以取得有關變數的更多詳細資料。

  • *args – 命名參數(未使用)。

  • rngs – PRNGKey 的字典,用於初始化 PRNG 順序。

  • mutable – 可以是布林、字串或清單。指定應視為可變的集合:bool:所有/沒有集合是可變的。 str:單一可變集合的集合名稱。 list:可變集合名稱的清單。

傳回

此實例的副本,具有繫結變數和 RNG。

copy(*, parent=<flax.linen.module._Sentinel object>, name=None, **updates)[source]#

建立此模組的副本,並可選擇更新參數。

參數
  • parent – 副本的父項。如果未明確指定,會以目前模組作為預設父項。

  • name – 複製模組的新名稱,預設會給予新的自動名稱。

  • **updates – 屬性更新。

傳回

此模組的副本,具有更新的名稱、父項和屬性。

get_variable(col, name, default=None)[source]#

擷取變數的值。

參數
  • col – 變數集合。

  • name – 變數名稱。

  • default – 如果變數不存在於此作用域,則傳回的預設值。

傳回

輸入變數的值,如果變數在此作用域不存在,則為預設值。

has_rng(name)[原始程式碼]#

如果名稱 namePRNGSequence 存在,將傳回 True。

has_variable(col, name)[原始程式碼]#

檢查此模組中是否具有指定集合和名稱的變數。

請參閱 flax.core.variables 以取得關於變數及集合的詳細說明。

參數
  • col – 變數集合名稱。

  • name – 變數的名稱。

傳回

如果變數存在,將傳回 True。

init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)[原始程式碼]#

使用變數初始化模組方法並傳回已修改的變數。

init 的第一個引數為單一 PRNGKey,或一個將變數集合名稱對應至其 PRNGKeys 的字典,並且會呼叫 method(在預設情況下為模組的 __call__ 函式),傳遞 *args**kwargs,並傳回已初始化變數的字典。

範例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(16)(x)
...     x = nn.BatchNorm(use_running_average=not train)(x)
...     x = nn.relu(x)
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> key = jax.random.key(0)
>>> variables = module.init(key, x, train=False)

如果您傳遞單一的 PRNGKey,Flax 將使用它來提供 'params' RNG stream。如果您想使用不同的 RNG stream 或需要使用多個 stream,您可以傳遞將每個 RNG stream 名稱對應到其對應的 PRNGKey 的字典給 init。如果在非使用者所傳入的 RNG stream 名稱上呼叫 self.make_rng(name),它將預設使用 'params' RNG stream。

範例

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     other_variable = self.variable(
...       'other_collection',
...       'other_variable',
...       lambda x: jax.random.normal(self.make_rng('other_rng'), x.shape),
...       x,
...     )
...     x = x + other_variable.value
...
...     return nn.Dense(1)(x)

>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'other_rng': jax.random.key(1)}
>>> variables0 = module.init(rngs, x)

>>> rngs['other_rng'] = jax.random.key(0)
>>> variables1 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables0['params'], variables1['params']
... )
>>> # different other_variable (key(1) vs key(0))
>>> np.testing.assert_raises(
...   AssertionError,
...   np.testing.assert_allclose,
...   variables0['other_collection']['other_variable'],
...   variables1['other_collection']['other_variable'],
... )

>>> del rngs['other_rng']
>>> # self.make_rng('other_rng') will default to using the 'params' RNG stream
>>> variables2 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables1['params'], variables2['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...   variables1['other_collection']['other_variable'],
...   variables2['other_collection']['other_variable'],
... )

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> variables3 = module.init(jax.random.key(0), x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables2['params'], variables3['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...   variables2['other_collection']['other_variable'],
...   variables3['other_collection']['other_variable'],
... )

Jit init 會用僅提供參數形狀的方式對模式進行延遲初始化,並避免使用實際值來計算前向通道。範例

>>> module = nn.Dense(1)
>>> init_jit = jax.jit(module.init)
>>> variables = init_jit(jax.random.key(0), x)

initapply 上的一層薄封裝,因此其他 apply 參數,例如 methodmutablecapture_intermediates 也可用。

參數
  • rngs – 變數集合的 rng。

  • *args – 傳遞給 init 函式的命名參數。

  • method – 選擇性方法。如果提供,套用此方法。如果未提供,套用 __call__ 方法。也可以提供一個字串,依據名稱指定方法。

  • mutable – 可以是布林值、字串或清單。指定哪些集合應視為可變的:bool:所有/無任何集合為可變。 str:單個可變集合的名稱。 list:可變集合名稱清單。預設上,除了「中間」以外的所有集合均為可變。

  • capture_intermediates – 如果為 True,請擷取「中間」集合內部所有模組的中間回傳值。預設上,僅儲存所有 __call__ 方法的回傳值。可以傳遞函式來變更篩選器行為。篩選器函式採用模組執行個體和方法名稱,並回傳布林值指示應否儲存該方法呼叫的輸出。

  • **kwargs – 傳遞給 init 函式的關鍵字參數。

傳回

初始化的變數字典。

init_with_output(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)[source]#

使用變數初始模組方法,並回傳輸出和修改的變數。

參數
  • rngs – 變數集合的 rng。

  • *args – 傳遞給 init 函式的命名參數。

  • method – 選擇性方法。如果提供,套用此方法。如果未提供,套用 __call__ 方法。也可以提供一個字串,依據名稱指定方法。

  • mutable – 可以是布林值、字串或清單。指定哪些集合應視為可變的: bool:所有/沒有集合是可變的。 str:單一可變集合名稱。 list:可變集合名稱清單。預設情況下,除了「intermediates」,所有集合都是可變的。

  • capture_intermediates – 如果為 True,請擷取「中間」集合內部所有模組的中間回傳值。預設上,僅儲存所有 __call__ 方法的回傳值。可以傳遞函式來變更篩選器行為。篩選器函式採用模組執行個體和方法名稱,並回傳布林值指示應否儲存該方法呼叫的輸出。

  • **kwargs – 傳遞給 init 函式的關鍵字參數。

傳回

(output, vars),其中 vars 是已修改集合的字典。

is_initializing()[source]#

如果在 self.init(…) 或 nn.init(…)() 中執行,傳回 True。

這是一個輔助方法,用於處理只希望在 module.initnn.init 下呼叫時執行設定邏輯的簡單初始化一般案例。對於更複雜的多階段初始化情境,最好測試特定變數集合的可變性和可能需要初始化的特定變數是否存在。

is_mutable_collection(col)[source]#

如果集合 col 是可變的,傳回 true。

lazy_init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs)[原始碼]#

初始化模組,而不會在實際輸入上進行運算。

lazy_init 將初始化變數,而不會進行不必要的運算。輸入資料應傳遞為 jax.ShapeDtypeStruct,此資料會指定輸入的形狀和資料類型,但不會包含具體資料。

範例

>>> model = nn.Dense(features=256)
>>> variables = model.lazy_init(
...     jax.random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32))

傳遞給 lazy_init 的 args 和 kwargs 引數可以是具體(jax 陣列、純量、布林)和抽象(ShapeDtypeStruct)值組合。具體值僅適用於會影響變數初始化的引數。例如,此模組可能會預期關鍵字引數能啟用/停用模組的子部分。在此情況下,應傳遞明確的值(True/Flase),否則 lazy_init 無法推斷應初始化哪些變數。

參數
  • rngs – 變數集合的 rng。

  • *args – 傳遞給 init 函式的引數。

  • method – 選用方法。如果提供,則套用此方法。如果未提供,則套用 __call__ 方法。

  • mutable – 可以是布林值、字串或清單。指定哪些集合應視為可變的:bool:所有/無任何集合為可變。 str:單個可變集合的名稱。 list:可變集合名稱清單。預設上,除了「中間」以外的所有集合均為可變。

  • **kwargs – 傳遞給 init 函式的關鍵字參數。

傳回

初始化的變數字典。

make_rng(name='params')[原始碼]#

從此模組的指定 RNG 序列傳回新的 RNG 金鑰。

新的 RNG 金鑰會從前一個 RNG 金鑰分離。因此,每次呼叫 make_rng 都會傳回新的 RNG 金鑰,同時仍能保證完全可重製。

注意

如果傳遞了無效的名稱(即使用在 .init.apply 中傳遞給使用者的 RNG 金鑰),則 name 將預設為 'params'

範例

>>> import jax
>>> import flax.linen as nn

>>> class ParamsModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('params')
>>> class OtherModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('other')

>>> key = jax.random.key(0)
>>> params_out, _ = ParamsModule().init_with_output({'params': key})
>>> # self.make_rng('other') will default to using the 'params' RNG stream
>>> other_out, _ = OtherModule().init_with_output({'params': key})
>>> assert params_out == other_out

閱讀 Flax RNG 指南了解更多關於 RNG 的資訊:https://flax.dev.org.tw/en/latest/guides/flax_fundamentals/rng_guide.html

參數

name – RNG 順序名稱。

傳回

新產生的 RNG 金鑰。

module_paths(rngs, *args, show_repeated=False, mutable=DenyList(deny='intermediates'), **kwargs)[source]#

傳回將模組路徑對應至模組實例的字典。

此方法有相同的簽章,並在內部呼叫 Module.init,但它並未傳回變數,而是傳回將模組路徑對應至在執行階段使用的未連結模組實例副本的字典。 module_paths 使用 jax.eval_shape 執行前向運算,而不會使用任何 FLOP 或配置記憶體。

範例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> modules = Foo().module_paths(jax.random.key(0), x)
>>> print({
...     p: type(m).__name__ for p, m in modules.items()
... })
{'': 'Foo', 'Dense_0': 'Dense', 'Dense_1': 'Dense'}
參數
  • rngs – 變數類別的 rng,傳遞給 Module.init

  • *args – 傳遞給前向運算的引數。

  • show_repeated – 如果為 True,重複呼叫同一個模組將顯示在表格中,否則僅顯示第一次呼叫。預設為 False

  • mutable – 可以是布林值、字串或清單。指定哪些類別應視為可變的:bool:所有/沒有類別是可變的。 str:單一可變類別的名稱。 list:可變類別名稱的清單。預設情況下,除了「中間體」以外的所有類別都是可變的。

  • **kwargs – 傳遞給前向運算的關鍵字引數。

傳回

將模組路徑對應至模組實例的字典。

param(name, init_fn, *init_args, unbox=True, **init_kwargs)[原始碼]#

宣告並傳回此模組中的參數。

參數是名為「params」的集合中的唯讀變數。請參閱 flax.core.variables 以進一步了解變數。

init_fn 的第一個引數假設為 PRNG 金鑰,會自動提供,不須使用 init_argsinit_kwargs 傳遞

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     mean = self.param('mean', nn.initializers.lecun_normal(), x.shape)
...     ...
...     return x * mean
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}, 'mean': (2, 4)}}

在上面的範例中,函數 lecun_normal 預計兩個引數:keyshape,但只有 shape 必須明確提供;key 會使用 params 的 PRNG 自動設定,而 params 是在使用 init() 初始化模組時傳遞的。

參數
  • name – 參數名稱。

  • init_fn – 將呼叫來計算此變數初始值的函數。此函數只會在此參數第一次在此模組中使用時呼叫一次。

  • *init_args – 要傳遞給 init_fn 的位置參數。

  • unbox – 如果為 True,AxisMetadata 實例會被其取消封裝的值取代,請參閱 flax.nn.meta.unbox(預設值:True)。

  • **init_kwargs – 要傳遞給 init_fn 的關鍵字參數。

傳回

已初始化參數的值。如果參數已存在,則會擲回錯誤。

屬性 path#

取得此模組的路徑。頂層根模組具有空的「路徑」()。請注意,這個方法只能用於具有有效範圍的繫結模組中。

使用範例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class SubModel(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     print(f'SubModel path: {self.path}')
...     return x

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     print(f'Model path: {self.path}')
...     return SubModel()(x)

>>> model = Model()
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 2)))
Model path: ()
SubModel path: ('SubModel_0',)
perturb(name, value, collection='perturbations')[原始碼]#

將零值變數(「擾動」)新增到中間值中。

value 的梯度將與此擾動變數的梯度相同。因此,如果您使用參數和擾動作為獨立參數來定義損失函式,則可以透過對擾動參數執行 jax.grad 來取得 value 的中間梯度。

注意

這是個實驗性 API,稍後可能會進行調整以強化效能和可用性。在目前的階段,它會建立額外的虛擬變數,而虛擬變數會佔用額外的記憶體空間。僅將其用於偵錯訓練中的梯度。

範例

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = self.perturb('dense3', x)
...     return nn.Dense(2)(x)

>>> def loss(variables, inputs, targets):
...   preds = model.apply(variables, inputs)
...   return jnp.square(preds - targets).mean()

>>> x = jnp.ones((2, 9))
>>> y = jnp.ones((2, 2))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y)
>>> print(intm_grads['perturbations']['dense3'])
[[-1.456924   -0.44332537  0.02422847]
 [-1.456924   -0.44332537  0.02422847]]

如果未將擾動傳遞至 apply,則 perturb 的行為就像停用操作,因此當不需要時,您可以輕鬆停用此行為。

>>> model.apply(variables, x) # works as expected
Array([[-1.0980128 , -0.67961735],
       [-1.0980128 , -0.67961735]], dtype=float32)
>>> model.apply({'params': variables['params']}, x) # behaves like a no-op
Array([[-1.0980128 , -0.67961735],
       [-1.0980128 , -0.67961735]], dtype=float32)
>>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y)
>>> 'perturbations' not in intm_grads
True
put_variable(col, name, value)[原始碼]#

更新指定變數的值(如果它可變更),否則傳回錯誤。

參數
  • col – 變數集合。

  • name – 變數名稱。

  • value - 變數的新值。

setup()[原始碼]#

初始化一個模組(類似於延遲的 __init__)。

setup 會在模組實體於綁定時呼叫一次,在呼叫任何其他方法(例如 __call__)之前,或是存取 selfsetup 定義的屬性之前,立即進行這個呼叫。

這會在三種情況下發生

  1. 在呼叫 apply()init()init_and_output()

  2. 在另一個模組的 setup 方法內將名子給予模組(將模組指定到另一個模組的屬性);(請參閱 __setattr__()

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    
    ...     # Accessing `submodule` attributes does not yet work here.
    
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # either causes setup() to be called once.
    
  3. 在使用 compact() 包裝的方法內建構模組時,在呼叫另一個方法或存取 setup 定義屬性之前

sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[原始碼]#

將一個值儲存在一個集合中。

集合可用於收集中間值,而不必承受在每個模組呼叫之間明確傳遞容器的負擔。

如果目標集合不可變異,sow 的行為就像一個無操作(no-op)並回傳 False

範例

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     self.sow('intermediates', 'h', h)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(variables, x, mutable=['intermediates'])
>>> jax.tree.map(jnp.shape, state['intermediates'])
{'h': ((16, 4),)}

預設情況下,這些值儲存在一個元組中,而且每個儲存值都會附加在尾端。以這種方式,當同一個模組被呼叫多次時,就可以追蹤所有中間值。或者,可以傳遞自訂的初始化/簡化函數

>>> class Foo2(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     init_fn = lambda: 0
...     reduce_fn = lambda a, b: a + b
...     self.sow('intermediates', 'h', x,
...               init_fn=init_fn, reduce_fn=reduce_fn)
...     self.sow('intermediates', 'h', x * 2,
...               init_fn=init_fn, reduce_fn=reduce_fn)
...     return x

>>> x = jnp.ones((1, 1))
>>> model = Foo2()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(
...     variables, x, mutable=['intermediates'])
>>> print(state['intermediates'])
{'h': Array([[3.]], dtype=float32)}
參數
  • col – 變數集合的名稱。

  • name – 變數的名稱。

  • value – 變數的值。

  • reduce_fn – 與新值組合現有值時所使用的函式。預設值是將值附加到元組。

  • init_fn – 對於儲存的第一個值,reduce_fn 將傳遞 init_fn 的結果與儲存的值。預設值是空元組。

傳回

True,如果值已成功儲存;否則為 False

tabulate(rngs, *args, depth=None, show_repeated=False, mutable=DenyList(deny='intermediates'), console_kwargs=None, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), compute_flops=False, compute_vjp_flops=False, **kwargs)[來源]#

建立表示為表格的模組摘要。

此方法具有相同的簽章,並在內部呼叫 Module.init,但不會傳回變數,而是傳回表格中模組摘要的字串。 tabulate 使用 jax.eval_shape 在不消耗任何浮點運算次數或配置記憶體的情況下執行正向計算。

額外的引數可以傳遞到 console_kwargs 引數中,例如 {'width': 120}。完整的 console_kwargs 引數清單,請參閱:https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console

範例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))

>>> # print(Foo().tabulate(
>>> #     jax.random.key(0), x, compute_flops=True, compute_vjp_flops=True))

輸出如下所示

                                      Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module ┃ inputs        ┃ outputs       ┃ flops ┃ vjp_flops ┃ params          ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│         │ Foo    │ float32[16,9] │ float32[16,2] │ 1504  │ 4460      │                 │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_0 │ Dense  │ float32[16,9] │ float32[16,4] │ 1216  │ 3620      │ bias:           │
│         │        │               │               │       │           │ float32[4]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[9,4]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 40 (160 B)      │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_1 │ Dense  │ float32[16,4] │ float32[16,2] │ 288   │ 840       │ bias:           │
│         │        │               │               │       │           │ float32[2]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[4,2]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 10 (40 B)       │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│         │        │               │               │       │     Total │ 50 (200 B)      │
└─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘

                              Total Parameters: 50 (200 B)

注意:表格中列的順序並未顯示執行順序,而是與按字母順序排序的 變數 的金鑰順序一致。

注意vjp_flops 在模組不可微時會傳回 0

參數
  • rngs – 變數類別的 rng,傳遞給 Module.init

  • *args – 傳遞給前向運算的引數。

  • 深度 – 控制概觀可以深入多少個子模組。預設值為 ,表示沒有限制。如果因為深度限制而未顯示某個子模組,其參數計數和位組會加到其第一個已顯示祖先的列中,讓所有列的總和永遠等於模組的總參數數目。

  • show_repeated – 如果為 True,重複呼叫同一個模組將顯示在表格中,否則僅顯示第一次呼叫。預設為 False

  • mutable – 可以是布林值、字串或清單。指定哪些類別應視為可變的:bool:所有/沒有類別是可變的。 str:單一可變類別的名稱。 list:可變類別名稱的清單。預設情況下,除了「中間體」以外的所有類別都是可變的。

  • 控制台關鍵字引數 – 額外的關鍵字引數的選用字典,會在使用 rich.console.Console 呈現表格時傳入。預設引數為 {'force_terminal': True, 'force_jupyter': False}

  • 表格關鍵字引數 – 額外的關鍵字引數的選用字典,會在 rich.table.Table 建構函式傳入。

  • 欄位關鍵字引數 – 額外的關鍵字引數的選用字典,會在使用 rich.table.Table.add_column 將欄位新增至表格時傳入。

  • 計算 FLOP – 是否在表格中加入 flop 欄位,以列出每個模組前向傳遞的預估 FLOP 成本。會執行實際的裝置運算 / 編譯 / 記憶體配置,但對於大型模組來說還是會增加負擔(例如,穩定擴散的 UNet 會多花 20 秒,否則表格化需花 5 秒即可完成)。

  • 計算 VJP FLOP – 是否在表格中加入 vjp_flops 欄位,以列出每個模組後向傳遞的預估 VJP FLOP 成本。會增加約為 compute_flops 2-3 倍的運算負擔。

  • **kwargs – 傳遞給前向運算的關鍵字引數。

傳回

一個摘要模組的字串。

unbind()[來源]#

傳回一個模組及其變數的不繫結副本。

unbind 有助於建立繫結模組的無狀態版本。

常見用途範例:擷取在 setup() 內部定義的子模組及其對應變數:1) 暫時 bind 父模組;然後 2) unbind 所需子模組。(請注意,setup() 僅在繫結模組時呼叫。)

>>> class Encoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(256)(x)

>>> class Decoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(784)(x)

>>> class AutoEncoder(nn.Module):
...   def setup(self):
...     self.encoder = Encoder()
...     self.decoder = Decoder()
...
...   def __call__(self, x):
...     return self.decoder(self.encoder(x))

>>> module = AutoEncoder()
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 784)))

>>> # Extract the Encoder sub-Module and its variables
>>> encoder, encoder_vars = module.bind(variables).encoder.unbind()
傳回

具有此模組和其變數未繫結副本的元組。

variable(col, name, init_fn=None, *init_args, unbox=True, **init_kwargs)[source]#

在此模組中宣告並傳回變數。

請參閱 flax.core.variables 以取得更多資訊。另請參閱 param(),以使用簡潔方式在「params」集合中定義唯讀變數。

param() 相反,傳遞使用 init_fn 的所有引數都應明確傳遞

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     key = self.make_rng('stats')
...     mean = self.variable('stats', 'mean', nn.initializers.lecun_normal(), key, x.shape)
...     ...
...     return x * mean.value
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}}, 'stats': {'mean': (2, 4)}}

在上述範例中,函式 lecun_normal 預期兩個引數:keyshape,且兩個都必須傳遞。呼叫 init()apply() 時,stats 的 PRNG 必須明確提供。

參數
  • col – 變數集合名稱。

  • name – 變數名稱。

  • init_fn – 將呼叫來計算此變數初始值的函式。只有在這個變數第一次在此模組中使用時,才會呼叫此函式。若為 None,則變數必須已初始化,否則會產生錯誤。

  • *init_args – 要傳遞給 init_fn 的位置參數。

  • unbox – 如果為 True,AxisMetadata 實例會被其取消封裝的值取代,請參閱 flax.nn.meta.unbox(預設值:True)。

  • **init_kwargs** – 傳遞至 init_fn 的關鍵字參數

傳回

一個 flax.core.variables.Variable 可透過「.value」屬性進行讀寫。若變數已存在,則會拋出錯誤。

property variables#

傳回此模組中的變數。

flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)[原始碼]#

建立一個 apply 函式,以繫結模組呼叫 fn

Module.apply 不同,此函式傳回一項新函式,其簽章為 (variables, *args, rngs=None, **kwargs) -> T,其中 Tfn 的傳回類型。若 mutable 不為 False,傳回類型為一個二元組,其中第二項為含有異動變數的 FrozenDict

傳回的 apply 函式可直接與 JAX 轉換搭配使用,例如 jax.jit

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> variables = {}
>>> foo = Foo()
>>> f_jitted = jax.jit(nn.apply(f, foo))
>>> f_jitted(variables, jnp.ones((1, 3)))
參數
  • fn – 應套用的函式。傳遞的第一個引數將會是 module 的模組實例,其中繫結了變數及 RNG。

  • module – 用於繫結變數和 RNG 的 Module。傳遞給 fn 做為第一個引數的 Module 將會是 module 的複製項。

  • mutable – 可以是布林、字串或清單。指定應視為可變的集合:bool:所有/沒有集合是可變的。 str:單一可變集合的集合名稱。 list:可變集合名稱的清單。

  • capture_intermediates – 如果 True,在“intermediates”集合中擷取所有模組的 промежу回傳值。預設只儲存所有 __call__ 方法的回傳值。可以傳遞函式來變更篩選器行為。篩選器函式接收模組實例和方法名稱,並回傳布林值,表示是否應儲存該方法呼叫的輸出。

傳回

包裝 fn 的 apply 函式。

flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#

建立 init 函式,用已繫結模組來呼叫 fn

不同於 Module.init,此函式會回傳新的函式,其簽章為 (rngs, *args, **kwargs) -> variables。rng 可以是 PRNGKeys 的詞典,或單一 `PRNGKey,這相當於傳遞一個包含一個名為“params”的 PRNGKey 的詞典。

回傳的 init 函式可直接與 JAX 轉換組合,例如 jax.jit

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init(f, foo))
>>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
參數
  • fn – 應套用的函式。傳遞的第一個引數將會是 module 的模組實例,其中繫結了變數及 RNG。

  • module – 用於繫結變數和 RNG 的 Module。傳遞給 fn 做為第一個引數的 Module 將會是 module 的複製項。

  • mutable – 可以是布林值、字串或清單。指定哪些集合應視為可變的: bool:所有/沒有集合是可變的。 str:單一可變集合名稱。 list:可變集合名稱清單。預設情況下,除了「intermediates」,所有集合都是可變的。

  • capture_intermediates – 如果 True,在“intermediates”集合中擷取所有模組的 промежу回傳值。預設只儲存所有 __call__ 方法的回傳值。可以傳遞函式來變更篩選器行為。篩選器函式接收模組實例和方法名稱,並回傳布林值,表示是否應儲存該方法呼叫的輸出。

傳回

包裝 fn 的 init 函式。

flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#

建立一個 init 函式,用已繫結模組呼叫 fn,同時傳回函式輸出。

Module.init_with_output 不同,此函式傳回一個簽名為 (rngs, *args, **kwargs) -> (T, variables) 的新函式,其中 Tfn 的傳回類型。rngs 可以是 PRNGKeys 的字典,或單一 `PRNGKey,後者等於傳遞一個名稱為 “params” 的單一 PRNGKey 的字典。

回傳的 init 函式可直接與 JAX 轉換組合,例如 jax.jit

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init_with_output(f, foo))
>>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
參數
  • fn – 應套用的函式。傳遞的第一個引數將會是 module 的模組實例,其中繫結了變數及 RNG。

  • module – 用於繫結變數和 RNG 的 Module。傳遞給 fn 做為第一個引數的 Module 將會是 module 的複製項。

  • mutable – 可以是布林值、字串或清單。指定哪些集合應視為可變的: bool:所有/沒有集合是可變的。 str:單一可變集合名稱。 list:可變集合名稱清單。預設情況下,除了「intermediates」,所有集合都是可變的。

  • capture_intermediates – 如果 True,在“intermediates”集合中擷取所有模組的 промежу回傳值。預設只儲存所有 __call__ 方法的回傳值。可以傳遞函式來變更篩選器行為。篩選器函式接收模組實例和方法名稱,並回傳布林值,表示是否應儲存該方法呼叫的輸出。

傳回

包裝 fn 的 init 函式。

flax.linen.intercept_methods(interceptor)[source]#

註冊新的方法攔截器。

方法攔截器讓您可以(遠距)攔截對模組的方法呼叫。其運作方式類似於裝飾器。您可以在呼叫基礎方法前修改 args/kwargs 和/或修改呼叫基礎方法後傳回的結果。或者,您可以完全略過呼叫基礎方法並執行不同的操作。例如

>>> import flax.linen as nn
>>> import jax.numpy as jnp
...
>>> class Foo(nn.Module):
...   def __call__(self, x):
...     return x
...
>>> def my_interceptor1(next_fun, args, kwargs, context):
...   print('calling my_interceptor1')
...   return next_fun(*args, **kwargs)
...
>>> foo = Foo()
>>> with nn.intercept_methods(my_interceptor1):
...   _ = foo(jnp.ones([1]))
calling my_interceptor1

您也可以在同一方法上註冊多個攔截器。攔截器將按順序執行。例如

>>> def my_interceptor2(next_fun, args, kwargs, context):
...   print('calling my_interceptor2')
...   return next_fun(*args, **kwargs)
...
>>> with nn.intercept_methods(my_interceptor1), \
...      nn.intercept_methods(my_interceptor2):
...   _ = foo(jnp.ones([1]))
calling my_interceptor1
calling my_interceptor2

您可以直接呼叫 context.orig_method 來略過其他攔截器。例如

>>> def my_interceptor3(next_fun, args, kwargs, context):
...   print('calling my_interceptor3')
...   return context.orig_method(*args, **kwargs)
>>> with nn.intercept_methods(my_interceptor3), \
...      nn.intercept_methods(my_interceptor1), \
...      nn.intercept_methods(my_interceptor2):
...   _ = foo(jnp.ones([1]))
calling my_interceptor3

下列方法無法攔截。

  1. 加上 nn.nowrap 裝飾器的

  2. Dunder 方法,包括 __eq____repr____init____hash____post_init__

  3. Module dataclass 欄位。

  4. Module 描述符。

參數

interceptor – 方法攔截器。

flax.linen.share_scope(module, other, /)[來源]#

變更其中一個 Modules,使其共享相同的範圍。當您要包裝一個 Module,並擴充其功能,而不改變參數結構時,此功能會很有用。

share_scope 使用兩個 Modules,moduleother。若 other 有範圍,且其並非 ``module`` 範圍的下屬範圍,則 module 將使用 other 的範圍

>>> import flax.linen as nn
>>> import jax
>>> from jax import numpy as jnp, random
...
>>> class DenseLoRA(nn.Module):
...   base: nn.Dense
...   rank: int
...
...   def setup(self):
...     nn.share_scope(self, self.base)
...
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     din, dout = x.shape[-1], self.base.features
...     A = self.param('A', nn.zeros_init(), (din, self.rank))
...     B = self.param('B', nn.zeros_init(), (self.rank, dout))
...     return self.base(x) + x @ A @ B
...
>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     dense = nn.Dense(10) # base scope
...     return DenseLoRA(dense, rank=2)(x) # reuse the base scope
...
>>> model = Model()
...
>>> params = model.init(random.key(0), jnp.ones((1, 5)))['params']
>>> list(params['Dense_0'].keys())
['A', 'B', 'kernel', 'bias']

other 的範圍是 module 範圍的下屬範圍時,other 將使用 module 的範圍

>>> class DenseLoRA(nn.Module):
...   features: int
...   rank: int
...
...   def setup(self):
...     self.child = nn.Dense(self.features)
...     nn.share_scope(self, self.child)
...
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     din, dout = x.shape[-1], self.features
...     A = self.param('A', nn.zeros_init(), (din, self.rank))
...     B = self.param('B', nn.zeros_init(), (self.rank, dout))
...     return self.child(x) + x @ A @ B
...
>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     return DenseLoRA(10, rank=2)(x)
...
>>> model = Model()
...
>>> params = model.init(random.key(0), jnp.ones((1, 5)))['params']
>>> list(params['DenseLoRA_0'].keys())
['A', 'B', 'kernel', 'bias']