Optimizer#
- class flax.nnx.optimizer.Optimizer(*args, **kwargs)#
對於具有單個 Optax 優化器的常見情況,簡化的訓練狀態。
使用範例
>>> import jax, jax.numpy as jnp >>> from flax import nnx >>> import optax ... >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... return self.linear2(self.linear1(x)) ... >>> x = jax.random.normal(jax.random.key(0), (1, 2)) >>> y = jnp.ones((1, 4)) ... >>> model = Model(nnx.Rngs(0)) >>> tx = optax.adam(1e-3) >>> state = nnx.Optimizer(model, tx) ... >>> loss_fn = lambda model: ((model(x) - y) ** 2).mean() >>> loss_fn(model) Array(1.7055722, dtype=float32) >>> grads = nnx.grad(loss_fn)(state.model) >>> state.update(grads) >>> loss_fn(model) Array(1.6925814, dtype=float32)
請注意,您可以透過子類化此類別來輕鬆擴展它,以儲存額外資料(例如,添加指標)。
使用範例
>>> class TrainState(nnx.Optimizer): ... def __init__(self, model, tx, metrics): ... self.metrics = metrics ... super().__init__(model, tx) ... def update(self, *, grads, **updates): ... self.metrics.update(**updates) ... super().update(grads) ... >>> metrics = nnx.metrics.Average() >>> state = TrainState(model, tx, metrics) ... >>> grads = nnx.grad(loss_fn)(state.model) >>> state.update(grads=grads, values=loss_fn(state.model)) >>> state.metrics.compute() Array(1.6925814, dtype=float32) >>> state.update(grads=grads, values=loss_fn(state.model)) >>> state.metrics.compute() Array(1.68612, dtype=float32)
對於更特殊的用例(例如,多個優化器),最好分叉類別並修改它。
- step#
一個
OptState
Variable
,用於追蹤步數。
- model#
已封裝的
Module
。
- tx#
一個 Optax 梯度轉換。
- opt_state#
Optax 優化器狀態。
- __init__(model, tx, wrt=<class 'flax.nnx.variablelib.Param'>)#
實例化類別並封裝
Module
和 Optax 梯度轉換。實例化優化器狀態,以追蹤在wrt
中指定的Variable
類型。將步數設定為 0。- 參數
model – 一個 NNX 模組。
tx – 一個 Optax 梯度轉換。
wrt – 可選參數,用於篩選在優化器狀態中要追蹤的
Variable
。這些應該是您計劃更新的Variable
;即,此參數值應與傳遞至nnx.grad
呼叫的wrt
參數匹配,該呼叫將產生將傳遞至update()
方法的grads
參數中的梯度。
- update(grads, **kwargs)#
更新傳回值中的
step
、params
、opt_state
和**kwargs
。grads
必須從nnx.grad(..., wrt=self.wrt)
衍生而來,其中梯度是相對於此Optimizer
實例化期間在self.wrt
中定義的相同Variable
類型。例如>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> import optax >>> class CustomVariable(nnx.Variable): ... pass >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... self.custom_variable = CustomVariable(jnp.ones((1, 3))) ... def __call__(self, x): ... return self.linear(x) + self.custom_variable >>> model = Model(rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(model)) State({ 'custom_variable': VariableState( type=CustomVariable, value=(1, 3) ), 'linear': { 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) } }) >>> # update: >>> # - only Linear layer parameters >>> # - only CustomVariable parameters >>> # - both Linear layer and CustomVariable parameters >>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() >>> for variable in (nnx.Param, CustomVariable, (nnx.Param, CustomVariable)): ... # make sure `wrt` arguments match for `nnx.Optimizer` and `nnx.grad` ... state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable) ... grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable))( ... state.model, jnp.ones((1, 2)), jnp.ones((1, 3)) ... ) ... state.update(grads=grads)
請注意,此函數在內部會呼叫
.tx.update()
,然後呼叫optax.apply_updates()
來更新params
和opt_state
。- 參數
grads – 從
nnx.grad
衍生的梯度。**kwargs – 傳遞至 tx.update 的其他關鍵字引數,以支援
GradientTransformationExtraArgs –
optax.scale_by_backtracking_linesearch. (例如) –