Kernels#

The attnax.kernels subpackage contains the attention kernels as pure JAX functions conforming to the AttentionFn protocol, and ScoreMod / MaskMod constructors for common attention biases. Each kernel can be passed to MultiHeadAttention via attention_fn= or called standalone on pre-projected (batch, num_heads, seq, head_dim) tensors.

Protocols#

class attnax.kernels.AttentionFn(*args, **kwargs)

Bases: Protocol

Signature for an attention kernel.

A pure JAX callable mapping pre-projected (batch, num_heads, seq, head_dim) query / key / value tensors to the attended output in the same layout. Carries no trainable parameters.

class attnax.kernels.ScoreMod(*args, **kwargs)

Bases: Protocol

Callable that modifies pre-softmax attention scores.

Returns scores + bias (or any other position-conditioned transformation) of the same shape as scores. Position indices broadcast over the batch, head, query and key axes respectively.

class attnax.kernels.MaskMod(*args, **kwargs)

Bases: Protocol

Callable that returns a boolean attention mask.

Reads only the position indices (no scores) and returns a boolean array that broadcasts to (batch, num_heads, seq_q, seq_kv). True means attend.
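The protocol descriptions above do not spell out the call signatures. The sketch below assumes a `(scores, b, h, q_idx, kv_idx)` form for a ScoreMod and `(b, h, q_idx, kv_idx)` for a MaskMod; these argument names, and the helper `index_grids`, are hypothetical and only illustrate how broadcasting index arrays can drive both kinds of callable.

```python
import jax.numpy as jnp

def index_grids(batch, num_heads, seq_q, seq_kv):
    """Index arrays that broadcast over the batch, head, query and key axes."""
    b = jnp.arange(batch)[:, None, None, None]
    h = jnp.arange(num_heads)[None, :, None, None]
    q_idx = jnp.arange(seq_q)[None, None, :, None]
    kv_idx = jnp.arange(seq_kv)[None, None, None, :]
    return b, h, q_idx, kv_idx

def distance_penalty(scores, b, h, q_idx, kv_idx):
    """Hypothetical ScoreMod: subtracts a distance penalty, keeps the shape."""
    return scores - 0.1 * jnp.abs(q_idx - kv_idx)

def causal_mask(b, h, q_idx, kv_idx):
    """Hypothetical MaskMod: True means attend."""
    return kv_idx <= q_idx

scores = jnp.zeros((2, 4, 5, 5))
grids = index_grids(2, 4, 5, 5)
modded = distance_penalty(scores, *grids)
mask = jnp.broadcast_to(causal_mask(*grids), scores.shape)
```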

Built-in kernels#

attnax.kernels.standard_attention(query, key, value, *, mask=None, score_mod=None, dropout_rng=None, dropout_rate=0.0, deterministic=True)

Scaled dot-product attention.

\[\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left( \frac{QK^\top}{\sqrt{d_k}} + \Delta\right) V\]

where \(\Delta\) is the optional score_mod bias. Entries with mask == False are set to \(-\infty\) before the softmax. Activation memory is \(O(n^2)\).

Parameters:
  • query (Array) – Array of shape (batch, num_heads, seq_q, head_dim).

  • key (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).

  • value (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).

  • mask (Array | None) – Boolean mask broadcastable to (batch, num_heads, seq_q, seq_kv). True means attend.

  • score_mod (ScoreMod | None) – Callable applied to the pre-softmax scores. See attnax.kernels.score_mods.

  • dropout_rng (Array | None) – PRNG key for attention-weight dropout.

  • dropout_rate (float) – Dropout probability applied to attention weights.

  • deterministic (bool) – If True, disables dropout.

Returns:

Array of shape (batch, num_heads, seq_q, head_dim).

Return type:

Array
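As an illustration of the formula above, the following is a minimal pure-JAX sketch of masked scaled dot-product attention. It is not the attnax implementation: `sdpa_sketch` is a hypothetical name, and score_mod and dropout are omitted.

```python
import jax
import jax.numpy as jnp

def sdpa_sketch(query, key, value, mask=None):
    """softmax(Q K^T / sqrt(d_k)) V on (batch, heads, seq, head_dim) inputs."""
    d_k = query.shape[-1]
    scores = jnp.einsum("bhqd,bhkd->bhqk", query, key) / jnp.sqrt(d_k)
    if mask is not None:
        # Entries with mask == False are set to -inf before the softmax.
        scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)  # (batch, heads, seq_q, seq_kv)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, value)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 4, 6, 8))
k = jax.random.normal(kk, (2, 4, 6, 8))
v = jax.random.normal(kv, (2, 4, 6, 8))
causal = jnp.tril(jnp.ones((6, 6), dtype=bool))
out = sdpa_sketch(q, k, v, mask=causal)
```

With a causal mask, the first query position can only attend to the first key, so its output equals the first value vector.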

attnax.kernels.memory_efficient_attention(query, key, value, *, mask=None, score_mod=None, dropout_rng=None, dropout_rate=0.0, deterministic=True, block_size=512)

Block-wise attention with \(O(n)\) activation memory.

Tiles the query and key/value sequences into blocks of size block_size and accumulates a running max, softmax denominator, and output via the online-softmax recurrence. With dropout disabled, the result is mathematically equivalent to standard_attention(), but the full (seq_q, seq_kv) score matrix is never materialised. Falls back to standard_attention() when both axes fit in one block.

Parameters:
  • query (Array) – Array of shape (batch, num_heads, seq_q, head_dim).

  • key (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).

  • value (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).

  • mask (Array | None) – Boolean mask broadcastable to (batch, num_heads, seq_q, seq_kv).

  • score_mod (ScoreMod | None) – Callable applied block-by-block to the pre-softmax scores; position indices are global.

  • dropout_rng (Array | None) – PRNG key for output dropout.

  • dropout_rate (float) – Dropout probability applied to the output.

  • deterministic (bool) – If True, disables dropout.

  • block_size (int) – Number of positions per query / key block.

Returns:

Array of shape (batch, num_heads, seq_q, head_dim).

Return type:

Array

References

Milakov and Gimelshein, Online normalizer calculation for softmax, 2018.

Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, 2022.
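The online-softmax recurrence can be sketched as follows. This is an illustrative version that tiles only the key/value axis with a Python loop and omits masking, score_mod and dropout; `blockwise_attention_sketch` and `reference_attention` are hypothetical names, not attnax functions.

```python
import jax
import jax.numpy as jnp

def blockwise_attention_sketch(q, k, v, block_size=4):
    """Online-softmax attention over key/value blocks."""
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    running_max = jnp.full(q.shape[:-1], -jnp.inf)      # (b, h, seq_q)
    denom = jnp.zeros(q.shape[:-1])                     # softmax denominator
    acc = jnp.zeros(q.shape[:-1] + (v.shape[-1],))      # unnormalised output
    for start in range(0, k.shape[2], block_size):
        k_blk = k[:, :, start:start + block_size]
        v_blk = v[:, :, start:start + block_size]
        s = jnp.einsum("bhqd,bhkd->bhqk", q, k_blk) * scale
        new_max = jnp.maximum(running_max, s.max(axis=-1))
        correction = jnp.exp(running_max - new_max)     # rescales old statistics
        p = jnp.exp(s - new_max[..., None])
        denom = denom * correction + p.sum(axis=-1)
        acc = acc * correction[..., None] + jnp.einsum("bhqk,bhkd->bhqd", p, v_blk)
        running_max = new_max
    return acc / denom[..., None]

def reference_attention(q, k, v):
    """Full-matrix attention used only for comparison."""
    s = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(s, axis=-1), v)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (1, 2, 5, 8))
k = jax.random.normal(kk, (1, 2, 10, 8))
v = jax.random.normal(kv, (1, 2, 10, 8))
blockwise = blockwise_attention_sketch(q, k, v)
full = reference_attention(q, k, v)
```

Only a (seq_q, block_size) slice of the score matrix exists at any time; the correction factor keeps earlier blocks consistent with the updated running maximum.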

attnax.kernels.flash_attention(query, key, value, *, mask=None, score_mod=None, dropout_rng=None, dropout_rate=0.0, deterministic=True, block_size=512)

Hardware-dispatched scaled dot-product attention.

Dispatches to jax.nn.dot_product_attention() on GPU (backed by cuDNN’s FlashAttention when available) and to memory_efficient_attention() on other backends. When score_mod is set, always uses memory_efficient_attention().

Parameters:
  • query (Array) – Array of shape (batch, num_heads, seq_q, head_dim).

  • key (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).

  • value (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).

  • mask (Array | None) – Boolean mask broadcastable to (batch, num_heads, seq_q, seq_kv).

  • score_mod (ScoreMod | None) – Callable applied to the pre-softmax scores. Forces the memory-efficient fallback when set.

  • dropout_rng (Array | None) – PRNG key for dropout (fallback path only).

  • dropout_rate (float) – Dropout probability.

  • deterministic (bool) – If True, disables dropout.

  • block_size (int) – Block size for the fallback path.

Returns:

Array of shape (batch, num_heads, seq_q, head_dim).

Return type:

Array

References

Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, 2022.

attnax.kernels.pallas_flash_attention(query, key, value, *, mask=None, score_mod=None, dropout_rng=None, dropout_rate=0.0, deterministic=True, block_q=128, block_kv=128, force_fallback=False)

FlashAttention forward pass compiled with Pallas.

Dispatches to a Pallas-lowered kernel on GPU/TPU and falls back to memory_efficient_attention() on CPU or when the kernel fails to lower. Only the forward pass is a hand-written kernel; gradients are obtained by tracing the kernel with jvp rather than from a hand-written backward pass.

Parameters:
  • query (Array) – Array of shape (batch, num_heads, seq_q, head_dim).

  • key (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).

  • value (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).

  • mask (Array | None) – Boolean mask broadcastable to (batch, num_heads, seq_q, seq_kv).

  • score_mod (ScoreMod | None) – Callable traced into the inner loop; position indices are global.

  • dropout_rng (Array | None) – PRNG key for output dropout.

  • dropout_rate (float) – Output dropout probability.

  • deterministic (bool) – If True, disables dropout.

  • block_q (int) – Queries per kernel program. seq_q must be divisible by block_q.

  • block_kv (int) – Keys per inner-loop block. seq_kv must be divisible by block_kv.

  • force_fallback (bool) – If True, always use the pure-JAX fallback.

Returns:

Array of shape (batch, num_heads, seq_q, head_dim).

Return type:

Array

References

Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, 2022.

attnax.kernels.ring_attention(query, key, value, *, mask=None, score_mod=None, dropout_rng=None, dropout_rate=0.0, deterministic=True, axis_name=None, block_size=512)

Sequence-parallel ring attention.

Each device holds a shard of Q, K, V along the sequence axis. The kernel rotates the local K / V shard around a ring of devices via jax.lax.ppermute(), accumulating online-softmax statistics so that each device has attended to every key/value after one full rotation. Activation memory is \(O(L / P)\) per device for sequence length \(L\) across \(P\) devices. Falls back to memory_efficient_attention() when axis_name is None or the named axis is unbound.

Parameters:
  • query (Array) – Local shard of shape (batch, num_heads, seq_q_local, head_dim).

  • key (Array) – Local shard of shape (batch, num_heads, seq_kv_local, head_dim).

  • value (Array) – Local shard of shape (batch, num_heads, seq_kv_local, head_dim).

  • mask (Array | None) – Boolean mask broadcastable to the local score shape (batch, num_heads, seq_q_local, seq_kv_local).

  • score_mod (ScoreMod | None) – Callable applied to the pre-softmax scores; position indices are global.

  • dropout_rng (Array | None) – PRNG key for output dropout.

  • dropout_rate (float) – Output dropout probability.

  • deterministic (bool) – If True, disables dropout.

  • axis_name (str | None) – Mesh axis to ring-permute over. When None, runs on the current device.

  • block_size (int) – Block size for the single-device fallback.

Returns:

Local output shard of shape (batch, num_heads, seq_q_local, head_dim).

Return type:

Array

References

Liu et al., Ring Attention with Blockwise Transformers for Near-Infinite Context, 2023.
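The rotation scheme can be illustrated on a single device by holding the per-device shards in Python lists. In this sketch the list rotation stands in for jax.lax.ppermute(); masking, score_mod and dropout are omitted, and `ring_attention_simulated` is a hypothetical name rather than part of attnax.

```python
import jax
import jax.numpy as jnp

def ring_attention_simulated(q_shards, k_shards, v_shards):
    """Simulate ring attention over P 'devices' represented as list entries."""
    k_shards, v_shards = list(k_shards), list(v_shards)
    num_devices = len(q_shards)
    scale = 1.0 / jnp.sqrt(q_shards[0].shape[-1])
    maxes = [jnp.full(q.shape[:-1], -jnp.inf) for q in q_shards]
    denoms = [jnp.zeros(q.shape[:-1]) for q in q_shards]
    accs = [jnp.zeros_like(q) for q in q_shards]
    for _ in range(num_devices):
        for d in range(num_devices):
            # Device d attends to whichever K/V shard it currently holds.
            s = jnp.einsum("bhqd,bhkd->bhqk", q_shards[d], k_shards[d]) * scale
            new_max = jnp.maximum(maxes[d], s.max(axis=-1))
            corr = jnp.exp(maxes[d] - new_max)
            p = jnp.exp(s - new_max[..., None])
            denoms[d] = denoms[d] * corr + p.sum(axis=-1)
            accs[d] = accs[d] * corr[..., None] + jnp.einsum(
                "bhqk,bhkd->bhqd", p, v_shards[d])
            maxes[d] = new_max
        # Each device hands its K/V shard to the next device in the ring.
        k_shards = k_shards[-1:] + k_shards[:-1]
        v_shards = v_shards[-1:] + v_shards[:-1]
    return [acc / den[..., None] for acc, den in zip(accs, denoms)]

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (1, 2, 8, 4))
k = jax.random.normal(kk, (1, 2, 8, 4))
v = jax.random.normal(kv, (1, 2, 8, 4))
shards = [list(jnp.split(x, 4, axis=2)) for x in (q, k, v)]
ring_out = jnp.concatenate(ring_attention_simulated(*shards), axis=2)

scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
full_out = jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)
```

After P rotation steps every query shard has seen every key/value shard, so the concatenated outputs match full attention over the unsharded sequence.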

attnax.kernels.ring_attention_reference(query, key, value, *, mask=None, score_mod=None)

Single-device reference for ring_attention().

Computes the same output as ring_attention() when the entire sequence resides on one device. Used for numerical equivalence tests.

Parameters:
  • query (Array) – Array of shape (batch, num_heads, seq_q, head_dim).

  • key (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).

  • value (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).

  • mask (Array | None) – Boolean mask broadcastable to (batch, num_heads, seq_q, seq_kv).

  • score_mod (ScoreMod | None) – Callable applied to the pre-softmax scores.

Returns:

Array of shape (batch, num_heads, seq_q, head_dim).

Return type:

Array

attnax.kernels.linear_attention(query, key, value, *, mask=None, score_mod=None, dropout_rng=None, dropout_rate=0.0, deterministic=True, causal=True, chunk_size=64)

Chunkwise-parallel linear attention.

Softmax-free attention with the recurrence

\[S_t = S_{t-1} + \phi(K_t)^\top V_t, \qquad Y_t = \frac{\phi(Q_t)\,S_t}{\phi(Q_t)\,Z_t + \epsilon},\]

where \(\phi(x) = \mathrm{elu}(x) + 1\) and \(Z_t\) is the running sum of \(\phi(K)\). Inside a chunk, attention is a softmax-free matmul; across chunks, a (head_dim, head_dim) state and a (head_dim,) normaliser are propagated with jax.lax.scan(). Activation memory is \(O(L \, d^2)\).

Parameters:
  • query (Array) – Array of shape (batch, num_heads, seq_q, head_dim).

  • key (Array) – Array of shape (batch, num_heads, seq_kv, head_dim). seq_kv must equal seq_q.

  • value (Array) – Array of shape (batch, num_heads, seq_kv, head_dim).

  • mask (Array | None) – Boolean mask broadcastable to (batch, num_heads, seq_q, seq_kv). Only the first row along the query axis is consulted (treated as a key-padding mask).

  • score_mod (ScoreMod | None) – Must be None.

  • dropout_rng (Array | None) – PRNG key for output dropout.

  • dropout_rate (float) – Dropout probability applied to the output.

  • deterministic (bool) – If True, disables dropout.

  • causal (bool) – If True, intra-chunk attention is causal and inter-chunk state propagates left-to-right.

  • chunk_size (int) – Tokens per chunk; seq_q must be divisible by chunk_size.

Returns:

Array of shape (batch, num_heads, seq_q, head_dim).

Raises:
Return type:

Array

References

Katharopoulos et al., Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention, 2020.
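The recurrence above can be checked against its quadratic causal form. The sketch below runs the recurrence token by token with jax.lax.scan() for a single head, rather than chunkwise as the kernel is described to do; the function names and the default value of eps are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def phi(x):
    """Feature map phi(x) = elu(x) + 1, which is strictly positive."""
    return jax.nn.elu(x) + 1.0

def linear_attention_recurrent(q, k, v, eps=1e-6):
    """Causal linear attention for one head; q, k, v have shape (seq, head_dim)."""
    def step(carry, inputs):
        state, normaliser = carry                   # (d, d_v) and (d,)
        q_t, k_t, v_t = inputs
        state = state + jnp.outer(phi(k_t), v_t)    # S_t = S_{t-1} + phi(K_t)^T V_t
        normaliser = normaliser + phi(k_t)          # Z_t = Z_{t-1} + phi(K_t)
        y_t = (phi(q_t) @ state) / (phi(q_t) @ normaliser + eps)
        return (state, normaliser), y_t
    init = (jnp.zeros((q.shape[-1], v.shape[-1])), jnp.zeros((q.shape[-1],)))
    _, y = jax.lax.scan(step, init, (q, k, v))
    return y

def linear_attention_quadratic(q, k, v, eps=1e-6):
    """Same computation written as a causal (seq, seq) matmul."""
    weights = jnp.tril(phi(q) @ phi(k).T)
    return (weights @ v) / (weights.sum(axis=-1, keepdims=True) + eps)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (6, 4))
k = jax.random.normal(kk, (6, 4))
v = jax.random.normal(kv, (6, 4))
recurrent = linear_attention_recurrent(q, k, v)
quadratic = linear_attention_quadratic(q, k, v)
```

Batched (batch, num_heads, seq, head_dim) inputs could be handled by wrapping the single-head function in two nested jax.vmap() calls.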

attnax.kernels.paged_attention(query, cache, sequence_idx, *, mask=None, score_mod=None, dropout_rng=None, dropout_rate=0.0, deterministic=True)

Attention against a paged KV cache for one sequence.

Gathers the keys and values pointed to by the block table for sequence_idx and delegates to standard_attention(). Repeats KV heads if cache.num_kv_heads < num_heads (grouped-query attention).

Parameters:
  • query (Array) – Array of shape (num_heads, seq_q, head_dim) for a single sequence. Use jax.vmap() to batch.

  • cache (PagedKVCache) – PagedKVCache storing keys and values.

  • sequence_idx (int) – Row of the block table to attend against.

  • mask (Array | None) – Boolean mask broadcastable to (num_heads, seq_q, seq_kv) where seq_kv is the current sequence length.

  • score_mod (ScoreMod | None) – Callable applied to the pre-softmax scores; key indices are cache positions starting at zero.

  • dropout_rng (Array | None) – PRNG key for attention dropout.

  • dropout_rate (float) – Attention dropout probability.

  • deterministic (bool) – If True, disables dropout.

Returns:

Array of shape (num_heads, seq_q, head_dim).

Return type:

Array
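The gather step can be sketched under an assumed cache layout. The internals of PagedKVCache are not documented here, so the `pages` and `block_table` arrays, their shapes, and the helper `gather_paged_kv` are all hypothetical.

```python
import jax.numpy as jnp

def gather_paged_kv(pages, block_table, sequence_idx, seq_len, num_heads):
    """Gather one sequence's keys (or values) from a hypothetical paged layout.

    pages:       (num_pages, page_size, num_kv_heads, head_dim)
    block_table: (num_sequences, max_pages_per_sequence) page indices
    """
    page_ids = block_table[sequence_idx]                      # (max_pages,)
    gathered = pages[page_ids]                                # (max_pages, page_size, kv_heads, d)
    flat = gathered.reshape(-1, *pages.shape[2:])[:seq_len]   # (seq_len, kv_heads, d)
    per_head = jnp.swapaxes(flat, 0, 1)                       # (kv_heads, seq_len, d)
    # Grouped-query attention: repeat each KV head for its group of query heads.
    return jnp.repeat(per_head, num_heads // per_head.shape[0], axis=0)

pages = jnp.arange(4 * 2 * 1 * 3, dtype=jnp.float32).reshape(4, 2, 1, 3)
block_table = jnp.array([[2, 0], [1, 3]])
keys = gather_paged_kv(pages, block_table, sequence_idx=0, seq_len=3, num_heads=2)
```

The result has shape (num_heads, seq_kv, head_dim) and can be fed to an ordinary attention computation together with a (num_heads, seq_q, head_dim) query.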

Prebuilt score-mods#

attnax.kernels.causal_mod()

Causal score-mod.

Sets the score to \(-\infty\) whenever kv_idx > q_idx.

Returns:

A ScoreMod callable.

Return type:

ScoreMod

attnax.kernels.sliding_window_mod(window_size, *, causal=True)

Sliding-window score-mod.

Restricts attention to keys within window_size of each query. When causal is True, attends only to positions q_idx - window_size < kv_idx <= q_idx; otherwise the window is symmetric around the query.

Parameters:
  • window_size (int) – Maximum query-key distance (positive integer).

  • causal (bool) – If True, restrict to past keys.

Returns:

A ScoreMod callable.

Raises:

ValueError – If window_size <= 0.

Return type:

ScoreMod
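The causal window condition can be written out as a boolean mask. This is a sketch only: `sliding_window_mask` is a hypothetical helper, and the exact bound used for the symmetric (non-causal) case is an assumption, since the description above gives the inequality only for the causal case.

```python
import jax.numpy as jnp

def sliding_window_mask(seq_len, window_size, causal=True):
    """Boolean (seq_len, seq_len) mask; True means attend."""
    if window_size <= 0:
        raise ValueError("window_size must be positive")
    q_idx = jnp.arange(seq_len)[:, None]
    kv_idx = jnp.arange(seq_len)[None, :]
    if causal:
        # q_idx - window_size < kv_idx <= q_idx
        return (kv_idx <= q_idx) & (kv_idx > q_idx - window_size)
    # Assumption: the symmetric window applies the same strict bound on both sides.
    return jnp.abs(q_idx - kv_idx) < window_size

causal_window = sliding_window_mask(5, window_size=2)
symmetric_window = sliding_window_mask(5, window_size=2, causal=False)
```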

attnax.kernels.alibi_mod(num_heads)

ALiBi additive bias score-mod.

Adds \(-m_h |q\_idx - kv\_idx|\) to the pre-softmax scores, where \(m_h\) is the per-head slope returned by alibi_slopes().

Parameters:

num_heads (int) – Number of attention heads.

Returns:

A ScoreMod callable.

Return type:

ScoreMod

References

Press et al., Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation, 2022.

attnax.kernels.alibi_slopes(num_heads)

Per-head ALiBi slopes.

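Geometric slope schedule from the ALiBi paper. When num_heads is a power of two, the slopes are \(2^{-8h/n}\) for \(h = 1, \dots, n\) with \(n\) = num_heads; otherwise, the schedule for the largest power of two below num_heads is extended with every other slope of the next power of two, truncated to num_heads entries.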

Parameters:

num_heads (int) – Number of attention heads.

Returns:

float32 array of shape (num_heads,) with positive slopes \(m_h\). The bias added to the score between query position \(i\) and key position \(j\) is \(-m_h |i - j|\).

Return type:

Array

References

Press et al., Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation, 2022.
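The schedule can be sketched following the recipe published with the ALiBi paper. This is an illustrative reimplementation, not the attnax source, and `alibi_slopes_sketch` is a hypothetical name.

```python
import math
import jax.numpy as jnp

def alibi_slopes_sketch(num_heads):
    """Per-head slopes following the reference ALiBi recipe."""
    def power_of_two_slopes(n):
        start = 2.0 ** (-8.0 / n)   # first slope, also the geometric ratio
        return [start ** (i + 1) for i in range(n)]
    if math.log2(num_heads).is_integer():
        slopes = power_of_two_slopes(num_heads)
    else:
        closest = 2 ** math.floor(math.log2(num_heads))
        # Take every other slope of the next power of two, then truncate.
        extra = power_of_two_slopes(2 * closest)[0::2][: num_heads - closest]
        slopes = power_of_two_slopes(closest) + extra
    return jnp.asarray(slopes, dtype=jnp.float32)

slopes = alibi_slopes_sketch(8)
# Additive bias -m_h * |i - j| of shape (num_heads, seq_q, seq_kv).
positions = jnp.arange(4)
bias = -slopes[:, None, None] * jnp.abs(positions[:, None] - positions[None, :])
```

For eight heads the slopes are 1/2, 1/4, and so on down to 1/256.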

attnax.kernels.prefix_lm_mod(prefix_lengths)

Prefix-LM score-mod.

Attention is bidirectional for kv_idx < prefix_lengths[b] and causal elsewhere.

Parameters:

prefix_lengths (Array) – Integer array of shape (batch,) with the per-sequence prefix length.

Returns:

A ScoreMod callable.

Return type:

ScoreMod
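The prefix-LM rule corresponds to the following boolean mask. `prefix_lm_mask` is a hypothetical helper written for illustration; it materialises the mask rather than returning a ScoreMod.

```python
import jax.numpy as jnp

def prefix_lm_mask(prefix_lengths, seq_len):
    """(batch, 1, seq, seq) mask: bidirectional inside the prefix, causal elsewhere."""
    q_idx = jnp.arange(seq_len)[None, None, :, None]
    kv_idx = jnp.arange(seq_len)[None, None, None, :]
    in_prefix = kv_idx < prefix_lengths[:, None, None, None]
    return in_prefix | (kv_idx <= q_idx)

prefix_mask = prefix_lm_mask(jnp.array([2, 0]), seq_len=4)
```

With a prefix length of 2, the first query can also see key 1; with a prefix length of 0, the mask reduces to the ordinary causal mask.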

attnax.kernels.document_mask_mod(document_ids)

Document-boundary score-mod for sequence packing.

Masks attention between tokens in different documents. A query at position \(i\) only attends to keys at positions \(j\) with document_ids[b, j] == document_ids[b, i].

Parameters:

document_ids (Array) – Integer array of shape (batch, seq_len).

Returns:

A ScoreMod callable.

Return type:

ScoreMod
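For packed sequences, the document rule amounts to an equality test between the document IDs of each query and key position. The helper `document_mask` below is a hypothetical illustration that materialises the mask.

```python
import jax.numpy as jnp

def document_mask(document_ids):
    """(batch, 1, seq, seq) mask that is True where query and key share a document."""
    same_doc = document_ids[:, :, None] == document_ids[:, None, :]
    return same_doc[:, None]  # add a broadcastable head axis

doc_ids = jnp.array([[0, 0, 1, 1, 1]])   # two packed documents of lengths 2 and 3
doc_mask = document_mask(doc_ids)
```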

attnax.kernels.additive_bias_mod(bias)

Additive score-mod from a precomputed bias tensor.

Parameters:

bias (Array) – Float array broadcastable to (batch, num_heads, seq_q, seq_kv).

Returns:

A ScoreMod callable.

Return type:

ScoreMod

Composition helpers#

attnax.kernels.compose_score_mods(*score_mods)

Composes ScoreMod callables left-to-right.

None arguments are skipped. Returns None when every argument is None.

Parameters:

*score_mods (ScoreMod | None) – Zero or more ScoreMod callables or None.

Returns:

A ScoreMod applying each non-None argument in order, or None.

Return type:

ScoreMod | None
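Left-to-right composition with None-skipping can be sketched as below. The `(scores, b, h, q_idx, kv_idx)` call signature is an assumption, and `compose_score_mods_sketch` is a hypothetical stand-in for the library helper.

```python
import jax.numpy as jnp

def compose_score_mods_sketch(*score_mods):
    """Apply each non-None score-mod in order; return None if there are none."""
    mods = [m for m in score_mods if m is not None]
    if not mods:
        return None
    def composed(scores, b, h, q_idx, kv_idx):
        for mod in mods:
            scores = mod(scores, b, h, q_idx, kv_idx)
        return scores
    return composed

add_one = lambda scores, b, h, q_idx, kv_idx: scores + 1.0
double = lambda scores, b, h, q_idx, kv_idx: scores * 2.0
composed = compose_score_mods_sketch(add_one, None, double)
result = composed(jnp.ones((1, 1, 2, 2)), 0, 0, 0, 0)   # (1 + 1) * 2
```

Order matters: applying add_one before double gives 4, whereas the reverse order would give 3.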

attnax.kernels.mask_mod_to_score_mod(mask_mod)

Converts a MaskMod into an equivalent ScoreMod.

Sets the score to \(-\infty\) where mask_mod returns False and leaves other entries untouched.

Parameters:

mask_mod (MaskMod) – Callable conforming to MaskMod.

Returns:

A ScoreMod callable.

Return type:

ScoreMod

attnax.kernels.mask_mod_to_boolean_mask(mask_mod, *, batch, num_heads, seq_q, seq_kv)

Materialises a MaskMod into a boolean mask tensor.

Parameters:
  • mask_mod (MaskMod) – Callable conforming to MaskMod.

  • batch (int) – Batch dimension.

  • num_heads (int) – Number of attention heads.

  • seq_q (int) – Query sequence length.

  • seq_kv (int) – Key/value sequence length.

Returns:

Boolean array of shape (batch, num_heads, seq_q, seq_kv).

Return type:

Array
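Both MaskMod conversions can be sketched in a few lines. The `(b, h, q_idx, kv_idx)` and `(scores, b, h, q_idx, kv_idx)` signatures are assumptions, and `to_score_mod`, `to_boolean_mask` and `index_grids` are hypothetical names used only for illustration.

```python
import jax.numpy as jnp

def index_grids(batch, num_heads, seq_q, seq_kv):
    """Index arrays that broadcast over the batch, head, query and key axes."""
    return (jnp.arange(batch)[:, None, None, None],
            jnp.arange(num_heads)[None, :, None, None],
            jnp.arange(seq_q)[None, None, :, None],
            jnp.arange(seq_kv)[None, None, None, :])

def causal_mask_mod(b, h, q_idx, kv_idx):
    return kv_idx <= q_idx   # True means attend

def to_score_mod(mask_mod):
    """Set scores to -inf where mask_mod is False; leave the rest untouched."""
    def score_mod(scores, b, h, q_idx, kv_idx):
        return jnp.where(mask_mod(b, h, q_idx, kv_idx), scores, -jnp.inf)
    return score_mod

def to_boolean_mask(mask_mod, *, batch, num_heads, seq_q, seq_kv):
    """Evaluate mask_mod on index grids and broadcast to the full mask shape."""
    grids = index_grids(batch, num_heads, seq_q, seq_kv)
    return jnp.broadcast_to(mask_mod(*grids), (batch, num_heads, seq_q, seq_kv))

bool_mask = to_boolean_mask(causal_mask_mod, batch=2, num_heads=3, seq_q=4, seq_kv=4)
masked = to_score_mod(causal_mask_mod)(jnp.zeros((2, 3, 4, 4)), *index_grids(2, 3, 4, 4))
```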

Examples#

Pass any AttentionFn to MultiHeadAttention:

from attnax import MultiHeadAttention
from attnax.kernels import standard_attention
from attnax.kernels.score_mods import (
    alibi_mod, sliding_window_mod, compose_score_mods,
)

# `rngs` is assumed to be an RNG container created by the caller beforehand.
attn = MultiHeadAttention(
    rngs,
    num_heads=8,
    in_features=512,
    attention_fn=standard_attention,
    score_mod=compose_score_mods(
        alibi_mod(num_heads=8),
        sliding_window_mod(window_size=4096),
    ),
)

Per-call score-mods compose with the constructor-supplied mod:

from attnax.kernels import document_mask_mod  # doc_ids: (batch, seq_len) integer array
out = attn(x, score_mod=document_mask_mod(doc_ids), deterministic=True)