Flax NNX 詞彙表#
如需其他術語,請參考 JAX 詞彙表。
- 篩選器#
一種僅從 Flax NNX 模組 (
nnx.Module
) 中提取特定 nnx.Variable 物件的方式。這通常透過在nnx.Module
上呼叫nnx.split
來完成。請參閱篩選器指南以了解更多資訊。- 折疊輸入#
在 Flax 中,折疊輸入表示在給定輸入 PRNG 金鑰和整數的情況下,產生新的 JAX 虛擬隨機數產生器 (PRNG) 金鑰。這通常用於當您想要產生新的金鑰,但仍然能夠在之後使用原始 PRNG 金鑰時。您也可以在 JAX 中使用 jax.random.split 來執行此操作,但此方法實際上會建立兩個 PRNG 金鑰,這會比較慢。請在 隨機性/PRNG 指南中了解 Flax 如何自動產生新的 PRNG 金鑰。
- GraphDef#
nnx.GraphDef
是一個類別,代表 Flax 模組 (nnx.Module
) 的所有靜態、無狀態和 Python 式的部分。- 合併#
請參閱分割與合併。
- 模組#
nnx.Module
是一個資料類別,可讓您以參照透明的方式定義和初始化參數。它負責儲存和更新其內部的 :term:`Variable<Variable> 物件和參數。- Params / 參數#
nnx.Param
是nnx.Variable
的特定子類別,通常包含可訓練的權重。- PRNG 狀態#
Flax
nnx.Module
可以保留 虛擬隨機數產生器 (PRNG) 狀態物件nnx.Rngs
的參照,該物件可以產生新的 JAX PRNG 金鑰。這些金鑰用於透過 JAX 的函數式 PRNG 產生隨機 JAX 陣列。您可以使用具有不同種子的 PRNG 狀態,以更精細地控制您的模型 (例如,讓參數和 dropout 遮罩擁有獨立的隨機數)。請參閱 Flax 隨機性/PRNG 指南以了解更多詳細資訊。- 分割與合併#
nnx.split
是一種使用兩個部分來表示nnx.Module
的方法:1) 一個靜態的 Flax NNX GraphDef,用於擷取其 Python 式的靜態資訊;以及 2) 一個或多個 變數狀態,用於擷取其 JAX 陣列 (jax.Array
),其形式為 JAX pytree。它們可以使用nnx.merge
合併回原始nnx.Module
。- 轉換#
Flax NNX 轉換 (transform) 是 JAX 轉換的封裝版本,允許正在轉換的函數將 Flax NNX 模組 (
nnx.Module
) 作為輸入或輸出。例如,jax.jit 的「提升」版本是nnx.jit
。請查看 Flax NNX 轉換指南以了解更多資訊。- 變數#
位於 Flax 模組 中的權重/參數/資料/陣列
nnx.Variable
。變數在模組內定義為nnx.Variable
或其子類別。- 變數狀態#
nnx.VariableState
是所有 變數 在 模組 內部的純粹函數式 JAX pytree。由於它是純粹的,它可以是 JAX 轉換函數的輸入或輸出。nnx.VariableState
是透過在nnx.Module
上使用nnx.split
取得。(請參閱分割和模組以了解更多資訊。)