轉換#
模組的 JAX 轉換。
Jax 函數式轉換操作純函式。Flax 擴充這些轉換,也可以操作具有狀態變數和 PRNG 序列的模組。我們稱這些擴充版本為「提升的轉換」。
提昇的轉換可套用於 Module
類別或將 Module
執行個體作為其第一個引數的函數。
- flax.linen.vmap(target, variable_axes=FrozenDict({}), split_rngs=FrozenDict({}), in_axes=0, out_axes=0, axis_size=None, axis_name=None, spmd_axis_name=None, metadata_params={}, methods=None)[source]#
jax.vmap
的提昇版本。請參閱
jax.vmap
,以取得 Jax 中未提昇的批次轉換。vmap
可用於將批次軸新增至Module
。例如,我們可以建立Dense
的一個版本,其批次軸不共用參數>>> import flax.linen as nn >>> BatchDense = nn.vmap( ... nn.Dense, ... in_axes=0, out_axes=0, ... variable_axes={'params': 0}, ... split_rngs={'params': True})
透過使用
variable_axes={'params': 0}
,我們表示參數本身會進行對應,因此不會沿著對應的軸共用。因此,我們也將分割「params」RNG,否則參數將沿著對應的軸以相同方式初始化。類似地,
vmap
可用於新增參數共用的批次軸>>> import flax.linen as nn >>> BatchDense = nn.vmap( ... nn.Dense, ... in_axes=0, out_axes=0, ... variable_axes={'params': None}, ... split_rngs={'params': False})
在這裡,我們使用
variable_axes={'params': None}
來表示參數變數將沿著對應的軸共用。因此,也必須共用「params」RNG。- 參數
target –
Module
或函數,以Module
作為第一個引數。variable_axes – 會提升到批次轉換中的變數集合。使用
None
表示廣播集合或整數,以對軸進行對應。例如,傳入variable_axes={'params': None}
表示應沿著對應的軸共用參數變數。split_rngs – 分割的 PRNG 序列對批次維度的每個索引都會不同。未分割的 PRNG 將廣播。
in_axes – 指定輸入引數的對應(請參閱
jax.vmap
)。out_axes – 指定傳回值的對應(請參閱
jax.vmap
)。axis_size – 指定批次軸的大小。如果無法從輸入引數導出,才需要指定。
axis_name – 為批次軸指定名稱。可以與並行簡化基元一起使用(例如
jax.lax.pmean
、jax.lax.ppermute
等)。請注意,這只用於 pmap 和分片對應。對於 SPMD jit,您不需要手動同步。只要確保正確註解軸,而 XLA:SPMD 會插入必要的資料集。methods – 如果
target
是Module
,則為Module
的方法對應 vmap。spmd_axis_name – 軸名稱新增到出現在
fn
中的任何 pjit 分片約束。另請參閱 google/flax。metadata_params – 在變數樹中傳遞給 AxisMetadata 實例的引數字典。
- 傳回
target
的批次/向量化版本,具有相同的引數,但多出軸在in_axes
指示的位置,以及相同的傳回值,但多出軸在out_axes
指示的位置。
- flax.linen.scan(target, variable_axes=FrozenDict({}), variable_broadcast=False, variable_carry=False, split_rngs=FrozenDict({}), in_axes=0, out_axes=0, length=None, reverse=False, unroll=1, data_transform=None, metadata_params={}, methods=None, _split_transpose=False)[來源碼]#
jax.lax.scan
的提升版本。請參閱
jax.lax.scan
以了解 Jax 中未提升的掃描。為了提升與
vmap
的一致性,此版本的掃描使用in_axes
和out_axes
來決定掃描哪些參數以及沿著哪個軸掃描。scan
區分迴圈內 3 種不同類型的值掃描:在迴圈中反覆運算的值。所有掃描值在掃描的軸上的大小必須相同。掃描的輸出將沿掃描軸堆疊。
傳遞:傳遞值在每次迴圈反覆運算時更新。在整個迴圈中,它的形狀和資料類型必須相同。
廣播:迴圈中封閉的值。當變數廣播時,它們通常在迴圈內部初始化,但獨立於迴圈變數。
應具有簽章目標值
,其中(module, 進位數, *xs) -> (進位數, ys)
和xs
是出入迴圈的掃描值。ys
範例
>>> import flax.linen as nn >>> import jax >>> import jax.numpy as jnp ... >>> class LSTM(nn.Module): ... features: int ... ... @nn.compact ... def __call__(self, x): ... ScanLSTM = nn.scan( ... nn.LSTMCell, variable_broadcast="params", ... split_rngs={"params": False}, in_axes=1, out_axes=1) ... ... lstm = ScanLSTM(self.features) ... input_shape = x[:, 0].shape ... carry = lstm.initialize_carry(jax.random.key(0), input_shape) ... carry, x = lstm(carry, x) ... return x ... >>> x = jnp.ones((4, 12, 7)) >>> module = LSTM(features=32) >>> y, variables = module.init_with_output(jax.random.key(0), x)
請注意,提供函式給
時,掃描從第三個引數開始處理所有引數,如nn.scan
所指定。上一個範例也可以使用函式形式撰寫為in_axes
>>> class LSTM(nn.Module): ... features: int ... ... @nn.compact ... def __call__(self, x): ... ... cell = nn.LSTMCell(self.features) ... def body_fn(cell, carry, x): ... carry, y = cell(carry, x) ... return carry, y ... scan = nn.scan( ... body_fn, variable_broadcast="params", ... split_rngs={"params": False}, in_axes=1, out_axes=1) ... ... input_shape = x[:, 0].shape ... carry = cell.initialize_carry( ... jax.random.key(0), input_shape) ... carry, x = scan(cell, carry, x) ... return x ... >>> module = LSTM(features=32) >>> variables = module.init(jax.random.key(0), jnp.ones((4, 12, 7)))
您也可以使用
將多個層次合併為單一掃描迴圈,以減少 JAX 程式編譯時間;當您有一系列您想要反覆套用於輸入的相同層次時,您可以這麼做。例如scan
>>> class ResidualMLPBlock(nn.Module): ... @nn.compact ... def __call__(self, x, _): ... h = nn.Dense(features=2)(x) ... h = nn.relu(h) ... return x + h, None ... >>> class ResidualMLP(nn.Module): ... n_layers: int = 4 ... ... @nn.compact ... def __call__(self, x): ... ScanMLP = nn.scan( ... ResidualMLPBlock, variable_axes={'params': 0}, ... variable_broadcast=False, split_rngs={'params': True}, ... length=self.n_layers) ... x, _ = ScanMLP()(x, None) ... return x ... >>> model = ResidualMLP(n_layers=4) >>> variables = model.init(jax.random.key(42), jnp.ones((1, 2)))
為同時減少編譯和記憶體使用量,您可以使用
,它會在掃描迴圈中檢查點每個層次。remat_scan()
- 參數
target –
Module
或函數,以Module
作為第一個引數。可變軸 – 掃描的變數集合。
可變廣播 – 指定廣播變數集合。廣播變數不應依賴於任何無法從迴圈中移除的運算。這通常用於在 fn 內部定義共用參數。
可變進位 – 指定在迴圈中傳遞的變數集合。對這些變數的突變會傳遞到下一個反覆運算,且在掃描結束時仍會保留。
分離隨機數產生器 – 分離的 PRNG 序列對於每個迴圈反覆運算都會不同。如果分離為 False,則 PRNG 會在反覆運算中相同。
in_axes – 指定要針對引數進行掃描的軸。應該是引數的前綴樹。使用
將整個輸入傳送到掃描主體的每個反覆運算。flax.core.broadcast
out_axes – 指定要針對回傳值進行掃描的軸。應該是回傳值的前綴樹。
長度 – 指定迴圈反覆運算的數量。這只需要在無法從掃描引數得出時指定。
反轉 – 如果為真,則從結束反向掃描到開始。
取消展開 – 取消在迴圈一次 iteration 內展開掃描迭代的次數(預設值:1)。
data_transform – 選擇性函式用於變換提升的掃描 body_fn 內的原始函數核心變數和 rng 群組,用於內聯 SPMD 註解。
metadata_params – 在變數樹中傳遞給 AxisMetadata 實例的引數字典。
methods – 如果
target
是Module
,則掃描Module
的方法。_split_transpose – 一項實驗性質功能,用於將掃描的轉置拆分成掃描和對應,由實驗性質 Jax lax.scan() 功能所支援。
- 傳回
具有簽章
(module, carry, *xs) -> (carry, ys)
的掃描函式,其中xs
和ys
是進出迴圈的掃描值。
- flax.linen.jit(target, variables=True, rngs=True, static_argnums=(), static_argnames=(), donate_argnums=(), device=None, backend=None, methods=None)[來源]#
提升版本的
jax.jit
。- 參數
target –
Module
或函數,以Module
作為第一個引數。variables – 提升的變數集合。預設值會提升所有集合。
rng – 提升的 PRNG 序列。預設值會提升所有 PRNG 序列。
static_argnums – 指定將哪些位置參數視為靜態(編譯時期常數)的 int 或 int 集合。僅依賴於靜態參數的操作將摺疊成 Python 常數(在追蹤期間),因此對應的參數值可以是任何 Python 物件。靜態參數應該可進行雜湊(表示同時實作了
__hash__
和__eq__
)且不可變。針對這些常數呼叫 jitted 函式,將會觸發重新編譯。如果以比static_argnums
指出的位置參數更少的參數呼叫 jitted 函式,將會引發錯誤。非陣列或其容器的參數必須標示為靜態。預設值為 ()。static_argnames – 指定將哪些命名參數視為靜態(編譯時期常數)的選用字串或字串集合。請參閱
static_argnums
的註解以取得詳細資料。如果未提供,但已設定static_argnums
,則預設值會根據呼叫inspect.signature(fun)
來找出對應的命名參數。donate_argnums – 指定哪些參數「提供」給運算。如果您在運算完成後不再需要參數,則可以安全地提供參數。在某些情況下,XLA 可以利用提供的緩衝區減少執行運算所需的記憶體量,例如回收您的其中一個輸入緩衝區來儲存結果。您不應重複使用您提供給運算的緩衝區,如果您嘗試執行,JAX 會引發錯誤。
device – 這是實驗性功能,且 API 可能會變更。jitted 函式將執行的裝置(選用)。(可用裝置可透過
jax.devices()
擷取。)預設值會從 XLA 的 DeviceAssignment 邏輯繼承,通常會使用jax.devices()[0]
backend – 表示 XLA 後端字串:
'cpu'
、'gpu'
或'tpu'
。methods – 如果
target
是Module
,則為Module
的 jit 方法。
- 傳回
target 的包裝版本,設定為即時編譯。
- flax.linen.remat(target, variables=True, rngs=True, concrete=False, prevent_cse=True, static_argnums=(), policy=None, methods=None)#
解除
jax.checkpoint
的版本。透過回傳啟動中反向傳播期間重新運算的檢查點式技巧能減少記憶體的使用。在訓練大型模型時,檢查點式的模型片段會很有幫助,以用額外運算來交換記憶體的使用。
範例
>>> import jax >>> import jax.numpy as jnp >>> import flax.linen as nn ... >>> class CheckpointedMLP(nn.Module): ... @nn.checkpoint ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(128)(x) ... x = nn.relu(x) ... x = nn.Dense(1)(x) ... return x ... >>> model = CheckpointedMLP() >>> variables = model.init(jax.random.key(0), jnp.ones((1, 16)))
這個函式被定義成
remat
就如同jax.remat
一樣。- 參數
target – 一個以模組做為其第一個參數的
Module
或函式。在運算目標的梯度時會重新運算中間的運算。variables – 提升的變數集合。預設值會提升所有集合。
rng – 提升的 PRNG 序列。預設值會提升所有 PRNG 序列。
concrete – 選擇性的,表示
fun
是否會涉及根據價值的 Python 控制流程(預設False
)。此種控制流程支援是選擇性的,而且預設是停用的,因為在和jax.jit()
的某個邊界案例組合中,這會導致一些額外的運算。prevent_cse – 選擇性,布林值會指出在由微分產生的 HLO 中,是否防止慣用子運算式消除 (CSE) 優化。此 CSE 防止措施會造成成本,因為它可能妨礙其他優化,而且它可能會在某些後端產生很高的額外負擔,特別是 GPU。預設值為 True,因為否則在
jit
或pmap
中,CSE 會抵銷這個裝飾器的目的。但在某些設定中,例如在scan
內部使用時,此 CSE 防止機制是不必要的,此時應將prevent_cse
設為 False。static_argnums – 選擇性,整數或整數序列,指出哪些引數值用於追蹤和快取目的。將引數指定為靜態值可在追蹤時避免具體化型態錯誤,但會花費較多重新追蹤的負擔。
policy – 試驗性檢查點政策,請參閱
jax.checkpoint
。methods – 選擇性的方法名稱清單,這些名稱將會提升,如果
methods
為 None(預設值),則只會提升__call__
方法。如果``target`` 是函式,則methods
會被忽略。
- 傳回
target
的封裝版本。在運算梯度時,中間運算會在反向傳遞中重新運算。
- flax.linen.remat_scan(target, lengths=(), policy=None, variable_broadcast=False, variable_carry=False, variable_axes=FrozenDict({True: 0}), split_rngs=FrozenDict({True: True}))[source]#
結合 remat 和 scan 以達到記憶體效率和常時間編譯。
remat_scan
允許對應於模型深度,固定的編譯時間及次線性記憶體使用量。以極小的固定懲罰為代價。這通常對非常深的模型有益。範例
>>> import flax.linen as nn >>> class BigModel(nn.Module): ... @nn.compact ... def __call__(self, x): ... DenseStack = nn.remat_scan(nn.Dense, lengths=(10, 10)) ... # 100x dense with O(sqrt(N)) memory for gradient computation ... return DenseStack(8, name="dense_stack")(x)
- 參數
target –
Module
或函數,以Module
作為第一個引數。lengths – 給定層級的循環迭代次數。總迭代次數
n = prod(lengths)
。每個循環都會重新具現化。這樣一來記憶體消耗量與n^(1 / d)
成正比,其中d = len(lengths)
。最小記憶體消耗量需要調整長度,讓嵌套循環的每個層級消耗相同的記憶體量。policy – 試驗性檢查點政策,請參閱
jax.checkpoint
。可變廣播 – 指定廣播變數集合。廣播變數不應依賴於任何無法從迴圈中移除的運算。這通常用於在 fn 內部定義共用參數。
可變進位 – 指定在迴圈中傳遞的變數集合。對這些變數的突變會傳遞到下一個反覆運算,且在掃描結束時仍會保留。
variable_axes – 掃描的變數集合。預設為
{True: 0}
。split_rngs – 分割 PRNG 序列將對每個循環迭代不同。如果 split 為 False,PRNG 在迭代中會相同。預設為
{True: True}
。
- 傳回
重複自身 prod(lengths) 次的
target
包裝版。
- flax.linen.map_variables(target, mapped_collections=True, trans_in_fn=<function <lambda>>, trans_out_fn=<function <lambda>>, init=False, mutable=False, rngs=True, variables=True, methods=None)[source]#
對模組內的變數進行對應。
map_variables
可用於在應用模組前後轉換模組內的變數。這可以用於將模組的權重進行遮罩,而不修改模組本身。範例
>>> import jax >>> import jax.numpy as jnp >>> import flax.linen as nn ... >>> class CausalDense(nn.Module): ... '''A dense layer that masks the weights such that the output is ... causal, i.e. output i only depends on input <= i. ... ''' ... features: int ... ... def apply_mask(self, variables): ... return (jax.tree_util.tree_map(jnp.triu, variables) ... if not self.is_initializing() else variables) ... ... def setup(self): ... # temporary class ... _CausalDense = nn.map_variables( ... nn.Dense, 'params', self.apply_mask, init=self.is_initializing()) ... self.dense = _CausalDense(features=self.features, use_bias=False) ... ... def __call__(self, x): ... return self.dense(x) ... >>> module = CausalDense(features=5) >>> variables = module.init(jax.random.key(0), jnp.ones((1, 5)))
- 參數
target – 要轉換的模組或函式。
mapped_collections – 要轉換的集合。
trans_in_fn – 在套用模組或函式之前修改變數。
trans_out_fn – 在套用模組或函式後修改變數,僅在
init
或mutable
不是 False 的情況下套用。init – 如果為 True,變數會在轉換前初始化。
mutable – 如果為 True,已對應的變數集合將可變更。
rngs – 新增到已轉換範圍的 PRNG 序列(預設為全部)。
variables – 新增到已轉換範圍的其他變數集合。除了由
target
指定的集合外(預設為全部)。methods – 如果
target
是Module
,對應Module
的方法用於對應變數。
- 傳回
對應指定集合的
target
的封裝版。
- flax.linen.jvp(fn, mdl, primals, tangents, variable_tangents, variables=True, rngs=True)[原始碼]#
jax.jvp
的提升版。請參閱
jax.jvp
以取得未提升的雅可比向量積(向前梯度)。請注意,不會傳回變數切線。當需要變數切線時,其值應由
fn
使用Module.variables
明確傳回>>> import flax.linen as nn >>> import jax.numpy as jnp >>> class LearnScale(nn.Module): ... @nn.compact ... def __call__(self, x): ... p = self.param('test', nn.initializers._init(), ()) ... return p * x >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... scale = LearnScale() ... vars_t = jax.tree_util.tree_map(jnp.ones_like, ... scale.variables.get('params', {})) ... _, out_t = nn.jvp( ... lambda mdl, x: mdl(x), scale, (x,), (jnp.zeros_like(x),), ... variable_tangents={'params': vars_t}) ... return out_t
範例
>>> def learn_scale(scope, x): ... p = scope.param('scale', nn.initializers.zeros_init(), ()) ... return p * x >>> def f(scope, x): ... vars_t = jax.tree_util.tree_map(jnp.ones_like, scope.variables().get('params', {})) ... x, out_t = lift.jvp( ... learn_scale, scope, (x,), (jnp.zeros_like(x),), ... variable_tangents={'params': vars_t}) ... return out_t
- 參數
fn – 要區分的函式。其引數應該是陣列、純量或陣列或純量的標準 Python 容器。在傳回值方面,它應該是陣列、純量或陣列或純量的標準 Python 容器。函式會收到範圍和原值作為引數。
mdl – 將區分其變數的模組。
primals – 應對其評估
fun
雅可比行列式的原值。應該是引數的串列或清單,並且其長度應等於fun
的位置參數數量。切線 – 評估雅可比向量乘積的切線向量。應該是元祖或切線清單,具有與
primals
相同的樹狀結構和陣列形狀。變數切線 - 一個與範疇具有相同結構的字典或 PyTree,字典中的每個條目指定一個變數集合的切線。在變數切線中不指定集合等於傳遞一個零向量作為切線。
變數 - 其他在
fn
可用的變數集合,但不會接收切線。rng - 在
fn
內可用的 PRNG。
- 傳回
(primals_out, tangents_out)
對,其中primals_out
是fun(*primals)
,而tangents_out
是function
的雅可比向量乘積,在primals
中評估並使用tangents
。tangents_out
值具有與primals_out
相同的 Python 樹狀結構和形狀。
- flax.linen.vjp(fn, mdl, *primals, has_aux=False, reduce_axes=(), vjp_variables='params', variables=True, rngs=True, multi_scope=False)[source]#
jax.vjp
的提升版本。請參閱
jax.vjp
,了解未提升的向量雅可比乘積(反向梯度)。請注意,
vjp_variables
` 所指定的集合中會傳回所有變數的梯度。但是,反向函數只會希望 `fn
` 的傳回值為餘切。如果變數也需要餘切,可以使用 `Module.variables
` 從 `fn
` 傳回。範例
>>> import flax.linen as nn >>> import jax.numpy as jnp >>> class LearnScale(nn.Module): ... @nn.compact ... def __call__(self, x, y): ... p = self.param('scale', nn.initializers.zeros_init(), ()) ... return p * x * y >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x, y): ... z, bwd = nn.vjp(lambda mdl, x, y: mdl(x, y), LearnScale(), x, y) ... params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape)) ... return z, params_grad, x_grad, y_grad
- 參數
fn – 要區分的函式。其引數應該是陣列、純量或陣列或純量的標準 Python 容器。在傳回值方面,它應該是陣列、純量或陣列或純量的標準 Python 容器。函式會收到範圍和原值作為引數。
mdl – 將區分其變數的模組。
*原值 – 在此位置,應評估 `
fn
` 的 Jacobian 值基本值順序。`primals
` 的長度應等於 `fn
` 的位置參數個數。每一個基本值都應為陣列、純量或其標準 Python 容器的組成。`has_aux – 選用,布林值。指出 `
fn
` 是否傳回一對值作為要微分的數學函數的輸出,以及第二個元素的輔助資料。預設值 `False
`。reduce_axes – 選用,軸名稱組成。如果這裡列出軸,而且 `
fn
` 在該軸上隱式廣播值,反向傳遞會針對應對梯度的 `psum
` 執行。否則,VJP 會在名為軸的名稱軸上針對每個範例執行。例如,如果 `'batch'
` 為一個命名批次軸,`vjp(f, *args, reduce_axes=('batch',))
` 會建立一個在批次上求和的 VJP 函數,而 `vjp(f, *args)
` 會建立一個每個範例的 VJP。vjp_variables – vjpfun 將傳回這項篩選器所指定的變數集合的餘切向量。
變數 – `
fn
` 內可用的其他變數集合,但未收到餘切。rng - 在
fn
內可用的 PRNG。multi_scope – 針對從外部模組傳出的包含多個範圍的模組,允許針對多個範圍傳回變數梯度,而不是出錯。
- 傳回
若
has_aux
為False
,傳回(primals_out, vjpfun)
配對,其中primals_out
為fn(*primals)
。vjpfun
是從切向量(形狀與primals_out
相同)到切向量組元(形狀與primals
相同)的函數,表示在primals
評估fn
的向量-雅可比乘積。若has_aux
為True
,傳回(primals_out, vjpfun, aux)
組元,其中aux
是fn
傳回的輔助資料。
- flax.linen.custom_vjp(fn, forward_fn, backward_fn, grad_vars='params', nondiff_argnums=())[來源碼]#
jax.custom_vjp
的提升版本。forward_fn
與backward_fn
共同定義fn
的自訂 vjp。若未計算 vjp(後向梯度),將執行原始的fn
。forward_fn
接收的參數與fn
相同,但預期傳回包含fn(mdl, *args)
輸出以及傳送給backward_fn
的殘差元組。backward_fn
接收非微分參數、殘差與輸出切線。它應傳回包含變數與輸入切線的元組。請注意,
nn.vjp
傳回的 vjp 函式可傳遞為殘差值並用於backward_fn
中。反向傳遞期間,範圍不可用。如果在backward_fn
中需要模組,可以擷取變數的快照並傳回為forward_fn
中的殘差值。範例
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... def f(mdl, x): ... return mdl(x) ... ... def fwd(mdl, x): ... return nn.vjp(f, mdl, x) ... ... def bwd(vjp_fn, y_t): ... params_t, *inputs_t = vjp_fn(y_t) ... params_t = jax.tree_util.tree_map(jnp.sign, params_t) ... return (params_t, *inputs_t) ... ... sign_grad = nn.custom_vjp( ... f, forward_fn=fwd, backward_fn=bwd) ... return sign_grad(nn.Dense(1), x).reshape(()) >>> x = jnp.ones((2,)) >>> variables = Foo().init(jax.random.key(0), x) >>> grad = jax.grad(Foo().apply)(variables, x)
- 參數
function – 定義自訂 vjp 的函式。
forward_fn – 帶有與
function
相同引數的函式,傳回一個包含原始輸出和將傳遞給backward_fn
的殘差值的元組。backward_fn – 引數傳遞為
(*nondiff_args, residuals, tangents)
。此函式應傳回一個元組,其中包含grad_vars
指定的集合中變數的切線,以及輸入引數(模組和非 diff 引數除外)。grad_vars – 需要計算 vjp 的集合(預設值: 「params」)。
nondiff_argnums – 不需要計算 vjp 的引數。
- 傳回
具有與
function
相同簽章的函式,包含自訂 vjp。
- flax.linen.while_loop(cond_fn, body_fn, mdl, init, carry_variables=False, broadcast_variables=True, split_rngs=FrozenDict({}))[source]#
jax.lax.while_loop 的提升版。
將提升的範圍傳遞給
cond_fn
和body_fn
。廣播變數無法變更。進位變數可變更,但無法變更形狀和資料類型。這也表示您無法在主體內初始化變數。如果需要初始化變數,請考慮在呼叫while_loop
之前先手動呼叫一次body_fn
。範例
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class WhileLoopExample(nn.Module): ... @nn.compact ... def __call__(self, x): ... def cond_fn(mdl, c): ... return mdl.variables['state']['acc'] < 10 ... def body_fn(mdl, c): ... acc = mdl.variable('state', 'acc', lambda: jnp.array(0)) ... acc.value += 1 ... y = nn.Dense(c.shape[-1])(c) ... return y ... c = x ... if self.is_mutable_collection('params'): ... return body_fn(self, c) ... else: ... return nn.while_loop(cond_fn, body_fn, self, c, ... carry_variables='state') >>> k = jax.random.key(0) >>> x = jnp.ones((2, 2)) >>> initial_vars = WhileLoopExample().init(k, x) >>> result, state = WhileLoopExample().apply(initial_vars, x, mutable=['state'])
- 參數
cond_fn – 只要迴圈應繼續,就應該回傳 True。
body_fn – while 迴圈的主體。
mdl – 應該提升到迴圈的模組。
init – 傳遞給迴圈的初始狀態
carry_variables – 在迴圈中傳遞且因此可變更的集合(預設值:無)。
broadcast_variables – 封閉且因此唯讀的集合(預設值:所有集合)
分離隨機數產生器 – 分離的 PRNG 序列對於每個迴圈反覆運算都會不同。如果分離為 False,則 PRNG 會在反覆運算中相同。
- 傳回
執行 while 迴圈後的最終狀態。
- flax.linen.cond(pred, true_fun, false_fun, mdl, *operands, variables=True, rngs=True)[source]#
已提升的
jax.lax.cond
版本。從
true_fun
和false_fun
回傳的值必須具有相同的 Pytree 結構、形狀和資料類型。分支內部建立或更新的變數也必須具有相同的結構。請注意,如果只在一個分支中建立變數或子模組,則會違反此限制。因為只在一個分支中初始化變數會導致參數結構不同。範例
>>> import flax.linen as nn >>> class CondExample(nn.Module): ... @nn.compact ... def __call__(self, x, pred): ... self.variable('state', 'true_count', lambda: 0) ... self.variable('state', 'false_count', lambda: 0) ... def true_fn(mdl, x): ... mdl.variable('state', 'true_count').value += 1 ... return nn.Dense(2, name='dense')(x) ... def false_fn(mdl, x): ... mdl.variable('state', 'false_count').value += 1 ... return -nn.Dense(2, name='dense')(x) ... return nn.cond(pred, true_fn, false_fn, self, x)
- 參數
- 傳回
評估分支的結果 (
true_fun
或false_fun
)。
- flax.linen.switch(index, branches, mdl, *operands, variables=True, rngs=True)[來源碼]#
jax.lax.switch
的抬高版本。branches
回傳的值必須擁有相同的 Pytree 結構、形狀與資料型態。在分支內建立或更新的變數也必須具有相同的結構。請注意,僅在一個分支內建立變數或子模組時,會違背此限制。這是因為僅在一個分支內初始化變數會導致參數結構有所不同。範例
>>> import flax.linen as nn >>> class SwitchExample(nn.Module): ... @nn.compact ... def __call__(self, x, index): ... self.variable('state', 'a_count', lambda: 0) ... self.variable('state', 'b_count', lambda: 0) ... self.variable('state', 'c_count', lambda: 0) ... def a_fn(mdl, x): ... mdl.variable('state', 'a_count').value += 1 ... return nn.Dense(2, name='dense')(x) ... def b_fn(mdl, x): ... mdl.variable('state', 'b_count').value += 1 ... return -nn.Dense(2, name='dense')(x) ... def c_fn(mdl, x): ... mdl.variable('state', 'c_count').value += 1 ... return nn.Dense(2, name='dense')(x) ... return nn.switch(index, [a_fn, b_fn, c_fn], self, x)
如果您希望個別分支的參數結構有所不同,您應該在呼叫 switch 之前,於初始化中執行所有分支
>>> class MultiHeadSwitchExample(nn.Module): ... def setup(self) -> None: ... self.heads = [ ... nn.Sequential([nn.Dense(10), nn.Dense(7), nn.Dense(5)]), ... nn.Sequential([nn.Dense(11), nn.Dense(5)]), ... nn.Dense(5), ... ] ... ... @nn.compact ... def __call__(self, x, index): ... def head_fn(i): ... return lambda mdl, x: mdl.heads[i](x) ... branches = [head_fn(i) for i in range(len(self.heads))] ... ... # run all branches on init ... if self.is_mutable_collection('params'): ... for branch in branches: ... _ = branch(self, x) ... ... return nn.switch(index, branches, self, x)
- 參數
index – 整數標量型別,指出要套用的分支函數。
branches – 依據 index 套用的函數序列。各個函數的符號為 (module, *operands) -> T。
mdl – 目標 Module 可供傳遞。
*operands – 傳遞到分支的引數。
variables – 傳遞到條件分支的變數集合 (預設:全部)
rngs – 傳遞到條件式的 PRNG 順序 (預設:全部)
- 傳回
評估分支的結果。