Attention#
| MultiHeadAttention | Multi-head attention layer. |
| AttentionType | Built-in attention backend selectors. |
Multi-head attention#
- class attnax.MultiHeadAttention(*args, **kwargs)[source]#
Bases: Module
Multi-head attention layer.
Supports MHA, GQA (1 < num_kv_heads < num_heads) and MQA (num_kv_heads == 1); optional rotary position embeddings on Q and K; an attention backend selected by AttentionType or a user-supplied AttentionFn; an optional ScoreMod and causal sliding window; and an optional KVLayerCache for autoregressive decoding.
- Parameters:
rngs – Flax NNX random key container.
num_heads – Number of query heads.
in_features – Input dimensionality.
qkv_features – QKV projection width. Defaults to in_features.
out_features – Output projection width. Defaults to in_features.
num_kv_heads – Number of key/value heads. Must divide num_heads. Defaults to num_heads (MHA).
dropout_rate – Output dropout probability.
broadcast_dropout – Share the dropout mask across the batch.
decode – Reserved; kept for API compatibility.
attention_type – Built-in backend selection. Ignored when attention_fn is set.
attention_block_size – Block size for the memory_efficient, flash and pallas_flash backends.
linear_attention_chunk_size – Chunk size for the linear backend.
attention_fn – Custom kernel conforming to AttentionFn. Takes priority over attention_type.
score_mod – ScoreMod applied on every call.
attention_window – Causal sliding-window size. When set, each query attends only to the most recent attention_window keys.
use_rope – Apply rotary position embeddings to Q and K. Requires even head_dim.
rope_base – RoPE base \(\theta\).
rope_max_positions – Length of the precomputed RoPE table.
args (Any)
kwargs (Any)
- Raises:
ValueError – If qkv_features is not divisible by num_heads, if num_kv_heads does not satisfy 1 <= num_kv_heads <= num_heads and divide num_heads, if use_rope is set with an odd head_dim, or if both attention_fn and AttentionType.LITE are supplied.
- Return type:
Any
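The head-sharing and sliding-window behaviour described for this class can be illustrated with a small NumPy reference. This is a sketch under stated assumptions, not attnax's implementation: the function name reference_attention, the (heads, seq, head_dim) layout, and the convention that the window includes the current position are all hypothetical choices made for illustration.

```python
import numpy as np


def reference_attention(q, k, v, window=None):
    """Causal scaled dot-product attention with grouped KV heads.

    q: (num_heads, seq, head_dim); k, v: (num_kv_heads, seq, head_dim).
    window: if set, each query sees only its `window` most recent keys
    (assumed here to include the query's own position).
    """
    num_heads, seq, head_dim = q.shape
    num_kv_heads = k.shape[0]
    if not (1 <= num_kv_heads <= num_heads) or num_heads % num_kv_heads:
        raise ValueError("num_kv_heads must divide num_heads")

    # GQA/MQA: each KV head serves num_heads // num_kv_heads query heads.
    group = num_heads // num_kv_heads
    k = np.repeat(k, group, axis=0)
    v = np.repeat(v, group, axis=0)

    scores = np.einsum("hqd,hkd->hqk", q, k) / np.sqrt(head_dim)

    # Causal mask: key index j is visible to query index i when j <= i.
    i = np.arange(seq)[:, None]
    j = np.arange(seq)[None, :]
    visible = j <= i
    if window is not None:
        # Sliding window: additionally require i - j < window.
        visible &= (i - j) < window
    scores = np.where(visible, scores, -np.inf)

    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("hqk,hkd->hqd", weights, v), weights


rng = np.random.default_rng(0)
q = rng.normal(size=(4, 6, 8))   # 4 query heads
kv = rng.normal(size=(2, 6, 8))  # 2 KV heads: GQA with groups of 2
out, weights = reference_attention(q, kv, kv, window=3)
```

Passing a single KV head (shape (1, seq, head_dim)) gives the MQA case, and passing as many KV heads as query heads gives plain MHA. A KV head count that does not divide the query head count raises ValueError, mirroring the constraint listed under Raises.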
Attention backend#
- class attnax.AttentionType(value)[source]#
Built-in attention backend selectors.
- STANDARD#
Scaled dot-product attention with \(O(n^2)\) memory.
- MEMORY_EFFICIENT#
Block-wise online-softmax attention.
- FLASH#
jax.nn.dot_product_attention() on GPU, MEMORY_EFFICIENT elsewhere.
- PALLAS_FLASH#
Pallas-lowered FlashAttention with score_mod in the inner loop; falls back to MEMORY_EFFICIENT.
- LINEAR#
Chunkwise-parallel linear attention. Ignores score_mod.
- LITE#
Element-wise gated attention.
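To make the difference between the STANDARD and MEMORY_EFFICIENT entries concrete, the following NumPy sketch compares a full-matrix softmax with a block-wise online softmax over the keys. It illustrates the general technique only: the function names are hypothetical, there is no masking or batching, and the actual attnax kernels may tile queries as well.

```python
import numpy as np


def standard_attention(q, k, v):
    """Materialises the full (n, n) score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def blockwise_attention(q, k, v, block_size):
    """Processes keys/values one block at a time with a running softmax."""
    n, d = q.shape
    running_max = np.full((n, 1), -np.inf)
    running_sum = np.zeros((n, 1))
    acc = np.zeros((n, v.shape[-1]))
    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        s = q @ k_blk.T / np.sqrt(d)  # only an (n, block) slice of scores
        new_max = np.maximum(running_max, s.max(axis=-1, keepdims=True))
        # Rescale what was accumulated under the old maximum.
        correction = np.exp(running_max - new_max)
        p = np.exp(s - new_max)
        running_sum = running_sum * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ v_blk
        running_max = new_max
    return acc / running_sum


rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(10, 16)) for _ in range(3))
full = standard_attention(q, k, v)
blocked = blockwise_attention(q, k, v, block_size=4)
```

Both functions return the same values up to floating-point error, but the block-wise version only ever holds an (n, block_size) slice of scores, which is the motivation for a block size parameter such as attention_block_size.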
Ring and paged attention are not enum entries because they require additional arguments (axis_name and PagedKVCache); pass ring_attention() or paged_attention() directly via attention_fn=.
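For the LINEAR entry, a chunkwise formulation of causal linear attention can be sketched as follows. Everything here is an assumption for illustration: the elu(x) + 1 feature map, the function names, and the unbatched single-head layout are not taken from attnax. The sketch shows that carrying a small state between chunks reproduces the quadratic form, and that no softmax score matrix exists in this formulation for a score modifier to act on.

```python
import numpy as np


def feature_map(x):
    # elu(x) + 1, a common positive feature map (an assumption here).
    return np.where(x > 0, x + 1.0, np.exp(x))


def quadratic_linear_attention(q, k, v):
    """Causal linear attention via the full (n, n) matrix, for reference."""
    fq, fk = feature_map(q), feature_map(k)
    a = np.tril(fq @ fk.T)  # causal, no softmax
    return (a @ v) / a.sum(axis=-1, keepdims=True)


def chunkwise_linear_attention(q, k, v, chunk_size):
    """Same result, carrying a (d, d_v) state from chunk to chunk."""
    fq, fk = feature_map(q), feature_map(k)
    n, d = fq.shape
    state = np.zeros((d, v.shape[-1]))  # running sum of outer(fk_j, v_j)
    norm = np.zeros(d)                  # running sum of fk_j
    out = np.empty((n, v.shape[-1]))
    for s in range(0, n, chunk_size):
        qc = fq[s:s + chunk_size]
        kc = fk[s:s + chunk_size]
        vc = v[s:s + chunk_size]
        intra = np.tril(qc @ kc.T)      # within-chunk causal part
        num = qc @ state + intra @ vc
        den = qc @ norm + intra.sum(axis=-1)
        out[s:s + chunk_size] = num / den[:, None]
        # Fold this chunk into the state seen by later chunks.
        state += kc.T @ vc
        norm += kc.sum(axis=0)
    return out


rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(10, 8)) for _ in range(3))
reference = quadratic_linear_attention(q, k, v)
chunked = chunkwise_linear_attention(q, k, v, chunk_size=4)
```

The chunk size trades off the size of the within-chunk matrix against the number of sequential state updates, which is the role a parameter like linear_attention_chunk_size would play.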
See also
Kernels for the attention kernels (standard_attention,
memory_efficient_attention, flash_attention,
pallas_flash_attention, linear_attention, ring_attention,
paged_attention, lite_attention), the AttentionFn
protocol, and the prebuilt ScoreMod / MaskMod
constructors that are passed to MultiHeadAttention via
attention_fn= and score_mod=.