批次正規化#
在此指南中,您將學習如何使用 批次正規化,方法是使用 flax.linen.BatchNorm
。
批次正規化是一種正規化技術,用於加快訓練速度並改善收斂性。在訓練期間,它會計算特徵維度的平均數。這會新增一種必須妥善處理的不可微狀態形式。
在整份指南中,你都將能夠比較含有 Flax BatchNorm
的程式碼範例和不含它的範例。
使用 BatchNorm
定義模型#
在 Flax 中,BatchNorm
是 flax.linen.Module
,它在訓練和推論之間展現出不同的執行時間行為。你必須透過 use_running_average
參數明確指定,如下所示。
一個常見的模式是,在父層的 Flax Module
中接受 train
(training
) 參數,並使用它來定義 BatchNorm
的 use_running_average
參數。
注意:在 PyTorch 或 TensorFlow (Keras) 等其他機器學習架構中,これは經由可變狀態或呼叫旗標來指定的(例如,在 torch.nn.Module.eval 或 tf.keras.Model
中,透過設定 training 旗標)。
class MLP(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(features=4)(x)
x = nn.relu(x)
x = nn.Dense(features=1)(x)
return x
class MLP(nn.Module):
@nn.compact
def __call__(self, x, train: bool):
x = nn.Dense(features=4)(x)
x = nn.BatchNorm(use_running_average=not train)(x)
x = nn.relu(x)
x = nn.Dense(features=1)(x)
return x
在建立你的模型之後,請呼叫 flax.linen.init()
來初始化它,以取得 variables
結構。在此,沒有 BatchNorm
的程式碼和有 BatchNorm
的程式碼之間的主要差別在於,必須提供 train
參數。
The batch_stats
集合#
除了 params
集合之外,BatchNorm
還新增一個 batch_stats
集合,其中包含批次統計資料的執行平均數。
註解:您可以在 flax.linen
變數 API 文件中瞭解詳細資訊。
必須從 variables
中擷取 batch_stats
集合以供後續使用。
mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.key(0), x)
params = variables['params']
jax.tree_util.tree_map(jnp.shape, variables)
mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.key(0), x, train=False)
params = variables['params']
batch_stats = variables['batch_stats']
jax.tree_util.tree_map(jnp.shape, variables)
Flax BatchNorm
總共會新增 4 個變數:存在 batch_stats
集合中的 mean
和 var
以及存在 params
集合中的 scale
和 bias
。
FrozenDict({
'params': {
'Dense_0': {
'bias': (4,),
'kernel': (3, 4),
},
'Dense_1': {
'bias': (1,),
'kernel': (4, 1),
},
},
})
FrozenDict({
'batch_stats': {
'BatchNorm_0': {
'mean': (4,),
'var': (4,),
},
},
'params': {
'BatchNorm_0': {
'bias': (4,),
'scale': (4,),
},
'Dense_0': {
'bias': (4,),
'kernel': (3, 4),
},
'Dense_1': {
'bias': (1,),
'kernel': (4, 1),
},
},
})
修改 flax.linen.apply
#
當使用 flax.linen.apply
搭配 train==True
參數來執行模型時(也就是說,您在呼叫 BatchNorm
時具有 use_running_average==False
),您需要考量以下事項
batch_stats
必須傳遞為輸入變數。batch_stats
集合需要透過設定mutable=['batch_stats']
來標示為可變。已變異的變數返回為第二個輸出。您必須從這裡擷取已更新的
batch_stats
。
y = mlp.apply(
{'params': params},
x,
)
...
y, updates = mlp.apply(
{'params': params, 'batch_stats': batch_stats},
x,
train=True, mutable=['batch_stats']
)
batch_stats = updates['batch_stats']
訓練和評估#
在將使用 BatchNorm
的模型整合到訓練迴圈中時,主要的問題是如何處理額外的 batch_stats
狀態。為此,您需要
將
batch_stats
欄位新增到自訂的flax.training.train_state.TrainState
類別。將
batch_stats
值傳遞給train_state.TrainState.create
方法。
from flax.training import train_state
state = train_state.TrainState.create(
apply_fn=mlp.apply,
params=params,
tx=optax.adam(1e-3),
)
from flax.training import train_state
class TrainState(train_state.TrainState):
batch_stats: Any
state = TrainState.create(
apply_fn=mlp.apply,
params=params,
batch_stats=batch_stats,
tx=optax.adam(1e-3),
)
此外,請更新您的 train_step
函數以反映這些變更
將所有新參數傳遞給
flax.linen.apply
(如前所述)。updates
對batch_stats
的更新必須從loss_fn
宣傳出去。必須更新來自
TrainState
的batch_stats
。
@jax.jit
def train_step(state: train_state.TrainState, batch):
"""Train for a single step."""
def loss_fn(params):
logits = state.apply_fn(
{'params': params},
x=batch['image'])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label']).mean()
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
metrics = {
'loss': loss,
'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
}
return state, metrics
@jax.jit
def train_step(state: TrainState, batch):
"""Train for a single step."""
def loss_fn(params):
logits, updates = state.apply_fn(
{'params': params, 'batch_stats': state.batch_stats},
x=batch['image'], train=True, mutable=['batch_stats'])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label']).mean()
return loss, (logits, updates)
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, (logits, updates)), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
state = state.replace(batch_stats=updates['batch_stats'])
metrics = {
'loss': loss,
'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
}
return state, metrics
eval_step
簡單許多。由於 batch_stats
不是可變的,因此無需宣傳更新。請務必將 batch_stats
傳遞到 flax.linen.apply
,並且將 train
參數設為 False
@jax.jit
def eval_step(state: train_state.TrainState, batch):
"""Train for a single step."""
logits = state.apply_fn(
{'params': params},
x=batch['image'])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label']).mean()
metrics = {
'loss': loss,
'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
}
return state, metrics
@jax.jit
def eval_step(state: TrainState, batch):
"""Evaluate for a single step."""
logits = state.apply_fn(
{'params': state.params, 'batch_stats': state.batch_stats},
x=batch['image'], train=False)
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label']).mean()
metrics = {
'loss': loss,
'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
}
return state, metrics