轉換#

模組的 JAX 轉換。

Jax 函數式轉換操作純函式。Flax 擴充這些轉換,也可以操作具有狀態變數和 PRNG 序列的模組。我們稱這些擴充版本為「提升的轉換」。

提昇的轉換可套用於 Module 類別或將 Module 執行個體作為其第一個引數的函數。

flax.linen.vmap(target, variable_axes=FrozenDict({}), split_rngs=FrozenDict({}), in_axes=0, out_axes=0, axis_size=None, axis_name=None, spmd_axis_name=None, metadata_params={}, methods=None)[source]#

jax.vmap 的提昇版本。

請參閱 jax.vmap,以取得 Jax 中未提昇的批次轉換。

vmap 可用於將批次軸新增至 Module。例如,我們可以建立 Dense 的一個版本,其批次軸不共用參數

>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': 0},
...     split_rngs={'params': True})

透過使用 variable_axes={'params': 0},我們表示參數本身會進行對應,因此不會沿著對應的軸共用。因此,我們也將分割「params」RNG,否則參數將沿著對應的軸以相同方式初始化。

類似地,vmap 可用於新增參數共用的批次軸

>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': None},
...     split_rngs={'params': False})

在這裡,我們使用 variable_axes={'params': None} 來表示參數變數將沿著對應的軸共用。因此,也必須共用「params」RNG。

參數
  • targetModule 或函數,以 Module 作為第一個引數。

  • variable_axes – 會提升到批次轉換中的變數集合。使用 None 表示廣播集合或整數,以對軸進行對應。例如,傳入 variable_axes={'params': None} 表示應沿著對應的軸共用參數變數。

  • split_rngs – 分割的 PRNG 序列對批次維度的每個索引都會不同。未分割的 PRNG 將廣播。

  • in_axes – 指定輸入引數的對應(請參閱 jax.vmap)。

  • out_axes – 指定傳回值的對應(請參閱 jax.vmap)。

  • axis_size – 指定批次軸的大小。如果無法從輸入引數導出,才需要指定。

  • axis_name – 為批次軸指定名稱。可以與並行簡化基元一起使用(例如 jax.lax.pmeanjax.lax.ppermute 等)。請注意,這只用於 pmap 和分片對應。對於 SPMD jit,您不需要手動同步。只要確保正確註解軸,而 XLA:SPMD 會插入必要的資料集。

  • methods – 如果 targetModule,則為 Module 的方法對應 vmap。

  • spmd_axis_name – 軸名稱新增到出現在 fn 中的任何 pjit 分片約束。另請參閱 google/flax

  • metadata_params – 在變數樹中傳遞給 AxisMetadata 實例的引數字典。

傳回

target 的批次/向量化版本,具有相同的引數,但多出軸在 in_axes 指示的位置,以及相同的傳回值,但多出軸在 out_axes 指示的位置。

flax.linen.scan(target, variable_axes=FrozenDict({}), variable_broadcast=False, variable_carry=False, split_rngs=FrozenDict({}), in_axes=0, out_axes=0, length=None, reverse=False, unroll=1, data_transform=None, metadata_params={}, methods=None, _split_transpose=False)[來源碼]#

jax.lax.scan 的提升版本。

請參閱 jax.lax.scan 以了解 Jax 中未提升的掃描。

為了提升與 vmap 的一致性,此版本的掃描使用 in_axesout_axes 來決定掃描哪些參數以及沿著哪個軸掃描。

scan 區分迴圈內 3 種不同類型的值

  1. 掃描:在迴圈中反覆運算的值。所有掃描值在掃描的軸上的大小必須相同。掃描的輸出將沿掃描軸堆疊。

  2. 傳遞:傳遞值在每次迴圈反覆運算時更新。在整個迴圈中,它的形狀和資料類型必須相同。

  3. 廣播:迴圈中封閉的值。當變數廣播時,它們通常在迴圈內部初始化,但獨立於迴圈變數。

目標值應具有簽章(module, 進位數, *xs) -> (進位數, ys),其中xsys是出入迴圈的掃描值。

範例

>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp
...
>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...     ScanLSTM = nn.scan(
...       nn.LSTMCell, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     lstm = ScanLSTM(self.features)
...     input_shape =  x[:, 0].shape
...     carry = lstm.initialize_carry(jax.random.key(0), input_shape)
...     carry, x = lstm(carry, x)
...     return x
...
>>> x = jnp.ones((4, 12, 7))
>>> module = LSTM(features=32)
>>> y, variables = module.init_with_output(jax.random.key(0), x)

請注意,提供函式給nn.scan時,掃描從第三個引數開始處理所有引數,如in_axes所指定。上一個範例也可以使用函式形式撰寫為

>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...
...     cell = nn.LSTMCell(self.features)
...     def body_fn(cell, carry, x):
...       carry, y = cell(carry, x)
...       return carry, y
...     scan = nn.scan(
...       body_fn, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     input_shape =  x[:, 0].shape
...     carry = cell.initialize_carry(
...       jax.random.key(0), input_shape)
...     carry, x = scan(cell, carry, x)
...     return x
...
>>> module = LSTM(features=32)
>>> variables = module.init(jax.random.key(0), jnp.ones((4, 12, 7)))

您也可以使用scan將多個層次合併為單一掃描迴圈,以減少 JAX 程式編譯時間;當您有一系列您想要反覆套用於輸入的相同層次時,您可以這麼做。例如

>>> class ResidualMLPBlock(nn.Module):
...   @nn.compact
...   def __call__(self, x, _):
...     h = nn.Dense(features=2)(x)
...     h = nn.relu(h)
...     return x + h, None
...
>>> class ResidualMLP(nn.Module):
...   n_layers: int = 4
...
...   @nn.compact
...   def __call__(self, x):
...     ScanMLP = nn.scan(
...       ResidualMLPBlock, variable_axes={'params': 0},
...       variable_broadcast=False, split_rngs={'params': True},
...       length=self.n_layers)
...     x, _ = ScanMLP()(x, None)
...     return x
...
>>> model = ResidualMLP(n_layers=4)
>>> variables = model.init(jax.random.key(42), jnp.ones((1, 2)))

為同時減少編譯和記憶體使用量,您可以使用remat_scan(),它會在掃描迴圈中檢查點每個層次。

參數
  • targetModule 或函數,以 Module 作為第一個引數。

  • 可變軸 – 掃描的變數集合。

  • 可變廣播 – 指定廣播變數集合。廣播變數不應依賴於任何無法從迴圈中移除的運算。這通常用於在 fn 內部定義共用參數。

  • 可變進位 – 指定在迴圈中傳遞的變數集合。對這些變數的突變會傳遞到下一個反覆運算,且在掃描結束時仍會保留。

  • 分離隨機數產生器 – 分離的 PRNG 序列對於每個迴圈反覆運算都會不同。如果分離為 False,則 PRNG 會在反覆運算中相同。

  • in_axes – 指定要針對引數進行掃描的軸。應該是引數的前綴樹。使用flax.core.broadcast將整個輸入傳送到掃描主體的每個反覆運算。

  • out_axes – 指定要針對回傳值進行掃描的軸。應該是回傳值的前綴樹。

  • 長度 – 指定迴圈反覆運算的數量。這只需要在無法從掃描引數得出時指定。

  • 反轉 – 如果為真,則從結束反向掃描到開始。

  • 取消展開 – 取消在迴圈一次 iteration 內展開掃描迭代的次數(預設值:1)。

  • data_transform – 選擇性函式用於變換提升的掃描 body_fn 內的原始函數核心變數和 rng 群組,用於內聯 SPMD 註解。

  • metadata_params – 在變數樹中傳遞給 AxisMetadata 實例的引數字典。

  • methods – 如果 targetModule,則掃描 Module 的方法。

  • _split_transpose – 一項實驗性質功能,用於將掃描的轉置拆分成掃描和對應,由實驗性質 Jax lax.scan() 功能所支援。

傳回

具有簽章 (module, carry, *xs) -> (carry, ys) 的掃描函式,其中 xsys 是進出迴圈的掃描值。

flax.linen.jit(target, variables=True, rngs=True, static_argnums=(), static_argnames=(), donate_argnums=(), device=None, backend=None, methods=None)[來源]#

提升版本的 jax.jit

參數
  • targetModule 或函數,以 Module 作為第一個引數。

  • variables – 提升的變數集合。預設值會提升所有集合。

  • rng – 提升的 PRNG 序列。預設值會提升所有 PRNG 序列。

  • static_argnums – 指定將哪些位置參數視為靜態(編譯時期常數)的 int 或 int 集合。僅依賴於靜態參數的操作將摺疊成 Python 常數(在追蹤期間),因此對應的參數值可以是任何 Python 物件。靜態參數應該可進行雜湊(表示同時實作了 __hash____eq__)且不可變。針對這些常數呼叫 jitted 函式,將會觸發重新編譯。如果以比 static_argnums 指出的位置參數更少的參數呼叫 jitted 函式,將會引發錯誤。非陣列或其容器的參數必須標示為靜態。預設值為 ()。

  • static_argnames – 指定將哪些命名參數視為靜態(編譯時期常數)的選用字串或字串集合。請參閱 static_argnums 的註解以取得詳細資料。如果未提供,但已設定 static_argnums,則預設值會根據呼叫 inspect.signature(fun) 來找出對應的命名參數。

  • donate_argnums – 指定哪些參數「提供」給運算。如果您在運算完成後不再需要參數,則可以安全地提供參數。在某些情況下,XLA 可以利用提供的緩衝區減少執行運算所需的記憶體量,例如回收您的其中一個輸入緩衝區來儲存結果。您不應重複使用您提供給運算的緩衝區,如果您嘗試執行,JAX 會引發錯誤。

  • device – 這是實驗性功能,且 API 可能會變更。jitted 函式將執行的裝置(選用)。(可用裝置可透過 jax.devices() 擷取。)預設值會從 XLA 的 DeviceAssignment 邏輯繼承,通常會使用 jax.devices()[0]

  • backend – 表示 XLA 後端字串:'cpu''gpu''tpu'

  • methods – 如果 targetModule,則為 Module 的 jit 方法。

傳回

target 的包裝版本,設定為即時編譯。

flax.linen.remat(target, variables=True, rngs=True, concrete=False, prevent_cse=True, static_argnums=(), policy=None, methods=None)#

解除 jax.checkpoint 的版本。

透過回傳啟動中反向傳播期間重新運算的檢查點式技巧能減少記憶體的使用。在訓練大型模型時,檢查點式的模型片段會很有幫助,以用額外運算來交換記憶體的使用。

範例

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
...
>>> class CheckpointedMLP(nn.Module):
...   @nn.checkpoint
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(128)(x)
...     x = nn.relu(x)
...     x = nn.Dense(1)(x)
...     return x
...
>>> model = CheckpointedMLP()
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 16)))

這個函式被定義成 remat 就如同 jax.remat 一樣。

參數
  • target – 一個以模組做為其第一個參數的 Module 或函式。在運算目標的梯度時會重新運算中間的運算。

  • variables – 提升的變數集合。預設值會提升所有集合。

  • rng – 提升的 PRNG 序列。預設值會提升所有 PRNG 序列。

  • concrete – 選擇性的,表示 fun 是否會涉及根據價值的 Python 控制流程(預設 False)。此種控制流程支援是選擇性的,而且預設是停用的,因為在和 jax.jit() 的某個邊界案例組合中,這會導致一些額外的運算。

  • prevent_cse – 選擇性,布林值會指出在由微分產生的 HLO 中,是否防止慣用子運算式消除 (CSE) 優化。此 CSE 防止措施會造成成本,因為它可能妨礙其他優化,而且它可能會在某些後端產生很高的額外負擔,特別是 GPU。預設值為 True,因為否則在 jitpmap 中,CSE 會抵銷這個裝飾器的目的。但在某些設定中,例如在 scan 內部使用時,此 CSE 防止機制是不必要的,此時應將 prevent_cse 設為 False。

  • static_argnums – 選擇性,整數或整數序列,指出哪些引數值用於追蹤和快取目的。將引數指定為靜態值可在追蹤時避免具體化型態錯誤,但會花費較多重新追蹤的負擔。

  • policy – 試驗性檢查點政策,請參閱 jax.checkpoint

  • methods – 選擇性的方法名稱清單,這些名稱將會提升,如果 methods 為 None(預設值),則只會提升 __call__ 方法。如果``target`` 是函式,則 methods 會被忽略。

傳回

target 的封裝版本。在運算梯度時,中間運算會在反向傳遞中重新運算。

flax.linen.remat_scan(target, lengths=(), policy=None, variable_broadcast=False, variable_carry=False, variable_axes=FrozenDict({True: 0}), split_rngs=FrozenDict({True: True}))[source]#

結合 remat 和 scan 以達到記憶體效率和常時間編譯。

remat_scan 允許對應於模型深度,固定的編譯時間及次線性記憶體使用量。以極小的固定懲罰為代價。這通常對非常深的模型有益。

範例

>>> import flax.linen as nn

>>> class BigModel(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     DenseStack = nn.remat_scan(nn.Dense, lengths=(10, 10))
...     # 100x dense with O(sqrt(N)) memory for gradient computation
...     return DenseStack(8, name="dense_stack")(x)
參數
  • targetModule 或函數,以 Module 作為第一個引數。

  • lengths – 給定層級的循環迭代次數。總迭代次數 n = prod(lengths)。每個循環都會重新具現化。這樣一來記憶體消耗量與 n^(1 / d) 成正比,其中 d = len(lengths)。最小記憶體消耗量需要調整長度,讓嵌套循環的每個層級消耗相同的記憶體量。

  • policy – 試驗性檢查點政策,請參閱 jax.checkpoint

  • 可變廣播 – 指定廣播變數集合。廣播變數不應依賴於任何無法從迴圈中移除的運算。這通常用於在 fn 內部定義共用參數。

  • 可變進位 – 指定在迴圈中傳遞的變數集合。對這些變數的突變會傳遞到下一個反覆運算,且在掃描結束時仍會保留。

  • variable_axes – 掃描的變數集合。預設為 {True: 0}

  • split_rngs – 分割 PRNG 序列將對每個循環迭代不同。如果 split 為 False,PRNG 在迭代中會相同。預設為 {True: True}

傳回

重複自身 prod(lengths) 次的 target 包裝版。

flax.linen.map_variables(target, mapped_collections=True, trans_in_fn=<function <lambda>>, trans_out_fn=<function <lambda>>, init=False, mutable=False, rngs=True, variables=True, methods=None)[source]#

對模組內的變數進行對應。

map_variables 可用於在應用模組前後轉換模組內的變數。這可以用於將模組的權重進行遮罩,而不修改模組本身。

範例

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
...
>>> class CausalDense(nn.Module):
...   '''A dense layer that masks the weights such that the output is
...   causal, i.e. output i only depends on input <= i.
...   '''
...   features: int
...
...   def apply_mask(self, variables):
...     return (jax.tree_util.tree_map(jnp.triu, variables)
...             if not self.is_initializing() else variables)
...
...   def setup(self):
...     # temporary class
...     _CausalDense = nn.map_variables(
...       nn.Dense, 'params', self.apply_mask, init=self.is_initializing())
...     self.dense = _CausalDense(features=self.features, use_bias=False)
...
...   def __call__(self, x):
...     return self.dense(x)
...
>>> module = CausalDense(features=5)
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 5)))
參數
  • target – 要轉換的模組或函式。

  • mapped_collections – 要轉換的集合。

  • trans_in_fn – 在套用模組或函式之前修改變數。

  • trans_out_fn – 在套用模組或函式後修改變數,僅在 initmutable 不是 False 的情況下套用。

  • init – 如果為 True,變數會在轉換前初始化。

  • mutable – 如果為 True,已對應的變數集合將可變更。

  • rngs – 新增到已轉換範圍的 PRNG 序列(預設為全部)。

  • variables – 新增到已轉換範圍的其他變數集合。除了由 target 指定的集合外(預設為全部)。

  • methods – 如果 targetModule,對應 Module 的方法用於對應變數。

傳回

對應指定集合的 target 的封裝版。

flax.linen.jvp(fn, mdl, primals, tangents, variable_tangents, variables=True, rngs=True)[原始碼]#

jax.jvp 的提升版。

請參閱 jax.jvp 以取得未提升的雅可比向量積(向前梯度)。

請注意,不會傳回變數切線。當需要變數切線時,其值應由 fn 使用 Module.variables 明確傳回

>>> import flax.linen as nn
>>> import jax.numpy as jnp

>>> class LearnScale(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     p = self.param('test', nn.initializers._init(), ())
...     return p * x

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     scale = LearnScale()
...     vars_t = jax.tree_util.tree_map(jnp.ones_like,
...                                     scale.variables.get('params', {}))
...     _, out_t = nn.jvp(
...         lambda mdl, x: mdl(x), scale, (x,), (jnp.zeros_like(x),),
...         variable_tangents={'params': vars_t})
...     return out_t

範例

>>> def learn_scale(scope, x):
...   p = scope.param('scale', nn.initializers.zeros_init(), ())
...   return p * x

>>> def f(scope, x):
...   vars_t = jax.tree_util.tree_map(jnp.ones_like, scope.variables().get('params', {}))
...   x, out_t = lift.jvp(
...       learn_scale, scope, (x,), (jnp.zeros_like(x),),
...       variable_tangents={'params': vars_t})
...   return out_t
參數
  • fn – 要區分的函式。其引數應該是陣列、純量或陣列或純量的標準 Python 容器。在傳回值方面,它應該是陣列、純量或陣列或純量的標準 Python 容器。函式會收到範圍和原值作為引數。

  • mdl – 將區分其變數的模組。

  • primals – 應對其評估 fun 雅可比行列式的原值。應該是引數的串列或清單,並且其長度應等於 fun 的位置參數數量。

  • 切線 – 評估雅可比向量乘積的切線向量。應該是元祖或切線清單,具有與 primals 相同的樹狀結構和陣列形狀。

  • 變數切線 - 一個與範疇具有相同結構的字典或 PyTree,字典中的每個條目指定一個變數集合的切線。在變數切線中不指定集合等於傳遞一個零向量作為切線。

  • 變數 - 其他在 fn 可用的變數集合,但不會接收切線。

  • rng - 在 fn 內可用的 PRNG。

傳回

(primals_out, tangents_out) 對,其中 primals_outfun(*primals),而 tangents_outfunction 的雅可比向量乘積,在 primals 中評估並使用 tangentstangents_out 值具有與 primals_out 相同的 Python 樹狀結構和形狀。

flax.linen.vjp(fn, mdl, *primals, has_aux=False, reduce_axes=(), vjp_variables='params', variables=True, rngs=True, multi_scope=False)[source]#

jax.vjp 的提升版本。

請參閱 jax.vjp,了解未提升的向量雅可比乘積(反向梯度)。

請注意,vjp_variables` 所指定的集合中會傳回所有變數的梯度。但是,反向函數只會希望 `fn` 的傳回值為餘切。如果變數也需要餘切,可以使用 `Module.variables` 從 `fn` 傳回。

範例

>>> import flax.linen as nn
>>> import jax.numpy as jnp

>>> class LearnScale(nn.Module):
...   @nn.compact
...   def __call__(self, x, y):
...     p = self.param('scale', nn.initializers.zeros_init(), ())
...     return p * x * y

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, y):
...     z, bwd = nn.vjp(lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
...     params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape))
...     return z, params_grad, x_grad, y_grad
參數
  • fn – 要區分的函式。其引數應該是陣列、純量或陣列或純量的標準 Python 容器。在傳回值方面,它應該是陣列、純量或陣列或純量的標準 Python 容器。函式會收到範圍和原值作為引數。

  • mdl – 將區分其變數的模組。

  • *原值 – 在此位置,應評估 `fn` 的 Jacobian 值基本值順序。`primals` 的長度應等於 `fn` 的位置參數個數。每一個基本值都應為陣列、純量或其標準 Python 容器的組成。`

  • has_aux – 選用,布林值。指出 `fn` 是否傳回一對值作為要微分的數學函數的輸出,以及第二個元素的輔助資料。預設值 `False`。

  • reduce_axes – 選用,軸名稱組成。如果這裡列出軸,而且 `fn` 在該軸上隱式廣播值,反向傳遞會針對應對梯度的 `psum` 執行。否則,VJP 會在名為軸的名稱軸上針對每個範例執行。例如,如果 `'batch'` 為一個命名批次軸,`vjp(f, *args, reduce_axes=('batch',))` 會建立一個在批次上求和的 VJP 函數,而 `vjp(f, *args)` 會建立一個每個範例的 VJP。

  • vjp_variables – vjpfun 將傳回這項篩選器所指定的變數集合的餘切向量。

  • 變數 – `fn` 內可用的其他變數集合,但未收到餘切。

  • rng - 在 fn 內可用的 PRNG。

  • multi_scope – 針對從外部模組傳出的包含多個範圍的模組,允許針對多個範圍傳回變數梯度,而不是出錯。

傳回

has_auxFalse,傳回 (primals_out, vjpfun) 配對,其中 primals_outfn(*primals)vjpfun 是從切向量(形狀與 primals_out 相同)到切向量組元(形狀與 primals 相同)的函數,表示在 primals 評估 fn 的向量-雅可比乘積。若 has_auxTrue,傳回 (primals_out, vjpfun, aux) 組元,其中 auxfn 傳回的輔助資料。

flax.linen.custom_vjp(fn, forward_fn, backward_fn, grad_vars='params', nondiff_argnums=())[來源碼]#

jax.custom_vjp 的提升版本。

forward_fnbackward_fn 共同定義 fn 的自訂 vjp。若未計算 vjp(後向梯度),將執行原始的 fn

forward_fn 接收的參數與 fn 相同,但預期傳回包含 fn(mdl, *args) 輸出以及傳送給 backward_fn 的殘差元組。

backward_fn 接收非微分參數、殘差與輸出切線。它應傳回包含變數與輸入切線的元組。

請注意,nn.vjp 傳回的 vjp 函式可傳遞為殘差值並用於 backward_fn 中。反向傳遞期間,範圍不可用。如果在 backward_fn 中需要模組,可以擷取變數的快照並傳回為 forward_fn 中的殘差值。

範例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     def f(mdl, x):
...       return mdl(x)
...
...     def fwd(mdl, x):
...       return nn.vjp(f, mdl, x)
...
...     def bwd(vjp_fn, y_t):
...       params_t, *inputs_t = vjp_fn(y_t)
...       params_t = jax.tree_util.tree_map(jnp.sign, params_t)
...       return (params_t, *inputs_t)
...
...     sign_grad = nn.custom_vjp(
...         f, forward_fn=fwd, backward_fn=bwd)
...     return sign_grad(nn.Dense(1), x).reshape(())

>>> x = jnp.ones((2,))
>>> variables = Foo().init(jax.random.key(0), x)
>>> grad = jax.grad(Foo().apply)(variables, x)
參數
  • function – 定義自訂 vjp 的函式。

  • forward_fn – 帶有與 function 相同引數的函式,傳回一個包含原始輸出和將傳遞給 backward_fn 的殘差值的元組。

  • backward_fn – 引數傳遞為 (*nondiff_args, residuals, tangents)。此函式應傳回一個元組,其中包含 grad_vars 指定的集合中變數的切線,以及輸入引數(模組和非 diff 引數除外)。

  • grad_vars – 需要計算 vjp 的集合(預設值: 「params」)。

  • nondiff_argnums – 不需要計算 vjp 的引數。

傳回

具有與 function 相同簽章的函式,包含自訂 vjp。

flax.linen.while_loop(cond_fn, body_fn, mdl, init, carry_variables=False, broadcast_variables=True, split_rngs=FrozenDict({}))[source]#

jax.lax.while_loop 的提升版。

將提升的範圍傳遞給 cond_fnbody_fn。廣播變數無法變更。進位變數可變更,但無法變更形狀和資料類型。這也表示您無法在主體內初始化變數。如果需要初始化變數,請考慮在呼叫 while_loop 之前先手動呼叫一次 body_fn

範例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class WhileLoopExample(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     def cond_fn(mdl, c):
...       return mdl.variables['state']['acc'] < 10
...     def body_fn(mdl, c):
...       acc = mdl.variable('state', 'acc', lambda: jnp.array(0))
...       acc.value += 1
...       y = nn.Dense(c.shape[-1])(c)
...       return y
...     c = x
...     if self.is_mutable_collection('params'):
...       return body_fn(self, c)
...     else:
...       return nn.while_loop(cond_fn, body_fn, self, c,
...                             carry_variables='state')

>>> k = jax.random.key(0)
>>> x = jnp.ones((2, 2))
>>> initial_vars = WhileLoopExample().init(k, x)
>>> result, state = WhileLoopExample().apply(initial_vars, x, mutable=['state'])
參數
  • cond_fn – 只要迴圈應繼續,就應該回傳 True。

  • body_fn – while 迴圈的主體。

  • mdl – 應該提升到迴圈的模組。

  • init – 傳遞給迴圈的初始狀態

  • carry_variables – 在迴圈中傳遞且因此可變更的集合(預設值:無)。

  • broadcast_variables – 封閉且因此唯讀的集合(預設值:所有集合)

  • 分離隨機數產生器 – 分離的 PRNG 序列對於每個迴圈反覆運算都會不同。如果分離為 False,則 PRNG 會在反覆運算中相同。

傳回

執行 while 迴圈後的最終狀態。

flax.linen.cond(pred, true_fun, false_fun, mdl, *operands, variables=True, rngs=True)[source]#

已提升的 jax.lax.cond 版本。

true_funfalse_fun 回傳的值必須具有相同的 Pytree 結構、形狀和資料類型。分支內部建立或更新的變數也必須具有相同的結構。請注意,如果只在一個分支中建立變數或子模組,則會違反此限制。因為只在一個分支中初始化變數會導致參數結構不同。

範例

>>> import flax.linen as nn

>>> class CondExample(nn.Module):
...   @nn.compact
...   def __call__(self, x, pred):
...     self.variable('state', 'true_count', lambda: 0)
...     self.variable('state', 'false_count', lambda: 0)
...     def true_fn(mdl, x):
...       mdl.variable('state', 'true_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     def false_fn(mdl, x):
...       mdl.variable('state', 'false_count').value += 1
...       return -nn.Dense(2, name='dense')(x)
...     return nn.cond(pred, true_fn, false_fn, self, x)
參數
  • pred – 決定是評估 true_fun 或 false_fun。

  • true_fun – 當 predTrue 時評估的函數。簽章為: (模組、 *運算元) -> T。

  • false_fun – 評估的函數,當 predFalse。符號為 (module, *operands) -> T。

  • mdl – 目標 Module 可供傳遞。

  • *operands – 傳遞給 true_funfalse_fun 的引數

  • variables – 傳遞到條件分支的變數集合 (預設:全部)

  • rngs – 傳遞到條件式的 PRNG 順序 (預設:全部)

傳回

評估分支的結果 (true_funfalse_fun)。

flax.linen.switch(index, branches, mdl, *operands, variables=True, rngs=True)[來源碼]#

jax.lax.switch 的抬高版本。

branches 回傳的值必須擁有相同的 Pytree 結構、形狀與資料型態。在分支內建立或更新的變數也必須具有相同的結構。請注意,僅在一個分支內建立變數或子模組時,會違背此限制。這是因為僅在一個分支內初始化變數會導致參數結構有所不同。

範例

>>> import flax.linen as nn

>>> class SwitchExample(nn.Module):
...   @nn.compact
...   def __call__(self, x, index):
...     self.variable('state', 'a_count', lambda: 0)
...     self.variable('state', 'b_count', lambda: 0)
...     self.variable('state', 'c_count', lambda: 0)
...     def a_fn(mdl, x):
...       mdl.variable('state', 'a_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     def b_fn(mdl, x):
...       mdl.variable('state', 'b_count').value += 1
...       return -nn.Dense(2, name='dense')(x)
...     def c_fn(mdl, x):
...       mdl.variable('state', 'c_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     return nn.switch(index, [a_fn, b_fn, c_fn], self, x)

如果您希望個別分支的參數結構有所不同,您應該在呼叫 switch 之前,於初始化中執行所有分支

>>> class MultiHeadSwitchExample(nn.Module):
...   def setup(self) -> None:
...     self.heads = [
...       nn.Sequential([nn.Dense(10), nn.Dense(7), nn.Dense(5)]),
...       nn.Sequential([nn.Dense(11), nn.Dense(5)]),
...       nn.Dense(5),
...     ]
...
...   @nn.compact
...   def __call__(self, x, index):
...     def head_fn(i):
...       return lambda mdl, x: mdl.heads[i](x)
...     branches = [head_fn(i) for i in range(len(self.heads))]
...
...     # run all branches on init
...     if self.is_mutable_collection('params'):
...       for branch in branches:
...         _ = branch(self, x)
...
...     return nn.switch(index, branches, self, x)
參數
  • index – 整數標量型別,指出要套用的分支函數。

  • branches – 依據 index 套用的函數序列。各個函數的符號為 (module, *operands) -> T。

  • mdl – 目標 Module 可供傳遞。

  • *operands – 傳遞到分支的引數。

  • variables – 傳遞到條件分支的變數集合 (預設:全部)

  • rngs – 傳遞到條件式的 PRNG 順序 (預設:全部)

傳回

評估分支的結果。