Kernels
The attnax.kernels subpackage contains the attention kernels as
pure JAX functions conforming to the AttentionFn protocol, and
ScoreMod / MaskMod constructors for common attention
biases. Each kernel can be passed to MultiHeadAttention via
attention_fn= or called standalone on pre-projected
(batch, num_heads, seq, head_dim) tensors.
See also: Attention for MultiHeadAttention and AttentionType.
Protocols
- class attnax.kernels.AttentionFn(*args, **kwargs)
Bases: Protocol
Signature for an attention kernel.
A pure JAX callable mapping pre-projected (batch, num_heads, seq, head_dim) query / key / value tensors to the attended output in the same layout. Carries no trainable parameters.
- class attnax.kernels.ScoreMod(*args, **kwargs)
Bases: Protocol
Callable that modifies pre-softmax attention scores.
Returns scores + bias (or any other position-conditioned transformation) of the same shape as scores. Position indices broadcast over the batch, head, query and key axes respectively.
Built-in kernels
- attnax.kernels.standard_attention(query, key, value, *, mask=None, score_mod=None, dropout_rng=None, dropout_rate=0.0, deterministic=True)
Scaled dot-product attention.
\[\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left( \frac{QK^\top}{\sqrt{d_k}} + \Delta\right) V\]
where \(\Delta\) is the optional score_mod bias. Entries with mask == False are set to \(-\infty\) before the softmax. Activation memory is \(O(n^2)\).
- Parameters:
query (Array) – Array of shape (batch, num_heads, seq_q, head_dim).
key (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).
value (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).
mask (Array | None) – Boolean mask broadcastable to (batch, num_heads, seq_q, seq_kv). True means attend.
score_mod (ScoreMod | None) – Callable applied to the pre-softmax scores. See attnax.kernels.score_mods.
dropout_rng (Array | None) – PRNG key for attention-weight dropout.
dropout_rate (float) – Dropout probability applied to attention weights.
deterministic (bool) – If True, disables dropout.
- Returns:
Array of shape (batch, num_heads, seq_q, head_dim).
- Return type:
Array
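The formula above can be traced by hand on a tiny input. The following is a minimal plain-Python sketch for a single head on nested lists, not the attnax implementation; the names softmax and attention_single_head are hypothetical, and the sketch assumes every query row has at least one unmasked key.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list; -inf entries get weight 0."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_single_head(q, k, v, mask=None, bias=None):
    """softmax(q k^T / sqrt(d_k) + bias) v for one head (hypothetical helper).

    q: [seq_q][d_k]; k: [seq_kv][d_k]; v: [seq_kv][d_v];
    mask and bias: [seq_q][seq_kv] or None. True in mask means attend.
    """
    d_k = len(q[0])
    out = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            s = sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k)
            if bias is not None:
                s += bias[i][j]
            if mask is not None and not mask[i][j]:
                s = -math.inf  # masked-out keys receive zero weight
            scores.append(s)
        w = softmax(scores)
        out.append([sum(w[j] * v[j][c] for j in range(len(v)))
                    for c in range(len(v[0]))])
    return out

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
causal = [[True, False], [True, True]]
out = attention_single_head(q, k, v, mask=causal)
# The first query can only see the first key, so out[0] equals v[0].
```

The full score row is built for every query, which is where the quadratic activation memory comes from.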
- attnax.kernels.memory_efficient_attention(query, key, value, *, mask=None, score_mod=None, dropout_rng=None, dropout_rate=0.0, deterministic=True, block_size=512)
Block-wise attention with \(O(n)\) activation memory.
Tiles the key/value sequence into blocks of size block_size and accumulates a running max, softmax denominator, and output via the online-softmax recurrence. Mathematically identical to standard_attention() but does not materialise the full (seq_q, seq_kv) score matrix. Falls back to standard_attention() when both axes fit in one block.
- Parameters:
query (Array) – Array of shape (batch, num_heads, seq_q, head_dim).
key (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).
value (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).
mask (Array | None) – Boolean mask broadcastable to (batch, num_heads, seq_q, seq_kv).
score_mod (ScoreMod | None) – Callable applied block-by-block to the pre-softmax scores; position indices are global.
dropout_rng (Array | None) – PRNG key for output dropout.
dropout_rate (float) – Dropout probability applied to the output.
deterministic (bool) – If True, disables dropout.
block_size (int) – Number of positions per query / key block.
- Returns:
Array of shape (batch, num_heads, seq_q, head_dim).
- Return type:
Array
References
Milakov and Gimelshein, Online normalizer calculation for softmax, 2018.
Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, 2022.
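The online-softmax recurrence that this kernel relies on can be illustrated for a single query row. This is a hedged plain-Python sketch, not the library's code: online_softmax_attention and full_softmax_attention are hypothetical names, and all scores are assumed finite.

```python
import math

def online_softmax_attention(scores, values, block_size):
    """Attend one query row to `values`, visiting keys block by block.

    Only a running max m, denominator l, and unnormalised output acc are
    kept, so the full score row is never needed at once.
    """
    dim = len(values[0])
    m, l, acc = -math.inf, 0.0, [0.0] * dim
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)  # rescale old statistics to the new max
        l *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(s_blk, v_blk):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * x for a, x in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]

def full_softmax_attention(scores, values):
    """Reference: softmax over the whole row, then the weighted sum."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    z = sum(p)
    return [sum(pi * v[c] for pi, v in zip(p, values)) / z
            for c in range(len(values[0]))]

scores = [0.5, -1.0, 2.0, 0.0, 3.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.25, 4.0]]
blockwise = online_softmax_attention(scores, values, block_size=2)
reference = full_softmax_attention(scores, values)
# Both agree up to floating-point rounding.
```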
- attnax.kernels.flash_attention(query, key, value, *, mask=None, score_mod=None, dropout_rng=None, dropout_rate=0.0, deterministic=True, block_size=512)
Hardware-dispatched scaled dot-product attention.
Dispatches to jax.nn.dot_product_attention() on GPU (backed by cuDNN's FlashAttention when available) and to memory_efficient_attention() on other backends. When score_mod is set, always uses memory_efficient_attention().
- Parameters:
query (Array) – Array of shape (batch, num_heads, seq_q, head_dim).
key (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).
value (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).
mask (Array | None) – Boolean mask broadcastable to (batch, num_heads, seq_q, seq_kv).
score_mod (ScoreMod | None) – Callable applied to the pre-softmax scores. Forces the memory-efficient fallback when set.
dropout_rng (Array | None) – PRNG key for dropout (fallback path only).
dropout_rate (float) – Dropout probability.
deterministic (bool) – If True, disables dropout.
block_size (int) – Block size for the fallback path.
- Returns:
Array of shape (batch, num_heads, seq_q, head_dim).
- Return type:
Array
References
Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, 2022.
- attnax.kernels.pallas_flash_attention(query, key, value, *, mask=None, score_mod=None, dropout_rng=None, dropout_rate=0.0, deterministic=True, block_q=128, block_kv=128, force_fallback=False)
FlashAttention forward pass compiled with Pallas.
Dispatches to a Pallas-lowered kernel on GPU/TPU and falls back to memory_efficient_attention() on CPU or when the kernel fails to lower. Forward pass only; the gradient is computed by tracing the kernel with jvp rather than by a hand-written backward.
- Parameters:
query (Array) – Array of shape (batch, num_heads, seq_q, head_dim).
key (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).
value (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).
mask (Array | None) – Boolean mask broadcastable to (batch, num_heads, seq_q, seq_kv).
score_mod (ScoreMod | None) – Callable traced into the inner loop; position indices are global.
dropout_rng (Array | None) – PRNG key for output dropout.
dropout_rate (float) – Output dropout probability.
deterministic (bool) – If True, disables dropout.
block_q (int) – Queries per kernel program. seq_q must be divisible by block_q.
block_kv (int) – Keys per inner-loop block. seq_kv must be divisible by block_kv.
force_fallback (bool) – If True, always use the pure-JAX fallback.
- Returns:
Array of shape (batch, num_heads, seq_q, head_dim).
- Return type:
Array
References
Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, 2022.
- attnax.kernels.ring_attention(query, key, value, *, mask=None, score_mod=None, dropout_rng=None, dropout_rate=0.0, deterministic=True, axis_name=None, block_size=512)
Sequence-parallel ring attention.
Each device holds a shard of Q, K, V along the sequence axis. The kernel rotates the local K/V shard around a ring of devices via jax.lax.ppermute(), accumulating online-softmax statistics so that each device has attended to every key/value after one full rotation. Activation memory is \(O(L / P)\) per device for sequence length \(L\) across \(P\) devices. Falls back to memory_efficient_attention() when axis_name is None or the named axis is unbound.
- Parameters:
query (Array) – Local shard of shape (batch, num_heads, seq_q_local, head_dim).
key (Array) – Local shard of shape (batch, num_heads, seq_kv_local, head_dim).
value (Array) – Local shard of shape (batch, num_heads, seq_kv_local, head_dim).
mask (Array | None) – Boolean mask broadcastable to the local score shape (batch, num_heads, seq_q_local, seq_kv_local).
score_mod (ScoreMod | None) – Callable applied to the pre-softmax scores; position indices are global.
dropout_rng (Array | None) – PRNG key for output dropout.
dropout_rate (float) – Output dropout probability.
deterministic (bool) – If True, disables dropout.
axis_name (str | None) – Mesh axis to ring-permute over. When None, runs on the current device.
block_size (int) – Block size for the single-device fallback.
- Returns:
Local output shard of shape (batch, num_heads, seq_q_local, head_dim).
- Return type:
Array
References
Liu et al., Ring Attention with Blockwise Transformers for Near-Infinite Context, 2023.
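The rotation scheme can be simulated on one machine to see why a full turn of the ring reproduces ordinary attention. The sketch below is an illustration under simplifying assumptions (head_dim = 1, no mask, no score_mod, a Python list rotation standing in for jax.lax.ppermute()); ring_attention_sim and the other names are hypothetical and not part of attnax.

```python
import math

def ring_attention_sim(q_shards, k_shards, v_shards):
    """Simulate ring attention on P 'devices' with scalar heads.

    q_shards[p], k_shards[p], v_shards[p] are lists of floats owned by device p.
    """
    P = len(q_shards)
    # Per-device, per-query running max, denominator, and unnormalised output.
    m = [[-math.inf] * len(q) for q in q_shards]
    l = [[0.0] * len(q) for q in q_shards]
    acc = [[0.0] * len(q) for q in q_shards]
    kv = list(zip(k_shards, v_shards))  # kv[p]: shard currently held by device p
    for _ in range(P):
        for p in range(P):
            ks, vs = kv[p]
            for i, qi in enumerate(q_shards[p]):
                scores = [qi * kj for kj in ks]  # sqrt(d_k) = 1 here
                m_new = max(m[p][i], max(scores))
                scale = math.exp(m[p][i] - m_new)
                probs = [math.exp(s - m_new) for s in scores]
                l[p][i] = l[p][i] * scale + sum(probs)
                acc[p][i] = acc[p][i] * scale + sum(
                    pr * vj for pr, vj in zip(probs, vs))
                m[p][i] = m_new
        # Stand-in for ppermute: device p receives the shard from device p - 1.
        kv = [kv[(p - 1) % P] for p in range(P)]
    return [[a / d for a, d in zip(acc[p], l[p])] for p in range(P)]

def full_attention_1d(q, k, v):
    """Single-device reference over the concatenated sequence."""
    out = []
    for qi in q:
        s = [qi * kj for kj in k]
        mx = max(s)
        p = [math.exp(x - mx) for x in s]
        out.append(sum(pi * vj for pi, vj in zip(p, v)) / sum(p))
    return out

def flatten(shards):
    return [x for shard in shards for x in shard]

q_shards = [[0.1, -0.4], [1.2, 0.3], [-0.7, 0.9]]
k_shards = [[0.5, 1.0], [-1.5, 0.2], [0.8, -0.3]]
v_shards = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
ring_out = flatten(ring_attention_sim(q_shards, k_shards, v_shards))
full_out = full_attention_1d(flatten(q_shards), flatten(k_shards), flatten(v_shards))
```

Each simulated device only ever holds one K/V shard at a time, which mirrors the per-device memory bound stated above.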
- attnax.kernels.ring_attention_reference(query, key, value, *, mask=None, score_mod=None)
Single-device reference for ring_attention().
Computes the same output as ring_attention() when the entire sequence resides on one device. Used for numerical equivalence tests.
- Parameters:
query (Array) – Array of shape (batch, num_heads, seq_q, head_dim).
key (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).
value (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).
mask (Array | None) – Boolean mask broadcastable to (batch, num_heads, seq_q, seq_kv).
score_mod (ScoreMod | None) – Callable applied to the pre-softmax scores.
- Returns:
Array of shape (batch, num_heads, seq_q, head_dim).
- Return type:
Array
- attnax.kernels.linear_attention(query, key, value, *, mask=None, score_mod=None, dropout_rng=None, dropout_rate=0.0, deterministic=True, causal=True, chunk_size=64)
Chunkwise-parallel linear attention.
Softmax-free attention with the recurrence
\[S_t = S_{t-1} + \phi(K_t)^\top V_t, \qquad Y_t = \frac{\phi(Q_t)\,S_t}{\phi(Q_t)\,Z_t + \epsilon},\]
where \(\phi(x) = \mathrm{elu}(x) + 1\) and \(Z_t\) is the running sum of \(\phi(K)\). Inside a chunk, attention is a softmax-free matmul; across chunks, a (head_dim, head_dim) state and a (head_dim,) normaliser are propagated with jax.lax.scan(). Activation memory is \(O(L \, d^2)\).
- Parameters:
query (Array) – Array of shape (batch, num_heads, seq_q, head_dim).
key (Array) – Array of shape (batch, num_heads, seq_kv, head_dim). seq_kv must equal seq_q.
value (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).
mask (Array | None) – Boolean mask broadcastable to (batch, num_heads, seq_q, seq_kv). Only the first row along the query axis is consulted (treated as a key-padding mask).
score_mod (ScoreMod | None) – Must be None.
dropout_rng (Array | None) – PRNG key for output dropout.
dropout_rate (float) – Dropout probability applied to the output.
deterministic (bool) – If True, disables dropout.
causal (bool) – If True, intra-chunk attention is causal and inter-chunk state propagates left-to-right.
chunk_size (int) – Tokens per chunk; seq_q must be divisible by chunk_size.
- Returns:
Array of shape (batch, num_heads, seq_q, head_dim).
- Raises:
NotImplementedError – If score_mod is set.
ValueError – If seq_q != seq_kv or seq_q is not divisible by chunk_size.
- Return type:
Array
References
Katharopoulos et al., Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention, 2020.
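To make the recurrence concrete, here is a token-by-token plain-Python sketch for one head in the causal case, checked against the explicit quadratic sum. It is not the chunkwise kernel described above; recurrent_form, quadratic_form, and the eps default are assumptions made for illustration.

```python
import math

def phi(x):
    """Feature map elu(x) + 1, which is strictly positive."""
    return x + 1.0 if x > 0 else math.exp(x)

def recurrent_form(q, k, v, eps=1e-6):
    """Causal linear attention via running state; q, k: [L][d], v: [L][d_v]."""
    d, d_v = len(q[0]), len(v[0])
    S = [[0.0] * d_v for _ in range(d)]  # running sum of phi(k)^T v
    Z = [0.0] * d                        # running sum of phi(k)
    out = []
    for qt, kt, vt in zip(q, k, v):
        fq = [phi(x) for x in qt]
        fk = [phi(x) for x in kt]
        for a in range(d):
            Z[a] += fk[a]
            for c in range(d_v):
                S[a][c] += fk[a] * vt[c]
        denom = sum(fq[a] * Z[a] for a in range(d)) + eps
        out.append([sum(fq[a] * S[a][c] for a in range(d)) / denom
                    for c in range(d_v)])
    return out

def quadratic_form(q, k, v, eps=1e-6):
    """Same output via explicit sums over the current and all earlier tokens."""
    out = []
    for t, qt in enumerate(q):
        fq = [phi(x) for x in qt]
        w = [sum(a * phi(b) for a, b in zip(fq, k[s])) for s in range(t + 1)]
        denom = sum(w) + eps
        out.append([sum(w[s] * v[s][c] for s in range(t + 1)) / denom
                    for c in range(len(v[0]))])
    return out

q = [[0.2, -1.0], [1.5, 0.3], [-0.4, 0.8]]
k = [[1.0, -0.5], [-2.0, 0.7], [0.1, 0.1]]
v = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0], [3.0, 3.0, 3.0]]
y_rec = recurrent_form(q, k, v)
y_quad = quadratic_form(q, k, v)
# The two forms agree up to floating-point rounding.
```

The recurrent form carries only the (d, d_v) state and the (d,) normaliser between steps, which is what allows the state to be propagated across chunks.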
- attnax.kernels.paged_attention(query, cache, sequence_idx, *, mask=None, score_mod=None, dropout_rng=None, dropout_rate=0.0, deterministic=True)
Attention against a paged KV cache for one sequence.
Gathers the keys and values pointed to by the block table for sequence_idx and delegates to standard_attention(). Repeats KV heads if cache.num_kv_heads < num_heads (grouped-query attention).
- Parameters:
query (Array) – Array of shape (num_heads, seq_q, head_dim) for a single sequence. Use jax.vmap() to batch.
cache (PagedKVCache) – PagedKVCache storing keys and values.
sequence_idx (int) – Row of the block table to attend against.
mask (Array | None) – Boolean mask broadcastable to (num_heads, seq_q, seq_kv) where seq_kv is the current sequence length.
score_mod (ScoreMod | None) – Callable applied to the pre-softmax scores; key indices are cache positions starting at zero.
dropout_rng (Array | None) – PRNG key for attention dropout.
dropout_rate (float) – Attention dropout probability.
deterministic (bool) – If True, disables dropout.
- Returns:
Array of shape (num_heads, seq_q, head_dim).
- Return type:
Array
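The gather step and the KV-head repetition can be sketched without any array library. The layout below (a list of fixed-size pages, a block table of page ids, and a list of sequence lengths) is a hypothetical stand-in for PagedKVCache, whose actual fields are not documented here; gather_paged and repeat_kv_heads are illustrative names, and repeating heads in consecutive groups is an assumption.

```python
def gather_paged(pages, block_table, seq_lens, sequence_idx):
    """Collect the first seq_lens[sequence_idx] entries of one sequence.

    pages: [num_pages][page_size] entries (one entry per cached token).
    block_table: [num_sequences][pages_per_sequence] page ids in logical order.
    """
    length = seq_lens[sequence_idx]
    gathered = []
    for page_id in block_table[sequence_idx]:
        for entry in pages[page_id]:
            if len(gathered) == length:
                return gathered  # the last page may be only partly filled
            gathered.append(entry)
    return gathered

def repeat_kv_heads(per_head, num_heads):
    """Repeat each KV head so the result has num_heads entries."""
    num_kv_heads = len(per_head)
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group = num_heads // num_kv_heads
    return [head for head in per_head for _ in range(group)]

# Strings stand in for key vectors so the gather order is easy to read.
pages = [["a0", "a1"], ["b0", "b1"], ["c0", "c1"], ["d0", "d1"]]
block_table = [[2, 0], [3, 1]]
seq_lens = [3, 4]
keys_seq0 = gather_paged(pages, block_table, seq_lens, sequence_idx=0)
# keys_seq0 == ["c0", "c1", "a0"]
heads = repeat_kv_heads(["kv0", "kv1"], num_heads=4)
# heads == ["kv0", "kv0", "kv1", "kv1"]
```

After the gather, the keys and values form an ordinary contiguous sequence, so the remaining computation is plain attention.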
Prebuilt score-mods
- attnax.kernels.causal_mod()
Causal score-mod.
Sets the score to \(-\infty\) whenever kv_idx > q_idx.
- attnax.kernels.sliding_window_mod(window_size, *, causal=True)
Sliding-window score-mod.
Restricts attention to keys within window_size of each query. When causal is True, attends only to positions q_idx - window_size < kv_idx <= q_idx; otherwise the window is symmetric around the query.
- Returns:
A ScoreMod callable.
- Raises:
ValueError – If window_size <= 0.
- Return type:
ScoreMod
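The window rule can be written as a small predicate on index pairs. This is a sketch with a hypothetical name (in_window); the causal branch follows the inequality stated above, while the symmetric bound |q_idx - kv_idx| < window_size is an assumption, since the exact non-causal bound is not spelled out.

```python
def in_window(q_idx, kv_idx, window_size, causal=True):
    """Return True if the key at kv_idx is visible to the query at q_idx."""
    if window_size <= 0:
        raise ValueError("window_size must be positive")
    if causal:
        return q_idx - window_size < kv_idx <= q_idx
    # Assumed symmetric rule: the same reach on both sides of the query.
    return abs(q_idx - kv_idx) < window_size

visible_causal = [j for j in range(10) if in_window(5, j, 3)]
# visible_causal == [3, 4, 5]
visible_symmetric = [j for j in range(10) if in_window(5, j, 3, causal=False)]
# visible_symmetric == [3, 4, 5, 6, 7]
```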
- attnax.kernels.alibi_mod(num_heads)
ALiBi additive bias score-mod.
Adds \(-m_h |q\_idx - kv\_idx|\) to the pre-softmax scores, where \(m_h\) is the per-head slope returned by alibi_slopes().
References
Press et al., Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation, 2022.
- attnax.kernels.alibi_slopes(num_heads)
Per-head ALiBi slopes.
Geometric slope schedule from the ALiBi paper: a geometric sequence of powers of two when num_heads is a power of two, and an interleaved, truncated schedule otherwise.
- Parameters:
num_heads (int) – Number of attention heads.
- Returns:
float32 array of shape (num_heads,) with positive slopes \(m_h\). The bias added to the score between query position \(i\) and key position \(j\) is \(-m_h |i - j|\).
- Return type:
Array
References
Press et al., Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation, 2022.
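For reference, the schedule can be computed in a few lines. The sketch below follows the slope recipe commonly used with ALiBi (start at 2^(-8/n) for a power-of-two head count n, and interleave slopes from the next power of two otherwise); alibi_slopes_sketch is a hypothetical name, it returns a Python list rather than a float32 array, and the ordering attnax uses for non-power-of-two head counts may differ.

```python
import math

def alibi_slopes_sketch(num_heads):
    """Return positive per-head slopes m_h as a list of floats."""
    def power_of_two_slopes(n):
        start = 2.0 ** (-8.0 / n)  # 1/2 when n == 8
        return [start ** (i + 1) for i in range(n)]

    if (num_heads & (num_heads - 1)) == 0:  # num_heads is a power of two
        return power_of_two_slopes(num_heads)
    closest = 2 ** math.floor(math.log2(num_heads))
    # Take every other slope of the next power of two to fill the remainder.
    extra = power_of_two_slopes(2 * closest)[0::2][: num_heads - closest]
    return power_of_two_slopes(closest) + extra

slopes8 = alibi_slopes_sketch(8)
# slopes8 == [1/2, 1/4, 1/8, 1/16, 1/32, 1/64, 1/128, 1/256]
slopes12 = alibi_slopes_sketch(12)

# Bias for head 0 between query position 5 and key position 2: -m_0 * |5 - 2|.
bias = -slopes8[0] * abs(5 - 2)
# bias == -1.5
```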
- attnax.kernels.prefix_lm_mod(prefix_lengths)
Prefix-LM score-mod.
Attention is bidirectional for kv_idx < prefix_lengths[b] and causal elsewhere.
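The resulting visibility pattern is easy to print for one batch element. This is a small illustrative sketch; prefix_lm_visible is a hypothetical helper that takes a single prefix length rather than the per-batch prefix_lengths array.

```python
def prefix_lm_visible(q_idx, kv_idx, prefix_length):
    """True if the key is inside the prefix or not after the query."""
    return kv_idx < prefix_length or kv_idx <= q_idx

# Rows are queries, columns are keys; "x" marks an attended position.
rows = ["".join("x" if prefix_lm_visible(i, j, 2) else "." for j in range(4))
        for i in range(4)]
# rows == ["xx..", "xx..", "xxx.", "xxxx"]
```

With a prefix of length 2, the first two queries see the whole prefix, and later queries add a causal tail.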
Composition helpers
- attnax.kernels.compose_score_mods(*score_mods)
Composes ScoreMod callables left-to-right. None arguments are skipped. Returns None when every argument is None.
- attnax.kernels.mask_mod_to_score_mod(mask_mod)
Converts a MaskMod into an equivalent ScoreMod.
Sets the score to \(-\infty\) where mask_mod returns False and leaves other entries untouched.
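The behaviour of both helpers can be sketched on scalar scores. The call signatures used here, (score, b, h, q_idx, kv_idx) for a score-mod and (b, h, q_idx, kv_idx) for a mask-mod, are assumptions inferred from the protocol description, and compose, from_mask_mod, causal_mask, and distance_penalty are hypothetical names rather than attnax functions.

```python
import math

def compose(*score_mods):
    """Apply the given score-mods left-to-right, skipping None."""
    mods = [m for m in score_mods if m is not None]
    if not mods:
        return None

    def composed(score, b, h, q_idx, kv_idx):
        for mod in mods:
            score = mod(score, b, h, q_idx, kv_idx)
        return score

    return composed

def from_mask_mod(mask_mod):
    """Turn a boolean mask-mod into a score-mod that writes -inf where False."""
    def score_mod(score, b, h, q_idx, kv_idx):
        return score if mask_mod(b, h, q_idx, kv_idx) else -math.inf
    return score_mod

def causal_mask(b, h, q_idx, kv_idx):
    return kv_idx <= q_idx

def distance_penalty(score, b, h, q_idx, kv_idx):
    return score - 0.5 * abs(q_idx - kv_idx)

mod = compose(None, distance_penalty, from_mask_mod(causal_mask))
past = mod(1.0, 0, 0, 3, 1)    # 1.0 - 0.5 * 2, key is in the past
future = mod(1.0, 0, 0, 1, 3)  # key is in the future, so the score is -inf
empty = compose(None, None)    # nothing to apply
```

Because the mask conversion only ever writes \(-\infty\) or passes the score through, it commutes with additive biases, so its position in the chain does not change the result here.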
Examples
Pass any AttentionFn to MultiHeadAttention:
from attnax import MultiHeadAttention
from attnax.kernels import standard_attention
from attnax.kernels.score_mods import (
    alibi_mod, sliding_window_mod, compose_score_mods,
)

attn = MultiHeadAttention(
    rngs,
    num_heads=8,
    in_features=512,
    attention_fn=standard_attention,
    score_mod=compose_score_mods(
        alibi_mod(num_heads=8),
        sliding_window_mod(window_size=4096),
    ),
)
Per-call score-mods compose with the constructor-supplied mod:
out = attn(x, score_mod=document_mask_mod(doc_ids), deterministic=True)