Attention

MultiHeadAttention(*args, **kwargs)

Multi-head attention layer.

AttentionType(value)

Built-in attention backend selectors.

Multi-head attention

class attnax.MultiHeadAttention(*args, **kwargs)

Bases: Module

Multi-head attention layer.

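The layer supports three attention variants:

- standard multi-head attention (MHA, num_kv_heads == num_heads)
- grouped-query attention (GQA, 1 < num_kv_heads < num_heads)
- multi-query attention (MQA, num_kv_heads == 1)

It can optionally apply rotary position embeddings to Q and K, a ScoreMod, a causal sliding window, and a KVLayerCache for autoregressive decoding. The attention backend is selected through AttentionType, or you can supply a custom AttentionFn instead.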
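Here is a minimal pure-Python sketch of the rotary option. It rests on two assumptions that this page does not document for attnax. First, dimension i is rotated together with dimension i + head_dim/2 (the "rotate-half" layout). Second, the base defaults to 10000.0. The names rope_tables and apply_rope are hypothetical.

The sketch shows three things:

- why use_rope needs an even head_dim: dimensions rotate in pairs
- what rope_base and rope_max_positions control
- why RoPE is used at all: the Q·K score depends only on the distance between positions

```python
import math

# Hypothetical RoPE sketch. Assumed (not from attnax): "rotate-half" pairing,
# i.e. dim i is rotated with dim i + head_dim // 2, and base = 10000.0.

def rope_tables(head_dim, max_positions, base=10000.0):
    """Precompute cos/sin tables of shape [max_positions][head_dim // 2]."""
    if head_dim % 2 != 0:
        raise ValueError("RoPE needs an even head_dim: dimensions rotate in pairs")
    half = head_dim // 2
    inv_freq = [base ** (-2.0 * i / head_dim) for i in range(half)]
    cos = [[math.cos(p * f) for f in inv_freq] for p in range(max_positions)]
    sin = [[math.sin(p * f) for f in inv_freq] for p in range(max_positions)]
    return cos, sin


def apply_rope(x, position, cos, sin):
    """Rotate each pair (x[i], x[i + half]) by the angle for `position`."""
    half = len(x) // 2
    c, s = cos[position], sin[position]  # IndexError past max_positions
    first = [x[i] * c[i] - x[i + half] * s[i] for i in range(half)]
    second = [x[i] * s[i] + x[i + half] * c[i] for i in range(half)]
    return first + second


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


if __name__ == "__main__":
    cos, sin = rope_tables(head_dim=4, max_positions=16)
    q, k = [0.3, -1.2, 0.8, 0.5], [1.0, 0.4, -0.7, 0.2]
    # Both scores use positions two apart, so they agree up to rounding.
    print(dot(apply_rope(q, 3, cos, sin), apply_rope(k, 5, cos, sin)))
    print(dot(apply_rope(q, 10, cos, sin), apply_rope(k, 12, cos, sin)))
```

Because each pair is a 2-D rotation, vector norms are unchanged. Only the relative offset between the query and key positions survives in their dot product.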

Parameters:
  • rngs – Flax NNX random key container.

  • num_heads – Number of query heads.

  • in_features – Input dimensionality.

  • qkv_features – QKV projection width. Defaults to in_features.

  • out_features – Output projection width. Defaults to in_features.

  • num_kv_heads – Number of key/value heads. Must divide num_heads. Defaults to num_heads (MHA).

  • dropout_rate – Output dropout probability.

  • broadcast_dropout – Share the dropout mask across the batch.

  • decode – Reserved; kept for API compatibility.

  • attention_type – Built-in backend selection. Ignored when attention_fn is set.

  • attention_block_size – Block size for memory_efficient, flash and pallas_flash backends.

  • linear_attention_chunk_size – Chunk size for the linear backend.

  • attention_fn – Custom kernel conforming to AttentionFn. Takes priority over attention_type.

  • score_mod – ScoreMod applied on every call.

  • attention_window – Causal sliding-window size. When set, each query attends only to the most recent attention_window keys.

  • use_rope – Apply rotary position embeddings to Q and K. Requires an even head_dim (qkv_features // num_heads).

  • rope_base – RoPE base \(\theta\).

  • rope_max_positions – Length of the precomputed RoPE table.

  • args (Any)

  • kwargs (Any)

Raises:

ValueError – Raised if any of the following holds:

- qkv_features is not divisible by num_heads
- num_kv_heads is outside 1 <= num_kv_heads <= num_heads, or does not divide num_heads
- use_rope is set and head_dim is odd
- attention_fn is supplied together with AttentionType.LITE

Return type:

Any
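The checks listed under Raises, together with the defaults from the parameter list, can be restated as a small pure-Python function. This is a hypothetical sketch, not attnax's constructor:

- resolve_heads, has_attention_fn and lite_selected are illustrative names
- the contiguous query-to-KV grouping (query heads 0 to g-1 share KV head 0, and so on) is an assumed layout

The sketch also shows how num_kv_heads separates MHA, GQA and MQA.

```python
# Hypothetical restatement of the documented constructor checks. The grouping
# of query heads onto KV heads is an assumed contiguous layout.

def resolve_heads(num_heads, in_features, qkv_features=None, num_kv_heads=None,
                  use_rope=False, has_attention_fn=False, lite_selected=False):
    """Return (head_dim, kv_head_of) or raise ValueError like the docs describe."""
    qkv_features = in_features if qkv_features is None else qkv_features
    num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads  # MHA default

    if qkv_features % num_heads != 0:
        raise ValueError("qkv_features must be divisible by num_heads")
    if not (1 <= num_kv_heads <= num_heads) or num_heads % num_kv_heads != 0:
        raise ValueError("num_kv_heads must be in [1, num_heads] and divide num_heads")
    head_dim = qkv_features // num_heads
    if use_rope and head_dim % 2 != 0:
        raise ValueError("use_rope requires an even head_dim")
    if has_attention_fn and lite_selected:
        raise ValueError("attention_fn cannot be combined with AttentionType.LITE")

    group_size = num_heads // num_kv_heads  # query heads sharing one KV head
    kv_head_of = [h // group_size for h in range(num_heads)]
    return head_dim, kv_head_of


if __name__ == "__main__":
    print(resolve_heads(8, 512))                  # MHA: (64, [0, 1, ..., 7])
    print(resolve_heads(8, 512, num_kv_heads=2))  # GQA: (64, [0, 0, 0, 0, 1, 1, 1, 1])
    print(resolve_heads(8, 512, num_kv_heads=1))  # MQA: (64, [0, 0, 0, 0, 0, 0, 0, 0])
```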

Attention backend

class attnax.AttentionType(value)

Bases: str, Enum

Built-in attention backend selectors.

STANDARD

Scaled dot-product attention that materializes the full score matrix. Memory therefore grows as \(O(n^2)\) in the sequence length \(n\).
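As a reference point, the computation behind STANDARD can be written for a single head on plain Python lists. This is an illustrative sketch with a hypothetical standard_attention function, not the backend's JAX implementation. The comments mark where the quadratic memory comes from.

```python
import math

# Hypothetical single-head sketch on nested lists; the real STANDARD backend
# works on batched JAX arrays.

def standard_attention(q, k, v, causal=False):
    """q: n x d, k: m x d, v: m x dv. Returns an n x dv list of lists."""
    scale = 1.0 / math.sqrt(len(q[0]))
    # The whole n x m score matrix exists at once: the O(n^2) memory term.
    scores = [[scale * sum(a * b for a, b in zip(qi, kj)) for kj in k] for qi in q]
    out = []
    for i, row in enumerate(scores):
        if causal:  # query i may only see keys 0..i
            row = [s if j <= i else -math.inf for j, s in enumerate(row)]
        mx = max(row)  # subtract the row max for a numerically stable softmax
        w = [math.exp(s - mx) for s in row]
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / z
                    for c in range(len(v[0]))])
    return out


if __name__ == "__main__":
    # A zero query scores every key equally, so the output is the mean value.
    print(standard_attention([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[2.0], [4.0]]))
```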

MEMORY_EFFICIENT

Block-wise attention with an online (running) softmax. The full score matrix is never materialized.
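The online-softmax idea can be checked on a single query row. Keys are visited in blocks, and only three running quantities are kept: a maximum, a denominator, and a weighted sum of values. No full score row is stored.

The sketch below is a simplified pure-Python illustration with hypothetical function names. It assumes nothing about how MEMORY_EFFICIENT tiles queries, batches heads, or handles dtypes. The direct version is included only to show that both give the same result.

```python
import math

# Hypothetical sketch of the online-softmax recurrence for one query row.

def _scores(q, keys):
    scale = 1.0 / math.sqrt(len(q))
    return [scale * sum(a * b for a, b in zip(q, k)) for k in keys]


def attention_row_direct(q, k, v):
    """Reference: softmax over all scores at once."""
    s = _scores(q, k)
    mx = max(s)
    w = [math.exp(x - mx) for x in s]
    z = sum(w)
    return [sum(wj * vj[c] for wj, vj in zip(w, v)) / z for c in range(len(v[0]))]


def attention_row_online(q, k, v, block_size):
    """Same result, visiting keys block by block with running statistics only."""
    m = -math.inf            # running max of the scores seen so far
    denom = 0.0              # running sum of exp(score - m)
    acc = [0.0] * len(v[0])  # running sum of exp(score - m) * value
    for start in range(0, len(k), block_size):
        s = _scores(q, k[start:start + block_size])
        m_new = max(m, max(s))
        rescale = math.exp(m - m_new)  # 0.0 on the first block because m = -inf
        p = [math.exp(x - m_new) for x in s]
        denom = denom * rescale + sum(p)
        block_v = v[start:start + block_size]
        acc = [a * rescale + sum(pj * vj[c] for pj, vj in zip(p, block_v))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / denom for a in acc]
```

Rescaling the old statistics by exp(m - m_new) whenever a larger score appears keeps the result identical to the direct softmax for any block size.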

FLASH

Uses jax.nn.dot_product_attention() on GPU and falls back to MEMORY_EFFICIENT on other backends.

PALLAS_FLASH

Pallas-lowered FlashAttention that applies score_mod inside the kernel's inner loop. It falls back to MEMORY_EFFICIENT where the Pallas kernel cannot be used.
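Applying score_mod inside the inner loop means each score is modified as it is computed, block by block. There is no masked score matrix built up front.

The sketch below illustrates this in pure Python. It expresses a causal sliding window (the rule that attention_window describes) as a score modifier. The (score, q_idx, kv_idx) signature is an assumption for illustration; the real ScoreMod protocol is documented under Kernels. Blocks whose scores are all masked are skipped, since they contribute nothing to the softmax. All function names are hypothetical.

```python
import math

NEG_INF = -math.inf

# Hypothetical score_mod signature (score, q_idx, kv_idx) -> score.
def sliding_window_causal(window):
    """Keep key j for query i only if i - window < j <= i."""
    def score_mod(score, q_idx, kv_idx):
        if kv_idx > q_idx or q_idx - kv_idx >= window:
            return NEG_INF
        return score
    return score_mod


def blocked_attention(q, k, v, score_mod=None, block_size=2):
    """Online-softmax attention that applies score_mod to each score in the inner loop."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for i, qi in enumerate(q):
        m, denom, acc = NEG_INF, 0.0, [0.0] * dv
        for start in range(0, len(k), block_size):
            cols = range(start, min(start + block_size, len(k)))
            s = []
            for j in cols:
                score = scale * sum(a * b for a, b in zip(qi, k[j]))
                s.append(score if score_mod is None else score_mod(score, i, j))
            if max(s) == NEG_INF:
                continue  # fully masked block: contributes nothing, so skip it
            m_new = max(m, max(s))
            rescale = math.exp(m - m_new)
            p = [math.exp(x - m_new) for x in s]
            denom = denom * rescale + sum(p)
            acc = [a * rescale + sum(pj * v[j][c] for pj, j in zip(p, cols))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / denom for a in acc])  # assumes >= 1 unmasked key per query
    return out
```

With window=2, query i averages values i-1 and i when all scores are equal. Query 0 sees only value 0.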

LINEAR

Chunkwise-parallel linear attention. Ignores score_mod.
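Chunkwise-parallel linear attention is easiest to understand next to the token-by-token recurrence it reproduces. In the sketch below, each chunk combines two parts:

- a carried state covering all earlier chunks: running sums of \(\phi(k_j) v_j^\top\) and of \(\phi(k_j)\)
- an exact causal sum within the current chunk

No score matrix spans more than one chunk, and chunk_size plays the role of linear_attention_chunk_size. The feature map \(\phi(x) = \mathrm{elu}(x) + 1\) is an assumption borrowed from the common linear-attention formulation. attnax's LINEAR backend may use a different map or normalization, and all function names here are hypothetical.

```python
import math

# Hypothetical sketch; phi = elu(x) + 1 is an assumed feature map.

def phi(x):
    """Positive feature map elu(x) + 1."""
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def _new_state(d, dv):
    # S accumulates phi(k_j) v_j^T; z accumulates phi(k_j) for the normalizer.
    return [[0.0] * dv for _ in range(d)], [0.0] * d


def _absorb(S, z, fk, vj):
    for a, fa in enumerate(fk):
        z[a] += fa
        for c, vc in enumerate(vj):
            S[a][c] += fa * vc


def linear_attention_recurrent(q, k, v):
    S, z = _new_state(len(q[0]), len(v[0]))
    out = []
    for qi, ki, vi in zip(q, k, v):
        _absorb(S, z, phi(ki), vi)  # include token i itself (causal)
        fq = phi(qi)
        num = [sum(fq[a] * S[a][c] for a in range(len(fq))) for c in range(len(vi))]
        out.append([x / _dot(fq, z) for x in num])
    return out


def linear_attention_chunkwise(q, k, v, chunk_size):
    S, z = _new_state(len(q[0]), len(v[0]))
    dv = len(v[0])
    out = []
    for start in range(0, len(q), chunk_size):
        fq = [phi(x) for x in q[start:start + chunk_size]]
        fk = [phi(x) for x in k[start:start + chunk_size]]
        vc = v[start:start + chunk_size]
        for i in range(len(fq)):
            # Inter-chunk part: every earlier chunk, summarized by the state.
            num = [sum(fq[i][a] * S[a][c] for a in range(len(z))) for c in range(dv)]
            den = _dot(fq[i], z)
            # Intra-chunk part: causal, quadratic only within this chunk.
            for j in range(i + 1):
                w = _dot(fq[i], fk[j])
                num = [n + w * vc[j][c] for c, n in enumerate(num)]
                den += w
            out.append([x / den for x in num])
        for fkj, vj in zip(fk, vc):  # fold the finished chunk into the state
            _absorb(S, z, fkj, vj)
    return out
```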

LITE

Element-wise gated attention.

Ring and paged attention are not enum entries because they require additional arguments (axis_name and PagedKVCache); pass ring_attention() or paged_attention() directly via attention_fn=.
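Paged attention needs a cache object in addition to Q, K and V. One generic Python way to turn such a kernel into a single callable is to bind the extra argument ahead of time with functools.partial. The sketch below is hypothetical throughout:

- ToyPagedKV is a toy stand-in for PagedKVCache: fixed-size pages plus a page table that maps logical pages to physical pages
- toy_paged_attention is not attnax's paged_attention()
- this page does not state the exact AttentionFn signature, or whether attnax expects this binding style

```python
import functools
import math
from dataclasses import dataclass


@dataclass
class ToyPagedKV:
    """Hypothetical paged cache: physical pages of (key, value) rows plus a page table."""
    page_size: int
    pages: list       # pages[p] = list of (key, value) tuples, in physical order
    page_table: list  # page_table[logical_page] = physical page index
    length: int       # number of valid cached tokens

    def gather(self):
        """Return keys and values in logical (token) order."""
        keys, values = [], []
        for pos in range(self.length):
            physical = self.page_table[pos // self.page_size]
            key, value = self.pages[physical][pos % self.page_size]
            keys.append(key)
            values.append(value)
        return keys, values


def toy_paged_attention(q, *, cache):
    """Hypothetical single-query softmax attention over the cached keys/values."""
    keys, values = cache.gather()
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    mx = max(scores)
    w = [math.exp(s - mx) for s in scores]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]


if __name__ == "__main__":
    pages = [[([0.0, 0.0], [30.0]), ([0.0, 0.0], [40.0])],   # physical page 0
             [([0.0, 0.0], [10.0]), ([0.0, 0.0], [20.0])]]   # physical page 1
    cache = ToyPagedKV(page_size=2, pages=pages, page_table=[1, 0], length=3)
    attend = functools.partial(toy_paged_attention, cache=cache)  # cache bound once
    print(attend([0.0, 0.0]))  # uniform weights over values 10, 20, 30
```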

See also

Kernels documents three things:

- the attention kernels: standard_attention, memory_efficient_attention, flash_attention, pallas_flash_attention, linear_attention, ring_attention, paged_attention, lite_attention
- the AttentionFn protocol
- the prebuilt ScoreMod / MaskMod constructors

These are passed to MultiHeadAttention via attention_fn= and score_mod=.