初始化/套用

Init/Apply#

flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)[來源代碼]#

建立一個 `apply` 函式來呼叫 `fn` 以綁定模組。

與 `Module.apply` 不同,此函式傳回新的函式,其簽章為 `(variables, *args, rngs=None, **kwargs) -> T`,其中 `T` 為 `fn` 的回傳型別。如果 `mutable` 不為 `False`,回傳型別為 Tuple,其中第二個項目為 `FrozenDict`,其變數已變更。

傳回的 `apply` 函式可直接與 JAX 轉換相結合,例如 `jax.jit`

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> variables = {}
>>> foo = Foo()
>>> f_jitted = jax.jit(nn.apply(f, foo))
>>> f_jitted(variables, jnp.ones((1, 3)))
參數
  • fn – 應使用的函式。傳遞的第一個引數將是 `module` 的模組執行個體,其中包含已綁定的變數和 RNG。

  • module – 將用於綁定變數和 RNG 的 `Module`。傳遞給 `fn` 作為第一個引數的 `Module`,將會是 `module` 的複製。

  • mutable – 可以是布林值、字串或清單。指定要視為可變的集合:`bool`:所有/無任何集合可變。<` str`:單一可變集合的名稱。<` list`:可變集合名稱清單。

  • capture_intermediates – 如果為 `True`,會擷取 `"intermediates"` 集合內所有 `Modules` 的中間回傳值。預設情況下,僅儲存所有 `__call__` 方法的回傳值。可以傳遞函式來變更篩選行為。篩選函式接收 `Module` 個體和方法名稱,並傳回布林值,指出是否應儲存該方法呼叫的輸出。

傳回

將函數包覆起來套用的 apply 函數 fn

flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#

建立一個 init 函數來呼叫有綁定模組的 fn

Module.init 不同,此函數會回傳一個新函數,其簽章為 (rngs, *args, **kwargs) -> variables。rngs 可以是 PRNGKeys 的字典,或單一的 `PRNGKey,這等於傳入一個名稱為「params」的 PRNGKey 字典。

回傳的 init 函數可以直接與 JAX 轉換組成,例如 jax.jit

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init(f, foo))
>>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
參數
  • fn – 應使用的函式。傳遞的第一個引數將是 `module` 的模組執行個體,其中包含已綁定的變數和 RNG。

  • module – 將用於綁定變數和 RNG 的 `Module`。傳遞給 `fn` 作為第一個引數的 `Module`,將會是 `module` 的複製。

  • mutable – 可以是 bool、str 或 list。指定哪些集合應該被視為可變動的: bool:所有/沒有集合是可變動的。 str:單一可變動集合的名稱。 list:可變動集合名稱的清單。預設情況下,除了「intermediates」之外的所有集合都是可變的。

  • capture_intermediates – 如果為 True,則擷取「intermediates」集合內所有模組的中間回傳值。預設情況下,只儲存所有 __call__ 方法的回傳值。可以傳入一個函數來變更篩選行為。篩選函數會取得模組實例與方法名稱,並回傳布林值,表示該方法呼叫的輸出是否應該儲存。

傳回

將函數包覆起來套用的 apply 函數 fn

flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#

建立初始化函式,使用連結的模組呼叫 fn,同時也會傳回函數的輸出。

不同於 Module.init_with_output 這個函數會傳回一個新的函數,它的簽章是 (rngs, *args, **kwargs) -> (T, variables) 這裡 Tfn 的回傳型別。rngs 可以是 PRNGKeys 的字典或是單一的 `PRNGKey。這等於傳遞一個字典,其中包含一個名稱為 “params” 的 PRNGKey。

回傳的 init 函數可以直接與 JAX 轉換組成,例如 jax.jit

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init_with_output(f, foo))
>>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
參數
  • fn – 應使用的函式。傳遞的第一個引數將是 `module` 的模組執行個體,其中包含已綁定的變數和 RNG。

  • module – 將用於綁定變數和 RNG 的 `Module`。傳遞給 `fn` 作為第一個引數的 `Module`,將會是 `module` 的複製。

  • mutable – 可以是 bool、str 或 list。指定哪些集合應該被視為可變動的: bool:所有/沒有集合是可變動的。 str:單一可變動集合的名稱。 list:可變動集合名稱的清單。預設情況下,除了「intermediates」之外的所有集合都是可變的。

  • capture_intermediates – 如果為 `True`,會擷取 `"intermediates"` 集合內所有 `Modules` 的中間回傳值。預設情況下,僅儲存所有 `__call__` 方法的回傳值。可以傳遞函式來變更篩選行為。篩選函式接收 `Module` 個體和方法名稱,並傳回布林值,指出是否應儲存該方法呼叫的輸出。

傳回

將函數包覆起來套用的 apply 函數 fn