線性#
NNX 線性層類別。
- class flax.nnx.Conv(*args, **kwargs)[原始碼]#
卷積模組,封裝
lax.conv_general_dilated
。範例用法
>>> from flax import nnx >>> import jax.numpy as jnp >>> rngs = nnx.Rngs(0) >>> x = jnp.ones((1, 8, 3)) >>> # valid padding >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... padding='VALID', rngs=rngs) >>> layer.kernel.value.shape (3, 3, 4) >>> layer.bias.value.shape (4,) >>> out = layer(x) >>> out.shape (1, 6, 4) >>> # circular padding with stride 2 >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3), ... strides=2, padding='CIRCULAR', rngs=rngs) >>> layer.kernel.value.shape (3, 3, 3, 4) >>> layer.bias.value.shape (4,) >>> out = layer(x) >>> out.shape (1, 4, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((3, 3, 4))) >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... mask=mask, padding='VALID', rngs=rngs) >>> out = layer(x)
- in_features#
輸入特徵的數量,可以是整數或元組。
- out_features#
輸出特徵的數量,可以是整數或元組。
- kernel_size#
卷積核的形狀。對於 1D 卷積,核大小可以作為整數傳遞,它將被解釋為單個整數的元組。對於所有其他情況,它必須是整數序列。
- strides#
一個整數或
n
個整數的序列,表示視窗間的步幅(預設值:1)。
- padding#
可以是字串
'SAME'
、字串'VALID'
、字串'CIRCULAR'
(週期性邊界條件),或n
個(low, high)
整數對的序列,指定在每個空間維度之前和之後應用的填充。單個整數會被解釋為在所有維度中應用相同的填充,而在序列中傳遞單個整數會導致在兩側使用相同的填充。對於 1D 卷積,'CAUSAL'
填充會將卷積軸左邊填充,以產生相同大小的輸出。
- input_dilation#
一個整數或
n
個整數的序列,指定應用於inputs
每個空間維度的膨脹因子(預設值:1)。輸入膨脹d
的卷積等效於步幅為d
的轉置卷積。
- kernel_dilation#
一個整數或
n
個整數的序列,指定應用於卷積核每個空間維度的膨脹因子(預設值:1)。使用核膨脹的卷積也稱為「空洞卷積」。
- feature_group_count#
整數,預設值為 1。如果指定,則將輸入特徵分為多個群組。
- use_bias#
是否在輸出中加入偏置(預設值:True)。
- mask#
可選的遮罩,用於在遮罩卷積期間遮蓋權重。遮罩的形狀必須與卷積權重矩陣相同。
- dtype#
計算的資料類型(預設值:從輸入和參數推斷)。
- param_dtype#
傳遞給參數初始化器的資料類型(預設值:float32)。
- precision#
計算的數值精度,詳細資訊請參閱
jax.lax.Precision
。
- kernel_init#
卷積核的初始化器。
- bias_init#
偏置的初始化器。
- rngs#
rng 鍵。
- __call__(inputs)[原始碼]#
將(可能未共享的)卷積應用於輸入。
- 參數
inputs – 輸入資料,維度為
(*batch_dims, spatial_dims..., features)
。這是通道最後的慣例,例如,2D 卷積的 NHWC 和 3D 卷積的 NDHWC。注意:這與lax.conv_general_dilated
使用的輸入慣例不同,後者將空間維度放在最後。注意:如果輸入有多個批次維度,則所有批次維度將被展平為單個維度以進行卷積,並在返回前還原。在某些情況下,直接 vmap 該層可能會產生比此預設展平方法更好的效能。如果輸入沒有批次維度,則將為卷積加入一個批次維度,並在返回時移除,此為啟用編寫單例程式碼所做的容許。- 傳回
卷積後的資料。
方法
- class flax.nnx.ConvTranspose(*args, **kwargs)[原始碼]#
包裝
lax.conv_transpose
的卷積模組。範例用法
>>> from flax import nnx >>> import jax.numpy as jnp >>> rngs = nnx.Rngs(0) >>> x = jnp.ones((1, 8, 3)) >>> # valid padding >>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,), ... padding='VALID', rngs=rngs) >>> layer.kernel.value.shape (3, 3, 4) >>> layer.bias.value.shape (4,) >>> out = layer(x) >>> out.shape (1, 10, 4) >>> # circular padding with stride 2 >>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(6, 6), ... strides=(2, 2), padding='CIRCULAR', ... transpose_kernel=True, rngs=rngs) >>> layer.kernel.value.shape (6, 6, 4, 3) >>> layer.bias.value.shape (4,) >>> out = layer(jnp.ones((1, 15, 15, 3))) >>> out.shape (1, 30, 30, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((3, 3, 4))) >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... mask=mask, padding='VALID', rngs=rngs) >>> out = layer(x)
- in_features#
輸入特徵的數量,可以是整數或元組。
- out_features#
輸出特徵的數量,可以是整數或元組。
- kernel_size#
卷積核的形狀。對於 1D 卷積,核大小可以作為整數傳遞,它將被解釋為單個整數的元組。對於所有其他情況,它必須是整數序列。
- strides#
一個整數或
n
個整數的序列,表示視窗間的步幅(預設值:1)。
- padding#
可以是字串
'SAME'
、字串'VALID'
、字串'CIRCULAR'
(週期性邊界條件),或n
個(low, high)
整數對的序列,指定在每個空間維度之前和之後應用的填充。單個整數會被解釋為在所有維度中應用相同的填充,而在序列中傳遞單個整數會導致在兩側使用相同的填充。對於 1D 卷積,'CAUSAL'
填充會將卷積軸左邊填充,以產生相同大小的輸出。
- kernel_dilation#
一個整數或
n
個整數的序列,指定應用於卷積核每個空間維度的膨脹因子(預設值:1)。使用核膨脹的卷積也稱為「空洞卷積」。
- use_bias#
是否在輸出中加入偏置(預設值:True)。
- mask#
可選的遮罩,用於在遮罩卷積期間遮蓋權重。遮罩的形狀必須與卷積權重矩陣相同。
- dtype#
計算的資料類型(預設值:從輸入和參數推斷)。
- param_dtype#
傳遞給參數初始化器的資料類型(預設值:float32)。
- precision#
計算的數值精度,詳細資訊請參閱
jax.lax.Precision
。
- kernel_init#
卷積核的初始化器。
- bias_init#
偏置的初始化器。
- transpose_kernel#
如果為
True
,則會翻轉空間軸並交換核心的輸入/輸出通道軸。
- rngs#
rng 鍵。
- __call__(inputs)[原始碼]#
將轉置卷積應用於輸入。
行為與
jax.lax.conv_transpose
相同。- 參數
inputs – 輸入資料的維度為
(*batch_dims, spatial_dims..., features)
。這是通道在後的慣例,例如,二維卷積為NHWC
,三維卷積為NDHWC
。注意:這與lax.conv_general_dilated
使用的輸入慣例不同,後者將空間維度放在最後。注意:如果輸入有多個批次維度,所有批次維度將被展平成單一維度以進行卷積,並在返回之前還原。在某些情況下,直接對該層進行 vmap 可能會比這種預設的展平方法產生更好的效能。如果輸入缺少批次維度,則將為卷積新增該維度並在返回時移除,以便能夠編寫單一範例程式碼。- 傳回
卷積後的資料。
方法
- class flax.nnx.Embed(*args, **kwargs)[原始碼]#
嵌入模組。
範例用法
>>> from flax import nnx >>> import jax.numpy as jnp >>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ 'embedding': VariableState( type=Param, value=Array([[-0.90411377, -0.3648777 , -1.1083648 ], [ 0.01070483, 0.27923733, 1.7487359 ], [ 0.59161806, 0.8660184 , 1.2838588 ], [-0.748139 , -0.15856352, 0.06061118], [-0.4769059 , -0.6607095 , 0.46697947]], dtype=float32) ) }) >>> # get the first three and last three embeddings >>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]]) >>> layer(indices_input) Array([[[-0.90411377, -0.3648777 , -1.1083648 ], [ 0.01070483, 0.27923733, 1.7487359 ], [ 0.59161806, 0.8660184 , 1.2838588 ]], [[-0.4769059 , -0.6607095 , 0.46697947], [-0.748139 , -0.15856352, 0.06061118], [ 0.59161806, 0.8660184 , 1.2838588 ]]], dtype=float32)
從整數 [0,
num_embeddings
) 到features
維向量的參數化函式。此Module
將建立一個形狀為(num_embeddings, features)
的embedding
矩陣。呼叫此層時,輸入值將用於 0 索引到embedding
矩陣中。索引大於或等於num_embeddings
的值將導致nan
值。當num_embeddings
等於 1 時,它會將embedding
矩陣廣播到輸入形狀,並附加features
維度。- num_embeddings#
嵌入數量/詞彙大小。
- features#
每個嵌入的特徵維度數量。
- dtype#
嵌入向量的資料類型(預設值:與嵌入相同)。
- param_dtype#
傳遞給參數初始化器的資料類型(預設值:float32)。
- embedding_init#
嵌入初始化器。
- rngs#
rng 鍵。
- __call__(inputs)[原始碼]#
沿著最後一個維度嵌入輸入。
- 參數
inputs – 輸入資料,所有維度都被視為批次維度。輸入陣列中的值必須為整數。
- 傳回
輸出為嵌入的輸入資料。輸出形狀遵循輸入,並附加一個額外的
features
維度。
- attend(query)[原始碼]#
使用查詢陣列對嵌入進行注意力操作。
- 參數
query – 最後一個維度等於嵌入的特徵深度
features
的陣列。- 傳回
一個最後維度為
num_embeddings
的陣列,對應於查詢向量陣列與每個嵌入的批次內積。通常用於 NLP 模型中嵌入和 logit 轉換之間的權重共享。
方法
attend
(query)使用查詢陣列對嵌入進行注意力操作。
- class flax.nnx.Linear(*args, **kwargs)[原始碼]#
應用於輸入最後一個維度的線性轉換。
範例用法
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ 'bias': VariableState( type=Param, value=(4,) ), 'kernel': VariableState( type=Param, value=(3, 4) ) })
- in_features#
輸入特徵的數量。
- out_features#
輸出特徵的數量。
- use_bias#
是否在輸出中加入偏置(預設值:True)。
- dtype#
計算的資料類型(預設值:從輸入和參數推斷)。
- param_dtype#
傳遞給參數初始化器的資料類型(預設值:float32)。
- precision#
計算的數值精度,詳細資訊請參閱
jax.lax.Precision
。
- kernel_init#
權重矩陣的初始化函式。
- bias_init#
偏差的初始化函式。
- dot_general#
點積函式。
- rngs#
rng 鍵。
方法
- class flax.nnx.LinearGeneral(*args, **kwargs)[原始碼]#
一個具有彈性軸的線性轉換。
範例用法
>>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> # equivalent to `nnx.Linear(2, 4)` >>> layer = nnx.LinearGeneral(2, 4, rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (2, 4) >>> # output features (4, 5) >>> layer = nnx.LinearGeneral(2, (4, 5), rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (2, 4, 5) >>> layer.bias.value.shape (4, 5) >>> # apply transformation on the the second and last axes >>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (2, 3, 4, 5) >>> layer.bias.value.shape (4, 5) >>> y = layer(jnp.ones((16, 2, 3))) >>> y.shape (16, 4, 5)
- in_features#
輸入特徵的數量,可以是整數或元組。
- out_features#
輸出特徵的數量,可以是整數或元組。
- axis#
要應用轉換的軸的整數或元組。例如,(-2, -1) 將轉換應用於最後兩個軸。
- batch_axis#
批次軸索引到軸大小的映射。
- use_bias#
是否在輸出中加入偏置(預設值:True)。
- dtype#
計算的資料類型(預設值:從輸入和參數推斷)。
- param_dtype#
傳遞給參數初始化器的資料類型(預設值:float32)。
- kernel_init#
權重矩陣的初始化函式。
- bias_init#
偏差的初始化函式。
- precision#
計算的數值精度,詳細資訊請參閱
jax.lax.Precision
。
- rngs#
rng 鍵。
方法
- class flax.nnx.Einsum(*args, **kwargs)[原始碼]#
一個具有可學習核心和偏差的 einsum 轉換。
範例用法
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (8, 2, 4) >>> layer.bias.value.shape (8, 4) >>> y = layer(jnp.ones((16, 11, 2))) >>> y.shape (16, 11, 8, 4)
- einsum_str#
一個表示 einsum 方程式的字串。方程式必須恰好有兩個運算元,左側 (lhs) 是傳入的輸入,右側 (rhs) 是可學習的核心。建構函式引數和呼叫引數中的
einsum_str
必須恰好有一個不是 None,而另一個必須是 None。
- kernel_shape#
核心的形狀。
- bias_shape#
偏差的形狀。如果這是 None,則不會使用偏差。
- dtype#
計算的資料類型(預設值:從輸入和參數推斷)。
- param_dtype#
傳遞給參數初始化器的資料類型(預設值:float32)。
- precision#
計算的數值精度,詳細資訊請參閱
jax.lax.Precision
。
- kernel_init#
權重矩陣的初始化函式。
- bias_init#
偏差的初始化函式。
- rngs#
rng 鍵。
- __call__(inputs, einsum_str=None)[原始碼]#
對輸入沿著最後一個維度應用線性轉換。
- 參數
inputs – 要轉換的 nd-array。
einsum_str – 一個表示 einsum 方程式的字串。方程式必須恰好有兩個運算元,左側 (lhs) 是傳入的輸入,右側 (rhs) 是可學習的核心。建構函式引數和呼叫引數中的
einsum_str
必須恰好有一個不是 None,而另一個必須是 None。
- 傳回
轉換後的輸入。
方法