Optimizer#

class flax.nnx.optimizer.Optimizer(*args, **kwargs)#

對於具有單個 Optax 優化器的常見情況,簡化的訓練狀態。

使用範例

>>> import jax, jax.numpy as jnp
>>> from flax import nnx
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     return self.linear2(self.linear1(x))
...
>>> x = jax.random.normal(jax.random.key(0), (1, 2))
>>> y = jnp.ones((1, 4))
...
>>> model = Model(nnx.Rngs(0))
>>> tx = optax.adam(1e-3)
>>> state = nnx.Optimizer(model, tx)
...
>>> loss_fn = lambda model: ((model(x) - y) ** 2).mean()
>>> loss_fn(model)
Array(1.7055722, dtype=float32)
>>> grads = nnx.grad(loss_fn)(state.model)
>>> state.update(grads)
>>> loss_fn(model)
Array(1.6925814, dtype=float32)

請注意,您可以透過子類化此類別來輕鬆擴展它,以儲存額外資料(例如,添加指標)。

使用範例

>>> class TrainState(nnx.Optimizer):
...   def __init__(self, model, tx, metrics):
...     self.metrics = metrics
...     super().__init__(model, tx)
...   def update(self, *, grads, **updates):
...     self.metrics.update(**updates)
...     super().update(grads)
...
>>> metrics = nnx.metrics.Average()
>>> state = TrainState(model, tx, metrics)
...
>>> grads = nnx.grad(loss_fn)(state.model)
>>> state.update(grads=grads, values=loss_fn(state.model))
>>> state.metrics.compute()
Array(1.6925814, dtype=float32)
>>> state.update(grads=grads, values=loss_fn(state.model))
>>> state.metrics.compute()
Array(1.68612, dtype=float32)

對於更特殊的用例(例如,多個優化器),最好分叉類別並修改它。

step#

一個 OptState Variable,用於追蹤步數。

model#

已封裝的 Module

tx#

一個 Optax 梯度轉換。

opt_state#

Optax 優化器狀態。

__init__(model, tx, wrt=<class 'flax.nnx.variablelib.Param'>)#

實例化類別並封裝 Module 和 Optax 梯度轉換。實例化優化器狀態,以追蹤在 wrt 中指定的 Variable 類型。將步數設定為 0。

參數
  • model – 一個 NNX 模組。

  • tx – 一個 Optax 梯度轉換。

  • wrt – 可選參數,用於篩選在優化器狀態中要追蹤的 Variable。這些應該是您計劃更新的 Variable;即,此參數值應與傳遞至 nnx.grad 呼叫的 wrt 參數匹配,該呼叫將產生將傳遞至 update() 方法的 grads 參數中的梯度。

update(grads, **kwargs)#

更新傳回值中的 stepparamsopt_state**kwargsgrads 必須從 nnx.grad(..., wrt=self.wrt) 衍生而來,其中梯度是相對於此 Optimizer 實例化期間在 self.wrt 中定義的相同 Variable 類型。例如

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> import optax

>>> class CustomVariable(nnx.Variable):
...   pass

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
...   def __call__(self, x):
...     return self.linear(x) + self.custom_variable
>>> model = Model(rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(model))
State({
  'custom_variable': VariableState(
    type=CustomVariable,
    value=(1, 3)
  ),
  'linear': {
    'bias': VariableState(
      type=Param,
      value=(3,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(2, 3)
    )
  }
})

>>> # update:
>>> # - only Linear layer parameters
>>> # - only CustomVariable parameters
>>> # - both Linear layer and CustomVariable parameters
>>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
>>> for variable in (nnx.Param, CustomVariable, (nnx.Param, CustomVariable)):
...   # make sure `wrt` arguments match for `nnx.Optimizer` and `nnx.grad`
...   state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable)
...   grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable))(
...     state.model, jnp.ones((1, 2)), jnp.ones((1, 3))
...   )
...   state.update(grads=grads)

請注意,此函數在內部會呼叫 .tx.update(),然後呼叫 optax.apply_updates() 來更新 paramsopt_state

參數
  • grads – 從 nnx.grad 衍生的梯度。

  • **kwargs – 傳遞至 tx.update 的其他關鍵字引數,以支援

  • GradientTransformationExtraArgs

  • optax.scale_by_backtracking_linesearch. (例如) –