Source code for attnax.kernels.paged

# SPDX-License-Identifier: Apache-2.0

"""Paged attention against a block-table KV cache."""

from __future__ import annotations

from typing import Optional

import jax
import jax.numpy as jnp

from ..paged_cache import PagedKVCache, gather_kv
from ._api import ScoreMod
from .attention import standard_attention


def paged_attention(
    query: jnp.ndarray,
    cache: PagedKVCache,
    sequence_idx: int,
    *,
    mask: Optional[jnp.ndarray] = None,
    score_mod: Optional[ScoreMod] = None,
    dropout_rng: Optional[jax.Array] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = True,
) -> jnp.ndarray:
    """Attention against a paged KV cache for one sequence.

    Gathers the keys and values pointed to by the block table for
    ``sequence_idx`` and delegates to
    :func:`~attnax.kernels.standard_attention`. Repeats KV heads if
    ``cache.num_kv_heads < num_heads`` (grouped-query attention).

    Args:
        query: Array of shape ``(num_heads, seq_q, head_dim)`` for a single
            sequence. Use :func:`jax.vmap` to batch.
        cache: :class:`PagedKVCache` storing keys and values.
        sequence_idx: Row of the block table to attend against.
        mask: Boolean mask broadcastable to ``(num_heads, seq_q, seq_kv)``
            where ``seq_kv`` is the current sequence length.
        score_mod: Callable applied to the pre-softmax scores; key indices
            are cache positions starting at zero.
        dropout_rng: PRNG key for attention dropout.
        dropout_rate: Attention dropout probability.
        deterministic: If ``True``, disables dropout.

    Returns:
        Array of shape ``(num_heads, seq_q, head_dim)``. If the gathered
        sequence is empty (no cached positions), an all-zero array shaped
        like ``query`` is returned without calling the attention kernel.

    Raises:
        ValueError: If ``query`` is not a rank-3 array, or if the number of
            query heads is not a multiple of the number of cached KV heads.
    """
    # Fail fast on a mis-shaped query. Without this, a batched 4-D query
    # would have its batch axis silently treated as ``num_heads`` below.
    # Under ``jax.vmap`` the per-example query is still rank 3, so this
    # check does not get in the way of batching.
    if query.ndim != 3:
        raise ValueError(
            "query must have shape (num_heads, seq_q, head_dim); got an "
            f"array of rank {query.ndim} with shape {tuple(query.shape)}. "
            "Use jax.vmap to batch over sequences."
        )

    keys, values, _ = gather_kv(cache, sequence_idx)

    # Nothing cached yet for this sequence: there is nothing to attend to.
    if keys.shape[0] == 0:
        return jnp.zeros_like(query)

    num_heads = query.shape[0]

    # The kernel is called with a leading batch axis of size 1. The gathered
    # keys/values are treated as (seq_kv, num_kv_heads, head_dim) and are
    # rearranged to (1, num_kv_heads, seq_kv, head_dim).
    q = query[None, ...]
    k = jnp.transpose(keys, (1, 0, 2))[None, ...]
    v = jnp.transpose(values, (1, 0, 2))[None, ...]

    num_kv_heads = k.shape[1]
    if num_kv_heads != num_heads:
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by num_kv_heads "
                f"({num_kv_heads}) for paged attention"
            )
        # Grouped-query attention: each KV head serves ``rep`` consecutive
        # query heads, so repeat along the head axis to match ``num_heads``.
        rep = num_heads // num_kv_heads
        k = jnp.repeat(k, rep, axis=1)
        v = jnp.repeat(v, rep, axis=1)

    if mask is not None:
        # Give the mask the same leading batch axis as q/k/v.
        mask = mask[None, ...]

    out = standard_attention(
        q,
        k,
        v,
        mask=mask,
        score_mod=score_mod,
        dropout_rng=dropout_rng,
        dropout_rate=dropout_rate,
        deterministic=deterministic,
    )
    # Drop the size-1 batch axis added above.
    return out[0]