Source code for attnax.kernels._api

# SPDX-License-Identifier: Apache-2.0

"""Protocols and helpers for attention kernels."""

from __future__ import annotations

from typing import Optional, Protocol

import jax
import jax.numpy as jnp


[docs] class ScoreMod(Protocol): """Callable that modifies pre-softmax attention scores. Returns ``scores + bias`` (or any other position-conditioned transformation) of the same shape as ``scores``. Position indices broadcast over the batch, head, query and key axes respectively. """ def __call__( self, scores: jnp.ndarray, b_idx: jnp.ndarray, h_idx: jnp.ndarray, q_idx: jnp.ndarray, kv_idx: jnp.ndarray, ) -> jnp.ndarray: """Modifies scores. Args: scores: Pre-softmax scores of shape ``(batch, num_heads, seq_q, seq_kv)``. b_idx: Batch indices of shape ``(batch, 1, 1, 1)``. h_idx: Head indices of shape ``(1, num_heads, 1, 1)``. q_idx: Query positions of shape ``(1, 1, seq_q, 1)``. kv_idx: Key positions of shape ``(1, 1, 1, seq_kv)``. Returns: Modified scores of the same shape as ``scores``. """ ...
[docs] class MaskMod(Protocol): """Callable that returns a boolean attention mask. Reads only the position indices (no scores) and returns a boolean array that broadcasts to ``(batch, num_heads, seq_q, seq_kv)``. ``True`` means attend. """ def __call__( self, b_idx: jnp.ndarray, h_idx: jnp.ndarray, q_idx: jnp.ndarray, kv_idx: jnp.ndarray, ) -> jnp.ndarray: """Returns a boolean mask. Args: b_idx: Batch indices of shape ``(batch, 1, 1, 1)``. h_idx: Head indices of shape ``(1, num_heads, 1, 1)``. q_idx: Query positions of shape ``(1, 1, seq_q, 1)``. kv_idx: Key positions of shape ``(1, 1, 1, seq_kv)``. Returns: Boolean array broadcastable to ``(batch, num_heads, seq_q, seq_kv)``. """ ...
[docs] class AttentionFn(Protocol): """Signature for an attention kernel. A pure JAX callable mapping pre-projected ``(batch, num_heads, seq, head_dim)`` query / key / value tensors to the attended output in the same layout. Carries no trainable parameters. """ def __call__( self, query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray, *, mask: Optional[jnp.ndarray] = None, score_mod: Optional[ScoreMod] = None, dropout_rng: Optional[jax.Array] = None, dropout_rate: float = 0.0, deterministic: bool = True, ) -> jnp.ndarray: """Computes the attention output. Args: query: Array of shape ``(batch, num_heads, seq_q, head_dim)``. key: Array of shape ``(batch, num_heads, seq_kv, head_dim)``. value: Array of shape ``(batch, num_heads, seq_kv, head_dim)``. mask: Boolean mask broadcastable to ``(batch, num_heads, seq_q, seq_kv)``. score_mod: Callable applied to the pre-softmax scores. dropout_rng: PRNG key for attention-weight dropout. dropout_rate: Dropout probability applied to attention weights. deterministic: If ``True``, disables dropout. Returns: Array of shape ``(batch, num_heads, seq_q, head_dim)``. """ ...
def _position_indices( batch: int, num_heads: int, seq_q: int, seq_kv: int ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Returns ``(b_idx, h_idx, q_idx, kv_idx)`` broadcast index arrays.""" b_idx = jnp.arange(batch, dtype=jnp.int32)[:, None, None, None] h_idx = jnp.arange(num_heads, dtype=jnp.int32)[None, :, None, None] q_idx = jnp.arange(seq_q, dtype=jnp.int32)[None, None, :, None] kv_idx = jnp.arange(seq_kv, dtype=jnp.int32)[None, None, None, :] return b_idx, h_idx, q_idx, kv_idx def _apply_score_mod( scores: jnp.ndarray, score_mod: Optional[ScoreMod] ) -> jnp.ndarray: """Applies ``score_mod`` to ``scores`` if non-``None``.""" if score_mod is None: return scores batch, num_heads, seq_q, seq_kv = scores.shape b_idx, h_idx, q_idx, kv_idx = _position_indices( batch, num_heads, seq_q, seq_kv ) return score_mod(scores, b_idx, h_idx, q_idx, kv_idx)
[docs] def mask_mod_to_boolean_mask( mask_mod: MaskMod, *, batch: int, num_heads: int, seq_q: int, seq_kv: int, ) -> jnp.ndarray: """Materialises a :data:`MaskMod` into a boolean mask tensor. Args: mask_mod: Callable conforming to :data:`MaskMod`. batch: Batch dimension. num_heads: Number of attention heads. seq_q: Query sequence length. seq_kv: Key/value sequence length. Returns: Boolean array of shape ``(batch, num_heads, seq_q, seq_kv)``. """ b_idx, h_idx, q_idx, kv_idx = _position_indices( batch, num_heads, seq_q, seq_kv ) out = mask_mod(b_idx, h_idx, q_idx, kv_idx) return jnp.broadcast_to(out, (batch, num_heads, seq_q, seq_kv))
[docs] def compose_score_mods(*score_mods: Optional[ScoreMod]) -> Optional[ScoreMod]: """Composes :data:`ScoreMod` callables left-to-right. ``None`` arguments are skipped. Returns ``None`` when every argument is ``None``. Args: *score_mods: Zero or more :data:`ScoreMod` callables or ``None``. Returns: A :data:`ScoreMod` applying each non-``None`` argument in order, or ``None``. """ active = [m for m in score_mods if m is not None] if not active: return None if len(active) == 1: return active[0] def composed( scores: jnp.ndarray, b_idx: jnp.ndarray, h_idx: jnp.ndarray, q_idx: jnp.ndarray, kv_idx: jnp.ndarray, ) -> jnp.ndarray: for m in active: scores = m(scores, b_idx, h_idx, q_idx, kv_idx) return scores return composed