rnglib#

class flax.nnx.Rngs(*args, **kwargs)[原始碼]#

NNX rng 容器類別。要實例化 Rngs,請傳入一個整數,指定起始種子。Rngs 可以有不同的「串流」,允許使用者產生不同的 rng 金鑰。例如,要產生一個用於 paramsdropout 串流的金鑰

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> rng1 = nnx.Rngs(0, params=1)
>>> rng2 = nnx.Rngs(0)

>>> assert rng1.params() != rng2.dropout()

因為我們傳入了 params=1,所以 params 的起始種子為 1,而 dropout 的起始種子預設為我們傳入的 0,因為我們沒有為 dropout 指定種子。如果我們沒有為 params 指定種子,那麼兩個串流都將預設使用我們傳入的 0

>>> rng1 = nnx.Rngs(0)
>>> rng2 = nnx.Rngs(0)

>>> assert rng1.params() == rng2.dropout()

Rngs 容器類別為每個串流包含一個單獨的計數器。每次呼叫串流產生新的 rng 金鑰時,計數器都會遞增 1。要產生新的 rng 金鑰,我們會將目前 rng 串流的計數器值折疊到其對應的起始種子中。如果我們嘗試為在實例化時未指定的串流產生 rng 金鑰,則會使用 default 串流 (即,在實例化期間傳遞給 Rngs 的第一個位置引數是 default 起始種子)

>>> rng1 = nnx.Rngs(100, params=42)
>>> # `params` stream starting seed is 42, counter is 0
>>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 0)
>>> # `dropout` stream starting seed is defaulted to 100, counter is 0
>>> assert rng1.dropout() == jax.random.fold_in(jax.random.key(100), 0)
>>> # empty stream starting seed is defaulted to 100, counter is 1
>>> assert rng1() == jax.random.fold_in(jax.random.key(100), 1)
>>> # `params` stream starting seed is 42, counter is 1
>>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 1)

讓我們來看一個在 Module 中使用 Rngs 的範例,並通過手動執行 Rngs 來驗證輸出

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     # Linear uses the `params` stream twice for kernel and bias
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     # Dropout uses the `dropout` stream once
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))

>>> def assert_same(x, rng_seed, **rng_kwargs):
...   model = Model(rngs=nnx.Rngs(rng_seed, **rng_kwargs))
...   out = model(x)
...
...   # manual forward propagation
...   rngs = nnx.Rngs(rng_seed, **rng_kwargs)
...   kernel = nnx.initializers.lecun_normal()(rngs.params(), (2, 3))
...   assert (model.linear.kernel.value==kernel).all()
...   bias = nnx.initializers.zeros_init()(rngs.params(), (3,))
...   assert (model.linear.bias.value==bias).all()
...   mask = jax.random.bernoulli(rngs.dropout(), p=0.5, shape=(1, 3))
...   # dropout scales the output proportional to the dropout rate
...   manual_out = mask * (jnp.dot(x, kernel) + bias) / 0.5
...   assert (out == manual_out).all()

>>> x = jnp.ones((1, 2))
>>> assert_same(x, 0)
>>> assert_same(x, 0, params=1)
>>> assert_same(x, 0, params=1, dropout=2)
__init__(default=None, /, **rngs)[原始碼]#
參數
  • defaultdefault 串流的起始種子。從未在 **rngs 關鍵字引數中指定的串流產生的任何金鑰,都將預設使用此起始種子。

  • **rngs – 可選的關鍵字引數,用於指定不同 rng 串流的起始種子。關鍵字是串流名稱,其值是該串流對應的起始種子。

class flax.nnx.RngStream(*args: 'Any', **kwargs: 'Any')[原始碼]#
flax.nnx.reseed(node, /, **stream_keys)[原始碼]#

使用新的金鑰更新指定的 RNG 串流的金鑰。

參數
  • node – 要在其中重新設定 RNG 串流種子的節點。

  • **stream_keys – 串流名稱到新金鑰的對應。金鑰可以是整數或 jax 陣列。如果傳入整數,則將使用 jax.random.key 產生金鑰。

提出

ValueError – 如果現有的串流金鑰不是純量。

範例

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))
...
>>> model = Model(nnx.Rngs(params=0, dropout=42))
>>> x = jnp.ones((1, 2))
...
>>> y1 = model(x)
...
>>> # reset the ``dropout`` stream key to 42
>>> nnx.reseed(model, dropout=42)
>>> y2 = model(x)
...
>>> jnp.allclose(y1, y2)
Array(True, dtype=bool)