rnglib#
- class flax.nnx.Rngs(*args, **kwargs)[原始碼]#
NNX rng 容器類別。要實例化
Rngs
,請傳入一個整數,指定起始種子。Rngs
可以有不同的「串流」,允許使用者產生不同的 rng 金鑰。例如,要產生一個用於params
和dropout
串流的金鑰>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> rng1 = nnx.Rngs(0, params=1) >>> rng2 = nnx.Rngs(0) >>> assert rng1.params() != rng2.dropout()
因為我們傳入了
params=1
,所以params
的起始種子為1
,而dropout
的起始種子預設為我們傳入的0
,因為我們沒有為dropout
指定種子。如果我們沒有為params
指定種子,那麼兩個串流都將預設使用我們傳入的0
>>> rng1 = nnx.Rngs(0) >>> rng2 = nnx.Rngs(0) >>> assert rng1.params() == rng2.dropout()
Rngs
容器類別為每個串流包含一個單獨的計數器。每次呼叫串流產生新的 rng 金鑰時,計數器都會遞增1
。要產生新的 rng 金鑰,我們會將目前 rng 串流的計數器值折疊到其對應的起始種子中。如果我們嘗試為在實例化時未指定的串流產生 rng 金鑰,則會使用default
串流 (即,在實例化期間傳遞給Rngs
的第一個位置引數是default
起始種子)>>> rng1 = nnx.Rngs(100, params=42) >>> # `params` stream starting seed is 42, counter is 0 >>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 0) >>> # `dropout` stream starting seed is defaulted to 100, counter is 0 >>> assert rng1.dropout() == jax.random.fold_in(jax.random.key(100), 0) >>> # empty stream starting seed is defaulted to 100, counter is 1 >>> assert rng1() == jax.random.fold_in(jax.random.key(100), 1) >>> # `params` stream starting seed is 42, counter is 1 >>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 1)
讓我們來看一個在
Module
中使用Rngs
的範例,並通過手動執行Rngs
來驗證輸出>>> class Model(nnx.Module): ... def __init__(self, rngs): ... # Linear uses the `params` stream twice for kernel and bias ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... # Dropout uses the `dropout` stream once ... self.dropout = nnx.Dropout(0.5, rngs=rngs) ... def __call__(self, x): ... return self.dropout(self.linear(x)) >>> def assert_same(x, rng_seed, **rng_kwargs): ... model = Model(rngs=nnx.Rngs(rng_seed, **rng_kwargs)) ... out = model(x) ... ... # manual forward propagation ... rngs = nnx.Rngs(rng_seed, **rng_kwargs) ... kernel = nnx.initializers.lecun_normal()(rngs.params(), (2, 3)) ... assert (model.linear.kernel.value==kernel).all() ... bias = nnx.initializers.zeros_init()(rngs.params(), (3,)) ... assert (model.linear.bias.value==bias).all() ... mask = jax.random.bernoulli(rngs.dropout(), p=0.5, shape=(1, 3)) ... # dropout scales the output proportional to the dropout rate ... manual_out = mask * (jnp.dot(x, kernel) + bias) / 0.5 ... assert (out == manual_out).all() >>> x = jnp.ones((1, 2)) >>> assert_same(x, 0) >>> assert_same(x, 0, params=1) >>> assert_same(x, 0, params=1, dropout=2)
- flax.nnx.reseed(node, /, **stream_keys)[原始碼]#
使用新的金鑰更新指定的 RNG 串流的金鑰。
- 參數
node – 要在其中重新設定 RNG 串流種子的節點。
**stream_keys – 串流名稱到新金鑰的對應。金鑰可以是整數或 jax 陣列。如果傳入整數,則將使用
jax.random.key
產生金鑰。
- 提出
ValueError – 如果現有的串流金鑰不是純量。
範例
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... self.dropout = nnx.Dropout(0.5, rngs=rngs) ... def __call__(self, x): ... return self.dropout(self.linear(x)) ... >>> model = Model(nnx.Rngs(params=0, dropout=42)) >>> x = jnp.ones((1, 2)) ... >>> y1 = model(x) ... >>> # reset the ``dropout`` stream key to 42 >>> nnx.reseed(model, dropout=42) >>> y2 = model(x) ... >>> jnp.allclose(y1, y2) Array(True, dtype=bool)