中斷#
本指南提供如何使用 中斷 的概述,方法是使用 flax.linen.Dropout()
。
中斷是一種隨機正則化技術,用於隨機移除網路中的隱藏和可見單元。
在整個指南中,你將能夠比較使用和未使用 Flax Dropout
的程式碼範例。
分割 PRNG 金鑰#
由於中斷是一種隨機操作,因此需要一個偽亂數產生器 (PRNG) 狀態。Flax 使用 JAX 的 (可分割) PRNG 金鑰,這些金鑰對神經網路具有許多理想的特性。若要深入了解,請參閱 JAX 教學課程中的偽亂數。
注意:請記住 JAX 有明確的方法可以提供你 PRNG 金鑰:你可以將主 PRNG 狀態(例如 key = jax.random.key(seed=0)
)分割成多個新的 PRNG 金鑰,方法為 key, subkey = jax.random.split(key)
。你可以在 🔪 JAX - The Sharp Bits 🔪 隨機性和 PRNG 金鑰 中更新你的記憶。
首先使用 jax.random.split() 將 PRNG 金鑰分割成三個金鑰,包括一個適用於 Flax Linen Dropout
的金鑰。
root_key = jax.random.key(seed=0)
main_key, params_key = jax.random.split(key=root_key)
root_key = jax.random.key(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
注意:在 Flax 中,你需要使用 名稱 來提供 PRNG 資料流,以便你稍後可以在 flax.linen.Module()
中使用它們。例如,你要傳遞 'params'
資料流來初始化參數,並傳遞 'dropout'
資料流來套用 flax.linen.Dropout()
。
使用 Dropout
定義你的模型#
若要建立有中斷的模型
子類別
flax.linen.Module()
,然後使用flax.linen.Dropout()
來新增 dropout 層。記住flax.linen.Module()
是所有神經網路模組的 基礎類別,所有層和模型都是其子類別。在
flax.linen.Dropout()
中,deterministic
參數必須作為關鍵字參數傳遞,以下何者皆可:在建構
flax.linen.Module()
時;或在已建構的
Module
上呼叫flax.linen.init()
或flax.linen.apply()
時。(請參閱flax.linen.module.merge_param()
以取得更多詳細資料。)
因為
deterministic
是布林值如果它設定為
False
,輸入會以rate
設定的機率進行遮罩(亦即設為 0)。而剩餘輸入會乘上1 / (1 - rate)
,這將確保輸入的平均值得以保留。如果它設定為
True
,不會套用遮罩(關閉 dropout),而會原樣回傳輸入。
一種常見的模式是在父層 Flax Module
中接受一個 training
(或 train
) 參數(一個布林值),並用它來啟用或停用中斷(如本指南後面的章節所示)。在其他機器學習框架中,例如 PyTorch 或 TensorFlow (Keras),這會透過可變狀態或呼叫旗標來指定(例如,在 torch.nn.Module.eval 或設定 training 旗標的 tf.keras.Model
中)。
注意:Flax 提供了一個隱含方法,透過 Flax flax.linen.Module()
的 flax.linen.Module.make_rng()
方法來處理 PRNG 關鍵串流。這允許你從 PRNG 串流中,在 Flax Modules(或其子 Modules)內分離出一個新的 PRNG 關鍵。每次呼叫 make_rng
方法時,都保證能提供一個唯一的關鍵。在內部,flax.linen.Dropout()
使用 flax.linen.Module.make_rng()
來建立一個中斷的關鍵。你可以查看 原始碼。簡而言之,flax.linen.Module.make_rng()
保證完全重現。
class MyModel(nn.Module):
num_neurons: int
@nn.compact
def __call__(self, x):
x = nn.Dense(self.num_neurons)(x)
return x
class MyModel(nn.Module):
num_neurons: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.num_neurons)(x)
# Set the dropout layer with a `rate` of 50%.
# When the `deterministic` flag is `True`, dropout is turned off.
x = nn.Dropout(rate=0.5, deterministic=not training)(x)
return x
初始化模型#
在建立模型後
建立模型實體。
然後,在
flax.linen.init()
呼叫中,設定training=False
。最後,從 變數詞典 中萃取
params
。
在此,沒有 Flax Dropout
和有 Dropout
的程式碼之間的主要差別,在於如果你需要啟用中斷,就必須提供 training
(或 train
)參數。
my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))
variables = my_model.init(params_key, x)
params = variables['params']
my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))
# Dropout is disabled with `training=False` (that is, `deterministic=True`).
variables = my_model.init(params_key, x, training=False)
params = variables['params']
在訓練期間執行前向傳遞#
當使用 flax.linen.apply()
執行模型時
將
training=True
傳遞至flax.linen.apply()
。接著,若要在前向傳遞期間繪製 PRNG 金鑰(透過中斷),在呼叫
flax.linen.apply()
時提供一個 PRNG 金鑰來為'dropout'
串流植入種子。
# No need to pass the `training` and `rngs` flags.
y = my_model.apply({'params': params}, x)
# Dropout is enabled with `training=True` (that is, `deterministic=False`).
y = my_model.apply({'params': params}, x, training=True, rngs={'dropout': dropout_key})
在此,沒有 Flax Dropout
和有 Dropout
的程式碼之間的主要差別,在於如果你需要啟用中斷,就必須提供 training
(或 train
)和 rngs
參數。
在評估期間,使用未啟用中斷的上述程式碼(這表示你也不必傳遞 RNG)。
TrainState
和訓練步驟#
此區段說明如何修改訓練步驟函式內的程式碼,如果你已啟用中斷。
備註:請回想 Flax 有以下常規模式,其中你會建立一個資料類別,來表示整個訓練狀態,包括參數和最佳化器狀態。接著,你就可以將單一參數 state: TrainState
傳遞至訓練步驟函式。請參閱 flax.training.train_state.TrainState()
API 文件以了解更多。
首先,新增一個
key
欄位至自訂flax.training.train_state.TrainState()
類別。然後,傳遞
key
值,本範例中是dropout_key
,到train_state.TrainState.create()
方法。
from flax.training import train_state
state = train_state.TrainState.create(
apply_fn=my_model.apply,
params=params,
tx=optax.adam(1e-3)
)
from flax.training import train_state
class TrainState(train_state.TrainState):
key: jax.Array
state = TrainState.create(
apply_fn=my_model.apply,
params=params,
key=dropout_key,
tx=optax.adam(1e-3)
)
接著,在 Flax 訓練步驟函式
train_step
,從dropout_key
產生新的 PRNG 金鑰,以便在每個步驟套用遞減。這可以使用下列其中一種方法完成:。使用
jax.random.fold_in()
通常較快。使用jax.random.split()
時,會分割一個 PRNG 金鑰,以便稍後重複使用。但是,使用jax.random.fold_in()
可以確保 1) 摺疊唯一資料;以及 2) 可以產生更長的 PRNG 串流序列。最後,執行前向傳遞時,將新的 PRNG 金鑰傳遞到
state.apply_fn()
作為額外參數。
@jax.jit
def train_step(state: train_state.TrainState, batch):
def loss_fn(params):
logits = state.apply_fn(
{'params': params},
x=batch['image'],
)
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label'])
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
return state
@jax.jit
def train_step(state: TrainState, batch, dropout_key):
dropout_train_key = jax.random.fold_in(key=dropout_key, data=state.step)
def loss_fn(params):
logits = state.apply_fn(
{'params': params},
x=batch['image'],
training=True,
rngs={'dropout': dropout_train_key}
)
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label'])
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
return state
有遞減的 Flax 範例#
基於 Transformer 的模型,於 WMT 機器翻譯資料集上訓練。此範例使用遞減和注意力遞減。
在 文字分類背景下,將字詞遞減套用至批次輸入 ID。此範例使用自訂
flax.linen.Dropout()
圖層。