批次正規化#

在此指南中,您將學習如何使用 批次正規化,方法是使用 flax.linen.BatchNorm

批次正規化是一種正規化技術,用於加快訓練速度並改善收斂性。在訓練期間,它會計算特徵維度的平均數。這會新增一種必須妥善處理的不可微狀態形式。

在整份指南中,你都將能夠比較含有 Flax BatchNorm 的程式碼範例和不含它的範例。

使用 BatchNorm 定義模型#

在 Flax 中,BatchNormflax.linen.Module,它在訓練和推論之間展現出不同的執行時間行為。你必須透過 use_running_average 參數明確指定,如下所示。

一個常見的模式是,在父層的 Flax Module 中接受 train (training) 參數,並使用它來定義 BatchNormuse_running_average 參數。

注意:在 PyTorch 或 TensorFlow (Keras) 等其他機器學習架構中,これは經由可變狀態或呼叫旗標來指定的(例如,在 torch.nn.Module.evaltf.keras.Model 中,透過設定 training 旗標)。

class MLP(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=4)(x)

    x = nn.relu(x)
    x = nn.Dense(features=1)(x)
    return x
class MLP(nn.Module):
  @nn.compact
  def __call__(self, x, train: bool):
    x = nn.Dense(features=4)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.relu(x)
    x = nn.Dense(features=1)(x)
    return x

在建立你的模型之後,請呼叫 flax.linen.init() 來初始化它,以取得 variables 結構。在此,沒有 BatchNorm 的程式碼和有 BatchNorm 的程式碼之間的主要差別在於,必須提供 train 參數。

The batch_stats 集合#

除了 params 集合之外,BatchNorm 還新增一個 batch_stats 集合,其中包含批次統計資料的執行平均數。

註解:您可以在 flax.linen 變數 API 文件中瞭解詳細資訊。

必須從 variables 中擷取 batch_stats 集合以供後續使用。

mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.key(0), x)
params = variables['params']


jax.tree_util.tree_map(jnp.shape, variables)
mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.key(0), x, train=False)
params = variables['params']
batch_stats = variables['batch_stats']

jax.tree_util.tree_map(jnp.shape, variables)

Flax BatchNorm 總共會新增 4 個變數:存在 batch_stats 集合中的 meanvar 以及存在 params 集合中的 scalebias

FrozenDict({
  'params': {
    'Dense_0': {
        'bias': (4,),
        'kernel': (3, 4),
    },
    'Dense_1': {
        'bias': (1,),
        'kernel': (4, 1),
    },
  },
})
FrozenDict({
  'batch_stats': {
    'BatchNorm_0': {
        'mean': (4,),
        'var': (4,),
    },
  },
  'params': {
    'BatchNorm_0': {
        'bias': (4,),
        'scale': (4,),
    },
    'Dense_0': {
        'bias': (4,),
        'kernel': (3, 4),
    },
    'Dense_1': {
        'bias': (1,),
        'kernel': (4, 1),
    },
  },
})

修改 flax.linen.apply#

當使用 flax.linen.apply 搭配 train==True 參數來執行模型時(也就是說,您在呼叫 BatchNorm 時具有 use_running_average==False),您需要考量以下事項

  • batch_stats 必須傳遞為輸入變數。

  • batch_stats 集合需要透過設定 mutable=['batch_stats'] 來標示為可變。

  • 已變異的變數返回為第二個輸出。您必須從這裡擷取已更新的 batch_stats

y = mlp.apply(
  {'params': params},
  x,
)
...
y, updates = mlp.apply(
  {'params': params, 'batch_stats': batch_stats},
  x,
  train=True, mutable=['batch_stats']
)
batch_stats = updates['batch_stats']

訓練和評估#

在將使用 BatchNorm 的模型整合到訓練迴圈中時,主要的問題是如何處理額外的 batch_stats 狀態。為此,您需要

from flax.training import train_state


state = train_state.TrainState.create(
  apply_fn=mlp.apply,
  params=params,

  tx=optax.adam(1e-3),
)
from flax.training import train_state

class TrainState(train_state.TrainState):
  batch_stats: Any

state = TrainState.create(
  apply_fn=mlp.apply,
  params=params,
  batch_stats=batch_stats,
  tx=optax.adam(1e-3),
)

此外,請更新您的 train_step 函數以反映這些變更

  • 將所有新參數傳遞給 flax.linen.apply(如前所述)。

  • updatesbatch_stats 的更新必須從 loss_fn 宣傳出去。

  • 必須更新來自 TrainStatebatch_stats

@jax.jit
def train_step(state: train_state.TrainState, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      x=batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']).mean()
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)

  metrics = {
    'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics
@jax.jit
def train_step(state: TrainState, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits, updates = state.apply_fn(
      {'params': params, 'batch_stats': state.batch_stats},
      x=batch['image'], train=True, mutable=['batch_stats'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']).mean()
    return loss, (logits, updates)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (logits, updates)), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  state = state.replace(batch_stats=updates['batch_stats'])
  metrics = {
    'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics

eval_step 簡單許多。由於 batch_stats 不是可變的,因此無需宣傳更新。請務必將 batch_stats 傳遞到 flax.linen.apply,並且將 train 參數設為 False

@jax.jit
def eval_step(state: train_state.TrainState, batch):
  """Train for a single step."""
  logits = state.apply_fn(
    {'params': params},
    x=batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']).mean()
  metrics = {
    'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics
@jax.jit
def eval_step(state: TrainState, batch):
  """Evaluate for a single step."""
  logits = state.apply_fn(
    {'params': state.params, 'batch_stats': state.batch_stats},
    x=batch['image'], train=False)
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']).mean()
  metrics = {
    'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics