中斷#

本指南提供如何使用 中斷 的概述,方法是使用 flax.linen.Dropout()

中斷是一種隨機正則化技術,用於隨機移除網路中的隱藏和可見單元。

在整個指南中,你將能夠比較使用和未使用 Flax Dropout 的程式碼範例。

分割 PRNG 金鑰#

由於中斷是一種隨機操作,因此需要一個偽亂數產生器 (PRNG) 狀態。Flax 使用 JAX 的 (可分割) PRNG 金鑰,這些金鑰對神經網路具有許多理想的特性。若要深入了解,請參閱 JAX 教學課程中的偽亂數

注意:請記住 JAX 有明確的方法可以提供你 PRNG 金鑰:你可以將主 PRNG 狀態(例如 key = jax.random.key(seed=0))分割成多個新的 PRNG 金鑰,方法為 key, subkey = jax.random.split(key)。你可以在 🔪 JAX - The Sharp Bits 🔪 隨機性和 PRNG 金鑰 中更新你的記憶。

首先使用 jax.random.split() 將 PRNG 金鑰分割成三個金鑰,包括一個適用於 Flax Linen Dropout 的金鑰。

root_key = jax.random.key(seed=0)
main_key, params_key = jax.random.split(key=root_key)
root_key = jax.random.key(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)

注意:在 Flax 中,你需要使用 名稱 來提供 PRNG 資料流,以便你稍後可以在 flax.linen.Module() 中使用它們。例如,你要傳遞 'params' 資料流來初始化參數,並傳遞 'dropout' 資料流來套用 flax.linen.Dropout()

使用 Dropout 定義你的模型#

若要建立有中斷的模型

  • 子類別 flax.linen.Module(),然後使用 flax.linen.Dropout() 來新增 dropout 層。記住 flax.linen.Module() 是所有神經網路模組的 基礎類別,所有層和模型都是其子類別。

  • flax.linen.Dropout() 中,deterministic 參數必須作為關鍵字參數傳遞,以下何者皆可:

  • 因為 deterministic 是布林值

    • 如果它設定為 False,輸入會以 rate 設定的機率進行遮罩(亦即設為 0)。而剩餘輸入會乘上 1 / (1 - rate),這將確保輸入的平均值得以保留。

    • 如果它設定為 True,不會套用遮罩(關閉 dropout),而會原樣回傳輸入。

一種常見的模式是在父層 Flax Module 中接受一個 training (或 train) 參數(一個布林值),並用它來啟用或停用中斷(如本指南後面的章節所示)。在其他機器學習框架中,例如 PyTorch 或 TensorFlow (Keras),這會透過可變狀態或呼叫旗標來指定(例如,在 torch.nn.Module.eval 或設定 training 旗標的 tf.keras.Model 中)。

注意:Flax 提供了一個隱含方法,透過 Flax flax.linen.Module()flax.linen.Module.make_rng() 方法來處理 PRNG 關鍵串流。這允許你從 PRNG 串流中,在 Flax Modules(或其子 Modules)內分離出一個新的 PRNG 關鍵。每次呼叫 make_rng 方法時,都保證能提供一個唯一的關鍵。在內部,flax.linen.Dropout() 使用 flax.linen.Module.make_rng() 來建立一個中斷的關鍵。你可以查看 原始碼。簡而言之,flax.linen.Module.make_rng() 保證完全重現

class MyModel(nn.Module):
  num_neurons: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.num_neurons)(x)

    return x
class MyModel(nn.Module):
  num_neurons: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.num_neurons)(x)
    # Set the dropout layer with a `rate` of 50%.
    # When the `deterministic` flag is `True`, dropout is turned off.
    x = nn.Dropout(rate=0.5, deterministic=not training)(x)
    return x

初始化模型#

在建立模型後

在此,沒有 Flax Dropout 和有 Dropout 的程式碼之間的主要差別,在於如果你需要啟用中斷,就必須提供 training(或 train)參數。

my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))

variables = my_model.init(params_key, x)
params = variables['params']
my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))
# Dropout is disabled with `training=False` (that is, `deterministic=True`).
variables = my_model.init(params_key, x, training=False)
params = variables['params']

在訓練期間執行前向傳遞#

當使用 flax.linen.apply() 執行模型時

  • training=True 傳遞至 flax.linen.apply()

  • 接著,若要在前向傳遞期間繪製 PRNG 金鑰(透過中斷),在呼叫 flax.linen.apply() 時提供一個 PRNG 金鑰來為 'dropout' 串流植入種子。

# No need to pass the `training` and `rngs` flags.
y = my_model.apply({'params': params}, x)
# Dropout is enabled with `training=True` (that is, `deterministic=False`).
y = my_model.apply({'params': params}, x, training=True, rngs={'dropout': dropout_key})

在此,沒有 Flax Dropout 和有 Dropout 的程式碼之間的主要差別,在於如果你需要啟用中斷,就必須提供 training(或 train)和 rngs 參數。

在評估期間,使用未啟用中斷的上述程式碼(這表示你也不必傳遞 RNG)。

TrainState 和訓練步驟#

此區段說明如何修改訓練步驟函式內的程式碼,如果你已啟用中斷。

備註:請回想 Flax 有以下常規模式,其中你會建立一個資料類別,來表示整個訓練狀態,包括參數和最佳化器狀態。接著,你就可以將單一參數 state: TrainState 傳遞至訓練步驟函式。請參閱 flax.training.train_state.TrainState() API 文件以了解更多。

  • 首先,新增一個 key 欄位至自訂 flax.training.train_state.TrainState() 類別。

  • 然後,傳遞 key 值,本範例中是 dropout_key,到 train_state.TrainState.create() 方法。

from flax.training import train_state

state = train_state.TrainState.create(
  apply_fn=my_model.apply,
  params=params,

  tx=optax.adam(1e-3)
)
from flax.training import train_state

class TrainState(train_state.TrainState):
  key: jax.Array

state = TrainState.create(
  apply_fn=my_model.apply,
  params=params,
  key=dropout_key,
  tx=optax.adam(1e-3)
)

  • 接著,在 Flax 訓練步驟函式 train_step,從 dropout_key 產生新的 PRNG 金鑰,以便在每個步驟套用遞減。這可以使用下列其中一種方法完成:。

    使用 jax.random.fold_in() 通常較快。使用 jax.random.split() 時,會分割一個 PRNG 金鑰,以便稍後重複使用。但是,使用 jax.random.fold_in() 可以確保 1) 摺疊唯一資料;以及 2) 可以產生更長的 PRNG 串流序列。

  • 最後,執行前向傳遞時,將新的 PRNG 金鑰傳遞到 state.apply_fn() 作為額外參數。

@jax.jit
def train_step(state: train_state.TrainState, batch):

  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      x=batch['image'],


      )
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state
@jax.jit
def train_step(state: TrainState, batch, dropout_key):
  dropout_train_key = jax.random.fold_in(key=dropout_key, data=state.step)
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      x=batch['image'],
      training=True,
      rngs={'dropout': dropout_train_key}
      )
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

有遞減的 Flax 範例#

更多使用模組 make_rng() 的 Flax 範例#