隨機性#

class flax.nnx.Dropout(*args, **kwargs)[原始碼]#

建立一個 Dropout 層。

要使用 dropout,請調用 train() 方法(或在建構子或呼叫時傳入 deterministic=False)。

要停用 dropout,請調用 eval() 方法(或在建構子或呼叫時傳入 deterministic=True)。

範例用法

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class MLP(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(in_features=3, out_features=4, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     x = self.dropout(x)
...     return x

>>> model = MLP(rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 3))

>>> model.train() # use dropout
>>> model(x)
Array([[-0.9353421,  0.       ,  1.434417 ,  0.       ]], dtype=float32)

>>> model.eval() # don't use dropout
>>> model(x)
Array([[-0.46767104, -0.7213411 ,  0.7172085 , -0.31562346]], dtype=float32)
rate#

dropout 的機率。(_不是_保留率!)

類型

float

broadcast_dims#

將共享相同 dropout 遮罩的維度

類型

collections.abc.Sequence[int]

deterministic#

如果為 false,則輸入會按 1 / (1 - rate) 縮放並遮罩,而如果為 true,則不套用遮罩,並按原樣返回輸入。

類型

bool

rng_collection#

請求 rng 金鑰時要使用的 rng 集合名稱。

類型

str

rngs#

rng 金鑰。

類型

flax.nnx.rnglib.Rngs | None