Init/Apply#
- flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)[來源代碼]#
建立一個 `apply` 函式來呼叫 `
fn
` 以綁定模組。與 `
Module.apply
` 不同,此函式傳回新的函式,其簽章為 `(variables, *args, rngs=None, **kwargs) -> T
`,其中 `T
` 為 `fn
` 的回傳型別。如果 `mutable
` 不為 `False
`,回傳型別為 Tuple,其中第二個項目為 `FrozenDict`,其變數已變更。傳回的 `apply` 函式可直接與 JAX 轉換相結合,例如 `
jax.jit
`>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> variables = {} >>> foo = Foo() >>> f_jitted = jax.jit(nn.apply(f, foo)) >>> f_jitted(variables, jnp.ones((1, 3)))
- 參數
fn – 應使用的函式。傳遞的第一個引數將是 `
module
` 的模組執行個體,其中包含已綁定的變數和 RNG。module – 將用於綁定變數和 RNG 的 `
Module
`。傳遞給 `fn
` 作為第一個引數的 `Module`,將會是 `module` 的複製。mutable – 可以是布林值、字串或清單。指定要視為可變的集合:`
bool
`:所有/無任何集合可變。<`str
`:單一可變集合的名稱。<`list
`:可變集合名稱清單。capture_intermediates – 如果為 `
True
`,會擷取 `"intermediates"` 集合內所有 `Modules` 的中間回傳值。預設情況下,僅儲存所有 `__call__` 方法的回傳值。可以傳遞函式來變更篩選行為。篩選函式接收 `Module` 個體和方法名稱,並傳回布林值,指出是否應儲存該方法呼叫的輸出。
- 傳回
將函數包覆起來套用的 apply 函數
fn
。
- flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#
建立一個 init 函數來呼叫有綁定模組的
fn
。與
Module.init
不同,此函數會回傳一個新函數,其簽章為(rngs, *args, **kwargs) -> variables
。rngs 可以是 PRNGKeys 的字典,或單一的`PRNGKey
,這等於傳入一個名稱為「params」的 PRNGKey 字典。回傳的 init 函數可以直接與 JAX 轉換組成,例如
jax.jit
>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> foo = Foo() >>> f_jitted = jax.jit(nn.init(f, foo)) >>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- 參數
fn – 應使用的函式。傳遞的第一個引數將是 `
module
` 的模組執行個體,其中包含已綁定的變數和 RNG。module – 將用於綁定變數和 RNG 的 `
Module
`。傳遞給 `fn
` 作為第一個引數的 `Module`,將會是 `module` 的複製。mutable – 可以是 bool、str 或 list。指定哪些集合應該被視為可變動的:
bool
:所有/沒有集合是可變動的。str
:單一可變動集合的名稱。list
:可變動集合名稱的清單。預設情況下,除了「intermediates」之外的所有集合都是可變的。capture_intermediates – 如果為 True,則擷取「intermediates」集合內所有模組的中間回傳值。預設情況下,只儲存所有 __call__ 方法的回傳值。可以傳入一個函數來變更篩選行為。篩選函數會取得模組實例與方法名稱,並回傳布林值,表示該方法呼叫的輸出是否應該儲存。
- 傳回
將函數包覆起來套用的 apply 函數
fn
。
- flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#
建立初始化函式,使用連結的模組呼叫
fn
,同時也會傳回函數的輸出。不同於
Module.init_with_output
這個函數會傳回一個新的函數,它的簽章是(rngs, *args, **kwargs) -> (T, variables)
這裡T
是fn
的回傳型別。rngs 可以是 PRNGKeys 的字典或是單一的`PRNGKey
。這等於傳遞一個字典,其中包含一個名稱為 “params” 的 PRNGKey。回傳的 init 函數可以直接與 JAX 轉換組成,例如
jax.jit
>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> foo = Foo() >>> f_jitted = jax.jit(nn.init_with_output(f, foo)) >>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- 參數
fn – 應使用的函式。傳遞的第一個引數將是 `
module
` 的模組執行個體,其中包含已綁定的變數和 RNG。module – 將用於綁定變數和 RNG 的 `
Module
`。傳遞給 `fn
` 作為第一個引數的 `Module`,將會是 `module` 的複製。mutable – 可以是 bool、str 或 list。指定哪些集合應該被視為可變動的:
bool
:所有/沒有集合是可變動的。str
:單一可變動集合的名稱。list
:可變動集合名稱的清單。預設情況下,除了「intermediates」之外的所有集合都是可變的。capture_intermediates – 如果為 `
True
`,會擷取 `"intermediates"` 集合內所有 `Modules` 的中間回傳值。預設情況下,僅儲存所有 `__call__` 方法的回傳值。可以傳遞函式來變更篩選行為。篩選函式接收 `Module` 個體和方法名稱,並傳回布林值,指出是否應儲存該方法呼叫的輸出。
- 傳回
將函數包覆起來套用的 apply 函數
fn
。