🔪 Flax - 精髓所在 🔪#

Flax 發揮了 JAX 的全部效能。然而,就像使用 JAX 時一樣,當使用 Flax 時您可能會遇到某些「棘手的部分」。這份不斷演進的文件旨在協助您解決這些部分。

首先,安裝及/或更新 Flax

! pip install -qq flax

🔪 flax.linen.Dropout 層及其隨機性#

TL;DR#

處理帶有輟學率的模型時(由 Flax Module 子類化),僅在向前傳遞期間新增'dropout' PRNG 金鑰。

  1. jax.random.split() 開始,為 'params''dropout' 明確建立 PRNG 金鑰。

  2. flax.linen.Dropout 層新增至您的模型(由 Flax Module 子類化)。

  3. 在初始化模型(flax.linen.init())時,無需傳遞額外的 'dropout' PRNG 金鑰,就像在「較為簡單」的模型中一樣,只需傳遞 'params' 金鑰即可。

  4. 在使用 flax.linen.apply() 進行向前傳遞時,傳遞 rngs={'dropout': dropout_key}

請在下方查看完整的範例。

此方法運作的原因#

  • 在內部,flax.linen.Dropout 使用 flax.linen.Module.make_rng 為輟學率建立金鑰(查看 原始程式碼)。

  • 每次呼叫 make_rng(在本例中,呼叫是隱式的,發生在 Dropout 中),您會從主/根 PRNG 金鑰中拆分一個新的 PRNG 金鑰。

  • make_rng確保完全可重製

背景#

隨機正規化的 dropout 技術會隨機刪除網路中的隱含和可見單元。Dropout 是項隨機運算,需要 PRNG 狀態,而 Flax(像 JAX)使用可分割的 Threefry PRNG。

注意:請回想 JAX 有種明確的方式可以提供 PRNG 金鑰:你可以將主 PRNG 狀態(例如 key = jax.random.key(seed=0))分叉成多個新的 PRNG 金鑰,方法為 key, subkey = jax.random.split(key)。在 🔪 JAX - 犀利的部分 🔪 隨機性和 PRNG 金鑰 中更新你的記憶。

Flax 提供一種藉由 Flax Moduleflax.linen.Module.make_rng 輔助函數,來隱式處理 PRNG 金鑰串流的方法。如此一來,Flax Module(或其子-Module)中的程式碼就能「拉取 PRNG 金鑰」。make_rng 保證每次呼叫時都能提供獨一無二的金鑰。進一步的詳細資訊,請參閱 RNG 指南

注意:請回想 flax.linen.Module 是所有神經網路模組的基本類別。所有的層級和模型都是其子類別。

範例#

請謹記每個 Flax PRNG 串流都有個名稱。以下範例使用 'params' 串流來初始化參數,也使用 'dropout' 串流。提供給 flax.linen.init() 的 PRNG 金鑰,是提供 給 'params' PRNG 金鑰串流種子的金鑰。要在正向通過(有 dropout)期間抽取 PRNG 金鑰,請呼叫 Module.apply() 時提供 PRNG 金鑰來提供那個串流種子('dropout')。

# Setup.
import jax
import jax.numpy as jnp
import flax.linen as nn
# Randomness.
seed = 0
root_key = jax.random.key(seed=seed)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)

# A simple network.
class MyModel(nn.Module):
  num_neurons: int
  training: bool
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.num_neurons)(x)
    # Set the dropout layer with a rate of 50% .
    # When the `deterministic` flag is `True`, dropout is turned off.
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
    return x

# Instantiate `MyModel` (you don't need to set `training=True` to
# avoid performing the forward pass computation).
my_model = MyModel(num_neurons=3, training=False)

x = jax.random.uniform(key=main_key, shape=(3, 4, 4))

# Initialize with `flax.linen.init()`.
# The `params_key` is equivalent to a dictionary of PRNGs.
# (Here, you are providing only one PRNG key.) 
variables = my_model.init(params_key, x)

# Perform the forward pass with `flax.linen.apply()`.
y = my_model.apply(variables, x, rngs={'dropout': dropout_key})

實際範例