隨機性#

與 Haiku 和 Flax Linen 等系統相比,Flax NNX 中隨機狀態的處理方式已大幅簡化,因為 Flax NNX 將隨機狀態定義為物件狀態。 本質上,這表示在 Flax NNX 中,隨機狀態:1) 只是另一種狀態;2) 儲存在 nnx.Variable 中;3) 由模型本身持有。

Flax NNX 假隨機數產生器 (PRNG) 系統具有以下主要特性

  • 它是明確的

  • 它是基於順序的

  • 它使用動態計數器

這與 Flax Linen 的 PRNG 系統 有些不同,後者是基於 (路徑 + 順序),並使用靜態計數器。

注意:若要深入瞭解 JAX 中的隨機數產生、jax.random API 和 PRNG 產生的序列,請查看此 JAX PRNG 教學

讓我們先從一些必要的匯入開始

from flax import nnx
import jax
from jax import random, numpy as jnp

RngsRngStreamRngState#

在 Flax NNX 中,nnx.Rngs 類型是管理隨機狀態的主要便利 API。 延續 Flax Linen 的腳步,nnx.Rngs 能夠建立多個具名的 PRNG 金鑰 串流,每個串流都有自己的狀態,以便在 JAX 轉換 (transforms) 的情況下嚴格控制隨機性。

以下是 Flax NNX 中主要的 PRNG 相關類型

  • nnx.Rngs:主要使用者介面。 它定義了一組具名的 nnx.RngStream 物件。

  • nnx.RngStream:可以產生 PRNG 金鑰串流的物件。 它在 nnx.RngKeynnx.RngCount nnx.Variable 中分別保存一個根 key 和一個 count。 當產生新的金鑰時,計數會遞增。

  • nnx.RngState:所有 RNG 相關狀態的基礎類型。

    • nnx.RngKey:用於保存 PRNG 金鑰的 NNX 變數類型。 它包含一個 tag 屬性,其中包含 PRNG 金鑰串流的名稱。

    • nnx.RngCount:用於保存 PRNG 計數的 NNX 變數類型。 它包含一個 tag 屬性,其中包含 PRNG 金鑰串流名稱。

若要建立 nnx.Rngs 物件,您可以簡單地將整數種子或 jax.random.key 實例傳遞給建構函式中您選擇的任何關鍵字引數。

這是一個範例

rngs = nnx.Rngs(params=0, dropout=random.key(1))
nnx.display(rngs)

請注意,keycount nnx.Variabletag 屬性中包含 PRNG 金鑰串流名稱。 這主要用於篩選,我們稍後會看到。

若要產生新的金鑰,您可以存取其中一個串流,並使用其 __call__ 方法,且不帶任何引數。 這將透過使用 random.fold_in 與目前的 keycount 來傳回新的金鑰。 然後,count 會遞增,以便後續呼叫將傳回新的金鑰。

params_key = rngs.params()
dropout_key = rngs.dropout()

nnx.display(rngs)

請注意,當產生新的 PRNG 金鑰時,key 屬性不會變更。

標準 PRNG 金鑰串流名稱#

Flax NNX 的內建模組只使用兩個標準 PRNG 金鑰串流名稱,如下表所示

PRNG 金鑰串流名稱

描述

params

用於參數初始化

dropout

nnx.Dropout 用於建立 dropout 遮罩

  • params 在建構期間被大多數標準層 (例如 nnx.Linearnnx.Convnnx.MultiHeadAttention 等) 用於初始化其參數。

  • dropoutnnx.Dropoutnnx.MultiHeadAttention 用於產生 dropout 遮罩。

以下是一個簡單的模型範例,該模型使用 paramsdropout PRNG 金鑰串流

class Model(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(20, 10, rngs=rngs)
    self.drop = nnx.Dropout(0.1, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.drop(self.linear(x)))

model = Model(nnx.Rngs(params=0, dropout=1))

y = model(x=jnp.ones((1, 20)))
print(f'{y.shape = }')
y.shape = (1, 10)

預設 PRNG 金鑰串流#

具有具名串流的缺點之一是,使用者在建立 nnx.Rngs 物件時,需要知道模型將使用的所有可能的名稱。 雖然這可以透過一些文件來解決,但 Flax NNX 提供了一個 default 串流,當找不到串流時,可以用作回退。 若要使用預設 PRNG 金鑰串流,您可以簡單地將整數種子或 jax.random.key 作為第一個位置引數傳遞。

rngs = nnx.Rngs(0, params=1)

key1 = rngs.params() # Call params.
key2 = rngs.dropout() # Fallback to the default stream.
key3 = rngs() # Call the default stream directly.

# Test with the `Model` that uses `params` and `dropout`.
model = Model(rngs)
y = model(jnp.ones((1, 20)))

nnx.display(rngs)

如上所示,也可以透過呼叫 nnx.Rngs 物件本身來產生來自 default 串流的 PRNG 金鑰。

注意
對於大型專案,建議使用具名串流以避免潛在的衝突。 對於小型專案或快速原型設計,僅使用 default 串流是一個不錯的選擇。

篩選隨機狀態#

可以使用 篩選器 來操作隨機狀態,就像任何其他類型的狀態一樣。 可以使用類型 (nnx.RngStatennx.RngKeynnx.RngCount) 或使用與串流名稱對應的字串來篩選 (請參閱 Flax NNX Filter DSL)。 以下範例使用 nnx.state 和各種篩選器,以選取 ModelRngs 的不同子狀態

model = Model(nnx.Rngs(params=0, dropout=1))

rng_state = nnx.state(model, nnx.RngState) # All random states.
key_state = nnx.state(model, nnx.RngKey) # Only PRNG keys.
count_state = nnx.state(model, nnx.RngCount) # Only counts.
rng_params_state = nnx.state(model, 'params') # Only `params`.
rng_dropout_state = nnx.state(model, 'dropout') # Only `dropout`.
params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # `Params` PRNG keys.

nnx.display(params_key_state)

重新設定種子#

在 Haiku 和 Flax Linen 中,每次在呼叫模型之前,都會將隨機狀態明確地傳遞至 Module.apply。 這使得在需要時 (例如,為了重現性) 輕鬆控制模型的隨機性。

在 Flax NNX 中,有兩種方法可以處理此問題

  1. 透過手動將 nnx.Rngs 物件傳遞到 __call__ 堆疊中。標準層(如 nnx.Dropoutnnx.MultiHeadAttention)會接受 rngs 參數,如果您想嚴格控制隨機狀態的話。

  2. 透過使用 nnx.reseed 將模型的隨機狀態設定為特定組態。這個選項的侵入性較小,即使模型並非設計為允許手動控制隨機狀態,也可以使用。

nnx.reseed 是一個函式,它接受任意的圖形節點(包括 pytreennx.Module)以及一些包含 nnx.RngStream 新種子或鍵值的關鍵字參數(由參數名稱指定)。nnx.reseed 接著會遍歷圖形並更新符合的 nnx.RngStream 的隨機狀態,這包括將 key 設定為可能的新值,並將 count 重置為零。

以下範例示範如何使用 nnx.reseed 重置 nnx.Dropout 層的隨機狀態,並驗證其計算結果與首次呼叫模型時相同。

model = Model(nnx.Rngs(params=0, dropout=1))
x = jnp.ones((1, 20))

y1 = model(x)
y2 = model(x)

nnx.reseed(model, dropout=1) # reset dropout RngState
y3 = model(x)

assert not jnp.allclose(y1, y2) # different
assert jnp.allclose(y1, y3)     # same

分割 PRNG 鍵#

當與 Flax NNX 轉換(如 nnx.vmapnnx.pmap)互動時,通常需要分割隨機狀態,以便每個副本都有自己獨特的狀態。這可以透過兩種方式完成:

  • 在將鍵傳遞到其中一個 nnx.Rngs 流之前,手動分割一個鍵;或

  • 使用 nnx.split_rngs 修飾詞,它會自動分割函式輸入中找到的任何 nnx.RngStream 的隨機狀態,並在函式呼叫結束時自動「降低」它們。

使用 nnx.split_rngs 更方便,因為它能與 Flax NNX 轉換良好搭配,以下為一個範例:

rngs = nnx.Rngs(params=0, dropout=1)

@nnx.split_rngs(splits=5, only='dropout')
def f(rngs: nnx.Rngs):
  print('Inside:')
  # rngs.dropout() # ValueError: fold_in accepts a single key...
  nnx.display(rngs)

f(rngs)

print('Outside:')
rngs.dropout() # works!
nnx.display(rngs)
Inside:
Outside:

注意: nnx.split_rngs 允許將 NNX Filter 傳遞到 only 關鍵字參數,以便選擇在函式內部應分割的 nnx.RngStream。在這種情況下,您只需要分割 dropout PRNG 鍵流。

轉換#

如前所述,在 Flax NNX 中,隨機狀態只是另一種類型的狀態。這表示在 Flax NNX 轉換方面它沒有任何特殊之處,這代表您應該能夠使用每個轉換的 Flax NNX 狀態處理 API 來取得您想要的結果。

在本節中,您將逐步了解在 Flax NNX 轉換中使用隨機狀態的兩個範例:一個是使用 nnx.pmap,您將學習如何分割 PRNG 狀態;另一個是使用 nnx.scan,您將凍結 PRNG 狀態。

資料平行 dropout#

在第一個範例中,您將探索如何使用 nnx.pmap 在資料平行環境中呼叫 nnx.Model

  • 由於 nnx.Model 使用 nnx.Dropout,因此您需要分割 dropout 的隨機狀態,以確保每個副本取得不同的 dropout 遮罩。

  • nnx.StateAxes 傳遞到 in_axes,以指定 modeldropout PRNG 鍵流將跨軸 0 平行化,而其餘狀態將被複製。

  • 使用 nnx.split_rngsdropout PRNG 鍵流的鍵分割成 N 個獨特的鍵,每個副本各一個。

model = Model(nnx.Rngs(params=0, dropout=1))

num_devices = jax.local_device_count()
x = jnp.ones((num_devices, 16, 20))
state_axes = nnx.StateAxes({'dropout': 0, ...: None})

@nnx.split_rngs(splits=num_devices, only='dropout')
@nnx.pmap(in_axes=(state_axes, 0), out_axes=0)
def forward(model: Model, x: jnp.ndarray):
  return model(x)

y = forward(model, x)
print(y.shape)
(1, 16, 10)

遞迴 dropout#

接下來,讓我們探索如何實作使用遞迴 dropout 的 RNNCell。為此:

  • 首先,您將建立一個 nnx.Dropout 層,它將從自訂的 recurrent_dropout 流中取樣 PRNG 鍵。

  • 您將把 dropout (drop) 套用至 RNNCell 的隱藏狀態 h

  • 接著,定義一個 initial_state 函式,以建立 RNNCell 的初始狀態。

  • 最後,實例化 RNNCell

class Count(nnx.Variable): pass

class RNNCell(nnx.Module):
  def __init__(self, din, dout, rngs):
    self.linear = nnx.Linear(dout + din, dout, rngs=rngs)
    self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout')
    self.dout = dout
    self.count = Count(jnp.array(0, jnp.uint32))

  def __call__(self, h, x) -> tuple[jax.Array, jax.Array]:
    h = self.drop(h) # Recurrent dropout.
    y = nnx.relu(self.linear(jnp.concatenate([h, x], axis=-1)))
    self.count += 1
    return y, y

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.dout))

cell = RNNCell(8, 16, nnx.Rngs(params=0, recurrent_dropout=1))

接下來,您將在 unroll 函式上使用 nnx.scan 來實作 rnn_forward 操作。

  • 遞迴 dropout 的關鍵在於在所有時間步長上套用相同的 dropout 遮罩。因此,為了達成這個目標,您將 nnx.StateAxes 傳遞到 nnx.scanin_axes,指定將廣播 cellrecurrent_dropout PRNG 流,並傳遞 RNNCell 的其餘狀態。

  • 此外,隱藏狀態 h 將會是 nnx.scanCarry 變數,而序列 x 將會跨其軸 1 進行 scan

@nnx.jit
def rnn_forward(cell: RNNCell, x: jax.Array):
  h = cell.initial_state(batch_size=x.shape[0])

  # Broadcast the 'recurrent_dropout' PRNG state to have the same mask on every step.
  state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry})
  @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))
  def unroll(cell: RNNCell, h, x) -> tuple[jax.Array, jax.Array]:
    h, y = cell(h, x)
    return h, y

  h, y = unroll(cell, h, x)
  return y

x = jnp.ones((4, 20, 8))
y = rnn_forward(cell, x)

print(f'{y.shape = }')
print(f'{cell.count.value = }')
y.shape = (4, 20, 16)
cell.count.value = Array(20, dtype=uint32)