隨機性#
與 Haiku 和 Flax Linen 等系統相比,Flax NNX 中隨機狀態的處理方式已大幅簡化,因為 Flax NNX 將隨機狀態定義為物件狀態。 本質上,這表示在 Flax NNX 中,隨機狀態:1) 只是另一種狀態;2) 儲存在 nnx.Variable
中;3) 由模型本身持有。
Flax NNX 假隨機數產生器 (PRNG) 系統具有以下主要特性
它是明確的。
它是基於順序的。
它使用動態計數器。
這與 Flax Linen 的 PRNG 系統 有些不同,後者是基於 (路徑 + 順序)
,並使用靜態計數器。
注意:若要深入瞭解 JAX 中的隨機數產生、
jax.random
API 和 PRNG 產生的序列,請查看此 JAX PRNG 教學。
讓我們先從一些必要的匯入開始
from flax import nnx
import jax
from jax import random, numpy as jnp
Rngs
、RngStream
和 RngState
#
在 Flax NNX 中,nnx.Rngs
類型是管理隨機狀態的主要便利 API。 延續 Flax Linen 的腳步,nnx.Rngs
能夠建立多個具名的 PRNG 金鑰 串流,每個串流都有自己的狀態,以便在 JAX 轉換 (transforms) 的情況下嚴格控制隨機性。
以下是 Flax NNX 中主要的 PRNG 相關類型
nnx.Rngs
:主要使用者介面。 它定義了一組具名的nnx.RngStream
物件。nnx.RngStream
:可以產生 PRNG 金鑰串流的物件。 它在nnx.RngKey
和nnx.RngCount
nnx.Variable
中分別保存一個根key
和一個count
。 當產生新的金鑰時,計數會遞增。nnx.RngState
:所有 RNG 相關狀態的基礎類型。nnx.RngKey
:用於保存 PRNG 金鑰的 NNX 變數類型。 它包含一個tag
屬性,其中包含 PRNG 金鑰串流的名稱。nnx.RngCount
:用於保存 PRNG 計數的 NNX 變數類型。 它包含一個tag
屬性,其中包含 PRNG 金鑰串流名稱。
若要建立 nnx.Rngs
物件,您可以簡單地將整數種子或 jax.random.key
實例傳遞給建構函式中您選擇的任何關鍵字引數。
這是一個範例
rngs = nnx.Rngs(params=0, dropout=random.key(1))
nnx.display(rngs)
請注意,key
和 count
nnx.Variable
在 tag
屬性中包含 PRNG 金鑰串流名稱。 這主要用於篩選,我們稍後會看到。
若要產生新的金鑰,您可以存取其中一個串流,並使用其 __call__
方法,且不帶任何引數。 這將透過使用 random.fold_in
與目前的 key
和 count
來傳回新的金鑰。 然後,count
會遞增,以便後續呼叫將傳回新的金鑰。
params_key = rngs.params()
dropout_key = rngs.dropout()
nnx.display(rngs)
請注意,當產生新的 PRNG 金鑰時,key
屬性不會變更。
標準 PRNG 金鑰串流名稱#
Flax NNX 的內建模組只使用兩個標準 PRNG 金鑰串流名稱,如下表所示
PRNG 金鑰串流名稱 |
描述 |
---|---|
|
用於參數初始化 |
|
由 |
params
在建構期間被大多數標準層 (例如nnx.Linear
、nnx.Conv
、nnx.MultiHeadAttention
等) 用於初始化其參數。dropout
由nnx.Dropout
和nnx.MultiHeadAttention
用於產生 dropout 遮罩。
以下是一個簡單的模型範例,該模型使用 params
和 dropout
PRNG 金鑰串流
class Model(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(20, 10, rngs=rngs)
self.drop = nnx.Dropout(0.1, rngs=rngs)
def __call__(self, x):
return nnx.relu(self.drop(self.linear(x)))
model = Model(nnx.Rngs(params=0, dropout=1))
y = model(x=jnp.ones((1, 20)))
print(f'{y.shape = }')
y.shape = (1, 10)
預設 PRNG 金鑰串流#
具有具名串流的缺點之一是,使用者在建立 nnx.Rngs
物件時,需要知道模型將使用的所有可能的名稱。 雖然這可以透過一些文件來解決,但 Flax NNX 提供了一個 default
串流,當找不到串流時,可以用作回退。 若要使用預設 PRNG 金鑰串流,您可以簡單地將整數種子或 jax.random.key
作為第一個位置引數傳遞。
rngs = nnx.Rngs(0, params=1)
key1 = rngs.params() # Call params.
key2 = rngs.dropout() # Fallback to the default stream.
key3 = rngs() # Call the default stream directly.
# Test with the `Model` that uses `params` and `dropout`.
model = Model(rngs)
y = model(jnp.ones((1, 20)))
nnx.display(rngs)
如上所示,也可以透過呼叫 nnx.Rngs
物件本身來產生來自 default
串流的 PRNG 金鑰。
注意
對於大型專案,建議使用具名串流以避免潛在的衝突。 對於小型專案或快速原型設計,僅使用default
串流是一個不錯的選擇。
篩選隨機狀態#
可以使用 篩選器 來操作隨機狀態,就像任何其他類型的狀態一樣。 可以使用類型 (nnx.RngState
、nnx.RngKey
、nnx.RngCount
) 或使用與串流名稱對應的字串來篩選 (請參閱 Flax NNX Filter
DSL)。 以下範例使用 nnx.state
和各種篩選器,以選取 Model
內 Rngs
的不同子狀態
model = Model(nnx.Rngs(params=0, dropout=1))
rng_state = nnx.state(model, nnx.RngState) # All random states.
key_state = nnx.state(model, nnx.RngKey) # Only PRNG keys.
count_state = nnx.state(model, nnx.RngCount) # Only counts.
rng_params_state = nnx.state(model, 'params') # Only `params`.
rng_dropout_state = nnx.state(model, 'dropout') # Only `dropout`.
params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # `Params` PRNG keys.
nnx.display(params_key_state)
重新設定種子#
在 Haiku 和 Flax Linen 中,每次在呼叫模型之前,都會將隨機狀態明確地傳遞至 Module.apply
。 這使得在需要時 (例如,為了重現性) 輕鬆控制模型的隨機性。
在 Flax NNX 中,有兩種方法可以處理此問題
透過手動將
nnx.Rngs
物件傳遞到__call__
堆疊中。標準層(如nnx.Dropout
和nnx.MultiHeadAttention
)會接受rngs
參數,如果您想嚴格控制隨機狀態的話。透過使用
nnx.reseed
將模型的隨機狀態設定為特定組態。這個選項的侵入性較小,即使模型並非設計為允許手動控制隨機狀態,也可以使用。
nnx.reseed
是一個函式,它接受任意的圖形節點(包括 pytree 的 nnx.Module
)以及一些包含 nnx.RngStream
新種子或鍵值的關鍵字參數(由參數名稱指定)。nnx.reseed
接著會遍歷圖形並更新符合的 nnx.RngStream
的隨機狀態,這包括將 key
設定為可能的新值,並將 count
重置為零。
以下範例示範如何使用 nnx.reseed
重置 nnx.Dropout
層的隨機狀態,並驗證其計算結果與首次呼叫模型時相同。
model = Model(nnx.Rngs(params=0, dropout=1))
x = jnp.ones((1, 20))
y1 = model(x)
y2 = model(x)
nnx.reseed(model, dropout=1) # reset dropout RngState
y3 = model(x)
assert not jnp.allclose(y1, y2) # different
assert jnp.allclose(y1, y3) # same
分割 PRNG 鍵#
當與 Flax NNX 轉換(如 nnx.vmap
或 nnx.pmap
)互動時,通常需要分割隨機狀態,以便每個副本都有自己獨特的狀態。這可以透過兩種方式完成:
在將鍵傳遞到其中一個
nnx.Rngs
流之前,手動分割一個鍵;或使用
nnx.split_rngs
修飾詞,它會自動分割函式輸入中找到的任何nnx.RngStream
的隨機狀態,並在函式呼叫結束時自動「降低」它們。
使用 nnx.split_rngs
更方便,因為它能與 Flax NNX 轉換良好搭配,以下為一個範例:
rngs = nnx.Rngs(params=0, dropout=1)
@nnx.split_rngs(splits=5, only='dropout')
def f(rngs: nnx.Rngs):
print('Inside:')
# rngs.dropout() # ValueError: fold_in accepts a single key...
nnx.display(rngs)
f(rngs)
print('Outside:')
rngs.dropout() # works!
nnx.display(rngs)
Inside:
Outside:
注意:
nnx.split_rngs
允許將 NNXFilter
傳遞到only
關鍵字參數,以便選擇在函式內部應分割的nnx.RngStream
。在這種情況下,您只需要分割dropout
PRNG 鍵流。
轉換#
如前所述,在 Flax NNX 中,隨機狀態只是另一種類型的狀態。這表示在 Flax NNX 轉換方面它沒有任何特殊之處,這代表您應該能夠使用每個轉換的 Flax NNX 狀態處理 API 來取得您想要的結果。
在本節中,您將逐步了解在 Flax NNX 轉換中使用隨機狀態的兩個範例:一個是使用 nnx.pmap
,您將學習如何分割 PRNG 狀態;另一個是使用 nnx.scan
,您將凍結 PRNG 狀態。
資料平行 dropout#
在第一個範例中,您將探索如何使用 nnx.pmap
在資料平行環境中呼叫 nnx.Model
。
由於
nnx.Model
使用nnx.Dropout
,因此您需要分割dropout
的隨機狀態,以確保每個副本取得不同的 dropout 遮罩。將
nnx.StateAxes
傳遞到in_axes
,以指定model
的dropout
PRNG 鍵流將跨軸0
平行化,而其餘狀態將被複製。使用
nnx.split_rngs
將dropout
PRNG 鍵流的鍵分割成 N 個獨特的鍵,每個副本各一個。
model = Model(nnx.Rngs(params=0, dropout=1))
num_devices = jax.local_device_count()
x = jnp.ones((num_devices, 16, 20))
state_axes = nnx.StateAxes({'dropout': 0, ...: None})
@nnx.split_rngs(splits=num_devices, only='dropout')
@nnx.pmap(in_axes=(state_axes, 0), out_axes=0)
def forward(model: Model, x: jnp.ndarray):
return model(x)
y = forward(model, x)
print(y.shape)
(1, 16, 10)
遞迴 dropout#
接下來,讓我們探索如何實作使用遞迴 dropout 的 RNNCell
。為此:
首先,您將建立一個
nnx.Dropout
層,它將從自訂的recurrent_dropout
流中取樣 PRNG 鍵。您將把 dropout (
drop
) 套用至RNNCell
的隱藏狀態h
。接著,定義一個
initial_state
函式,以建立RNNCell
的初始狀態。最後,實例化
RNNCell
。
class Count(nnx.Variable): pass
class RNNCell(nnx.Module):
def __init__(self, din, dout, rngs):
self.linear = nnx.Linear(dout + din, dout, rngs=rngs)
self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout')
self.dout = dout
self.count = Count(jnp.array(0, jnp.uint32))
def __call__(self, h, x) -> tuple[jax.Array, jax.Array]:
h = self.drop(h) # Recurrent dropout.
y = nnx.relu(self.linear(jnp.concatenate([h, x], axis=-1)))
self.count += 1
return y, y
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.dout))
cell = RNNCell(8, 16, nnx.Rngs(params=0, recurrent_dropout=1))
接下來,您將在 unroll
函式上使用 nnx.scan
來實作 rnn_forward
操作。
遞迴 dropout 的關鍵在於在所有時間步長上套用相同的 dropout 遮罩。因此,為了達成這個目標,您將
nnx.StateAxes
傳遞到nnx.scan
的in_axes
,指定將廣播cell
的recurrent_dropout
PRNG 流,並傳遞RNNCell
的其餘狀態。此外,隱藏狀態
h
將會是nnx.scan
的Carry
變數,而序列x
將會跨其軸1
進行scan
。
@nnx.jit
def rnn_forward(cell: RNNCell, x: jax.Array):
h = cell.initial_state(batch_size=x.shape[0])
# Broadcast the 'recurrent_dropout' PRNG state to have the same mask on every step.
state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry})
@nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))
def unroll(cell: RNNCell, h, x) -> tuple[jax.Array, jax.Array]:
h, y = cell(h, x)
return h, y
h, y = unroll(cell, h, x)
return y
x = jnp.ones((4, 20, 8))
y = rnn_forward(cell, x)
print(f'{y.shape = }')
print(f'{cell.count.value = }')
y.shape = (4, 20, 16)
cell.count.value = Array(20, dtype=uint32)