Managing Parameters and State

We will show how to...

  • manage variables from initialization to updates.

  • split and re-assemble parameters and state.

  • use vmap with batch-dependent state.
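
The snippets below rely on a handful of imports that are left implicit; a minimal setup sketch with the conventional aliases:

from functools import partial

import jax
import jax.numpy as jnp
from jax import random

import flax
import flax.linen as nn
import optax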

class BiasAdderWithRunningMean(nn.Module):
  momentum: float = 0.9

  @nn.compact
  def __call__(self, x):
    is_initialized = self.has_variable('batch_stats', 'mean')
    # State variable, stored in the 'batch_stats' collection.
    mean = self.variable('batch_stats', 'mean', jnp.zeros, x.shape[1:])
    # Regular parameter, stored in the 'params' collection and updated by the optimizer.
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      # Update the running mean only after initialization.
      mean.value = (self.momentum * mean.value +
                    (1.0 - self.momentum) * jnp.mean(x, axis=0, keepdims=True))
    return mean.value + bias

This example model is a minimal demonstration that contains both parameters (declared with self.param) and state variables (declared with self.variable).

The tricky part with initialization here is that we need to split the state variables from the parameters we are going to optimize for.
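
Before writing the training step it helps to see what model.init produces for this model: a nested dict (or FrozenDict, depending on your Flax version) with one sub-dict per collection. A minimal sketch, where dummy_input and its shape are assumptions chosen only for illustration:

dummy_input = jnp.ones((4, 5))  # hypothetical batch of 4 examples with 5 features
model = BiasAdderWithRunningMean()
variables = model.init(random.key(0), dummy_input)
print(jax.tree_util.tree_map(jnp.shape, variables))
# roughly: {'batch_stats': {'mean': (5,)}, 'params': {'bias': (5,)}}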

We first define update_step as follows (with a dummy loss that you should replace with your own):

def update_step(apply_fn, x, opt_state, params, state):
  def loss(params):
    y, updated_state = apply_fn({'params': params, **state},
                                x, mutable=list(state.keys()))
    l = ((x - y) ** 2).sum() # Replace with your loss here.
    return l, updated_state

  (l, updated_state), grads = jax.value_and_grad(
      loss, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state)  # Defined below.
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state

Then we can write the actual training code.

model = BiasAdderWithRunningMean()
variables = model.init(random.key(0), dummy_input)
# Split state and params (which are updated by optimizer).
state, params = flax.core.pop(variables, 'params')
del variables  # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(num_epochs):
  opt_state, params, state = update_step(
      model.apply, dummy_input, opt_state, params, state)
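
In practice you will usually want to compile the update step. One way to do this, sketched here, is to bind apply_fn with functools.partial so that jax.jit only traces array arguments:

# Bind apply_fn so that jit only sees jittable (array/pytree) arguments.
jitted_update_step = jax.jit(partial(update_step, model.apply))

for _ in range(num_epochs):
  opt_state, params, state = jitted_update_step(
      dummy_input, opt_state, params, state)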

vmap across the batch dimension

When using vmap and managing state that depends on the batch dimension, for example when using BatchNorm, the setup above must be modified slightly. This is because any layer whose state depends on the batch dimension is not strictly vectorizable. In the case of BatchNorm, lax.pmean() must be used to average the statistics over the batch dimension so that the state is in sync for every item in the batch.

This requires two small changes. Firstly, we need to name the batch axis in our model definition. Here, this is done by specifying the axis_name argument of BatchNorm. In your own code this might require specifying the axis_name argument of lax.pmean() directly (a sketch of this follows the model definition below).

class MLP(nn.Module):
  hidden_size: int
  out_size: int

  @nn.compact
  def __call__(self, x, train=False):
    norm = partial(
        nn.BatchNorm,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        axis_name="batch", # Name batch dim
    )

    x = nn.Dense(self.hidden_size)(x)
    x = norm()(x)
    x = nn.relu(x)
    x = nn.Dense(self.hidden_size)(x)
    x = norm()(x)
    x = nn.relu(x)
    y = nn.Dense(self.out_size)(x)

    return y
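
If you write a custom layer instead of using BatchNorm, calling lax.pmean() with the same axis name is enough to keep per-example statistics in sync. A minimal sketch of such a hypothetical layer (the name and details are illustrative), written for a single example and meant to be called under vmap(..., axis_name='batch') like the MLP above:

class SyncedBiasAdder(nn.Module):
  momentum: float = 0.9

  @nn.compact
  def __call__(self, x):
    is_initialized = self.has_variable('batch_stats', 'mean')
    mean = self.variable('batch_stats', 'mean', jnp.zeros, x.shape)
    bias = self.param('bias', nn.initializers.zeros, x.shape)
    if is_initialized:
      # pmean averages this example's value over the axis named 'batch',
      # so every example writes the same, synchronized statistic.
      batch_mean = jax.lax.pmean(x, axis_name='batch')
      mean.value = self.momentum * mean.value + (1. - self.momentum) * batch_mean
    return x + bias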

Secondly, we need to specify the same name when calling vmap in our training code:

def update_step(apply_fn, x_batch, y_batch, opt_state, params, state):

  def batch_loss(params):
    def loss_fn(x, y):
      pred, updated_state = apply_fn(
        {'params': params, **state},
        x, train=True,  # Let BatchNorm update (and sync) the batch statistics.
        mutable=list(state.keys())
      )
      return (pred - y) ** 2, updated_state

    loss, updated_state = jax.vmap(
      loss_fn, out_axes=(0, None),  # Do not vmap `updated_state`.
      axis_name='batch'  # Name batch dim
    )(x_batch, y_batch)  # vmap only `x`, `y`, but not `state`.
    return jnp.mean(loss), updated_state

  (loss, updated_state), grads = jax.value_and_grad(
    batch_loss, has_aux=True
  )(params)

  updates, opt_state = tx.update(grads, opt_state)  # Defined below.
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state, loss

Note that we also need to specify that the model state does not have a batch dimension; this is what out_axes=(0, None) does in the vmap call above. Now we can train the model:

model = MLP(hidden_size=10, out_size=1)
variables = model.init(random.key(0), dummy_input)
# Split state and params (which are updated by optimizer).
state, params = flax.core.pop(variables, 'params')
del variables  # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(num_epochs):
  opt_state, params, state, loss = update_step(
      model.apply, X, Y, opt_state, params, state)
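
The names dummy_input, X, Y and num_epochs are left to the reader. A minimal sketch of values that would make the example above run, with hypothetical shapes and sizes, to be defined before the training code:

num_epochs = 100                             # hypothetical number of epochs
dummy_input = jnp.ones((32, 5))              # hypothetical batch: 32 examples, 5 features
X = random.normal(random.key(1), (32, 5))    # hypothetical inputs
Y = random.normal(random.key(2), (32, 1))    # hypothetical targets, matching out_size=1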