橋接#

class flax.nnx.bridge.ToNNX(*args, **kwargs)[原始碼]#

一個將任何 Linen 模組轉換為 NNX 模組的包裝器。

產生的 NNX 模組可以獨立使用所有 NNX API,或者作為另一個 NNX 模組的子模組。

由於 Linen 模組初始化需要一個範例輸入,您需要使用一個參數呼叫 lazy_init 來初始化變數。

範例

>>> from flax import linen as nn, nnx
>>> import jax
>>> linen_module = nn.Dense(features=64)
>>> x = jax.numpy.ones((1, 32))
>>> # Like Linen init(), initialize with a sample input
>>> model = nnx.bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x)
>>> # Like Linen apply(), but using NNX's direct call method
>>> y = model(x)
>>> model.kernel.shape
(32, 64)
參數
  • module – Linen 模組實例。

  • rngs – 傳遞到任何 NNX 模組的 nnx.Rngs 實例。

回傳值

一個有狀態的 NNX 模組,其行為與包裝的 Linen 模組相同。

__call__(*args, rngs=None, method=None, **kwargs)[原始碼]#

將自身作為函數呼叫。

lazy_init(*args, **kwargs)[原始碼]#

在此模組上呼叫 nnx.bridge.lazy_init() 的快捷方式。

方法

lazy_init(*args, **kwargs)

在此模組上呼叫 nnx.bridge.lazy_init() 的快捷方式。

class flax.nnx.bridge.ToLinen(nnx_class, args=(), kwargs=FrozenDict({}), skip_rng=False, metadata_type=<class 'flax.nnx.bridge.variables.NNXMeta'>, parent=<flax.linen.module._Sentinel object>, name=None)[原始碼]#

一個將任何 NNX 模組轉換為 Linen 模組的包裝器。

產生的 Linen 模組可以獨立使用所有 Linen API,或者作為另一個 Linen 模組的子模組。

由於 NNX 模組是有狀態的並擁有狀態,我們僅在初始化時建立它一次,並將其狀態和靜態資料作為單獨的變數進行追蹤。

範例

>>> from flax import linen as nn, nnx
>>> import jax
>>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64))
>>> x = jax.numpy.ones((1, 32))
>>> y, variables = model.init_with_output(jax.random.key(0), x)
>>> y.shape
(1, 64)
>>> variables['params']['kernel'].shape
(32, 64)
>>> # The static GraphDef of the underlying NNX module
>>> variables.keys()
dict_keys(['nnx', 'params'])
>>> type(variables['nnx']['graphdef'])
<class 'flax.nnx.graph.NodeDef'>
參數
  • nnx_class – NNX 模組類別(而非實例!)。

  • args – 通常會傳遞以建立 NNX 模組的引數。

  • kwargs – 通常會傳遞以建立 NNX 模組的關鍵字引數。

  • skip_rng – 如果此 NNX 模組在初始化期間不需要 rngs 引數,則為 True(不常見)。

回傳值

一個有狀態的 NNX 模組,其行為與包裝的 Linen 模組相同。

__call__(*args, **kwargs)[原始碼]#

將自身作為函數呼叫。

方法

flax.nnx.bridge.to_linen(nnx_class, *args, name=None, **kwargs)[原始碼]#

如果使用者未變更其任何預設欄位,則為 nnx.bridge.ToLinen 的快捷方式。

class flax.nnx.bridge.NNXMeta(var_type, value, metadata)[原始碼]#

nnx.VariableState 的預設 Flax 元資料類別。

__call__(**kwargs)#

將自身作為函數呼叫。

add_axis(index, params)[原始碼]#

將新軸新增至軸元資料。

請注意,add_axis 和 remove_axis 應互為反運算(意思是:x.add_axis(i, p).remove_axis(i, p) == x

參數
  • index – 新軸將插入的位置

  • params – 轉換引入新軸傳遞的任意參數字典(例如:nn.scannn.vmap)。使用者將此字典作為 metadata_param 引數傳遞給轉換。

回傳值

一個與 self 類型相同,且具有相同 unbox 內容以及已更新軸元資料的新實例。

get_partition_spec()[原始碼]#

傳回此已分割值的 Partitionspec

remove_axis(index, params)[原始碼]#

從軸元資料中移除軸。

請注意,add_axis 和 remove_axis 應互為反運算(意思是:x.remove_axis(i, p).add_axis(i, p) == x

參數
  • index – 要移除的軸位置

  • params – 轉換引入軸傳遞的任意參數字典(例如:nn.scannn.vmap)。使用者將此字典作為 metadata_param 引數傳遞給轉換。

回傳值

一個與 self 類型相同,且具有相同 unbox 內容以及已更新軸元資料的新實例。

replace(**updates)#

傳回一個新物件,將指定的欄位替換為新值。

replace_boxed(val)[原始碼]#

將已裝箱的值替換為提供的值。

參數

val – 要由此 AxisMetadata 包裝器裝箱的新值

回傳值

與自身類型相同的新實例,並以 val 作為新的 unbox 內容

to_nnx_variable()[原始碼]#
unbox()[原始碼]#

傳回 AxisMetadata 盒子的內容。

請注意,與 meta.unbox 不同,unbox 呼叫不應遞迴地解開 metadata。它應該直接傳回它包裝的值,即使該值本身是 AxisMetadata 的實例。

實際上,AxisMetadata 子類別應該註冊為 PyTree 節點,以支援將實例傳遞給 JAX 和 Flax API。此節點傳回的葉子應該對應於 unbox 傳回的值。

回傳值

解箱的值。

方法

add_axis(index, params)

將新軸新增至軸元資料。

get_partition_spec()

傳回此已分割值的 Partitionspec

remove_axis(index, params)

從軸元資料中移除軸。

replace(**updates)

傳回一個新物件,將指定的欄位替換為新值。

replace_boxed(val)

將已裝箱的值替換為提供的值。

to_nnx_variable()

unbox()

傳回 AxisMetadata 盒子的內容。