轉換#
- flax.nnx.grad(f=<flax.typing.Missing object>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[原始碼]#
`jax.grad` 的提升版本,可以處理模組/圖形節點作為引數。
每個圖形節點的可微分狀態由 wrt 篩選器定義,預設設定為 nnx.Param。在內部,圖形節點的
State
會被提取,根據 wrt 篩選器進行篩選,並傳遞到基礎的jax.grad
函式。圖形節點的梯度類型為State
。範例
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 3)) ... >>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2) >>> grad_fn = nnx.grad(loss_fn) ... >>> grads = grad_fn(m, x, y) >>> jax.tree.map(jnp.shape, grads) State({ 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) })
- 參數
fun – 要微分的函式。由
argnums
指定位置的引數應為陣列、純量、圖形節點或標準 Python 容器。由argnums
指定位置的引數陣列必須為非精確類型(即浮點或複數)。它應返回純量(包括形狀為()
的陣列,但不包括形狀為(1,)
等的陣列)argnums – 可選,整數或整數序列。指定要針對哪個位置引數進行微分(預設為 0)。
has_aux – 可選,布林值。表示
fun
是否返回一個配對,其中第一個元素被視為要微分的數學函式的輸出,而第二個元素是輔助資料。預設為 False。holomorphic – 可選,布林值。表示
fun
是否保證為全純的。如果為 True,則輸入和輸出必須為複數。預設為 False。allow_int – 可選,布林值。是否允許針對整數值輸入進行微分。整數輸入的梯度將具有一個微不足道的向量空間資料類型 (float0)。預設為 False。
reduce_axes – 可選,軸名稱的元組。如果此處列出了一個軸,且
fun
隱式地在該軸上廣播值,則反向傳遞將執行相應梯度的psum
。否則,梯度將在命名軸上為每個範例計算。例如,如果'batch'
是一個命名的批次軸,則grad(f, reduce_axes=('batch',))
將建立一個計算總梯度的函式,而grad(f)
將建立一個計算每個範例梯度的函式。
- flax.nnx.jit(fun=<class 'flax.typing.Missing'>, *, in_shardings=None, out_shardings=None, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)[原始碼]#
可以處理模組/圖形節點作為引數的
jax.jit
的提升版本。- 參數
fun –
要 JIT 編譯的函式。
fun
應為純函式,因為副作用可能只執行一次。fun
的引數和傳回值應為陣列、純量或(巢狀)標準 Python 容器(tuple/list/dict)。由static_argnums
指示的位置引數可以是任何東西,只要它們是可雜湊的且定義了相等運算。靜態引數會被包含在編譯快取鍵的一部分,這就是為什麼必須定義雜湊和相等運算子的原因。JAX 會保留
fun
的弱參考,以用作編譯快取鍵,因此物件fun
必須是可弱參考的。大多數Callable
物件已經滿足此要求。in_shardings –
與
fun
的引數結構匹配的 Pytree,其中所有實際引數都替換為資源指派規範。指定 Pytree 前置詞(例如,用一個值取代整個子樹)也有效,在這種情況下,葉節點會廣播到該子樹中的所有值。in_shardings
參數是選填的。JAX 會從輸入的jax.Array
推斷分片方式,如果無法推斷,則預設為複製輸入。- 有效的資源分配規範如下:
Sharding
,它將決定數值如何被分割。有了這個,就不需要使用網格上下文管理器。
None
,JAX 可以自由選擇它想要的分片方式。對於 in_shardings,JAX 會將其標記為複製,但此行為未來可能會改變。對於 out_shardings,我們將依賴 XLA GSPMD 分割器來決定輸出分片。
每個維度的大小必須是分配給它的資源總數的倍數。這與 pjit 的 in_shardings 類似。
out_shardings –
與
in_shardings
類似,但指定函數輸出的資源分配。這與 pjit 的 out_shardings 類似。out_shardings
參數是選填的。如果未指定,jax.jit()
將使用 GSPMD 的分片傳播來計算輸出應該的分片方式。static_argnums –
一個可選的整數或整數集合,用於指定哪些位置參數應被視為靜態(編譯時常數)。僅依賴於靜態參數的操作將在 Python 中(在追蹤期間)進行常數折疊,因此相應的參數值可以是任何 Python 物件。
靜態參數應該是可雜湊的,這表示
__hash__
和__eq__
都已實作,並且是不可變的。使用不同的常數值呼叫已 jit 的函數將觸發重新編譯。不是陣列或其容器的參數必須標記為靜態。如果沒有提供
static_argnums
或static_argnames
,則不會將任何參數視為靜態。如果沒有提供static_argnums
但提供了static_argnames
,反之亦然,JAX 會使用inspect.signature(fun)
來尋找任何對應於static_argnames
的位置參數(反之亦然)。如果同時提供了static_argnums
和static_argnames
,則不會使用inspect.signature
,並且只會將static_argnums
或static_argnames
中列出的實際參數視為靜態。static_argnames – 一個可選的字串或字串集合,用於指定哪些具名參數應被視為靜態(編譯時常數)。有關詳細資訊,請參閱關於
static_argnums
的註解。如果沒有提供,但設定了static_argnums
,則預設值基於呼叫inspect.signature(fun)
來尋找對應的具名參數。donate_argnums –
指定哪些位置參數緩衝區會「捐贈」給計算。如果計算完成後您不再需要參數緩衝區,則可以安全地捐贈它們。在某些情況下,XLA 可以使用捐贈的緩衝區來減少執行計算所需的記憶體量,例如,回收您的其中一個輸入緩衝區來儲存結果。您不應重複使用捐贈給計算的緩衝區,如果您嘗試這樣做,JAX 會引發錯誤。預設情況下,不會捐贈任何參數緩衝區。
如果沒有提供
donate_argnums
或donate_argnames
,則不會捐贈任何參數。如果沒有提供donate_argnums
但提供了donate_argnames
,反之亦然,JAX 會使用inspect.signature(fun)
來尋找任何對應於donate_argnames
的位置參數(反之亦然)。如果同時提供了donate_argnums
和donate_argnames
,則不會使用inspect.signature
,並且只會捐贈donate_argnums
或donate_argnames
中列出的實際參數。有關緩衝區捐贈的更多詳細資訊,請參閱 常見問題解答。
donate_argnames – 一個可選的字串或字串集合,用於指定哪些具名參數會捐贈給計算。有關詳細資訊,請參閱關於
donate_argnums
的註解。如果沒有提供,但設定了donate_argnums
,則預設值基於呼叫inspect.signature(fun)
來尋找對應的具名參數。keep_unused – 如果為 False (預設值),則 JAX 判斷為 fun 未使用的參數可能會從產生的已編譯 XLA 可執行檔中移除。這些參數不會傳輸到裝置或提供給底層的可執行檔。如果為 True,則不會修剪未使用的參數。
device – 這是一個實驗性功能,API 很可能會變更。選填,Jitted 函數將在其上執行的裝置。(可以透過
jax.devices()
檢索可用的裝置。)預設值繼承自 XLA 的 DeviceAssignment 邏輯,通常是使用jax.devices()[0]
。backend – 這是一個實驗性功能,API 很可能會變更。選填,表示 XLA 後端的字串:
'cpu'
、'gpu'
或'tpu'
。inline – 指定是否應將此函數內聯到封閉的 jaxpr 中(而不是表示為具有其自身 subjaxpr 的 xla_call primitive 的應用)。預設值為 False。
- 返回
fun
的包裝版本,已設定為即時編譯。
-
flax.nnx.remat
(f=<flax.typing.Missing object>, *, prevent_cse=True, static_argnums=(), policy=None)[來源]#
-
flax.nnx.scan
(f=<class 'flax.typing.Missing'>, *, length=None, reverse=False, unroll=1, _split_transpose=False, in_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), out_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), transform_metadata=FrozenDict({}))[來源]#
-
flax.nnx.value_and_grad
(f=<class 'flax.typing.Missing'>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[來源]#
-
flax.nnx.vmap
(f=<class 'flax.typing.Missing'>, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, transform_metadata=FrozenDict({}))[來源]# 參考感知版本的 jax.vmap。
- 參數
f – 要在額外軸上映射的函數。
in_axes – 一個整數、None 或一系列值,用於指定要映射的輸入陣列軸 (請參閱 jax.vmap)。除了整數和 None 之外,
StateAxes
可用於控制圖形節點 (如 Modules) 的向量化方式,方法是指定要應用於圖形節點子狀態的軸,並給定一個 Filter。out_axes – 一個整數、None 或 pytree,指示映射的軸應出現在輸出中的位置 (請參閱 jax.vmap)。
axis_name – 選填,一個可雜湊的 Python 物件,用於識別映射的軸,以便可以應用並行集合。
axis_size – 選填,一個整數,指示要映射的軸的大小。如果未提供,則會從參數推斷映射的軸大小。
- 返回
f
的批次/向量化版本,其參數對應於f
的參數,但在in_axes
指示的位置有額外的陣列軸,以及一個傳回值,其對應於f
的傳回值,但在out_axes
指示的位置有額外的陣列軸。
範例
>>> from flax import nnx >>> from jax import random, numpy as jnp ... >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((5, 2)) ... >>> @nnx.vmap(in_axes=(None, 0), out_axes=0) ... def forward(model, x): ... return model(x) ... >>> y = forward(model, x) >>> y.shape (5, 3)
>>> class LinearEnsemble(nnx.Module): ... def __init__(self, num, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3))) ... >>> model = LinearEnsemble(5, rngs=nnx.Rngs(0)) >>> x = jnp.ones((2,)) ... >>> @nnx.vmap(in_axes=(0, None), out_axes=0) ... def forward(model, x): ... return jnp.dot(x, model.w.value) ... >>> y = forward(model, x) >>> y.shape (5, 3)
為了控制圖形節點子狀態如何向量化,可以將
StateAxes
傳遞給in_axes
和out_axes
,以指定要套用至每個子狀態的軸,並給定一個篩選器。以下範例展示如何在集成成員之間共享參數,同時保持不同的批次統計資訊和 dropout 隨機狀態。>>> class Foo(nnx.Module): ... def __init__(self): ... self.a = nnx.Param(jnp.arange(4)) ... self.b = nnx.BatchStat(jnp.arange(4)) ... >>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None}) >>> @nnx.vmap(in_axes=(state_axes,), out_axes=0) ... def mul(foo): ... return foo.a * foo.b ... >>> foo = Foo() >>> y = mul(foo) >>> y Array([[0, 0, 0, 0], [0, 1, 2, 3], [0, 2, 4, 6], [0, 3, 6, 9]], dtype=int32)
- flax.nnx.custom_vjp(fun=<flax.typing.Missing object>, *, nondiff_argnums=())[原始碼]#
jax.custom_vjp 的參考感知版本。
nnx.custom_vjp
接受模組和其他 Flax NNX 物件作為引數。與 JAX 版本的主要區別在於,由於模組遵循參考語義,它們會將輸入的狀態更新作為輔助輸出傳播。這表示bwd
函數中的傳入梯度將具有(input_updates_g, out_g)
的形式,其中input_updates_g
是輸入的梯度更新狀態(相對於輸入)。輸入中所有模組項都將在input_updates_g
中有相關聯的State
項,而所有非模組項將顯示為 None。切線的形狀預期與輸入的形狀相同,並在相應的模組項的位置有State
項。範例
>>> import jax >>> import jax.numpy as jnp >>> from flax import nnx ... >>> class Foo(nnx.Module): ... def __init__(self, x, y): ... self.x = nnx.Param(x) ... self.y = nnx.Param(y) ... >>> @nnx.custom_vjp ... def f(m: Foo): ... return jnp.sin(m.x) * m.y ... >>> def f_fwd(m: Foo): ... return f(m), (jnp.cos(m.x), jnp.sin(m.x), m) ... >>> def f_bwd(res, g): ... input_updates_g, out_g = g ... cos_x, sin_x, m = res ... (m_updates_g,) = input_updates_g ... m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy ... ... m_g['x'].value = cos_x * out_g * m.y ... m_g['y'].value = sin_x * out_g ... return (m_g,) ... >>> f.defvjp(f_fwd, f_bwd) ... >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.)) >>> grads = nnx.grad(f)(m) ... >>> jax.tree.map(jnp.shape, grads) State({ 'x': VariableState( type=Param, value=() ), 'y': VariableState( type=Param, value=() ) })
請注意,在
input_updates_g
中表示模組項的 State 物件與輸出切線中預期的 State 物件具有相同的形狀。這表示您通常可以直接從input_updates_g
複製它們,並使用它們對應的梯度值更新它們。您可以透過將
DiffState
傳遞給nondiff_argnums
,來選擇模組和其他圖形節點哪些子狀態是可微分的(具有切線)。例如,如果您只想區分Foo
類的x
屬性,您可以執行以下操作:>>> x_attribute = nnx.PathContains('x') >>> diff_state = nnx.DiffState(0, x_attribute) ... >>> @nnx.custom_vjp(nondiff_argnums=(diff_state,)) ... def f(m: Foo): ... return jnp.sin(m.x) * m.y # type: ignore >>> def f_fwd(m: Foo): ... y = f(m) ... res = (jnp.cos(m.x), m) # type: ignore ... return y, res ... >>> def f_bwd(res, g): ... input_updates_g, out_g = g ... cos_x, m = res ... (m_updates_g,) = input_updates_g ... m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy ... ... m_g.x.value = cos_x * out_g * m.y ... del m_g['y'] # y is not differentiable ... return (m_g,) >>> f.defvjp(f_fwd, f_bwd) ... >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.)) >>> grad = nnx.grad(f, argnums=nnx.DiffState(0, x_attribute))(m) ... >>> jax.tree.map(jnp.shape, grad) State({ 'x': VariableState( type=Param, value=() ) })
請注意,
grad
無法計算沒有由custom_vjp
定義切線的狀態的梯度,在上面的範例中,我們重複使用相同的x_attribute
篩選器,以使custom_vjp
和grad
保持同步。- 參數
fun – 可呼叫的基本函數。
nondiff_argnums – 指定不進行微分的引數索引的整數或 DiffState 物件的元組。預設情況下,所有引數都會被微分。整數不能用於將圖形節點(如模組)標記為不可微分,在這種情況下,請使用 DiffState 物件。DiffState 物件定義可微分子狀態的集合,與此引數的名稱所暗示的相反,這樣做是為了與
grad
相容。
- flax.nnx.while_loop(cond_fun, body_fun, init_val)[原始碼]#
Flax NNX 對 jax.lax.while_loop 的轉換。
注意:為了使 NNX 內部參考追蹤機制正常運作,您不能更改
body_fun
內init_val
的變數參考結構。範例
>>> import jax >>> from flax import nnx >>> def fwd_fn(input): ... module, x, count = input ... return module, module(x), count - 1.0 >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) >>> x = jax.random.normal(jax.random.key(0), (10,)) >>> # `module` will be called three times >>> _, y, _ = nnx.while_loop( ... lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
- 參數
cond_fun – while 迴圈的繼續條件函數,採用類型為
T
的單一輸入並輸出布林值。body_fun – 一個函數,接受類型為
T
的輸入並輸出一個T
。請注意,T
的資料和模組在輸入和輸出之間必須具有相同的參考結構。init_val –
cond_fun
和body_fun
的初始輸入。必須為T
類型。
- flax.nnx.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[原始碼]#
Flax NNX 對 jax.lax.fori_loop 的轉換。
注意:為了使 NNX 內部參考追蹤機制正常運作,您不能更改 body_fun 內 init_val 的變數參考結構。
範例
>>> import jax >>> from flax import nnx >>> def fwd_fn(i, input): ... m, x = input ... m.kernel.value = jnp.identity(10) * i ... return m, m(x) >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) >>> x = jax.random.normal(jax.random.key(0), (10,)) >>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x)) >>> np.testing.assert_array_equal(y, x * 2 * 3)
- 參數
lower – 表示迴圈索引下限(包含)的整數。
upper – 表示迴圈索引上限(不包含)的整數。
body_fun – 一個函數,接受類型為
T
的輸入並輸出一個T
。請注意,T
的資料和模組在輸入和輸出之間必須具有相同的參考結構。init_val – body_fun 的初始輸入。必須為
T
類型。unroll – 一個可選的整數或布林值,決定要展開迴圈多少。如果提供整數,它會決定在迴圈的單一展開迭代中執行多少個展開的迴圈迭代。如果提供布林值,它會決定迴圈是否完全展開(即
unroll=True
)或完全不展開(即unroll=False
)。此引數僅適用於靜態已知迴圈邊界的情況。
- 返回
來自最後一次迭代的迴圈值,類型為
T
。