正規化#

class flax.nnx.BatchNorm(*args, **kwargs)[原始碼]#

BatchNorm 模組。

若要在輸入上計算批次正規化並更新批次統計資訊,請呼叫 train() 方法(或在建構函式中或呼叫期間傳入 use_running_average=False)。

若要使用儲存的批次統計資訊的執行平均值,請呼叫 eval() 方法(或在建構函式中或呼叫期間傳入 use_running_average=True)。

用法範例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.BatchNorm(num_features=6, momentum=0.9, epsilon=1e-5,
...                       dtype=jnp.float32, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(6,)
  ),
  'mean': VariableState(
    type=BatchStat,
    value=(6,)
  ),
  'scale': VariableState(
    type=Param,
    value=(6,)
  ),
  'var': VariableState(
    type=BatchStat,
    value=(6,)
  )
})

>>> # calculate batch norm on input and update batch statistics
>>> layer.train()
>>> y = layer(x)
>>> batch_stats1 = nnx.state(layer, nnx.BatchStat)
>>> y = layer(x)
>>> batch_stats2 = nnx.state(layer, nnx.BatchStat)
>>> assert (batch_stats1['mean'].value != batch_stats2['mean'].value).all()
>>> assert (batch_stats1['var'].value != batch_stats2['var'].value).all()

>>> # use stored batch statistics' running average
>>> layer.eval()
>>> y = layer(x)
>>> batch_stats3 = nnx.state(layer, nnx.BatchStat)
>>> assert (batch_stats2['mean'].value == batch_stats3['mean'].value).all()
>>> assert (batch_stats2['var'].value == batch_stats3['var'].value).all()
num_features#

輸入特徵的數量。

use_running_average#

如果為 True,則會使用儲存的批次統計資訊,而不是計算輸入上的批次統計資訊。

axis#

輸入的特徵或非批次軸。

momentum#

批次統計資訊的指數移動平均的衰減率。

epsilon#

加入變異數的小浮點數,以避免除以零。

dtype#

結果的資料類型(預設值:從輸入和參數推斷)。

param_dtype#

傳遞到參數初始化器的資料類型(預設值:float32)。

use_bias#

如果為 True,則會加入偏差 (beta)。

use_scale#

如果為 True,則乘以縮放 (gamma)。當下一層是線性時(例如 nn.relu),可以停用此選項,因為縮放將由下一層完成。

bias_init#

偏差的初始化器,預設為零。

scale_init#

縮放的初始化器,預設為一。

axis_name#

用於合併多個裝置的批次統計資訊的軸名稱。有關軸名稱的說明,請參閱 jax.pmap(預設值:None)。

axis_index_groups#

該命名軸內軸索引的群組,表示要縮減的裝置子集(預設值:None)。例如,[[0, 1], [2, 3]] 會獨立於前兩個和最後兩個裝置上的範例進行批次正規化。有關詳細資訊,請參閱 jax.lax.psum

use_fast_variance#

如果為 true,則使用更快但數值穩定性較差的變異數計算。

rngs#

rng 金鑰。

__call__(x, use_running_average=None, *, mask=None)[原始碼]#

使用批次統計資訊對輸入進行正規化。

參數
  • x – 要正規化的輸入。

  • use_running_average – 如果為 true,則會使用儲存的批次統計資訊,而不是計算輸入上的批次統計資訊。傳遞到呼叫方法的 use_running_average 旗標會優先於傳遞到建構函式的 use_running_average 旗標。

傳回

已正規化的輸入(與輸入相同的形狀)。

方法

class flax.nnx.LayerNorm(*args, **kwargs)[原始碼]#

層正規化 (https://arxiv.org/abs/1607.06450)。

LayerNorm 會獨立於批次中的每個指定範例來正規化該層的激活,而不是像批次正規化那樣跨批次正規化。亦即,套用一種轉換,使每個範例中的平均激活維持接近 0,並且激活標準差接近 1。

用法範例

>>> from flax import nnx
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nnx.LayerNorm(num_features=6, rngs=nnx.Rngs(0))

>>> nnx.state(layer)
State({
  'bias': VariableState(
    type=Param,
    value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})

>>> y = layer(x)
num_features#

輸入特徵的數量。

epsilon#

加入變異數的小浮點數,以避免除以零。

dtype#

結果的資料類型(預設值:從輸入和參數推斷)。

param_dtype#

傳遞到參數初始化器的資料類型(預設值:float32)。

use_bias#

如果為 True,則會加入偏差 (beta)。

use_scale#

如果為 True,則乘以縮放 (gamma)。當下一層是線性時(例如 nnx.relu),可以停用此選項,因為縮放將由下一層完成。

bias_init#

偏差的初始化器,預設為零。

scale_init#

縮放的初始化器,預設為一。

reduction_axes#

用於計算正規化統計資訊的軸。

feature_axes#

用於學習偏差和縮放的特徵軸。

axis_name#

用於合併多個裝置的批次統計資訊的軸名稱。有關軸名稱的說明,請參閱 jax.pmap(預設值:None)。只有在模型跨裝置細分時才需要此設定,亦即,正在正規化的陣列會在 pmap 中的裝置之間分片。

axis_index_groups#

該命名軸內軸索引的群組,表示要縮減的裝置子集(預設值:None)。例如,[[0, 1], [2, 3]] 會獨立於前兩個和最後兩個裝置上的範例進行批次正規化。有關詳細資訊,請參閱 jax.lax.psum

use_fast_variance#

如果為 true,則使用更快但數值穩定性較差的變異數計算。

rngs#

rng 金鑰。

__call__(x, *, mask=None)[原始碼]#

對輸入應用層正規化。

參數

x – 輸入

傳回

已正規化的輸入(與輸入相同的形狀)。

方法

class flax.nnx.RMSNorm(*args, **kwargs)[原始碼]#

RMS 層正規化 (https://arxiv.org/abs/1910.07467)。

RMSNorm 對於批次中每個給定的範例,獨立地正規化層的激活,而不是像批次正規化那樣跨批次正規化。與 LayerNorm 將平均值重新置中為 0 並通過激活的標準差進行正規化不同,RMSNorm 根本不重新置中,而是通過激活的均方根進行正規化。

用法範例

>>> from flax import nnx
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.RMSNorm(num_features=6, rngs=nnx.Rngs(0))

>>> nnx.state(layer)
State({
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})

>>> y = layer(x)
num_features#

輸入特徵的數量。

epsilon#

加入變異數的小浮點數,以避免除以零。

dtype#

結果的資料類型(預設值:從輸入和參數推斷)。

param_dtype#

傳遞到參數初始化器的資料類型(預設值:float32)。

use_scale#

如果為 True,則乘以 scale (gamma)。當下一層是線性層(例如,nn.relu)時,可以停用此功能,因為縮放將由下一層完成。

scale_init#

縮放的初始化器,預設為一。

reduction_axes#

用於計算正規化統計資訊的軸。

feature_axes#

用於學習偏差和縮放的特徵軸。

axis_name#

用於合併多個裝置的批次統計資訊的軸名稱。有關軸名稱的說明,請參閱 jax.pmap(預設值:None)。只有在模型跨裝置細分時才需要此設定,亦即,正在正規化的陣列會在 pmap 中的裝置之間分片。

axis_index_groups#

該命名軸內軸索引的群組,表示要縮減的裝置子集(預設值:None)。例如,[[0, 1], [2, 3]] 會獨立於前兩個和最後兩個裝置上的範例進行批次正規化。有關詳細資訊,請參閱 jax.lax.psum

use_fast_variance#

如果為 true,則使用更快但數值穩定性較差的變異數計算。

rngs#

rng 金鑰。

__call__(x, mask=None)[原始碼]#

對輸入應用層正規化。

參數

x – 輸入

傳回

已正規化的輸入(與輸入相同的形狀)。

方法

class flax.nnx.GroupNorm(*args, **kwargs)[原始碼]#

群組正規化 (arxiv.org/abs/1803.08494)。

此操作類似於批次正規化,但統計資訊在大小相等的通道群組之間共享,而不是跨批次維度共享。因此,群組正規化不依賴於批次組成,也不需要維護用於儲存統計資訊的內部狀態。使用者應指定通道群組的總數或每個群組的通道數。

注意

LayerNorm 是 GroupNorm 的特殊情況,其中 num_groups=1

用法範例

>>> from flax import nnx
>>> import jax
>>> import numpy as np
...
>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'bias': VariableState(
    type=Param,
    value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})
>>> y = layer(x)
...
>>> y = nnx.GroupNorm(num_features=6, num_groups=1, rngs=nnx.Rngs(0))(x)
>>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x)
>>> np.testing.assert_allclose(y, y2)
num_features#

輸入特徵/通道的數量。

num_groups#

通道群組的總數。原始群組正規化論文提出了預設值 32。

group_size#

群組中的通道數。

epsilon#

加入變異數的小浮點數,以避免除以零。

dtype#

結果的資料類型(預設值:從輸入和參數推斷)。

param_dtype#

傳遞到參數初始化器的資料類型(預設值:float32)。

use_bias#

如果為 True,則會加入偏差 (beta)。

use_scale#

如果為 True,則乘以 scale (gamma)。當下一層是線性層(例如,nn.relu)時,可以停用此功能,因為縮放將由下一層完成。

bias_init#

偏差的初始化器,預設為零。

scale_init#

縮放的初始化器,預設為一。

reduction_axes#

用於計算正規化統計資訊的軸的列表。此列表必須包含最後一個維度,假設它是特徵軸。此外,如果在呼叫時使用的輸入與用於初始化的資料相比具有額外的前導軸,例如由於批次處理,則需要明確定義縮減軸。

axis_name#

用於合併來自多個裝置的批次統計資訊的軸名稱。有關軸名稱的說明,請參閱 jax.pmap(預設值:None)。僅當模型在裝置之間細分時才需要此功能,也就是說,要正規化的陣列在 pmap 或分片映射中的裝置之間分片。對於 SPMD jit,您不需要手動同步。只需確保正確註釋軸,XLA:SPMD 就會插入必要的集合。

axis_index_groups#

該命名軸內軸索引的群組,表示要縮減的裝置子集(預設值:None)。例如,[[0, 1], [2, 3]] 會獨立於前兩個和最後兩個裝置上的範例進行批次正規化。有關詳細資訊,請參閱 jax.lax.psum

use_fast_variance#

如果為 true,則使用更快但數值穩定性較差的變異數計算。

rngs#

rng 金鑰。

__call__(x, *, mask=None)[原始碼]#

對輸入應用群組正規化 (arxiv.org/abs/1803.08494)。

參數
  • x – 形狀為 ...self.num_features 的輸入,其中 self.num_features 是通道維度,而 ... 表示可用於累積統計資訊的任意數量的額外維度。如果未指定任何縮減軸,則所有額外維度 ... 都將用於累積統計資訊,但前導維度除外,該維度假設代表批次。

  • mask – 可廣播到 inputs 張量的二進制陣列,表示應計算平均值和變異數的位置。

傳回

已正規化的輸入(與輸入相同的形狀)。

方法