激活函數#

flax.nnx.celu(x, alpha=1.0)[原始碼]#

連續可微分指數線性單元激活。

計算逐元素函數

\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]

有關更多資訊,請參閱 連續可微分指數線性單元

參數
  • x – 輸入陣列

  • alpha – 陣列或純量 (預設值:1.0)

回傳

一個陣列。

flax.nnx.elu(x, alpha=1.0)[原始碼]#

指數線性單元激活函數。

計算逐元素函數

\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
參數
  • x – 輸入陣列

  • alpha – alpha 值的純量或陣列 (預設值:1.0)

回傳

一個陣列。

另請參閱

selu()

flax.nnx.gelu(x, approximate=True)[原始碼]#

高斯誤差線性單元激活函數。

如果 approximate=False,則計算逐元素函數

\[\mathrm{gelu}(x) = \frac{x}{2} \left(\mathrm{erfc} \left( \frac{-x}{\sqrt{2}} \right) \right)\]

如果 approximate=True,則使用 GELU 的近似公式

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]

有關更多資訊,請參閱 高斯誤差線性單元 (GELU),第 2 節。

參數
  • x – 輸入陣列

  • approximate – 是否使用近似公式或精確公式。

flax.nnx.glu(x, axis=-1)[原始碼]#

閘控線性單元激活函數。

計算函數

\[\mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right] \right)\]

其中陣列沿著 axis 分成兩部分。axis 維度的大小必須可被 2 整除。

參數
  • x – 輸入陣列

  • axis – 應沿著計算分割的軸 (預設值:-1)

回傳

一個陣列。

另請參閱

sigmoid()

flax.nnx.hard_sigmoid(x)[原始碼]#

Hard Sigmoid 激活函數。

計算逐元素函數

\[\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}\]
參數

x – 輸入陣列

回傳

一個陣列。

另請參閱

relu6()

flax.nnx.hard_silu(x)[原始碼]#

Hard SiLU (swish) 激活函數

計算逐元素函數

\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]

hard_silu()hard_swish() 都是同一個函數的別名。

參數

x – 輸入陣列

回傳

一個陣列。

另請參閱

hard_sigmoid()

flax.nnx.hard_swish(x)#

Hard SiLU (swish) 激活函數

計算逐元素函數

\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]

hard_silu()hard_swish() 都是同一個函數的別名。

參數

x – 輸入陣列

回傳

一個陣列。

另請參閱

hard_sigmoid()

flax.nnx.hard_tanh(x)[原始碼]#

硬性 \(\mathrm{tanh}\) 激活函數。

計算逐元素函數

\[\begin{split}\mathrm{hard\_tanh}(x) = \begin{cases} -1, & x < -1\\ x, & -1 \le x \le 1\\ 1, & 1 < x \end{cases}\end{split}\]
參數

x – 輸入陣列

回傳

一個陣列。

flax.nnx.leaky_relu(x, negative_slope=0.01)[原始碼]#

洩漏式線性整流單元激活函數。

計算逐元素函數

\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]

其中 \(\alpha\) = negative_slope

參數
  • x – 輸入陣列

  • negative_slope – 指定負斜率的陣列或純量 (預設值:0.01)

回傳

一個陣列。

另請參閱

relu()

flax.nnx.log_sigmoid(x)[原始碼]#

對數-Sigmoid 激活函數。

計算逐元素函數

\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
參數

x – 輸入陣列

回傳

一個陣列。

另請參閱

sigmoid()

flax.nnx.log_softmax(x, axis=-1, where=None, initial=_UNSPECIFIED)[原始碼]#

對數 Softmax 函數。

計算 softmax 函數的對數,它會將元素重新縮放到 \([-\infty, 0)\) 的範圍內。

\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
參數
  • x – 輸入陣列

  • axis – 應沿著計算 log_softmax 的軸或多個軸。可以是整數或整數的元組。

  • where – 要包含在 log_softmax 中的元素。

回傳

一個陣列。

注意

如果任何輸入值為 +inf,則結果將全部為 NaN:這反映了在浮點數數學的上下文中,inf / inf 沒有明確定義的事實。

另請參閱

softmax()

flax.nnx.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, where=None)[原始碼]#

對數和指數 (Log-sum-exp) 縮減。

scipy.special.logsumexp() 的 JAX 實作。

\[\mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij})\]

其中 \(j\) 索引的範圍涵蓋一個或多個要縮減的維度。

參數
  • a – 輸入陣列

  • axis – 要縮減的軸或多個軸。可以是 None、整數或整數的元組。

  • b\(\mathrm{exp}(a)\) 的縮放因子。必須可廣播至 a 的形狀。

  • keepdims – 如果為 True,則縮減的軸會在輸出中保留為大小為 1 的維度。

  • return_sign – 如果為 True,則輸出會是一個 (結果, 符號) 對,其中 符號 是總和的符號,而 結果 包含其絕對值的對數。如果為 False,則只會回傳 結果,並且如果總和為負數,則會包含 NaN 值。

  • where – 要包含在縮減中的元素。

回傳

根據 return_sign 引數的值,會回傳一個陣列 結果 或一對陣列 (結果, 符號)

flax.nnx.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.float64'>, axis=-1)[原始碼]#

對給定的索引進行獨熱編碼。

輸入 x 中的每個索引都會被編碼為長度為 num_classes 的零向量,其中 索引 處的元素設定為一。

>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

範圍 [0, num_classes) 之外的索引將被編碼為零。

>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
參數
  • x – 索引張量。

  • num_classes – 獨熱維度中的類別數。

  • dtype – 可選,回傳值的浮點 dtype (預設為 jnp.float_)。

  • axis – 應沿其計算函數的軸或多個軸。

flax.nnx.relu(x)[原始碼]#

整流線性單元激活函數。

計算逐元素函數

\[\mathrm{relu}(x) = \max(x, 0)\]

但在微分下,我們取

\[\nabla \mathrm{relu}(0) = 0\]

如需詳細資訊,請參閱 ReLU'(0) 對反向傳播的數值影響

參數

x – 輸入陣列

回傳

一個陣列。

範例

>>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)

另請參閱

relu6()

flax.nnx.selu(x)[原始碼]#

縮放指數線性單元激活。

計算逐元素函數

\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]

其中 \(\lambda = 1.0507009873554804934193349852946\)\(\alpha = 1.6732632423543772848170429916717\)

如需詳細資訊,請參閱 自正規化神經網路

參數

x – 輸入陣列

回傳

一個陣列。

另請參閱

elu()

flax.nnx.sigmoid(x)[原始碼]#

Sigmoid 激活函數。

計算逐元素函數

\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
參數

x – 輸入陣列

回傳

一個陣列。

另請參閱

log_sigmoid()

flax.nnx.silu(x)[原始碼]#

SiLU (又稱 swish) 激活函數。

計算逐元素函數

\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]

swish()silu() 都是同一個函數的別名。

參數

x – 輸入陣列

回傳

一個陣列。

另請參閱

sigmoid()

flax.nnx.soft_sign(x)[原始碼]#

軟符號激活函數。

計算逐元素函數

\[\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}\]
參數

x – 輸入陣列

flax.nnx.softmax(x, axis=-1, where=None, initial=_UNSPECIFIED)[原始碼]#

Softmax 函數。

計算將元素重新調整為 \([0, 1]\) 範圍的函數,使得沿著 axis 的元素總和為 \(1\)

\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
參數
  • x – 輸入陣列

  • axis – 應沿其計算 softmax 的軸或多個軸。softmax 輸出在這些維度上的總和應為 \(1\)。可以是整數或整數的元組。

  • where – 要包含在 softmax 中的元素。

回傳

一個陣列。

注意

如果任何輸入值為 +inf,則結果將全部為 NaN:這反映了在浮點數數學的上下文中,inf / inf 沒有明確定義的事實。

另請參閱

log_softmax()

flax.nnx.softplus(x)[原始碼]#

軟加 (Softplus) 激活函數。

計算逐元素函數

\[\mathrm{softplus}(x) = \log(1 + e^x)\]
參數

x – 輸入陣列

flax.nnx.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)[原始碼]#

透過減去 mean 並除以 \(\sqrt{\mathrm{variance}}\) 來正規化陣列。

flax.nnx.swish(x)#

SiLU (又稱 swish) 激活函數。

計算逐元素函數

\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]

swish()silu() 都是同一個函數的別名。

參數

x – 輸入陣列

回傳

一個陣列。

另請參閱

sigmoid()

flax.nnx.tanh(x, /)#

計算輸入的逐元素雙曲正切值。

JAX 實現的 numpy.tanh

雙曲正切的定義為

\[tanh(x) = \frac{sinh(x)}{cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}}\]
參數

x – 輸入陣列或純量。

回傳

一個包含 x 每個元素的雙曲正切值的陣列,並提升為不精確的 dtype。

注意

jnp.tanh 等效於計算 -1j * jnp.tan(1j * x)

另請參閱

  • jax.numpy.sinh(): 計算輸入的逐元素雙曲正弦值。

  • jax.numpy.cosh(): 計算輸入的逐元素雙曲餘弦值。

  • jax.numpy.arctanh(): 計算輸入的逐元素反雙曲正切值。

範例

>>> x = jnp.array([[-1, 0, 1],
...                [3, -2, 5]])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.tanh(x)
Array([[-0.762,  0.   ,  0.762],
       [ 0.995, -0.964,  1.   ]], dtype=float32)
>>> with jnp.printoptions(precision=3, suppress=True):
...   -1j * jnp.tan(1j * x)
Array([[-0.762+0.j,  0.   -0.j,  0.762-0.j],
       [ 0.995-0.j, -0.964+0.j,  1.   -0.j]],      dtype=complex64, weak_type=True)

對於複數值輸入

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.tanh(2-5j)
Array(1.031+0.021j, dtype=complex64, weak_type=True)
>>> with jnp.printoptions(precision=3, suppress=True):
...   -1j * jnp.tan(1j * (2-5j))
Array(1.031+0.021j, dtype=complex64, weak_type=True)