隨機性#
- class flax.nnx.Dropout(*args, **kwargs)[原始碼]#
建立一個 Dropout 層。
要使用 dropout,請調用
train()
方法(或在建構子或呼叫時傳入deterministic=False
)。要停用 dropout,請調用
eval()
方法(或在建構子或呼叫時傳入deterministic=True
)。範例用法
>>> from flax import nnx >>> import jax.numpy as jnp >>> class MLP(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(in_features=3, out_features=4, rngs=rngs) ... self.dropout = nnx.Dropout(0.5, rngs=rngs) ... def __call__(self, x): ... x = self.linear(x) ... x = self.dropout(x) ... return x >>> model = MLP(rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 3)) >>> model.train() # use dropout >>> model(x) Array([[-0.9353421, 0. , 1.434417 , 0. ]], dtype=float32) >>> model.eval() # don't use dropout >>> model(x) Array([[-0.46767104, -0.7213411 , 0.7172085 , -0.31562346]], dtype=float32)
- rate#
dropout 的機率。(_不是_保留率!)
- 類型
float
- broadcast_dims#
將共享相同 dropout 遮罩的維度
- 類型
collections.abc.Sequence[int]
- deterministic#
如果為 false,則輸入會按
1 / (1 - rate)
縮放並遮罩,而如果為 true,則不套用遮罩,並按原樣返回輸入。- 類型
bool
- rng_collection#
請求 rng 金鑰時要使用的 rng 集合名稱。
- 類型
str
- rngs#
rng 金鑰。
- 類型
flax.nnx.rnglib.Rngs | None