flax.training 套件#
訓練狀態#
- class flax.training.train_state.TrainState(step, apply_fn, params, tx, opt_state)[來源]#
適用於單一 Optax 優化器的常見案例的簡單訓練狀態。
使用範例
>>> import flax.linen as nn >>> from flax.training.train_state import TrainState >>> import jax, jax.numpy as jnp >>> import optax >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 2)) >>> model = nn.Dense(2) >>> variables = model.init(jax.random.key(0), x) >>> tx = optax.adam(1e-3) >>> state = TrainState.create( ... apply_fn=model.apply, ... params=variables['params'], ... tx=tx) >>> def loss_fn(params, x, y): ... predictions = state.apply_fn({'params': params}, x) ... loss = optax.l2_loss(predictions=predictions, targets=y).mean() ... return loss >>> loss_fn(state.params, x, y) Array(3.3514676, dtype=float32) >>> grads = jax.grad(loss_fn)(state.params, x, y) >>> state = state.apply_gradients(grads=grads) >>> loss_fn(state.params, x, y) Array(3.343844, dtype=float32)
請注意,您可以通過繼承此資料類別來輕鬆擴展它,以儲存額外的資料(例如,額外的變數集合)。
對於更特殊的使用案例(例如,多個優化器),最好複製該類別並修改它。
- 參數
step – 計數器從 0 開始,每次呼叫
.apply_gradients()
時都會遞增。apply_fn – 通常設定為
model.apply()
。為了方便起見,將其保留在此資料類別中,以便在您的訓練迴圈中,train_step()
函數的參數列表較短。params – 要由
tx
更新並由apply_fn
使用的參數。tx – Optax 梯度轉換。
opt_state –
tx
的狀態。
- apply_gradients(*, grads, **kwargs)[來源]#
更新傳回值中的
step
、params
、opt_state
和**kwargs
。請注意,此函數內部會呼叫
.tx.update()
,然後呼叫optax.apply_updates()
以更新params
和opt_state
。- 參數
grads – 梯度,其具有與
.params
相同的 pytree 結構。**kwargs – 應使用
.replace()
的其他資料類別屬性。
- 傳回值
已更新的
self
執行個體,其中step
遞增 1,params
和opt_state
已透過套用grads
來更新,並且其他屬性已由kwargs
指定的方式來取代。