將我的程式碼庫升級到 Optax#

我們已提出在 2021 年用 Optax 取代 flax.optim,讓其符合 FLIP #1009,而 Flax 優化器已移除在 v0.6.0 中 - 此指南提供給 flax.optim 使用者,協助他們將程式碼更新為 Optax。

請參閱 Optax 快速入門文件: https://optax.readthedocs.io/en/latest/getting_started.html

optax 取代 flax.optim #

Optax 提供所有 Flax 優化器的替代方案。請參閱 Optax 的文件 常見的最佳化,了解 API 詳細資訊。

用法非常類似,不同之處在於 optax 沒有儲存 params 的副本,因此需要另外傳遞。Flax 提供公用程式 TrainState,目的是在單一資料類別中儲存最佳化狀態、參數和其他關聯資料(以下程式碼未採用)。

@jax.jit
def train_step(optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)


  return optimizer.apply_gradient(grads)

optimizer_def = flax.optim.Momentum(
    learning_rate, momentum)
optimizer = optimizer_def.create(variables['params'])

for batch in get_ds_train():
  optimizer = train_step(optimizer, batch)
@jax.jit
def train_step(params, opt_state, batch):
  grads = jax.grad(loss)(params, batch)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state

tx = optax.sgd(learning_rate, momentum)
params = variables['params']
opt_state = tx.init(params)

for batch in ds_train:
  params, opt_state = train_step(params, opt_state, batch)

可組合的梯度轉換 #

上述程式碼範本中所使用的函式 optax.sgd() 基本上是兩個梯度轉換連續應用的包裝函式。与其使用此別名,通常會使用 optax.chain() 來結合多個此類的泛用建構模組。

# Note that the aliases follow the convention to use positive
# values for the learning rate by default.
tx = optax.sgd(learning_rate, momentum)
#

tx = optax.chain(
    # 1. Step: keep a trace of past updates and add to gradients.
    optax.trace(decay=momentum),
    # 2. Step: multiply result from step 1 with negative learning rate.
    # Note that `optax.apply_updates()` simply adds the final updates to the
    # parameters, so we must make sure to flip the sign here for gradient
    # descent.
    optax.scale(-learning_rate),
)

權重衰減 #

Flax 的部分最佳化也包含權重衰減。在 Optax 中,部分最佳化也有權重衰減參數(例如 optax.adamw()),而其他最佳化的權重衰減則可以另新增一個「梯度轉換」 optax.add_decayed_weights(),能新增來自參數的更新。

optimizer_def = flax.optim.Adam(
    learning_rate, weight_decay=weight_decay)
optimizer = optimizer_def.create(variables['params'])
# (Note that you could also use `optax.adamw()` in this case)
tx = optax.chain(
    optax.scale_by_adam(),
    optax.add_decayed_weights(weight_decay),
    # params -= learning_rate * (adam(grads) + params * weight_decay)
    optax.scale(-learning_rate),
)
# Note that you'll need to specify `params` when computing the udpates:
# tx.update(grads, opt_state, params)

梯度裁剪 #

訓練可以透過將梯度剪裁至全球範數來穩定,(Pascanu 等人,2012)。在 Flax 中,這通常是在將梯度傳遞給最佳化器之前處理梯度時進行。使用 Optax,這就成為另一個梯度轉換 optax.clip_by_global_norm()

def train_step(optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)
  grads_flat, _ = jax.tree_util.tree_flatten(grads)
  global_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat]))
  g_factor = jnp.minimum(1.0, grad_clip_norm / global_l2)
  grads = jax.tree_util.tree_map(lambda g: g * g_factor, grads)
  return optimizer.apply_gradient(grads)
tx = optax.chain(
    optax.clip_by_global_norm(grad_clip_norm),
    optax.trace(decay=momentum),
    optax.scale(-learning_rate),
)

學習率排程#

對於學習率排程,Flax 允許在應用梯度時覆寫超參數。Optax 維護步驟計數器,並將其提供作為一個函式的引數,以調整使用 optax.scale_by_schedule() 新增的更新。Optax 還允許透過 optax.inject_hyperparams() 指定函式,以插入其他梯度更新的任意純量值。

lr_schedule 指南中深入了解學習率排程。

最佳化器排程 下深入了解在 Optax 中定義的排程。標準最佳化器(例如 optax.adam()optax.sgd() 等)也接受學習率排程作為 learning_rate 的參數。

def train_step(step, optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)
  return step + 1, optimizer.apply_gradient(grads, learning_rate=schedule(step))
tx = optax.chain(
    optax.trace(decay=momentum),
    # Note that we still want a negative value for scaling the updates!
    optax.scale_by_schedule(lambda step: -schedule(step)),
)

多個最佳化器/更新參數子集#

在 Flax 中,遍歷用於指定最佳化器應該更新哪些參數。您可以使用 flax.optim.MultiOptimizer 結合遍歷,以對不同參數套用不同的最佳化器。Optax 中的等效寫法是 optax.masked()optax.chain()

請注意,以下範例使用 flax.traverse_util 來建立 optax.masked() 所需的布林遮罩,您也可以手動建立或使用 optax.multi_transform(),讓 pytree 能接受多值,以指定梯度轉換。

請小心 optax.masked() 會在內部將 pytree 扁平化,而且內部梯度轉換僅使用 params/gradients 部份的扁平化視圖呼叫,這通常不會造成問題,但會造成難以巢狀多層次的遮蔽梯度轉換(原因是內部遮罩期望遮罩會依據外部遮罩外的部分扁平化視圖來定義)。

kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

kernel_opt = flax.optim.Momentum(learning_rate, momentum)
bias_opt = flax.optim.Momentum(learning_rate * 0.1, momentum)


optimizer = flax.optim.MultiOptimizer(
    (kernels, kernel_opt),
    (biases, bias_opt)
).create(variables['params'])
kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

all_false = jax.tree_util.tree_map(lambda _: False, params)
kernels_mask = kernels.update(lambda _: True, all_false)
biases_mask = biases.update(lambda _: True, all_false)

tx = optax.chain(
    optax.trace(decay=momentum),
    optax.masked(optax.scale(-learning_rate), kernels_mask),
    optax.masked(optax.scale(-learning_rate * 0.1), biases_mask),
)

最後的話#

當然,上述所有模式都可以混合使用,而 Optax 讓您可以在主訓練迴圈外將所有這些轉換封裝在單一位置,這讓測試工作容易許多。