Source code for attnax.cache

# SPDX-License-Identifier: Apache-2.0

"""Key/value cache for autoregressive self-attention (inference)."""

from __future__ import annotations

from dataclasses import dataclass

import jax.numpy as jnp

from .config import TransformerConfig


@dataclass(frozen=True)
class KVLayerCache:
    """Immutable key/value buffers belonging to a single attention layer.

    Keys are stored after RoPE rotation, laid out per KV head
    (``num_kv_heads``), which may be fewer than the query heads.

    Args:
        keys: Cached keys with shape
            ``(batch, num_kv_heads, max_len, head_dim)``.
        values: Cached values; same shape as ``keys``.
        length: ``int32`` scalar counting the valid (written) positions.
    """

    keys: jnp.ndarray
    values: jnp.ndarray
    length: jnp.ndarray

    @property
    def max_len(self) -> int:
        """Capacity of the cache: the extent of the sequence axis of ``keys``."""
        seq_axis = 2  # layout is (batch, heads, seq, head_dim)
        capacity = self.keys.shape[seq_axis]
        return int(capacity)
def init_kv_layer_cache(
    batch_size: int,
    num_kv_heads: int,
    head_dim: int,
    max_len: int,
    dtype: jnp.dtype,
) -> KVLayerCache:
    """Allocates an empty, zero-filled :class:`KVLayerCache`.

    Args:
        batch_size: Size of the batch dimension.
        num_kv_heads: Number of key/value heads.
        head_dim: Feature size of each head.
        max_len: Capacity along the sequence axis.
        dtype: Element type of the key/value buffers.

    Returns:
        A :class:`KVLayerCache` whose buffers are all zeros and whose
        ``length`` is an ``int32`` zero.
    """
    buffer_shape = (batch_size, num_kv_heads, max_len, head_dim)
    zeros = jnp.zeros(buffer_shape, dtype=dtype)
    # JAX arrays are immutable, so keys and values may alias one buffer:
    # writing via ``.at[...].set`` always yields a fresh array.
    empty_length = jnp.array(0, dtype=jnp.int32)
    return KVLayerCache(keys=zeros, values=zeros, length=empty_length)
def update_kv_layer_cache(
    cache: KVLayerCache,
    keys_new: jnp.ndarray,
    values_new: jnp.ndarray,
    start: int,
) -> KVLayerCache:
    """Writes new keys/values into ``cache`` at ``[start, start + chunk)``.

    The update is functional: a new :class:`KVLayerCache` is returned and
    ``cache`` itself is not modified. The result's ``length`` is set to
    ``start + chunk`` irrespective of the previous ``cache.length``.

    Args:
        cache: Existing :class:`KVLayerCache`.
        keys_new: Keys of shape ``(batch, num_kv_heads, chunk, head_dim)``.
            All dims except ``chunk`` must match ``cache.keys``.
        values_new: Values of exactly the same shape as ``keys_new``.
        start: Non-negative position to write at. It is compared in Python
            ``if`` statements, so it must be a concrete (non-traced) integer.

    Returns:
        Updated :class:`KVLayerCache` with ``length = start + chunk``.

    Raises:
        ValueError: If ``start`` is negative, if ``keys_new`` does not match
            the cache layout, if ``values_new`` differs in shape from
            ``keys_new``, or if ``start + chunk > cache.max_len``.
    """
    # A negative start would make Python slicing count from the end of the
    # buffer, writing to the wrong slots and producing a negative length.
    if start < 0:
        raise ValueError(
            f"KV cache start must be non-negative, got start={start}"
        )

    chunk = keys_new.shape[2]

    # ``.at[...].set`` broadcasts its operand, so shape mismatches (e.g. a
    # batch-1 update into a batch-4 cache) would otherwise be silently
    # replicated instead of rejected.
    cache_shape = tuple(cache.keys.shape)
    expected = cache_shape[:2] + (chunk,) + cache_shape[3:]
    if tuple(keys_new.shape) != expected:
        raise ValueError(
            f"keys_new has shape {tuple(keys_new.shape)}, expected {expected} "
            f"to match cache layout {cache_shape}"
        )
    if tuple(values_new.shape) != tuple(keys_new.shape):
        raise ValueError(
            f"values_new shape {tuple(values_new.shape)} does not match "
            f"keys_new shape {tuple(keys_new.shape)}"
        )

    end = start + chunk
    if end > cache.max_len:
        raise ValueError(
            f"KV cache overflow: end={end} exceeds max_len={cache.max_len}"
        )

    new_keys = cache.keys.at[:, :, start:end, :].set(keys_new)
    new_values = cache.values.at[:, :, start:end, :].set(values_new)
    return KVLayerCache(
        keys=new_keys,
        values=new_values,
        length=jnp.array(end, dtype=jnp.int32),
    )
def init_decoder_kv_caches(
    *,
    num_layers: int,
    batch_size: int,
    num_kv_heads: int,
    head_dim: int,
    max_len: int,
    dtype: jnp.dtype,
) -> tuple[KVLayerCache, ...]:
    """Allocates a separate empty :class:`KVLayerCache` for every layer.

    Args:
        num_layers: How many per-layer caches to create.
        batch_size: Size of the batch dimension.
        num_kv_heads: Number of key/value heads.
        head_dim: Feature size of each head.
        max_len: Capacity of each cache along the sequence axis.
        dtype: Element type of the key/value buffers.

    Returns:
        A tuple holding ``num_layers`` freshly initialised caches, one per
        layer, each produced by its own ``init_kv_layer_cache`` call.
    """
    layer_args = (batch_size, num_kv_heads, head_dim, max_len, dtype)
    per_layer = [init_kv_layer_cache(*layer_args) for _ in range(num_layers)]
    return tuple(per_layer)
def init_decoder_kv_caches_from_config(
    config: TransformerConfig,
    *,
    batch_size: int,
    max_len: int | None = None,
    dtype: jnp.dtype = jnp.float32,
) -> tuple[KVLayerCache, ...]:
    """Derives cache dimensions from a :class:`TransformerConfig` and allocates them.

    Args:
        config: Transformer hyperparameters.
        batch_size: Size of the batch dimension.
        max_len: Cache capacity. When ``None``, ``config.kv_cache_max_len``
            is used if it is truthy, otherwise ``config.max_len``.
        dtype: Element type of the key/value buffers.

    Returns:
        A tuple of ``config.num_layers`` empty :class:`KVLayerCache` objects.
    """
    if max_len is None:
        # A falsy kv_cache_max_len (e.g. None or 0) defers to config.max_len.
        max_len = config.kv_cache_max_len or config.max_len

    # Without an explicit KV-head count, fall back to one KV head per query head.
    kv_heads = config.num_kv_heads or config.num_heads
    # Floor division: any remainder of d_model / num_heads is dropped.
    per_head_dim = config.d_model // config.num_heads

    return init_decoder_kv_caches(
        num_layers=config.num_layers,
        batch_size=batch_size,
        num_kv_heads=kv_heads,
        head_dim=per_head_dim,
        max_len=max_len,
        dtype=dtype,
    )