注意力#
- class flax.nnx.MultiHeadAttention(*args, **kwargs)[來源]#
多頭注意力機制。
使用範例
>>> from flax import nnx >>> import jax >>> layer = nnx.MultiHeadAttention(num_heads=8, in_features=5, qkv_features=16, ... decode=False, rngs=nnx.Rngs(0)) >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3) >>> shape = (4, 3, 2, 5) >>> q, k, v = ( ... jax.random.uniform(key1, shape), ... jax.random.uniform(key2, shape), ... jax.random.uniform(key3, shape), ... ) >>> # different inputs for inputs_q, inputs_k and inputs_v >>> out = layer(q, k, v) >>> # equivalent output when inferring v >>> assert (layer(q, k) == layer(q, k, k)).all() >>> # equivalent output when inferring k and v >>> assert (layer(q) == layer(q, q)).all() >>> assert (layer(q) == layer(q, q, q)).all()
- num_heads#
注意力頭的數量。特徵(即 inputs_q.shape[-1])應可被頭的數量整除。
- in_features#
整數或包含輸入特徵數量的元組。
- qkv_features#
鍵、查詢和值的維度。
- out_features#
最後投影的維度
- dtype#
計算的 dtype (預設:從輸入和參數推斷)
- param_dtype#
傳遞給參數初始化器的 dtype (預設:float32)
- broadcast_dropout#
布林值:使用沿著批次維度廣播的 dropout。
- dropout_rate#
dropout 率
- deterministic#
如果為 false,則注意力權重會使用 dropout 隨機遮罩,如果為 true,則注意力權重是確定的。
- precision#
計算的數值精度,詳情請參閱 jax.lax.Precision。
- kernel_init#
用於密集層的核初始化器。
- out_kernel_init#
用於輸出密集層核的可選初始化器,如果為 None,則使用 kernel_init。
- bias_init#
用於密集層偏差的初始化器。
- out_bias_init#
用於輸出密集層偏差的可選初始化器,如果為 None,則使用 bias_init。
- use_bias#
布林值:逐點 QKVO 密集轉換是否使用偏差。
- attention_fn#
點積注意力或相容的函式。接受查詢、鍵、值,並返回形狀為 [bs, dim1, dim2, …, dimN,, num_heads, value_channels]` 的輸出
- decode#
是否準備並使用自動迴歸快取。
- normalize_qk#
是否應套用 QK 正規化 (arxiv.org/abs/2302.05442)。
- rngs#
rng 金鑰。
- __call__(inputs_q, inputs_k=None, inputs_v=None, *, mask=None, deterministic=None, rngs=None, sow_weights=False, decode=None)[來源]#
將多頭點積注意力應用於輸入資料。
將輸入投影到多頭查詢、鍵和值向量中,應用點積注意力,並將結果投影到輸出向量。
如果 inputs_k 和 inputs_v 皆為 None,則它們都會複製 inputs_q 的值(自我注意力)。如果只有 inputs_v 為 None,則它會複製 inputs_k 的值。
- 參數
inputs_q – 形狀為 [batch_sizes…, length, features] 的輸入查詢。
inputs_k – 形狀為 [batch_sizes…, length, features] 的鍵。如果為 None,inputs_k 將複製 inputs_q 的值。
inputs_v – 形狀為 [batch_sizes…, length, features] 的值。如果為 None,inputs_v 將複製 inputs_k 的值。
mask – 形狀為 [batch_sizes…, num_heads, query_length, key/value_length] 的注意力遮罩。如果對應的遮罩值為 False,則會遮罩注意力權重。
deterministic – 如果為 false,則注意力權重會使用 dropout 隨機遮罩,如果為 true,則注意力權重是確定的。傳遞到呼叫方法的
deterministic
旗標將優先於傳遞到建構函式的deterministic
旗標。rngs – rng 金鑰。傳遞到呼叫方法的 rng 金鑰將優先於傳遞到建構函式的 rng 金鑰。
sow_weights – 如果
True
,則注意力權重會被播種到「intermediates」集合中。decode – 是否準備並使用自動迴歸快取。傳遞到呼叫方法的
decode
旗標將優先於傳遞到建構函式的decode
旗標。
- 返回
形狀為 [batch_sizes…, length, features] 的輸出。
- init_cache(input_shape, dtype=<class 'jax.numpy.float32'>)[來源]#
初始化用於快速自動迴歸解碼的快取。當
decode=True
時,必須先呼叫此方法,然後才能執行前向推論。在解碼模式下,一次只能傳遞一個 token。使用範例
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> batch_size = 5 >>> embed_dim = 3 >>> x = jnp.ones((batch_size, 1, embed_dim)) # single token ... >>> model_nnx = nnx.MultiHeadAttention( ... num_heads=2, ... in_features=3, ... qkv_features=6, ... out_features=6, ... decode=True, ... rngs=nnx.Rngs(42), ... ) ... >>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized ... >>> model_nnx.init_cache(x.shape) >>> out_nnx = model_nnx(x)
方法
init_cache
(input_shape[, dtype])初始化用於快速自動迴歸解碼的快取。
- flax.nnx.combine_masks(*masks, dtype=<class 'jax.numpy.float32'>)[來源]#
組合注意力遮罩。
- 參數
*masks – 要組合的一組注意力遮罩引數,有些可以是 None。
dtype – 返回的遮罩的 dtype。
- 返回
組合遮罩,透過邏輯 AND 減少,如果沒有給定遮罩,則返回 None。
- flax.nnx.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None)[原始碼]#
計算給定 query、key 和 value 的點積注意力。
這是基於 https://arxiv.org/abs/1706.03762 應用注意力的核心函數。它會計算給定 query 和 key 的注意力權重,並使用這些權重組合 value。
注意
query
、key
、value
不需要有任何批次維度。- 參數
query – 用於計算注意力的查詢,形狀為
[batch..., q_length, num_heads, qk_depth_per_head]
。key – 用於計算注意力的鍵,形狀為
[batch..., kv_length, num_heads, qk_depth_per_head]
。value – 在注意力中使用的值,形狀為
[batch..., kv_length, num_heads, v_depth_per_head]
。bias – 注意力權重的偏差。它應該可廣播到形狀為 [batch…, num_heads, q_length, kv_length]。這可以用於加入因果遮罩、填充遮罩、鄰近偏差等。
mask – 注意力權重的遮罩。它應該可廣播到形狀為 [batch…, num_heads, q_length, kv_length]。這可以用於加入因果遮罩。如果對應的遮罩值為 False,則注意力權重會被遮蔽。
broadcast_dropout – bool:沿著批次維度使用廣播的 dropout。
dropout_rng – JAX PRNGKey:用於 dropout。
dropout_rate – dropout 率。
deterministic – bool,是否為確定性的(用於套用 dropout)。
dtype – 計算的資料類型(預設:從輸入推斷)。
precision – 計算的數值精度,詳細資訊請參閱 jax.lax.Precision。
module – 將注意力權重播種到
nnx.Intermediate
集合的 Module。如果module
為 None,則不會播種注意力權重。
- 返回
輸出形狀為 [batch…, q_length, num_heads, v_depth_per_head]。
- flax.nnx.make_attention_mask(query_input, key_input, pairwise_fn=<jnp.ufunc 'multiply'>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[原始碼]#
用於注意力權重的遮罩建立輔助函數。
對於 1 維輸入(即 [batch…, len_q]、[batch…, len_kv]),注意力權重將為 [batch…, heads, len_q, len_kv],而此函數將產生 [batch…, 1, len_q, len_kv]。
- 參數
query_input – 批次的、平坦的 query_length 大小的輸入。
key_input – 批次的、平坦的 key_length 大小的輸入。
pairwise_fn – 廣播式的逐元素比較函數。
extra_batch_dims – 要新增單例軸的額外批次維度數量,預設為無。
dtype – 遮罩回傳的資料類型。
- 返回
用於 1 維注意力的 [batch…, 1, len_q, len_kv] 形狀的遮罩。
- flax.nnx.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[原始碼]#
為自我注意力建立因果遮罩。
對於 1 維輸入(即 [batch…, len]),自我注意力權重將為 [batch…, heads, len, len],而此函數將產生形狀為 [batch…, 1, len, len] 的因果遮罩。
- 參數
x – 形狀為 [batch…, len] 的輸入陣列。
extra_batch_dims – 要新增單例軸的批次維度數量,預設為無。
dtype – 遮罩回傳的資料類型。
- 返回
用於 1 維注意力的 [batch…, 1, len, len] 形狀的因果遮罩。