注意力#

class flax.nnx.MultiHeadAttention(*args, **kwargs)[來源]#

多頭注意力機制。

使用範例

>>> from flax import nnx
>>> import jax

>>> layer = nnx.MultiHeadAttention(num_heads=8, in_features=5, qkv_features=16,
...                                decode=False, rngs=nnx.Rngs(0))
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = (
...   jax.random.uniform(key1, shape),
...   jax.random.uniform(key2, shape),
...   jax.random.uniform(key3, shape),
... )

>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer(q, k, v)
>>> # equivalent output when inferring v
>>> assert (layer(q, k) == layer(q, k, k)).all()
>>> # equivalent output when inferring k and v
>>> assert (layer(q) == layer(q, q)).all()
>>> assert (layer(q) == layer(q, q, q)).all()
num_heads#

注意力頭的數量。特徵(即 inputs_q.shape[-1])應可被頭的數量整除。

in_features#

整數或包含輸入特徵數量的元組。

qkv_features#

鍵、查詢和值的維度。

out_features#

最後投影的維度

dtype#

計算的 dtype (預設:從輸入和參數推斷)

param_dtype#

傳遞給參數初始化器的 dtype (預設:float32)

broadcast_dropout#

布林值:使用沿著批次維度廣播的 dropout。

dropout_rate#

dropout 率

deterministic#

如果為 false,則注意力權重會使用 dropout 隨機遮罩,如果為 true,則注意力權重是確定的。

precision#

計算的數值精度,詳情請參閱 jax.lax.Precision

kernel_init#

用於密集層的核初始化器。

out_kernel_init#

用於輸出密集層核的可選初始化器,如果為 None,則使用 kernel_init。

bias_init#

用於密集層偏差的初始化器。

out_bias_init#

用於輸出密集層偏差的可選初始化器,如果為 None,則使用 bias_init。

use_bias#

布林值:逐點 QKVO 密集轉換是否使用偏差。

attention_fn#

點積注意力或相容的函式。接受查詢、鍵、值,並返回形狀為 [bs, dim1, dim2, …, dimN,, num_heads, value_channels]` 的輸出

decode#

是否準備並使用自動迴歸快取。

normalize_qk#

是否應套用 QK 正規化 (arxiv.org/abs/2302.05442)。

rngs#

rng 金鑰。

__call__(inputs_q, inputs_k=None, inputs_v=None, *, mask=None, deterministic=None, rngs=None, sow_weights=False, decode=None)[來源]#

將多頭點積注意力應用於輸入資料。

將輸入投影到多頭查詢、鍵和值向量中,應用點積注意力,並將結果投影到輸出向量。

如果 inputs_k 和 inputs_v 皆為 None,則它們都會複製 inputs_q 的值(自我注意力)。如果只有 inputs_v 為 None,則它會複製 inputs_k 的值。

參數
  • inputs_q – 形狀為 [batch_sizes…, length, features] 的輸入查詢。

  • inputs_k – 形狀為 [batch_sizes…, length, features] 的鍵。如果為 None,inputs_k 將複製 inputs_q 的值。

  • inputs_v – 形狀為 [batch_sizes…, length, features] 的值。如果為 None,inputs_v 將複製 inputs_k 的值。

  • mask – 形狀為 [batch_sizes…, num_heads, query_length, key/value_length] 的注意力遮罩。如果對應的遮罩值為 False,則會遮罩注意力權重。

  • deterministic – 如果為 false,則注意力權重會使用 dropout 隨機遮罩,如果為 true,則注意力權重是確定的。傳遞到呼叫方法的 deterministic 旗標將優先於傳遞到建構函式的 deterministic 旗標。

  • rngs – rng 金鑰。傳遞到呼叫方法的 rng 金鑰將優先於傳遞到建構函式的 rng 金鑰。

  • sow_weights – 如果 True,則注意力權重會被播種到「intermediates」集合中。

  • decode – 是否準備並使用自動迴歸快取。傳遞到呼叫方法的 decode 旗標將優先於傳遞到建構函式的 decode 旗標。

返回

形狀為 [batch_sizes…, length, features] 的輸出。

init_cache(input_shape, dtype=<class 'jax.numpy.float32'>)[來源]#

初始化用於快速自動迴歸解碼的快取。當 decode=True 時,必須先呼叫此方法,然後才能執行前向推論。在解碼模式下,一次只能傳遞一個 token。

使用範例

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> batch_size = 5
>>> embed_dim = 3
>>> x = jnp.ones((batch_size, 1, embed_dim)) # single token
...
>>> model_nnx = nnx.MultiHeadAttention(
...   num_heads=2,
...   in_features=3,
...   qkv_features=6,
...   out_features=6,
...   decode=True,
...   rngs=nnx.Rngs(42),
... )
...
>>> # out_nnx = model_nnx(x)  <-- throws an error because cache isn't initialized
...
>>> model_nnx.init_cache(x.shape)
>>> out_nnx = model_nnx(x)

方法

init_cache(input_shape[, dtype])

初始化用於快速自動迴歸解碼的快取。

flax.nnx.combine_masks(*masks, dtype=<class 'jax.numpy.float32'>)[來源]#

組合注意力遮罩。

參數
  • *masks – 要組合的一組注意力遮罩引數,有些可以是 None。

  • dtype – 返回的遮罩的 dtype。

返回

組合遮罩,透過邏輯 AND 減少,如果沒有給定遮罩,則返回 None。

flax.nnx.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None)[原始碼]#

計算給定 query、key 和 value 的點積注意力。

這是基於 https://arxiv.org/abs/1706.03762 應用注意力的核心函數。它會計算給定 query 和 key 的注意力權重,並使用這些權重組合 value。

注意

querykeyvalue 不需要有任何批次維度。

參數
  • query – 用於計算注意力的查詢,形狀為 [batch..., q_length, num_heads, qk_depth_per_head]

  • key – 用於計算注意力的鍵,形狀為 [batch..., kv_length, num_heads, qk_depth_per_head]

  • value – 在注意力中使用的值,形狀為 [batch..., kv_length, num_heads, v_depth_per_head]

  • bias – 注意力權重的偏差。它應該可廣播到形狀為 [batch…, num_heads, q_length, kv_length]。這可以用於加入因果遮罩、填充遮罩、鄰近偏差等。

  • mask – 注意力權重的遮罩。它應該可廣播到形狀為 [batch…, num_heads, q_length, kv_length]。這可以用於加入因果遮罩。如果對應的遮罩值為 False,則注意力權重會被遮蔽。

  • broadcast_dropout – bool:沿著批次維度使用廣播的 dropout。

  • dropout_rng – JAX PRNGKey:用於 dropout。

  • dropout_rate – dropout 率。

  • deterministic – bool,是否為確定性的(用於套用 dropout)。

  • dtype – 計算的資料類型(預設:從輸入推斷)。

  • precision – 計算的數值精度,詳細資訊請參閱 jax.lax.Precision

  • module – 將注意力權重播種到 nnx.Intermediate 集合的 Module。如果 module 為 None,則不會播種注意力權重。

返回

輸出形狀為 [batch…, q_length, num_heads, v_depth_per_head]

flax.nnx.make_attention_mask(query_input, key_input, pairwise_fn=<jnp.ufunc 'multiply'>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[原始碼]#

用於注意力權重的遮罩建立輔助函數。

對於 1 維輸入(即 [batch…, len_q][batch…, len_kv]),注意力權重將為 [batch…, heads, len_q, len_kv],而此函數將產生 [batch…, 1, len_q, len_kv]

參數
  • query_input – 批次的、平坦的 query_length 大小的輸入。

  • key_input – 批次的、平坦的 key_length 大小的輸入。

  • pairwise_fn – 廣播式的逐元素比較函數。

  • extra_batch_dims – 要新增單例軸的額外批次維度數量,預設為無。

  • dtype – 遮罩回傳的資料類型。

返回

用於 1 維注意力的 [batch…, 1, len_q, len_kv] 形狀的遮罩。

flax.nnx.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[原始碼]#

為自我注意力建立因果遮罩。

對於 1 維輸入(即 [batch…, len]),自我注意力權重將為 [batch…, heads, len, len],而此函數將產生形狀為 [batch…, 1, len, len] 的因果遮罩。

參數
  • x – 形狀為 [batch…, len] 的輸入陣列。

  • extra_batch_dims – 要新增單例軸的批次維度數量,預設為無。

  • dtype – 遮罩回傳的資料類型。

返回

用於 1 維注意力的 [batch…, 1, len, len] 形狀的因果遮罩。