激活函數#
- flax.nnx.celu(x, alpha=1.0)[原始碼]#
連續可微分指數線性單元激活。
計算逐元素函數
\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]有關更多資訊,請參閱 連續可微分指數線性單元。
- 參數
x – 輸入陣列
alpha – 陣列或純量 (預設值:1.0)
- 回傳
一個陣列。
- flax.nnx.elu(x, alpha=1.0)[原始碼]#
指數線性單元激活函數。
計算逐元素函數
\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]- 參數
x – 輸入陣列
alpha – alpha 值的純量或陣列 (預設值:1.0)
- 回傳
一個陣列。
另請參閱
- flax.nnx.gelu(x, approximate=True)[原始碼]#
高斯誤差線性單元激活函數。
如果
approximate=False
,則計算逐元素函數\[\mathrm{gelu}(x) = \frac{x}{2} \left(\mathrm{erfc} \left( \frac{-x}{\sqrt{2}} \right) \right)\]如果
approximate=True
,則使用 GELU 的近似公式\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]有關更多資訊,請參閱 高斯誤差線性單元 (GELU),第 2 節。
- 參數
x – 輸入陣列
approximate – 是否使用近似公式或精確公式。
- flax.nnx.glu(x, axis=-1)[原始碼]#
閘控線性單元激活函數。
計算函數
\[\mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right] \right)\]其中陣列沿著
axis
分成兩部分。axis
維度的大小必須可被 2 整除。- 參數
x – 輸入陣列
axis – 應沿著計算分割的軸 (預設值:-1)
- 回傳
一個陣列。
另請參閱
- flax.nnx.hard_sigmoid(x)[原始碼]#
Hard Sigmoid 激活函數。
計算逐元素函數
\[\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}\]- 參數
x – 輸入陣列
- 回傳
一個陣列。
另請參閱
relu6()
- flax.nnx.hard_silu(x)[原始碼]#
Hard SiLU (swish) 激活函數
計算逐元素函數
\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]hard_silu()
和hard_swish()
都是同一個函數的別名。- 參數
x – 輸入陣列
- 回傳
一個陣列。
另請參閱
- flax.nnx.hard_swish(x)#
Hard SiLU (swish) 激活函數
計算逐元素函數
\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]hard_silu()
和hard_swish()
都是同一個函數的別名。- 參數
x – 輸入陣列
- 回傳
一個陣列。
另請參閱
- flax.nnx.hard_tanh(x)[原始碼]#
硬性 \(\mathrm{tanh}\) 激活函數。
計算逐元素函數
\[\begin{split}\mathrm{hard\_tanh}(x) = \begin{cases} -1, & x < -1\\ x, & -1 \le x \le 1\\ 1, & 1 < x \end{cases}\end{split}\]- 參數
x – 輸入陣列
- 回傳
一個陣列。
- flax.nnx.leaky_relu(x, negative_slope=0.01)[原始碼]#
洩漏式線性整流單元激活函數。
計算逐元素函數
\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]其中 \(\alpha\) =
negative_slope
。- 參數
x – 輸入陣列
negative_slope – 指定負斜率的陣列或純量 (預設值:0.01)
- 回傳
一個陣列。
另請參閱
- flax.nnx.log_sigmoid(x)[原始碼]#
對數-Sigmoid 激活函數。
計算逐元素函數
\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]- 參數
x – 輸入陣列
- 回傳
一個陣列。
另請參閱
- flax.nnx.log_softmax(x, axis=-1, where=None, initial=_UNSPECIFIED)[原始碼]#
對數 Softmax 函數。
計算
softmax
函數的對數,它會將元素重新縮放到 \([-\infty, 0)\) 的範圍內。\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]- 參數
x – 輸入陣列
axis – 應沿著計算
log_softmax
的軸或多個軸。可以是整數或整數的元組。where – 要包含在
log_softmax
中的元素。
- 回傳
一個陣列。
注意
如果任何輸入值為
+inf
,則結果將全部為NaN
:這反映了在浮點數數學的上下文中,inf / inf
沒有明確定義的事實。另請參閱
- flax.nnx.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, where=None)[原始碼]#
對數和指數 (Log-sum-exp) 縮減。
scipy.special.logsumexp()
的 JAX 實作。\[\mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij})\]其中 \(j\) 索引的範圍涵蓋一個或多個要縮減的維度。
- 參數
a – 輸入陣列
axis – 要縮減的軸或多個軸。可以是
None
、整數或整數的元組。b – \(\mathrm{exp}(a)\) 的縮放因子。必須可廣播至 a 的形狀。
keepdims – 如果為
True
,則縮減的軸會在輸出中保留為大小為 1 的維度。return_sign – 如果為
True
,則輸出會是一個(結果, 符號)
對,其中符號
是總和的符號,而結果
包含其絕對值的對數。如果為False
,則只會回傳結果
,並且如果總和為負數,則會包含 NaN 值。where – 要包含在縮減中的元素。
- 回傳
根據
return_sign
引數的值,會回傳一個陣列結果
或一對陣列(結果, 符號)
。
- flax.nnx.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.float64'>, axis=-1)[原始碼]#
對給定的索引進行獨熱編碼。
輸入
x
中的每個索引都會被編碼為長度為num_classes
的零向量,其中索引
處的元素設定為一。>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3) Array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32)
範圍 [0, num_classes) 之外的索引將被編碼為零。
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32)
- 參數
x – 索引張量。
num_classes – 獨熱維度中的類別數。
dtype – 可選,回傳值的浮點 dtype (預設為
jnp.float_
)。axis – 應沿其計算函數的軸或多個軸。
- flax.nnx.relu(x)[原始碼]#
整流線性單元激活函數。
計算逐元素函數
\[\mathrm{relu}(x) = \max(x, 0)\]但在微分下,我們取
\[\nabla \mathrm{relu}(0) = 0\]如需詳細資訊,請參閱 ReLU'(0) 對反向傳播的數值影響。
- 參數
x – 輸入陣列
- 回傳
一個陣列。
範例
>>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.])) Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
另請參閱
relu6()
- flax.nnx.selu(x)[原始碼]#
縮放指數線性單元激活。
計算逐元素函數
\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]其中 \(\lambda = 1.0507009873554804934193349852946\) 和 \(\alpha = 1.6732632423543772848170429916717\)。
如需詳細資訊,請參閱 自正規化神經網路。
- 參數
x – 輸入陣列
- 回傳
一個陣列。
另請參閱
- flax.nnx.sigmoid(x)[原始碼]#
Sigmoid 激活函數。
計算逐元素函數
\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]- 參數
x – 輸入陣列
- 回傳
一個陣列。
另請參閱
- flax.nnx.silu(x)[原始碼]#
SiLU (又稱 swish) 激活函數。
計算逐元素函數
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]- 參數
x – 輸入陣列
- 回傳
一個陣列。
另請參閱
- flax.nnx.soft_sign(x)[原始碼]#
軟符號激活函數。
計算逐元素函數
\[\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}\]- 參數
x – 輸入陣列
- flax.nnx.softmax(x, axis=-1, where=None, initial=_UNSPECIFIED)[原始碼]#
Softmax 函數。
計算將元素重新調整為 \([0, 1]\) 範圍的函數,使得沿著
axis
的元素總和為 \(1\)。\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]- 參數
x – 輸入陣列
axis – 應沿其計算 softmax 的軸或多個軸。softmax 輸出在這些維度上的總和應為 \(1\)。可以是整數或整數的元組。
where – 要包含在
softmax
中的元素。
- 回傳
一個陣列。
注意
如果任何輸入值為
+inf
,則結果將全部為NaN
:這反映了在浮點數數學的上下文中,inf / inf
沒有明確定義的事實。另請參閱
- flax.nnx.softplus(x)[原始碼]#
軟加 (Softplus) 激活函數。
計算逐元素函數
\[\mathrm{softplus}(x) = \log(1 + e^x)\]- 參數
x – 輸入陣列
- flax.nnx.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)[原始碼]#
透過減去
mean
並除以 \(\sqrt{\mathrm{variance}}\) 來正規化陣列。
- flax.nnx.swish(x)#
SiLU (又稱 swish) 激活函數。
計算逐元素函數
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]- 參數
x – 輸入陣列
- 回傳
一個陣列。
另請參閱
- flax.nnx.tanh(x, /)#
計算輸入的逐元素雙曲正切值。
JAX 實現的
numpy.tanh
。雙曲正切的定義為
\[tanh(x) = \frac{sinh(x)}{cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}}\]- 參數
x – 輸入陣列或純量。
- 回傳
一個包含
x
每個元素的雙曲正切值的陣列,並提升為不精確的 dtype。
注意
jnp.tanh
等效於計算-1j * jnp.tan(1j * x)
。另請參閱
jax.numpy.sinh()
: 計算輸入的逐元素雙曲正弦值。jax.numpy.cosh()
: 計算輸入的逐元素雙曲餘弦值。jax.numpy.arctanh()
: 計算輸入的逐元素反雙曲正切值。
範例
>>> x = jnp.array([[-1, 0, 1], ... [3, -2, 5]]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.tanh(x) Array([[-0.762, 0. , 0.762], [ 0.995, -0.964, 1. ]], dtype=float32) >>> with jnp.printoptions(precision=3, suppress=True): ... -1j * jnp.tan(1j * x) Array([[-0.762+0.j, 0. -0.j, 0.762-0.j], [ 0.995-0.j, -0.964+0.j, 1. -0.j]], dtype=complex64, weak_type=True)
對於複數值輸入
>>> with jnp.printoptions(precision=3, suppress=True): ... jnp.tanh(2-5j) Array(1.031+0.021j, dtype=complex64, weak_type=True) >>> with jnp.printoptions(precision=3, suppress=True): ... -1j * jnp.tan(1j * (2-5j)) Array(1.031+0.021j, dtype=complex64, weak_type=True)