🔪 Flax - 精髓所在 🔪#
Flax 發揮了 JAX 的全部效能。然而,就像使用 JAX 時一樣,當使用 Flax 時您可能會遇到某些「棘手的部分」。這份不斷演進的文件旨在協助您解決這些部分。
首先,安裝及/或更新 Flax
! pip install -qq flax
🔪 flax.linen.Dropout
層及其隨機性#
TL;DR#
處理帶有輟學率的模型時(由 Flax Module
子類化),僅在向前傳遞期間新增'dropout'
PRNG 金鑰。
從
jax.random.split()
開始,為'params'
和'dropout'
明確建立 PRNG 金鑰。將
flax.linen.Dropout
層新增至您的模型(由 FlaxModule
子類化)。在初始化模型(
flax.linen.init()
)時,無需傳遞額外的'dropout'
PRNG 金鑰,就像在「較為簡單」的模型中一樣,只需傳遞'params'
金鑰即可。在使用
flax.linen.apply()
進行向前傳遞時,傳遞rngs={'dropout': dropout_key}
請在下方查看完整的範例。
此方法運作的原因#
在內部,
flax.linen.Dropout
使用flax.linen.Module.make_rng
為輟學率建立金鑰(查看 原始程式碼)。每次呼叫
make_rng
(在本例中,呼叫是隱式的,發生在Dropout
中),您會從主/根 PRNG 金鑰中拆分一個新的 PRNG 金鑰。make_rng
仍確保完全可重製。
背景#
隨機正規化的 dropout 技術會隨機刪除網路中的隱含和可見單元。Dropout 是項隨機運算,需要 PRNG 狀態,而 Flax(像 JAX)使用可分割的 Threefry PRNG。
注意:請回想 JAX 有種明確的方式可以提供 PRNG 金鑰:你可以將主 PRNG 狀態(例如
key = jax.random.key(seed=0)
)分叉成多個新的 PRNG 金鑰,方法為key, subkey = jax.random.split(key)
。在 🔪 JAX - 犀利的部分 🔪 隨機性和 PRNG 金鑰 中更新你的記憶。
Flax 提供一種藉由 Flax Module
的 flax.linen.Module.make_rng
輔助函數,來隱式處理 PRNG 金鑰串流的方法。如此一來,Flax Module
(或其子-Module
)中的程式碼就能「拉取 PRNG 金鑰」。make_rng
保證每次呼叫時都能提供獨一無二的金鑰。進一步的詳細資訊,請參閱 RNG 指南。
注意:請回想
flax.linen.Module
是所有神經網路模組的基本類別。所有的層級和模型都是其子類別。
範例#
請謹記每個 Flax PRNG 串流都有個名稱。以下範例使用 'params'
串流來初始化參數,也使用 'dropout'
串流。提供給 flax.linen.init()
的 PRNG 金鑰,是提供 給 'params'
PRNG 金鑰串流種子的金鑰。要在正向通過(有 dropout)期間抽取 PRNG 金鑰,請呼叫 Module.apply()
時提供 PRNG 金鑰來提供那個串流種子('dropout'
)。
# Setup.
import jax
import jax.numpy as jnp
import flax.linen as nn
# Randomness.
seed = 0
root_key = jax.random.key(seed=seed)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
# A simple network.
class MyModel(nn.Module):
num_neurons: int
training: bool
@nn.compact
def __call__(self, x):
x = nn.Dense(self.num_neurons)(x)
# Set the dropout layer with a rate of 50% .
# When the `deterministic` flag is `True`, dropout is turned off.
x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
return x
# Instantiate `MyModel` (you don't need to set `training=True` to
# avoid performing the forward pass computation).
my_model = MyModel(num_neurons=3, training=False)
x = jax.random.uniform(key=main_key, shape=(3, 4, 4))
# Initialize with `flax.linen.init()`.
# The `params_key` is equivalent to a dictionary of PRNGs.
# (Here, you are providing only one PRNG key.)
variables = my_model.init(params_key, x)
# Perform the forward pass with `flax.linen.apply()`.
y = my_model.apply(variables, x, rngs={'dropout': dropout_key})
實際範例