flax.training 套件#

訓練狀態#

class flax.training.train_state.TrainState(step, apply_fn, params, tx, opt_state)[來源]#

適用於單一 Optax 優化器的常見案例的簡單訓練狀態。

使用範例

>>> import flax.linen as nn
>>> from flax.training.train_state import TrainState
>>> import jax, jax.numpy as jnp
>>> import optax

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 2))
>>> model = nn.Dense(2)
>>> variables = model.init(jax.random.key(0), x)
>>> tx = optax.adam(1e-3)

>>> state = TrainState.create(
...     apply_fn=model.apply,
...     params=variables['params'],
...     tx=tx)

>>> def loss_fn(params, x, y):
...   predictions = state.apply_fn({'params': params}, x)
...   loss = optax.l2_loss(predictions=predictions, targets=y).mean()
...   return loss
>>> loss_fn(state.params, x, y)
Array(3.3514676, dtype=float32)

>>> grads = jax.grad(loss_fn)(state.params, x, y)
>>> state = state.apply_gradients(grads=grads)
>>> loss_fn(state.params, x, y)
Array(3.343844, dtype=float32)

請注意,您可以通過繼承此資料類別來輕鬆擴展它,以儲存額外的資料(例如,額外的變數集合)。

對於更特殊的使用案例(例如,多個優化器),最好複製該類別並修改它。

參數
  • step – 計數器從 0 開始,每次呼叫 .apply_gradients() 時都會遞增。

  • apply_fn – 通常設定為 model.apply()。為了方便起見,將其保留在此資料類別中,以便在您的訓練迴圈中,train_step() 函數的參數列表較短。

  • params – 要由 tx 更新並由 apply_fn 使用的參數。

  • tx – Optax 梯度轉換。

  • opt_statetx 的狀態。

apply_gradients(*, grads, **kwargs)[來源]#

更新傳回值中的 stepparamsopt_state**kwargs

請注意,此函數內部會呼叫 .tx.update(),然後呼叫 optax.apply_updates() 以更新 paramsopt_state

參數
  • grads – 梯度,其具有與 .params 相同的 pytree 結構。

  • **kwargs – 應使用 .replace() 的其他資料類別屬性。

傳回值

已更新的 self 執行個體,其中 step 遞增 1,paramsopt_state 已透過套用 grads 來更新,並且其他屬性已由 kwargs 指定的方式來取代。

classmethod create(*, apply_fn, params, tx, **kwargs)[來源]#

建立 step=0 並初始化 opt_state 的新執行個體。