Flax NNX 詞彙表

Flax NNX 詞彙表#

如需其他術語,請參考 JAX 詞彙表

篩選器#

一種僅從 Flax NNX 模組 (nnx.Module) 中提取特定 nnx.Variable 物件的方式。這通常透過在 nnx.Module 上呼叫 nnx.split 來完成。請參閱篩選器指南以了解更多資訊。

折疊輸入#

在 Flax 中,折疊輸入表示在給定輸入 PRNG 金鑰和整數的情況下,產生新的 JAX 虛擬隨機數產生器 (PRNG) 金鑰。這通常用於當您想要產生新的金鑰,但仍然能夠在之後使用原始 PRNG 金鑰時。您也可以在 JAX 中使用 jax.random.split 來執行此操作,但此方法實際上會建立兩個 PRNG 金鑰,這會比較慢。請在 隨機性/PRNG 指南中了解 Flax 如何自動產生新的 PRNG 金鑰。

GraphDef#

nnx.GraphDef 是一個類別,代表 Flax 模組 (nnx.Module) 的所有靜態、無狀態和 Python 式的部分。

合併#

請參閱分割與合併

模組#

nnx.Module 是一個資料類別,可讓您以參照透明的方式定義和初始化參數。它負責儲存和更新其內部的 :term:`Variable<Variable> 物件和參數。

Params / 參數#

nnx.Paramnnx.Variable 的特定子類別,通常包含可訓練的權重。

PRNG 狀態#

Flax nnx.Module 可以保留 虛擬隨機數產生器 (PRNG) 狀態物件 nnx.Rngs 的參照,該物件可以產生新的 JAX PRNG 金鑰。這些金鑰用於透過 JAX 的函數式 PRNG 產生隨機 JAX 陣列。您可以使用具有不同種子的 PRNG 狀態,以更精細地控制您的模型 (例如,讓參數和 dropout 遮罩擁有獨立的隨機數)。請參閱 Flax 隨機性/PRNG 指南以了解更多詳細資訊。

分割與合併#

nnx.split 是一種使用兩個部分來表示 nnx.Module 的方法:1) 一個靜態的 Flax NNX GraphDef,用於擷取其 Python 式的靜態資訊;以及 2) 一個或多個 變數狀態,用於擷取其 JAX 陣列 (jax.Array),其形式為 JAX pytree。它們可以使用 nnx.merge 合併回原始 nnx.Module

轉換#

Flax NNX 轉換 (transform) 是 JAX 轉換的封裝版本,允許正在轉換的函數將 Flax NNX 模組 (nnx.Module) 作為輸入或輸出。例如,jax.jit 的「提升」版本是 nnx.jit。請查看 Flax NNX 轉換指南以了解更多資訊。

變數#

位於 Flax 模組 中的權重/參數/資料/陣列 nnx.Variable。變數在模組內定義為 nnx.Variable 或其子類別。

變數狀態#

nnx.VariableState 是所有 變數模組 內部的純粹函數式 JAX pytree。由於它是純粹的,它可以是 JAX 轉換函數的輸入或輸出。nnx.VariableState 是透過在 nnx.Module 上使用 nnx.split 取得。(請參閱分割模組以了解更多資訊。)