初始化器#
- flax.nnx.initializers.constant(value, dtype=<class 'jax.numpy.float64'>)#
建立一個初始化器,它會傳回充滿常數
value
的陣列。- 參數
value – 要用來填充初始化器的常數值。
dtype – 可選;初始化器的預設 dtype。
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.constant(-7) >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[-7., -7., -7.], [-7., -7., -7.]], dtype=float32)
- flax.nnx.initializers.delta_orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)#
建立 delta 正交核心的初始化器。
- 參數
scale – 均勻分佈的上限。
column_axis – 包含應該正交的列的軸。
dtype – 權重的預設 dtype。
- 傳回值
一個 delta 正交初始化器。傳遞給初始化器的形狀必須是 3D、4D 或 5D。
範例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.delta_orthogonal() >>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32) Array([[[ 0. , 0. , 0. ], [ 0. , 0. , 0. ], [ 0. , 0. , 0. ]], [[ 0.27858758, -0.7949833 , -0.53887904], [ 0.9120717 , 0.04322892, 0.40774566], [-0.30085585, -0.6050892 , 0.73712474]], [[ 0. , 0. , 0. ], [ 0. , 0. , 0. ], [ 0. , 0. , 0. ]]], dtype=float32)
- flax.nnx.initializers.glorot_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
建立 Glorot 常態初始化器(又稱 Xavier 常態初始化器)。
Glorot 常態初始化器是
jax.nn.initializers.variance_scaling()
的一個特例,其中scale = 1.0
、mode="fan_avg"
和distribution="truncated_normal"
。- 參數
in_axis – 權重陣列中輸入維度的軸或軸序列。
out_axis – 權重陣列中輸出維度的軸或軸序列。
batch_axis – 權重陣列中應該忽略的軸或軸序列。
dtype – 權重的 dtype。
- 傳回值
一個初始化器。
範例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_normal() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.41770416, 0.75262755, 0.7619329 ], [-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
- flax.nnx.initializers.glorot_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
建立 Glorot 均勻初始化器(又稱 Xavier 均勻初始化器)。
Glorot 均勻初始化器是
jax.nn.initializers.variance_scaling()
的一個特例,其中scale = 1.0
、mode="fan_avg"
和distribution="uniform"
。- 參數
in_axis – 權重陣列中輸入維度的軸或軸序列。
out_axis – 權重陣列中輸出維度的軸或軸序列。
batch_axis – 權重陣列中應該忽略的軸或軸序列。
dtype – 權重的 dtype。
- 傳回值
一個初始化器。
範例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_uniform() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.50350785, 0.8088631 , 0.81566876], [-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
- flax.nnx.initializers.he_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
建立 He 常態初始化器(又稱 Kaiming 常態初始化器)。
He 常態初始化器是
jax.nn.initializers.variance_scaling()
的一個特例,其中scale = 2.0
、mode="fan_in"
和distribution="truncated_normal"
。- 參數
in_axis – 權重陣列中輸入維度的軸或軸序列。
out_axis – 權重陣列中輸出維度的軸或軸序列。
batch_axis – 權重陣列中應該忽略的軸或軸序列。
dtype – 權重的 dtype。
- 傳回值
一個初始化器。
範例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_normal() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.6604483 , 1.1900088 , 1.2047218 ], [-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
- flax.nnx.initializers.he_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
建立 He 均勻初始化器(又稱 Kaiming 均勻初始化器)。
He 均勻初始化器是
jax.nn.initializers.variance_scaling()
的一個特例,其中scale = 2.0
、mode="fan_in"
和distribution="uniform"
。- 參數
in_axis – 權重陣列中輸入維度的軸或軸序列。
out_axis – 權重陣列中輸出維度的軸或軸序列。
batch_axis – 權重陣列中應該忽略的軸或軸序列。
dtype – 權重的 dtype。
- 傳回值
一個初始化器。
範例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_uniform() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.79611576, 1.2789248 , 1.2896855 ], [-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
- flax.nnx.initializers.kaiming_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
建立 He 常態初始化器(又稱 Kaiming 常態初始化器)。
He 常態初始化器是
jax.nn.initializers.variance_scaling()
的一個特例,其中scale = 2.0
、mode="fan_in"
和distribution="truncated_normal"
。- 參數
in_axis – 權重陣列中輸入維度的軸或軸序列。
out_axis – 權重陣列中輸出維度的軸或軸序列。
batch_axis – 權重陣列中應該忽略的軸或軸序列。
dtype – 權重的 dtype。
- 傳回值
一個初始化器。
範例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_normal() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.6604483 , 1.1900088 , 1.2047218 ], [-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
- flax.nnx.initializers.kaiming_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
建立 He 均勻初始化器(又稱 Kaiming 均勻初始化器)。
He 均勻初始化器是
jax.nn.initializers.variance_scaling()
的一個特例,其中scale = 2.0
、mode="fan_in"
和distribution="uniform"
。- 參數
in_axis – 權重陣列中輸入維度的軸或軸序列。
out_axis – 權重陣列中輸出維度的軸或軸序列。
batch_axis – 權重陣列中應該忽略的軸或軸序列。
dtype – 權重的 dtype。
- 傳回值
一個初始化器。
範例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_uniform() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.79611576, 1.2789248 , 1.2896855 ], [-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
- flax.nnx.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
建立 Lecun 常態初始化器。
Lecun 常態初始化器是
jax.nn.initializers.variance_scaling()
的一個特例,其中scale = 1.0
、mode="fan_in"
和distribution="truncated_normal"
。- 參數
in_axis – 權重陣列中輸入維度的軸或軸序列。
out_axis – 權重陣列中輸出維度的軸或軸序列。
batch_axis – 權重陣列中應該忽略的軸或軸序列。
dtype – 權重的 dtype。
- 傳回值
一個初始化器。
範例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.lecun_normal() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.46700746, 0.8414632 , 0.8518669 ], [-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
- flax.nnx.initializers.lecun_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
建立 Lecun 均勻初始化器。
Lecun 均勻初始化器是
jax.nn.initializers.variance_scaling()
的一個特例,其中scale = 1.0
、mode="fan_in"
和distribution="uniform"
。- 參數
in_axis – 權重陣列中輸入維度的軸或軸序列。
out_axis – 權重陣列中輸出維度的軸或軸序列。
batch_axis – 權重陣列中應該忽略的軸或軸序列。
dtype – 權重的 dtype。
- 傳回值
一個初始化器。
範例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.lecun_uniform() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.56293887, 0.90433645, 0.9119454 ], [-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
- flax.nnx.initializers.normal(stddev=0.01, dtype=<class 'jax.numpy.float64'>)#
建立一個初始化器,會回傳實數、常態分佈的隨機陣列。
- 參數
stddev – 選填;分佈的標準差。
dtype – 可選;初始化器的預設 dtype。
- 傳回值
一個初始化器,會回傳數值呈常態分佈的陣列,其平均值為
0
,標準差為stddev
。
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.normal(5.0) >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 3.0613258 , 5.6129413 , 5.6866574 ], [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
- flax.nnx.initializers.truncated_normal(stddev=0.01, dtype=<class 'jax.numpy.float64'>, lower=-2.0, upper=2.0)#
建立一個初始化器,會回傳截斷常態分佈的隨機陣列。
- 參數
stddev – 選填;未截斷分佈的標準差。請注意,此函式不會像 variancescaling 初始化器那樣套用標準差校正,如果使用者希望使用此校正,則應透過 stddev 引數自行套用。
dtype – 可選;初始化器的預設 dtype。
lower – 代表截斷下限的浮點數。會在輸出乘以 stddev 之前套用。
upper – 代表截斷上限的浮點數。會在輸出乘以 stddev 之前套用。
- 傳回值
一個初始化器,會回傳數值遵循截斷常態分佈的陣列,其平均值為
0
,標準差為stddev
,範圍為 \(\rm{lower * stddev} < x < \rm{upper * stddev}\)。
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.truncated_normal(5.0) >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 2.9047365, 5.2338114, 5.29852 ], [-3.836303 , -4.192359 , 0.6022964]], dtype=float32)
- flax.nnx.initializers.ones(key, shape, dtype=<class 'jax.numpy.float64'>)#
一個初始化器,會回傳一個充滿 1 的常數陣列。
會忽略
key
引數。>>> import jax, jax.numpy as jnp >>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32) Array([[1., 1.], [1., 1.], [1., 1.]], dtype=float32)
- flax.nnx.initializers.ones_init()#
建立一個初始化器,會回傳一個充滿 1 的常數陣列。
>>> import jax, jax.numpy as jnp >>> from flax.nnx import initializers >>> ones_initializer = initializers.ones_init() >>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32) Array([[1., 1.], [1., 1.], [1., 1.]], dtype=float32)
- flax.nnx.initializers.orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)#
建立一個初始化器,會回傳均勻分佈的正交矩陣。
如果形狀不是正方形,則矩陣將具有正交的列或行,具體取決於哪一側較小。
- 參數
scale – 均勻分佈的上限。
column_axis – 包含應該正交的列的軸。
dtype – 權重的預設 dtype。
- 傳回值
一個正交初始化器。
範例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.orthogonal() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01], [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
- flax.nnx.initializers.uniform(scale=0.01, dtype=<class 'jax.numpy.float64'>)#
建立一個初始化器,會回傳實數、均勻分佈的隨機陣列。
- 參數
scale – 選填;隨機分佈的上限。
dtype – 可選;初始化器的預設 dtype。
- 傳回值
一個初始化器,會回傳數值均勻分佈在
[0, scale)
範圍內的陣列。
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.uniform(10.0) >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[7.298188 , 8.691938 , 8.7230015], [2.0818567, 1.8662417, 5.5022564]], dtype=float32)
- flax.nnx.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
一個初始化器,會使其尺度適應權重張量的形狀。
使用
distribution="truncated_normal"
或distribution="normal"
時,樣本會從平均值為零且標準差(如果適用,則在截斷後)為 \(\sqrt{\frac{scale}{n}}\) 的(截斷)常態分佈中抽取,其中 n 為權重張量中的輸入單元數(如果
mode="fan_in"
),輸出單元數(如果
mode="fan_out"
),或輸入和輸出單元數的平均值(如果
mode="fan_avg"
)。
此初始化器可以使用
in_axis
、out_axis
和batch_axis
進行設定,以適用於通用卷積層或密集層;未在任何這些引數中的軸會被視為「感受野」(卷積核空間軸)。使用
distribution="truncated_normal"
時,樣本的絕對值會在縮放之前截斷為 2 個標準差。使用
distribution="uniform"
時,樣本會從一個均勻間隔中抽取(如果 dtype 為實數),或
一個均勻圓盤中抽取(如果 dtype 為複數),
其平均值為零,標準差為 \(\sqrt{\frac{scale}{n}}\),其中 n 定義如上。
- 參數
scale – 縮放因子(正浮點數)。
mode –
"fan_in"
、"fan_out"
和"fan_avg"
的其中一個。distribution – 要使用的隨機分佈。
"truncated_normal"
、"normal"
和"uniform"
的其中一個。in_axis – 權重陣列中輸入維度的軸或軸序列。
out_axis – 權重陣列中輸出維度的軸或軸序列。
batch_axis – 權重陣列中應該忽略的軸或軸序列。
dtype – 權重的 dtype。
- flax.nnx.initializers.xavier_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
建立 Glorot 常態初始化器(又稱 Xavier 常態初始化器)。
Glorot 常態初始化器是
jax.nn.initializers.variance_scaling()
的一個特例,其中scale = 1.0
、mode="fan_avg"
和distribution="truncated_normal"
。- 參數
in_axis – 權重陣列中輸入維度的軸或軸序列。
out_axis – 權重陣列中輸出維度的軸或軸序列。
batch_axis – 權重陣列中應該忽略的軸或軸序列。
dtype – 權重的 dtype。
- 傳回值
一個初始化器。
範例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_normal() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.41770416, 0.75262755, 0.7619329 ], [-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
- flax.nnx.initializers.xavier_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
建立 Glorot 均勻初始化器(又稱 Xavier 均勻初始化器)。
Glorot 均勻初始化器是
jax.nn.initializers.variance_scaling()
的一個特例,其中scale = 1.0
、mode="fan_avg"
和distribution="uniform"
。- 參數
in_axis – 權重陣列中輸入維度的軸或軸序列。
out_axis – 權重陣列中輸出維度的軸或軸序列。
batch_axis – 權重陣列中應該忽略的軸或軸序列。
dtype – 權重的 dtype。
- 傳回值
一個初始化器。
範例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_uniform() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.50350785, 0.8088631 , 0.81566876], [-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
- flax.nnx.initializers.zeros(key, shape, dtype=<class 'jax.numpy.float64'>)#
一個初始化器,會回傳一個充滿零的常數陣列。
會忽略
key
引數。>>> import jax, jax.numpy as jnp >>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32)
- flax.nnx.initializers.zeros_init()#
建立一個初始化器,會回傳一個充滿零的常數陣列。
>>> import jax, jax.numpy as jnp >>> from flax.nnx import initializers >>> zeros_initializer = initializers.zeros_init() >>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32)