轉換#

class flax.nnx.Jit(*args, **kwargs)[原始碼]#
class flax.nnx.Remat(*args, **kwargs)[原始碼]#
class flax.nnx.Scan(*args, **kwargs)[原始碼]#
class flax.nnx.Vmap(*args, **kwargs)[原始碼]#
flax.nnx.grad(f=<flax.typing.Missing object>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[原始碼]#

`jax.grad` 的提升版本,可以處理模組/圖形節點作為引數。

每個圖形節點的可微分狀態由 wrt 篩選器定義,預設設定為 nnx.Param。在內部,圖形節點的 State 會被提取,根據 wrt 篩選器進行篩選,並傳遞到基礎的 jax.grad 函式。圖形節點的梯度類型為 State

範例

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
...
>>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
>>> grad_fn = nnx.grad(loss_fn)
...
>>> grads = grad_fn(m, x, y)
>>> jax.tree.map(jnp.shape, grads)
State({
  'bias': VariableState(
    type=Param,
    value=(3,)
  ),
  'kernel': VariableState(
    type=Param,
    value=(2, 3)
  )
})
參數
  • fun – 要微分的函式。由 argnums 指定位置的引數應為陣列、純量、圖形節點或標準 Python 容器。由 argnums 指定位置的引數陣列必須為非精確類型(即浮點或複數)。它應返回純量(包括形狀為 () 的陣列,但不包括形狀為 (1,) 等的陣列)

  • argnums – 可選,整數或整數序列。指定要針對哪個位置引數進行微分(預設為 0)。

  • has_aux – 可選,布林值。表示 fun 是否返回一個配對,其中第一個元素被視為要微分的數學函式的輸出,而第二個元素是輔助資料。預設為 False。

  • holomorphic – 可選,布林值。表示 fun 是否保證為全純的。如果為 True,則輸入和輸出必須為複數。預設為 False。

  • allow_int – 可選,布林值。是否允許針對整數值輸入進行微分。整數輸入的梯度將具有一個微不足道的向量空間資料類型 (float0)。預設為 False。

  • reduce_axes – 可選,軸名稱的元組。如果此處列出了一個軸,且 fun 隱式地在該軸上廣播值,則反向傳遞將執行相應梯度的 psum。否則,梯度將在命名軸上為每個範例計算。例如,如果 'batch' 是一個命名的批次軸,則 grad(f, reduce_axes=('batch',)) 將建立一個計算總梯度的函式,而 grad(f) 將建立一個計算每個範例梯度的函式。

flax.nnx.jit(fun=<class 'flax.typing.Missing'>, *, in_shardings=None, out_shardings=None, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)[原始碼]#

可以處理模組/圖形節點作為引數的 jax.jit 的提升版本。

參數
  • fun

    要 JIT 編譯的函式。fun 應為純函式,因為副作用可能只執行一次。

    fun 的引數和傳回值應為陣列、純量或(巢狀)標準 Python 容器(tuple/list/dict)。由 static_argnums 指示的位置引數可以是任何東西,只要它們是可雜湊的且定義了相等運算。靜態引數會被包含在編譯快取鍵的一部分,這就是為什麼必須定義雜湊和相等運算子的原因。

    JAX 會保留 fun 的弱參考,以用作編譯快取鍵,因此物件 fun 必須是可弱參考的。大多數 Callable 物件已經滿足此要求。

  • in_shardings

    fun 的引數結構匹配的 Pytree,其中所有實際引數都替換為資源指派規範。指定 Pytree 前置詞(例如,用一個值取代整個子樹)也有效,在這種情況下,葉節點會廣播到該子樹中的所有值。

    in_shardings 參數是選填的。JAX 會從輸入的 jax.Array 推斷分片方式,如果無法推斷,則預設為複製輸入。

    有效的資源分配規範如下:
    • Sharding,它將決定數值如何

      被分割。有了這個,就不需要使用網格上下文管理器。

    • None,JAX 可以自由選擇它想要的分片方式。對於 in_shardings,JAX 會將其標記為複製,但此行為未來可能會改變。對於 out_shardings,我們將依賴 XLA GSPMD 分割器來決定輸出分片。

    每個維度的大小必須是分配給它的資源總數的倍數。這與 pjit 的 in_shardings 類似。

  • out_shardings

    in_shardings 類似,但指定函數輸出的資源分配。這與 pjit 的 out_shardings 類似。

    out_shardings 參數是選填的。如果未指定,jax.jit() 將使用 GSPMD 的分片傳播來計算輸出應該的分片方式。

  • static_argnums

    一個可選的整數或整數集合,用於指定哪些位置參數應被視為靜態(編譯時常數)。僅依賴於靜態參數的操作將在 Python 中(在追蹤期間)進行常數折疊,因此相應的參數值可以是任何 Python 物件。

    靜態參數應該是可雜湊的,這表示 __hash____eq__ 都已實作,並且是不可變的。使用不同的常數值呼叫已 jit 的函數將觸發重新編譯。不是陣列或其容器的參數必須標記為靜態。

    如果沒有提供 static_argnumsstatic_argnames,則不會將任何參數視為靜態。如果沒有提供 static_argnums 但提供了 static_argnames,反之亦然,JAX 會使用 inspect.signature(fun) 來尋找任何對應於 static_argnames 的位置參數(反之亦然)。如果同時提供了 static_argnumsstatic_argnames,則不會使用 inspect.signature,並且只會將 static_argnumsstatic_argnames 中列出的實際參數視為靜態。

  • static_argnames – 一個可選的字串或字串集合,用於指定哪些具名參數應被視為靜態(編譯時常數)。有關詳細資訊,請參閱關於 static_argnums 的註解。如果沒有提供,但設定了 static_argnums,則預設值基於呼叫 inspect.signature(fun) 來尋找對應的具名參數。

  • donate_argnums

    指定哪些位置參數緩衝區會「捐贈」給計算。如果計算完成後您不再需要參數緩衝區,則可以安全地捐贈它們。在某些情況下,XLA 可以使用捐贈的緩衝區來減少執行計算所需的記憶體量,例如,回收您的其中一個輸入緩衝區來儲存結果。您不應重複使用捐贈給計算的緩衝區,如果您嘗試這樣做,JAX 會引發錯誤。預設情況下,不會捐贈任何參數緩衝區。

    如果沒有提供 donate_argnumsdonate_argnames,則不會捐贈任何參數。如果沒有提供 donate_argnums 但提供了 donate_argnames,反之亦然,JAX 會使用 inspect.signature(fun) 來尋找任何對應於 donate_argnames 的位置參數(反之亦然)。如果同時提供了 donate_argnumsdonate_argnames,則不會使用 inspect.signature,並且只會捐贈 donate_argnumsdonate_argnames 中列出的實際參數。

    有關緩衝區捐贈的更多詳細資訊,請參閱 常見問題解答

  • donate_argnames – 一個可選的字串或字串集合,用於指定哪些具名參數會捐贈給計算。有關詳細資訊,請參閱關於 donate_argnums 的註解。如果沒有提供,但設定了 donate_argnums,則預設值基於呼叫 inspect.signature(fun) 來尋找對應的具名參數。

  • keep_unused – 如果為 False (預設值),則 JAX 判斷為 fun 未使用的參數可能會從產生的已編譯 XLA 可執行檔中移除。這些參數不會傳輸到裝置或提供給底層的可執行檔。如果為 True,則不會修剪未使用的參數。

  • device – 這是一個實驗性功能,API 很可能會變更。選填,Jitted 函數將在其上執行的裝置。(可以透過 jax.devices() 檢索可用的裝置。)預設值繼承自 XLA 的 DeviceAssignment 邏輯,通常是使用 jax.devices()[0]

  • backend – 這是一個實驗性功能,API 很可能會變更。選填,表示 XLA 後端的字串:'cpu''gpu''tpu'

  • inline – 指定是否應將此函數內聯到封閉的 jaxpr 中(而不是表示為具有其自身 subjaxpr 的 xla_call primitive 的應用)。預設值為 False。

返回

fun 的包裝版本,已設定為即時編譯。

flax.nnx.remat(f=<flax.typing.Missing object>, *, prevent_cse=True, static_argnums=(), policy=None)[來源]#
flax.nnx.scan(f=<class 'flax.typing.Missing'>, *, length=None, reverse=False, unroll=1, _split_transpose=False, in_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), out_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), transform_metadata=FrozenDict({}))[來源]#
flax.nnx.value_and_grad(f=<class 'flax.typing.Missing'>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[來源]#
flax.nnx.vmap(f=<class 'flax.typing.Missing'>, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, transform_metadata=FrozenDict({}))[來源]#

參考感知版本的 jax.vmap

參數
  • f – 要在額外軸上映射的函數。

  • in_axes – 一個整數、None 或一系列值,用於指定要映射的輸入陣列軸 (請參閱 jax.vmap)。除了整數和 None 之外,StateAxes 可用於控制圖形節點 (如 Modules) 的向量化方式,方法是指定要應用於圖形節點子狀態的軸,並給定一個 Filter

  • out_axes – 一個整數、None 或 pytree,指示映射的軸應出現在輸出中的位置 (請參閱 jax.vmap)。

  • axis_name – 選填,一個可雜湊的 Python 物件,用於識別映射的軸,以便可以應用並行集合。

  • axis_size – 選填,一個整數,指示要映射的軸的大小。如果未提供,則會從參數推斷映射的軸大小。

返回

f 的批次/向量化版本,其參數對應於 f 的參數,但在 in_axes 指示的位置有額外的陣列軸,以及一個傳回值,其對應於 f 的傳回值,但在 out_axes 指示的位置有額外的陣列軸。

範例

>>> from flax import nnx
>>> from jax import random, numpy as jnp
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((5, 2))
...
>>> @nnx.vmap(in_axes=(None, 0), out_axes=0)
... def forward(model, x):
...   return model(x)
...
>>> y = forward(model, x)
>>> y.shape
(5, 3)
>>> class LinearEnsemble(nnx.Module):
...   def __init__(self, num, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3)))
...
>>> model = LinearEnsemble(5, rngs=nnx.Rngs(0))
>>> x = jnp.ones((2,))
...
>>> @nnx.vmap(in_axes=(0, None), out_axes=0)
... def forward(model, x):
...   return jnp.dot(x, model.w.value)
...
>>> y = forward(model, x)
>>> y.shape
(5, 3)

為了控制圖形節點子狀態如何向量化,可以將 StateAxes 傳遞給 in_axesout_axes,以指定要套用至每個子狀態的軸,並給定一個篩選器。以下範例展示如何在集成成員之間共享參數,同時保持不同的批次統計資訊和 dropout 隨機狀態。

>>> class Foo(nnx.Module):
...   def __init__(self):
...     self.a = nnx.Param(jnp.arange(4))
...     self.b = nnx.BatchStat(jnp.arange(4))
...
>>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None})
>>> @nnx.vmap(in_axes=(state_axes,), out_axes=0)
... def mul(foo):
...   return foo.a * foo.b
...
>>> foo = Foo()
>>> y = mul(foo)
>>> y
Array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]], dtype=int32)
flax.nnx.eval_shape(f, *args, **kwargs)[原始碼]#
flax.nnx.custom_vjp(fun=<flax.typing.Missing object>, *, nondiff_argnums=())[原始碼]#

jax.custom_vjp 的參考感知版本。

nnx.custom_vjp 接受模組和其他 Flax NNX 物件作為引數。與 JAX 版本的主要區別在於,由於模組遵循參考語義,它們會將輸入的狀態更新作為輔助輸出傳播。這表示 bwd 函數中的傳入梯度將具有 (input_updates_g, out_g) 的形式,其中 input_updates_g 是輸入的梯度更新狀態(相對於輸入)。輸入中所有模組項都將在 input_updates_g 中有相關聯的 State 項,而所有非模組項將顯示為 None。切線的形狀預期與輸入的形狀相同,並在相應的模組項的位置有 State 項。

範例

>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
...
>>> class Foo(nnx.Module):
...   def __init__(self, x, y):
...     self.x = nnx.Param(x)
...     self.y = nnx.Param(y)
...
>>> @nnx.custom_vjp
... def f(m: Foo):
...   return jnp.sin(m.x) * m.y
...
>>> def f_fwd(m: Foo):
...   return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)
...
>>> def f_bwd(res, g):
...   input_updates_g, out_g = g
...   cos_x, sin_x, m = res
...   (m_updates_g,) = input_updates_g
...   m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy
...
...   m_g['x'].value = cos_x * out_g * m.y
...   m_g['y'].value = sin_x * out_g
...   return (m_g,)
...
>>> f.defvjp(f_fwd, f_bwd)
...
>>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
>>> grads = nnx.grad(f)(m)
...
>>> jax.tree.map(jnp.shape, grads)
State({
  'x': VariableState(
    type=Param,
    value=()
  ),
  'y': VariableState(
    type=Param,
    value=()
  )
})

請注意,在 input_updates_g 中表示模組項的 State 物件與輸出切線中預期的 State 物件具有相同的形狀。這表示您通常可以直接從 input_updates_g 複製它們,並使用它們對應的梯度值更新它們。

您可以透過將 DiffState 傳遞給 nondiff_argnums,來選擇模組和其他圖形節點哪些子狀態是可微分的(具有切線)。例如,如果您只想區分 Foo 類的 x 屬性,您可以執行以下操作:

>>> x_attribute = nnx.PathContains('x')
>>> diff_state = nnx.DiffState(0, x_attribute)
...
>>> @nnx.custom_vjp(nondiff_argnums=(diff_state,))
... def f(m: Foo):
...   return jnp.sin(m.x) * m.y  # type: ignore

>>> def f_fwd(m: Foo):
...   y = f(m)
...   res = (jnp.cos(m.x), m)  # type: ignore
...   return y, res
...
>>> def f_bwd(res, g):
...   input_updates_g, out_g = g
...   cos_x, m = res
...   (m_updates_g,) = input_updates_g
...   m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy
...
...   m_g.x.value = cos_x * out_g * m.y
...   del m_g['y'] # y is not differentiable
...   return (m_g,)

>>> f.defvjp(f_fwd, f_bwd)
...
>>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
>>> grad = nnx.grad(f, argnums=nnx.DiffState(0, x_attribute))(m)
...
>>> jax.tree.map(jnp.shape, grad)
State({
  'x': VariableState(
    type=Param,
    value=()
  )
})

請注意,grad 無法計算沒有由 custom_vjp 定義切線的狀態的梯度,在上面的範例中,我們重複使用相同的 x_attribute 篩選器,以使 custom_vjpgrad 保持同步。

參數
  • fun – 可呼叫的基本函數。

  • nondiff_argnums – 指定不進行微分的引數索引的整數或 DiffState 物件的元組。預設情況下,所有引數都會被微分。整數不能用於將圖形節點(如模組)標記為不可微分,在這種情況下,請使用 DiffState 物件。DiffState 物件定義可微分子狀態的集合,與此引數的名稱所暗示的相反,這樣做是為了與 grad 相容。

flax.nnx.cond(pred, true_fun, false_fun, *operands, **kwargs)[原始碼]#
flax.nnx.switch(index, branches, *operands)[原始碼]#
flax.nnx.while_loop(cond_fun, body_fun, init_val)[原始碼]#

Flax NNX 對 jax.lax.while_loop 的轉換。

注意:為了使 NNX 內部參考追蹤機制正常運作,您不能更改 body_funinit_val 的變數參考結構。

範例

>>> import jax
>>> from flax import nnx
>>> def fwd_fn(input):
...   module, x, count = input
...   return module, module(x), count - 1.0

>>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
>>> x = jax.random.normal(jax.random.key(0), (10,))
>>> # `module` will be called three times
>>> _, y, _ = nnx.while_loop(
...   lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
參數
  • cond_fun – while 迴圈的繼續條件函數,採用類型為 T 的單一輸入並輸出布林值。

  • body_fun – 一個函數,接受類型為 T 的輸入並輸出一個 T。請注意,T 的資料和模組在輸入和輸出之間必須具有相同的參考結構。

  • init_valcond_funbody_fun 的初始輸入。必須為 T 類型。

flax.nnx.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[原始碼]#

Flax NNX 對 jax.lax.fori_loop 的轉換。

注意:為了使 NNX 內部參考追蹤機制正常運作,您不能更改 body_funinit_val 的變數參考結構。

範例

>>> import jax
>>> from flax import nnx

>>> def fwd_fn(i, input):
...   m, x = input
...   m.kernel.value = jnp.identity(10) * i
...   return m, m(x)

>>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
>>> x = jax.random.normal(jax.random.key(0), (10,))
>>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x))
>>> np.testing.assert_array_equal(y, x * 2 * 3)
參數
  • lower – 表示迴圈索引下限(包含)的整數。

  • upper – 表示迴圈索引上限(不包含)的整數。

  • body_fun – 一個函數,接受類型為 T 的輸入並輸出一個 T。請注意,T 的資料和模組在輸入和輸出之間必須具有相同的參考結構。

  • init_val – body_fun 的初始輸入。必須為 T 類型。

  • unroll – 一個可選的整數或布林值,決定要展開迴圈多少。如果提供整數,它會決定在迴圈的單一展開迭代中執行多少個展開的迴圈迭代。如果提供布林值,它會決定迴圈是否完全展開(即 unroll=True)或完全不展開(即 unroll=False)。此引數僅適用於靜態已知迴圈邊界的情況。

返回

來自最後一次迭代的迴圈值,類型為 T