線性

線性#

NNX 線性層類別。

class flax.nnx.Conv(*args, **kwargs)[原始碼]#

卷積模組,封裝 lax.conv_general_dilated

範例用法

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))

>>> # valid padding
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  padding='VALID', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 6, 4)

>>> # circular padding with stride 2
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3),
...                  strides=2, padding='CIRCULAR', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 4, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
in_features#

輸入特徵的數量,可以是整數或元組。

out_features#

輸出特徵的數量,可以是整數或元組。

kernel_size#

卷積核的形狀。對於 1D 卷積,核大小可以作為整數傳遞,它將被解釋為單個整數的元組。對於所有其他情況,它必須是整數序列。

strides#

一個整數或 n 個整數的序列,表示視窗間的步幅(預設值:1)。

padding#

可以是字串 'SAME'、字串 'VALID'、字串 'CIRCULAR'(週期性邊界條件),或 n(low, high) 整數對的序列,指定在每個空間維度之前和之後應用的填充。單個整數會被解釋為在所有維度中應用相同的填充,而在序列中傳遞單個整數會導致在兩側使用相同的填充。對於 1D 卷積,'CAUSAL' 填充會將卷積軸左邊填充,以產生相同大小的輸出。

input_dilation#

一個整數或 n 個整數的序列,指定應用於 inputs 每個空間維度的膨脹因子(預設值:1)。輸入膨脹 d 的卷積等效於步幅為 d 的轉置卷積。

kernel_dilation#

一個整數或 n 個整數的序列,指定應用於卷積核每個空間維度的膨脹因子(預設值:1)。使用核膨脹的卷積也稱為「空洞卷積」。

feature_group_count#

整數,預設值為 1。如果指定,則將輸入特徵分為多個群組。

use_bias#

是否在輸出中加入偏置(預設值:True)。

mask#

可選的遮罩,用於在遮罩卷積期間遮蓋權重。遮罩的形狀必須與卷積權重矩陣相同。

dtype#

計算的資料類型(預設值:從輸入和參數推斷)。

param_dtype#

傳遞給參數初始化器的資料類型(預設值:float32)。

precision#

計算的數值精度,詳細資訊請參閱 jax.lax.Precision

kernel_init#

卷積核的初始化器。

bias_init#

偏置的初始化器。

rngs#

rng 鍵。

__call__(inputs)[原始碼]#

將(可能未共享的)卷積應用於輸入。

參數

inputs – 輸入資料,維度為 (*batch_dims, spatial_dims..., features)。這是通道最後的慣例,例如,2D 卷積的 NHWC 和 3D 卷積的 NDHWC。注意:這與 lax.conv_general_dilated 使用的輸入慣例不同,後者將空間維度放在最後。注意:如果輸入有多個批次維度,則所有批次維度將被展平為單個維度以進行卷積,並在返回前還原。在某些情況下,直接 vmap 該層可能會產生比此預設展平方法更好的效能。如果輸入沒有批次維度,則將為卷積加入一個批次維度,並在返回時移除,此為啟用編寫單例程式碼所做的容許。

傳回

卷積後的資料。

方法

class flax.nnx.ConvTranspose(*args, **kwargs)[原始碼]#

包裝 lax.conv_transpose 的卷積模組。

範例用法

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))

>>> # valid padding
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,),
...                           padding='VALID', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 10, 4)

>>> # circular padding with stride 2
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(6, 6),
...                           strides=(2, 2), padding='CIRCULAR',
...                           transpose_kernel=True, rngs=rngs)
>>> layer.kernel.value.shape
(6, 6, 4, 3)
>>> layer.bias.value.shape
(4,)
>>> out = layer(jnp.ones((1, 15, 15, 3)))
>>> out.shape
(1, 30, 30, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
in_features#

輸入特徵的數量,可以是整數或元組。

out_features#

輸出特徵的數量,可以是整數或元組。

kernel_size#

卷積核的形狀。對於 1D 卷積,核大小可以作為整數傳遞,它將被解釋為單個整數的元組。對於所有其他情況,它必須是整數序列。

strides#

一個整數或 n 個整數的序列,表示視窗間的步幅(預設值:1)。

padding#

可以是字串 'SAME'、字串 'VALID'、字串 'CIRCULAR'(週期性邊界條件),或 n(low, high) 整數對的序列,指定在每個空間維度之前和之後應用的填充。單個整數會被解釋為在所有維度中應用相同的填充,而在序列中傳遞單個整數會導致在兩側使用相同的填充。對於 1D 卷積,'CAUSAL' 填充會將卷積軸左邊填充,以產生相同大小的輸出。

kernel_dilation#

一個整數或 n 個整數的序列,指定應用於卷積核每個空間維度的膨脹因子(預設值:1)。使用核膨脹的卷積也稱為「空洞卷積」。

use_bias#

是否在輸出中加入偏置(預設值:True)。

mask#

可選的遮罩,用於在遮罩卷積期間遮蓋權重。遮罩的形狀必須與卷積權重矩陣相同。

dtype#

計算的資料類型(預設值:從輸入和參數推斷)。

param_dtype#

傳遞給參數初始化器的資料類型(預設值:float32)。

precision#

計算的數值精度,詳細資訊請參閱 jax.lax.Precision

kernel_init#

卷積核的初始化器。

bias_init#

偏置的初始化器。

transpose_kernel#

如果為 True,則會翻轉空間軸並交換核心的輸入/輸出通道軸。

rngs#

rng 鍵。

__call__(inputs)[原始碼]#

將轉置卷積應用於輸入。

行為與 jax.lax.conv_transpose 相同。

參數

inputs – 輸入資料的維度為 (*batch_dims, spatial_dims..., features)。這是通道在後的慣例,例如,二維卷積為 NHWC,三維卷積為 NDHWC。注意:這與 lax.conv_general_dilated 使用的輸入慣例不同,後者將空間維度放在最後。注意:如果輸入有多個批次維度,所有批次維度將被展平成單一維度以進行卷積,並在返回之前還原。在某些情況下,直接對該層進行 vmap 可能會比這種預設的展平方法產生更好的效能。如果輸入缺少批次維度,則將為卷積新增該維度並在返回時移除,以便能夠編寫單一範例程式碼。

傳回

卷積後的資料。

方法

class flax.nnx.Embed(*args, **kwargs)[原始碼]#

嵌入模組。

範例用法

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'embedding': VariableState(
    type=Param,
    value=Array([[-0.90411377, -0.3648777 , -1.1083648 ],
           [ 0.01070483,  0.27923733,  1.7487359 ],
           [ 0.59161806,  0.8660184 ,  1.2838588 ],
           [-0.748139  , -0.15856352,  0.06061118],
           [-0.4769059 , -0.6607095 ,  0.46697947]], dtype=float32)
  )
})
>>> # get the first three and last three embeddings
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> layer(indices_input)
Array([[[-0.90411377, -0.3648777 , -1.1083648 ],
        [ 0.01070483,  0.27923733,  1.7487359 ],
        [ 0.59161806,  0.8660184 ,  1.2838588 ]],

       [[-0.4769059 , -0.6607095 ,  0.46697947],
        [-0.748139  , -0.15856352,  0.06061118],
        [ 0.59161806,  0.8660184 ,  1.2838588 ]]], dtype=float32)

從整數 [0, num_embeddings) 到 features 維向量的參數化函式。此 Module 將建立一個形狀為 (num_embeddings, features)embedding 矩陣。呼叫此層時,輸入值將用於 0 索引到 embedding 矩陣中。索引大於或等於 num_embeddings 的值將導致 nan 值。當 num_embeddings 等於 1 時,它會將 embedding 矩陣廣播到輸入形狀,並附加 features 維度。

num_embeddings#

嵌入數量/詞彙大小。

features#

每個嵌入的特徵維度數量。

dtype#

嵌入向量的資料類型(預設值:與嵌入相同)。

param_dtype#

傳遞給參數初始化器的資料類型(預設值:float32)。

embedding_init#

嵌入初始化器。

rngs#

rng 鍵。

__call__(inputs)[原始碼]#

沿著最後一個維度嵌入輸入。

參數

inputs – 輸入資料,所有維度都被視為批次維度。輸入陣列中的值必須為整數。

傳回

輸出為嵌入的輸入資料。輸出形狀遵循輸入,並附加一個額外的 features 維度。

attend(query)[原始碼]#

使用查詢陣列對嵌入進行注意力操作。

參數

query – 最後一個維度等於嵌入的特徵深度 features 的陣列。

傳回

一個最後維度為 num_embeddings 的陣列,對應於查詢向量陣列與每個嵌入的批次內積。通常用於 NLP 模型中嵌入和 logit 轉換之間的權重共享。

方法

attend(query)

使用查詢陣列對嵌入進行注意力操作。

class flax.nnx.Linear(*args, **kwargs)[原始碼]#

應用於輸入最後一個維度的線性轉換。

範例用法

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(4,)
  ),
  'kernel': VariableState(
    type=Param,
    value=(3, 4)
  )
})
in_features#

輸入特徵的數量。

out_features#

輸出特徵的數量。

use_bias#

是否在輸出中加入偏置(預設值:True)。

dtype#

計算的資料類型(預設值:從輸入和參數推斷)。

param_dtype#

傳遞給參數初始化器的資料類型(預設值:float32)。

precision#

計算的數值精度,詳細資訊請參閱 jax.lax.Precision

kernel_init#

權重矩陣的初始化函式。

bias_init#

偏差的初始化函式。

dot_general#

點積函式。

rngs#

rng 鍵。

__call__(inputs)[原始碼]#

對輸入沿著最後一個維度應用線性轉換。

參數

inputs – 要轉換的 nd-array。

傳回

轉換後的輸入。

方法

class flax.nnx.LinearGeneral(*args, **kwargs)[原始碼]#

一個具有彈性軸的線性轉換。

範例用法

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> # equivalent to `nnx.Linear(2, 4)`
>>> layer = nnx.LinearGeneral(2, 4, rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 4)
>>> # output features (4, 5)
>>> layer = nnx.LinearGeneral(2, (4, 5), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 4, 5)
>>> layer.bias.value.shape
(4, 5)
>>> # apply transformation on the the second and last axes
>>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 3, 4, 5)
>>> layer.bias.value.shape
(4, 5)
>>> y = layer(jnp.ones((16, 2, 3)))
>>> y.shape
(16, 4, 5)
in_features#

輸入特徵的數量,可以是整數或元組。

out_features#

輸出特徵的數量,可以是整數或元組。

axis#

要應用轉換的軸的整數或元組。例如,(-2, -1) 將轉換應用於最後兩個軸。

batch_axis#

批次軸索引到軸大小的映射。

use_bias#

是否在輸出中加入偏置(預設值:True)。

dtype#

計算的資料類型(預設值:從輸入和參數推斷)。

param_dtype#

傳遞給參數初始化器的資料類型(預設值:float32)。

kernel_init#

權重矩陣的初始化函式。

bias_init#

偏差的初始化函式。

precision#

計算的數值精度,詳細資訊請參閱 jax.lax.Precision

rngs#

rng 鍵。

__call__(inputs)[原始碼]#

沿著多個維度將線性轉換應用於輸入。

參數

inputs – 要轉換的 nd-array。

傳回

轉換後的輸入。

方法

class flax.nnx.Einsum(*args, **kwargs)[原始碼]#

一個具有可學習核心和偏差的 einsum 轉換。

範例用法

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(8, 2, 4)
>>> layer.bias.value.shape
(8, 4)
>>> y = layer(jnp.ones((16, 11, 2)))
>>> y.shape
(16, 11, 8, 4)
einsum_str#

一個表示 einsum 方程式的字串。方程式必須恰好有兩個運算元,左側 (lhs) 是傳入的輸入,右側 (rhs) 是可學習的核心。建構函式引數和呼叫引數中的 einsum_str 必須恰好有一個不是 None,而另一個必須是 None。

kernel_shape#

核心的形狀。

bias_shape#

偏差的形狀。如果這是 None,則不會使用偏差。

dtype#

計算的資料類型(預設值:從輸入和參數推斷)。

param_dtype#

傳遞給參數初始化器的資料類型(預設值:float32)。

precision#

計算的數值精度,詳細資訊請參閱 jax.lax.Precision

kernel_init#

權重矩陣的初始化函式。

bias_init#

偏差的初始化函式。

rngs#

rng 鍵。

__call__(inputs, einsum_str=None)[原始碼]#

對輸入沿著最後一個維度應用線性轉換。

參數
  • inputs – 要轉換的 nd-array。

  • einsum_str – 一個表示 einsum 方程式的字串。方程式必須恰好有兩個運算元,左側 (lhs) 是傳入的輸入,右側 (rhs) 是可學習的核心。建構函式引數和呼叫引數中的 einsum_str 必須恰好有一個不是 None,而另一個必須是 None。

傳回

轉換後的輸入。

方法