初始化器#

Flax 的初始化器。

flax.linen.initializers.constant(value, dtype=<class 'jax.numpy.float64'>)#

建立初始化器,回傳充滿常數 value 的陣列。

參數
  • value – 初始化器中使用的常數。

  • dtype – 可選擇;初始化器的預設資料型別。

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.constant(-7)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[-7., -7., -7.],
       [-7., -7., -7.]], dtype=float32)
flax.linen.initializers.delta_orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)#

建立用於 delta 正交核心的初始化器。

參數
  • scale – 均勻分配的上限。

  • column_axis – 包含應為正交的欄的軸。

  • dtype – 權重的預設資料型別。

傳回

一個 delta 正交初始化器。傳遞給初始化器的形狀必須為 3D、4D 或 5D。

範例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.delta_orthogonal()
>>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32)  
Array([[[ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]],

       [[ 0.27858758, -0.7949833 , -0.53887904],
        [ 0.9120717 ,  0.04322892,  0.40774566],
        [-0.30085585, -0.6050892 ,  0.73712474]],

       [[ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]]], dtype=float32)
flax.linen.initializers.glorot_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

建立 Glorot 正態初始化器(又稱 Xavier 正態初始化器)。

一個 Glorot 正態初始化器jax.nn.initializers.variance_scaling()scale = 1.0, mode="fan_avg", 和 distribution="truncated_normal" 下的特例。

參數
  • in_axis – 權重陣列中輸入維度的軸或軸序列。

  • out_axis – 權重陣列中輸出維度的軸或軸序列。

  • batch_axis – 權重陣列中應略過的軸或軸序列。

  • dtype – 權重的資料型別。

傳回

一個初始化器。

範例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.41770416,  0.75262755,  0.7619329 ],
       [-0.5516644 , -0.6028657 ,  0.08661086]], dtype=float32)
flax.linen.initializers.glorot_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

建立 Glorot 均勻初始化器(又稱 Xavier 均勻初始化器)。

jax.nn.initializers.variance_scaling() 的特別化 Glorot 均勻初始化器 ,其 scale = 1.0mode="fan_avg"distribution="uniform"

參數
  • in_axis – 權重陣列中輸入維度的軸或軸序列。

  • out_axis – 權重陣列中輸出維度的軸或軸序列。

  • batch_axis – 權重陣列中應略過的軸或軸序列。

  • dtype – 權重的資料型別。

傳回

一個初始化器。

範例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.50350785,  0.8088631 ,  0.81566876],
       [-0.6393332 , -0.6865721 ,  0.11003882]], dtype=float32)
flax.linen.initializers.he_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

建立 He 正規初始化器(又稱 Kaiming 正規初始化器)。

He 正規初始化器 的特別化 jax.nn.initializers.variance_scaling(),其 scale = 2.0mode="fan_in"distribution="truncated_normal"

參數
  • in_axis – 權重陣列中輸入維度的軸或軸序列。

  • out_axis – 權重陣列中輸出維度的軸或軸序列。

  • batch_axis – 權重陣列中應略過的軸或軸序列。

  • dtype – 權重的資料型別。

傳回

一個初始化器。

範例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.6604483 ,  1.1900088 ,  1.2047218 ],
       [-0.87225807, -0.95321447,  0.1369438 ]], dtype=float32)
flax.linen.initializers.he_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

建立 He 均勻初始化器(又稱 Kaiming 均勻初始化器)。

He 均勻初始化器 的特別化 jax.nn.initializers.variance_scaling(),其 scale = 2.0mode="fan_in"distribution="uniform"

參數
  • in_axis – 權重陣列中輸入維度的軸或軸序列。

  • out_axis – 權重陣列中輸出維度的軸或軸序列。

  • batch_axis – 權重陣列中應略過的軸或軸序列。

  • dtype – 權重的資料型別。

傳回

一個初始化器。

範例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.79611576,  1.2789248 ,  1.2896855 ],
       [-1.0108745 , -1.0855657 ,  0.17398663]], dtype=float32)
flax.linen.initializers.kaiming_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

建立 He 正規初始化器(又稱 Kaiming 正規初始化器)。

He 正規初始化器 的特別化 jax.nn.initializers.variance_scaling(),其 scale = 2.0mode="fan_in"distribution="truncated_normal"

參數
  • in_axis – 權重陣列中輸入維度的軸或軸序列。

  • out_axis – 權重陣列中輸出維度的軸或軸序列。

  • batch_axis – 權重陣列中應略過的軸或軸序列。

  • dtype – 權重的資料型別。

傳回

一個初始化器。

範例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.6604483 ,  1.1900088 ,  1.2047218 ],
       [-0.87225807, -0.95321447,  0.1369438 ]], dtype=float32)
flax.linen.initializers.kaiming_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

建立 He 均勻初始化器(又稱 Kaiming 均勻初始化器)。

He 均勻初始化器 的特別化 jax.nn.initializers.variance_scaling(),其 scale = 2.0mode="fan_in"distribution="uniform"

參數
  • in_axis – 權重陣列中輸入維度的軸或軸序列。

  • out_axis – 權重陣列中輸出維度的軸或軸序列。

  • batch_axis – 權重陣列中應略過的軸或軸序列。

  • dtype – 權重的資料型別。

傳回

一個初始化器。

範例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.79611576,  1.2789248 ,  1.2896855 ],
       [-1.0108745 , -1.0855657 ,  0.17398663]], dtype=float32)
flax.linen.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

建立 Lecun normal 初始化項。

一個 Lecun normal 初始化項jax.nn.initializers.variance_scaling() 的一個特別用途,其中 scale = 1.0mode="fan_in",且 distribution="truncated_normal"

參數
  • in_axis – 權重陣列中輸入維度的軸或軸序列。

  • out_axis – 權重陣列中輸出維度的軸或軸序列。

  • batch_axis – 權重陣列中應略過的軸或軸序列。

  • dtype – 權重的資料型別。

傳回

一個初始化器。

範例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.46700746,  0.8414632 ,  0.8518669 ],
       [-0.61677957, -0.67402434,  0.09683388]], dtype=float32)
flax.linen.initializers.lecun_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

建立 Lecun uniform 初始化項。

一個 Lecun uniform 初始化項jax.nn.initializers.variance_scaling() 的一個特別用途,其中 scale = 1.0mode="fan_in",且 distribution="uniform"

參數
  • in_axis – 權重陣列中輸入維度的軸或軸序列。

  • out_axis – 權重陣列中輸出維度的軸或軸序列。

  • batch_axis – 權重陣列中應略過的軸或軸序列。

  • dtype – 權重的資料型別。

傳回

一個初始化器。

範例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.56293887,  0.90433645,  0.9119454 ],
       [-0.71479625, -0.7676109 ,  0.12302713]], dtype=float32)
flax.linen.initializers.normal(stddev=0.01, dtype=<class 'jax.numpy.float64'>)#

建立回傳真實常態分配隨機陣列的初始化函數。

參數
  • stddev – 選用;分配的標準差。

  • dtype – 可選擇;初始化器的預設資料型別。

傳回

回傳其值為常態分配的陣列,其平均值為 0,標準差為 stddev 的初始化函數。

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.normal(5.0)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 3.0613258 ,  5.6129413 ,  5.6866574 ],
       [-4.063663  , -4.4520254 ,  0.63115686]], dtype=float32)
flax.linen.initializers.truncated_normal(stddev=0.01, dtype=<class 'jax.numpy.float64'>, lower=-2.0, upper=2.0)#

建立回傳截斷常態隨機陣列的初始化函數。

參數
  • stddev – 選用;未截斷分配的標準差。請注意,此函數並未套用變異縮放初始化函數中所採用的 stddev 校正,使用者預期若要套用此項校正,需自行透過 stddev 參數執行。

  • dtype – 可選擇;初始化器的預設資料型別。

  • lower – 表示截斷下限的浮點數。在輸出乘上 stddev 之前套用。

  • upper – 表示截斷上限的浮點數。在輸出乘上 stddev 之前套用。

傳回

回傳其值遵循平均值為 0,標準差為 stddev,且範圍為 \(\rm{lower * stddev} < x < \rm{upper * stddev}\) 的截斷常態分配的陣列的初始化函數。

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.truncated_normal(5.0)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 2.9047365,  5.2338114,  5.29852  ],
       [-3.836303 , -4.192359 ,  0.6022964]], dtype=float32)
flax.linen.initializers.ones(key, shape, dtype=<class 'jax.numpy.float64'>)#

回傳恆定陣列的初始化函數,其中必定是 1。

忽略 key 參數。

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
flax.linen.initializers.ones_init()[來源碼]#

建構一個初始化器,會回傳一個充滿 $1$ 的常數陣列。

>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import ones_init
>>> ones_initializer = ones_init()
>>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
flax.linen.initializers.orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)#

建構一個初始化器,會回傳均勻分佈的正交矩陣。

如果形狀不是正方形,矩陣會有正規直行或直欄,視哪一側較小而定。

參數
  • scale – 均勻分配的上限。

  • column_axis – 包含應為正交的欄的軸。

  • dtype – 權重的預設資料型別。

傳回

一個正交初始化器。

範例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.orthogonal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 3.9026976e-01,  7.2495741e-01, -5.6756169e-01],
       [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]],            dtype=float32)
flax.linen.initializers.uniform(scale=0.01, dtype=<class 'jax.numpy.float64'>)#

建構一個初始化器,會回傳真實均勻分佈隨機矩陣。

參數
  • scale – 選用;隨機分佈的上限。

  • dtype – 可選擇;初始化器的預設資料型別。

傳回

一個初始化器,會回傳值在範圍 [0, scale) 之間均勻分佈的陣列。

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.uniform(10.0)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[7.298188 , 8.691938 , 8.7230015],
       [2.0818567, 1.8662417, 5.5022564]], dtype=float32)
flax.linen.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

初始化器用於調整其比例以配合權重張量的形狀。

使用 distribution="truncated_normal"distribution="normal" 時,範例會從一個 (截斷) 常態分佈中抽取,其均值為 0,標準差 (若適用,在截斷之後) 為 \(\sqrt{\frac{scale}{n}}\),其中 n

  • 權重張量中的輸入單元數,若 mode="fan_in"

  • 輸出單元數,若 mode="fan_out",或

  • 輸入和輸出單元數的平均值,若 mode="fan_avg"

此初始化器可以透過 in_axisout_axis,以及 batch_axis 進行設定,以搭配一般的卷積或密集層;任何未包含在這些參數中的軸預設為「感受域」(卷積核的空間軸)。

使用 distribution="truncated_normal",範例中的絕對值在縮放前會先在 2 個標準差處進行截斷。

使用 distribution="uniform",範例會從以下位置繪製:

  • 均勻區間,如果 dtype 為實數,或

  • 均勻圓盤,如果 dtype 為複數,

平均值為零,標準差為 \(\sqrt{\frac{scale}{n}}\),其中 n 在上面定義。

參數
  • scale – 縮放因子(正浮點數)。

  • mode"fan_in""fan_out""fan_avg" 其中之一。

  • distribution – 要使用的亂數字元分布。其中 "truncated_normal""normal""uniform" 之一。

  • in_axis – 權重陣列中輸入維度的軸或軸序列。

  • out_axis – 權重陣列中輸出維度的軸或軸序列。

  • batch_axis – 權重陣列中應略過的軸或軸序列。

  • dtype – 權重的資料型別。

flax.linen.initializers.xavier_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

建立 Glorot 正態初始化器(又稱 Xavier 正態初始化器)。

一個 Glorot 正態初始化器jax.nn.initializers.variance_scaling()scale = 1.0, mode="fan_avg", 和 distribution="truncated_normal" 下的特例。

參數
  • in_axis – 權重陣列中輸入維度的軸或軸序列。

  • out_axis – 權重陣列中輸出維度的軸或軸序列。

  • batch_axis – 權重陣列中應略過的軸或軸序列。

  • dtype – 權重的資料型別。

傳回

一個初始化器。

範例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.41770416,  0.75262755,  0.7619329 ],
       [-0.5516644 , -0.6028657 ,  0.08661086]], dtype=float32)
flax.linen.initializers.xavier_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

建立 Glorot 均勻初始化器(又稱 Xavier 均勻初始化器)。

jax.nn.initializers.variance_scaling() 的特別化 Glorot 均勻初始化器 ,其 scale = 1.0mode="fan_avg"distribution="uniform"

參數
  • in_axis – 權重陣列中輸入維度的軸或軸序列。

  • out_axis – 權重陣列中輸出維度的軸或軸序列。

  • batch_axis – 權重陣列中應略過的軸或軸序列。

  • dtype – 權重的資料型別。

傳回

一個初始化器。

範例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.50350785,  0.8088631 ,  0.81566876],
       [-0.6393332 , -0.6865721 ,  0.11003882]], dtype=float32)
flax.linen.initializers.zeros(key, shape, dtype=<class 'jax.numpy.float64'>)#

一個回傳含有常數陣列(全為 0)的初始化器。

忽略 key 參數。

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
flax.linen.initializers.zeros_init()[來源]#

建構回傳一個由零填滿的常數陣列的初始化器。

>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import zeros_init
>>> zeros_initializer = zeros_init()
>>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)