KV cache#

Key/value caches for autoregressive self-attention. Buffers are post-RoPE and in KV-head layout. Cross-attention caching is not supported.

KVLayerCache(keys, values, length)

KV cache buffers for one attention layer.

init_kv_layer_cache(batch_size, ...)

Creates a zero-filled KVLayerCache.

update_kv_layer_cache(cache, keys_new, ...)

Writes new keys/values into cache at [start, start + chunk).

init_decoder_kv_caches(*, num_layers, ...)

Creates one empty KVLayerCache per layer.

init_decoder_kv_caches_from_config(config, ...)

Builds per-layer KV caches from a TransformerConfig.

Per-layer cache#

class attnax.KVLayerCache(keys, values, length)[source]#

Bases: object

KV cache buffers for one attention layer.

Stores keys and values after RoPE rotation in KV-head layout (num_kv_heads).

Parameters:
  • keys (Array) – Cached keys of shape (batch, num_kv_heads, max_len, head_dim).

  • values (Array) – Cached values of the same shape as keys.

  • length (Array) – int32 scalar number of valid positions.

property max_len: int#

Maximum number of tokens the cache can hold.

attnax.init_kv_layer_cache(batch_size, num_kv_heads, head_dim, max_len, dtype)[source]#

Creates a zero-filled KVLayerCache.

Parameters:
  • batch_size (int) – Batch dimension.

  • num_kv_heads (int) – Number of KV heads.

  • head_dim (int) – Per-head feature dimensionality.

  • max_len (int) – Maximum sequence length.

  • dtype (dtype) – Buffer dtype.

Returns:

KVLayerCache with zero buffers and length = 0.

Return type:

KVLayerCache
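The layout described above can be sketched in plain JAX. This is a minimal illustration, not the attnax implementation: `LayerCacheSketch` and `init_layer_cache_sketch` are hypothetical names, and the use of a `NamedTuple` is an assumption made only so the example is self-contained.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class LayerCacheSketch(NamedTuple):
    """Hypothetical stand-in for KVLayerCache (not the attnax class)."""

    keys: jax.Array    # (batch, num_kv_heads, max_len, head_dim)
    values: jax.Array  # same shape as keys
    length: jax.Array  # int32 scalar: number of valid positions

    @property
    def max_len(self) -> int:
        # Capacity is the size of the sequence axis (axis 2).
        return self.keys.shape[2]


def init_layer_cache_sketch(batch_size, num_kv_heads, head_dim, max_len, dtype):
    """Builds zero-filled buffers with length = 0, as the docs describe."""
    shape = (batch_size, num_kv_heads, max_len, head_dim)
    return LayerCacheSketch(
        keys=jnp.zeros(shape, dtype),
        values=jnp.zeros(shape, dtype),
        length=jnp.zeros((), jnp.int32),
    )


cache = init_layer_cache_sketch(
    batch_size=2, num_kv_heads=4, head_dim=64, max_len=128, dtype=jnp.bfloat16
)
```

Note the argument order (head_dim before max_len) differs from the buffer axis order (max_len before head_dim), which is easy to mix up when calling positionally.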

attnax.update_kv_layer_cache(cache, keys_new, values_new, start)[source]#

Writes new keys/values into cache at positions [start, start + chunk) along the sequence axis, where chunk is the sequence length of keys_new.

Parameters:
  • cache (KVLayerCache) – Existing KVLayerCache.

  • keys_new (Array) – Keys of shape (batch, num_kv_heads, chunk, head_dim).

  • values_new (Array) – Values of the same shape as keys_new.

  • start (int) – Position to write at.

Returns:

Updated KVLayerCache with length = start + chunk.

Raises:

ValueError – If start + chunk > cache.max_len.

Return type:

KVLayerCache
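The write semantics can be illustrated with a small sketch on raw buffers. It assumes start is a concrete Python int (so the bounds check can run eagerly) and uses `jax.lax.dynamic_update_slice_in_dim`; `write_chunk_sketch` is a hypothetical helper, and the library may implement the update differently.

```python
import jax.numpy as jnp
from jax import lax


def write_chunk_sketch(keys, values, keys_new, values_new, start):
    """Writes a chunk at [start, start + chunk) along the sequence axis (axis 2)."""
    chunk = keys_new.shape[2]
    max_len = keys.shape[2]
    if start + chunk > max_len:
        raise ValueError(
            f"start + chunk = {start + chunk} exceeds max_len = {max_len}"
        )
    keys = lax.dynamic_update_slice_in_dim(keys, keys_new, start, axis=2)
    values = lax.dynamic_update_slice_in_dim(values, values_new, start, axis=2)
    # The new length is the end of the written range, not old length + chunk.
    length = jnp.asarray(start + chunk, dtype=jnp.int32)
    return keys, values, length


# Prefill three positions into an empty 8-slot cache.
keys = jnp.zeros((1, 2, 8, 4), jnp.float32)
values = jnp.zeros_like(keys)
k_new = jnp.ones((1, 2, 3, 4), jnp.float32)
v_new = 2.0 * k_new
keys, values, length = write_chunk_sketch(keys, values, k_new, v_new, start=0)
```

Because the returned length is start + chunk, writing at a start below the current length effectively truncates the valid region to the end of the new chunk.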

Whole-model caches#

attnax.init_decoder_kv_caches(*, num_layers, batch_size, num_kv_heads, head_dim, max_len, dtype)[source]#

Creates one empty KVLayerCache per layer.

Parameters:
  • num_layers (int) – Number of layers.

  • batch_size (int) – Batch dimension.

  • num_kv_heads (int) – Number of KV heads.

  • head_dim (int) – Per-head feature dimensionality.

  • max_len (int) – Maximum cached sequence length.

  • dtype (dtype) – Buffer dtype.

Returns:

Tuple of num_layers KVLayerCache objects.

Return type:

tuple[KVLayerCache, …]
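Since every layer holds one keys buffer and one values buffer of the shape given above, the total buffer footprint can be estimated before allocating. The helper below is a hypothetical sketch, not part of attnax; it ignores the per-layer int32 length scalars.

```python
import jax.numpy as jnp


def decoder_kv_cache_bytes(
    *, num_layers, batch_size, num_kv_heads, head_dim, max_len, dtype
):
    """Bytes used by the keys and values buffers across all layers."""
    elements_per_buffer = batch_size * num_kv_heads * max_len * head_dim
    # Factor 2: one keys buffer plus one values buffer per layer.
    return 2 * num_layers * elements_per_buffer * jnp.dtype(dtype).itemsize


total = decoder_kv_cache_bytes(
    num_layers=32,
    batch_size=1,
    num_kv_heads=8,
    head_dim=128,
    max_len=4096,
    dtype=jnp.bfloat16,
)
print(total / 2**20)  # size in MiB
```

For this configuration the estimate comes to 512 MiB, and it scales linearly in batch_size and max_len.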

attnax.init_decoder_kv_caches_from_config(config, *, batch_size, max_len=None, dtype=jnp.float32)[source]#

Builds per-layer KV caches from a TransformerConfig.

Parameters:
  • config (TransformerConfig) – Transformer hyperparameters.

  • batch_size (int) – Batch dimension.

  • max_len (int | None) – Maximum cached length. Defaults to config.kv_cache_max_len if set, otherwise config.max_len.

  • dtype (dtype) – Buffer dtype.

Returns:

Tuple of config.num_layers KVLayerCache objects.

Return type:

tuple[KVLayerCache, …]
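The max_len fallback order can be written out explicitly. This sketch assumes that "if set" means "is not None"; `ConfigSketch` is a hypothetical stand-in carrying only the fields named above, not the real TransformerConfig.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ConfigSketch:
    """Hypothetical stand-in with only the fields this page mentions."""

    num_layers: int
    max_len: int
    kv_cache_max_len: Optional[int] = None


def resolve_cache_max_len(config, max_len=None):
    """Explicit argument first, then kv_cache_max_len, then max_len."""
    if max_len is not None:
        return max_len
    if config.kv_cache_max_len is not None:
        return config.kv_cache_max_len
    return config.max_len


cfg = ConfigSketch(num_layers=12, max_len=2048)
cfg_short = ConfigSketch(num_layers=12, max_len=2048, kv_cache_max_len=512)
```

Here `resolve_cache_max_len(cfg)` falls through to 2048, `resolve_cache_max_len(cfg_short)` picks 512, and an explicit `max_len` argument overrides both.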

Paged KV cache#

PagedKVCache(key_pool, value_pool, ...)

Paged KV cache storage and per-sequence block table.

init_paged_kv_cache(*, num_blocks, ...)

Allocates an empty PagedKVCache.

allocate_blocks(cache, *, sequence_idx, ...)

Reserves blocks for num_new_tokens additional tokens.

append_kv(cache, sequence_idx, keys_new, ...)

Writes new keys/values for one sequence into the pool.

gather_kv(cache, sequence_idx)

Materialises the contiguous KV view for one sequence.

class attnax.PagedKVCache(key_pool, value_pool, block_table, seq_lengths)[source]#

Bases: object

Paged KV cache storage and per-sequence block table.

Parameters:
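  • key_pool (Array) – Physical key blocks of shape (num_blocks, block_size, num_kv_heads, head_dim).

  • value_pool (Array) – Physical value blocks of the same shape as key_pool.

  • block_table (Array) – int32 map of shape (batch, max_blocks_per_seq) from logical to physical block indices.

  • seq_lengths (Array) – int32 array of shape (batch,) with the number of valid tokens per sequence.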

key_pool#

Float array of shape (num_blocks, block_size, num_kv_heads, head_dim) holding the physical key blocks.

Type:

jax.Array

value_pool#

Float array of the same shape holding the value blocks.

Type:

jax.Array

block_table#

int32 array of shape (batch, max_blocks_per_seq) mapping each logical block to a physical block index. Unused slots are -1.

Type:

jax.Array

seq_lengths#

int32 array of shape (batch,) giving the number of valid tokens per sequence.

Type:

jax.Array

property num_blocks: int#

property block_size: int#

property num_kv_heads: int#

property head_dim: int#

attnax.init_paged_kv_cache(*, num_blocks, block_size, num_kv_heads, head_dim, batch_size, max_blocks_per_seq, dtype)[source]#

Allocates an empty PagedKVCache.

Parameters:
  • num_blocks (int) – Total physical blocks in the pool.

  • block_size (int) – Tokens per physical block.

  • num_kv_heads (int) – Number of KV heads.

  • head_dim (int) – Per-head feature dimensionality.

  • batch_size (int) – Number of sequences.

  • max_blocks_per_seq (int) – Maximum logical blocks per sequence.

  • dtype (dtype) – Buffer dtype.

Returns:

PagedKVCache with zero pools, -1 block table, and seq_lengths = 0.

Return type:

PagedKVCache
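The empty state described in the return value can be reproduced with plain JAX arrays. This is a sketch under the shapes documented for PagedKVCache; `PagedCacheSketch` and `init_paged_sketch` are hypothetical names, not the library's types.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class PagedCacheSketch(NamedTuple):
    """Hypothetical stand-in for PagedKVCache (not the attnax class)."""

    key_pool: jax.Array     # (num_blocks, block_size, num_kv_heads, head_dim)
    value_pool: jax.Array   # same shape as key_pool
    block_table: jax.Array  # (batch, max_blocks_per_seq), int32, -1 = unused
    seq_lengths: jax.Array  # (batch,), int32


def init_paged_sketch(
    *, num_blocks, block_size, num_kv_heads, head_dim,
    batch_size, max_blocks_per_seq, dtype
):
    pool_shape = (num_blocks, block_size, num_kv_heads, head_dim)
    return PagedCacheSketch(
        key_pool=jnp.zeros(pool_shape, dtype),
        value_pool=jnp.zeros(pool_shape, dtype),
        # -1 marks a logical block with no physical block assigned yet.
        block_table=jnp.full((batch_size, max_blocks_per_seq), -1, jnp.int32),
        seq_lengths=jnp.zeros((batch_size,), jnp.int32),
    )


paged = init_paged_sketch(
    num_blocks=16, block_size=4, num_kv_heads=2, head_dim=8,
    batch_size=3, max_blocks_per_seq=5, dtype=jnp.float32,
)
```

With these numbers the pool can hold 16 × 4 = 64 tokens in total, while any single sequence is capped at 5 × 4 = 20 tokens by its block-table row.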

attnax.allocate_blocks(cache, *, sequence_idx, num_new_tokens, free_block_ids)[source]#

Reserves blocks for num_new_tokens additional tokens.

Parameters:
  • cache (PagedKVCache) – Current cache.

  • sequence_idx (int) – Row of the block table to update.

  • num_new_tokens (int) – Tokens about to be appended.

  • free_block_ids (Array) – 1-D array of free physical block indices.

Returns:

Tuple (new_cache, blocks_consumed).

Raises:

ValueError – If free_block_ids is shorter than the number of blocks needed.

Return type:

tuple[PagedKVCache, int]
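One plausible way to compute how many blocks such a reservation needs is ceiling division over the sequence's new total length, minus the blocks already assigned. The sketch below works on a raw block table and is an assumption about the bookkeeping, not the attnax code; it runs eagerly because it reads concrete values. The capacity check against max_blocks_per_seq is also an assumption.

```python
import jax.numpy as jnp


def allocate_blocks_sketch(
    block_table, seq_len, block_size, *, sequence_idx, num_new_tokens, free_block_ids
):
    """Returns (new_block_table, blocks_consumed) for one sequence."""
    row = block_table[sequence_idx]
    have = int(jnp.sum(row >= 0))  # blocks already assigned (entries != -1)
    need_total = -(-(seq_len + num_new_tokens) // block_size)  # ceil division
    needed = max(need_total - have, 0)
    if needed > free_block_ids.shape[0]:
        raise ValueError(
            f"need {needed} blocks, only {free_block_ids.shape[0]} free"
        )
    if have + needed > row.shape[0]:
        raise ValueError("sequence would exceed max_blocks_per_seq")
    if needed:
        new_ids = free_block_ids[:needed].astype(jnp.int32)
        block_table = block_table.at[sequence_idx, have:have + needed].set(new_ids)
    return block_table, needed


table = jnp.full((2, 4), -1, jnp.int32)
free = jnp.array([5, 7, 9], jnp.int32)

# 6 new tokens with block_size 4 need two blocks.
table, used_first = allocate_blocks_sketch(
    table, 0, 4, sequence_idx=0, num_new_tokens=6, free_block_ids=free
)
# 2 more tokens still fit in the second block, so nothing is consumed.
table, used_second = allocate_blocks_sketch(
    table, 6, 4, sequence_idx=0, num_new_tokens=2, free_block_ids=free[used_first:]
)
```

The caller owns the free list in this sketch: after each call it drops the first blocks_consumed entries, as the second call does with `free[used_first:]`.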

attnax.append_kv(cache, sequence_idx, keys_new, values_new)[source]#

Writes new keys/values for one sequence into the pool.

The block table for sequence_idx must be populated for every logical block required by the new tokens.

Parameters:
  • cache (PagedKVCache) – Current cache.

  • sequence_idx (int) – Sequence to append to.

  • keys_new (Array) – Keys of shape (chunk, num_kv_heads, head_dim).

  • values_new (Array) – Values of the same shape.

Returns:

Updated PagedKVCache with seq_lengths[sequence_idx] incremented by chunk.

Raises:

ValueError – If shapes disagree or the block table is unallocated.

Return type:

PagedKVCache
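The mapping from a token position to a pool slot can be sketched as follows: position p lives in logical block p // block_size at offset p % block_size, and the block table translates the logical block to a physical one. `append_kv_sketch` is a hypothetical helper operating on raw arrays, assumed to run eagerly; it is not the library's implementation.

```python
import jax.numpy as jnp


def append_kv_sketch(
    key_pool, value_pool, block_table, seq_lengths, sequence_idx, keys_new, values_new
):
    """Scatters a (chunk, num_kv_heads, head_dim) chunk into the pools."""
    if keys_new.shape != values_new.shape:
        raise ValueError("keys_new and values_new must have the same shape")
    block_size = key_pool.shape[1]
    chunk = keys_new.shape[0]
    start = int(seq_lengths[sequence_idx])
    if start + chunk > block_table.shape[1] * block_size:
        raise ValueError("chunk does not fit in max_blocks_per_seq")
    positions = start + jnp.arange(chunk)
    logical = positions // block_size  # logical block per token
    offsets = positions % block_size   # slot within that block
    physical = block_table[sequence_idx, logical]
    if bool(jnp.any(physical < 0)):
        raise ValueError("block table is unallocated for some new tokens")
    key_pool = key_pool.at[physical, offsets].set(keys_new)
    value_pool = value_pool.at[physical, offsets].set(values_new)
    seq_lengths = seq_lengths.at[sequence_idx].add(chunk)
    return key_pool, value_pool, seq_lengths


# One sequence whose logical blocks 0 and 1 map to physical blocks 3 and 1.
key_pool = jnp.zeros((4, 2, 1, 2), jnp.float32)
value_pool = jnp.zeros_like(key_pool)
block_table = jnp.array([[3, 1]], jnp.int32)
seq_lengths = jnp.zeros((1,), jnp.int32)
keys_new = jnp.arange(6, dtype=jnp.float32).reshape(3, 1, 2)
values_new = -keys_new

key_pool, value_pool, seq_lengths = append_kv_sketch(
    key_pool, value_pool, block_table, seq_lengths, 0, keys_new, values_new
)
```

The three tokens land in physical block 3 (slots 0 and 1) and physical block 1 (slot 0), which shows that consecutive tokens need not be contiguous in the pool.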

attnax.gather_kv(cache, sequence_idx)[source]#

Materialises the contiguous KV view for one sequence.

Parameters:
  • cache (PagedKVCache) – Current cache.

  • sequence_idx (int) – Sequence to materialise.

Returns:

Tuple (keys, values, seq_len) where keys and values have shape (seq_len, num_kv_heads, head_dim).

Return type:

tuple[Array, Array, int]
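Materialising the contiguous view amounts to the inverse lookup: for each valid position, read the slot that the block table points to. The following is a minimal sketch on raw arrays with a hypothetical helper name, assuming seq_lengths holds concrete values.

```python
import jax.numpy as jnp


def gather_kv_sketch(key_pool, value_pool, block_table, seq_lengths, sequence_idx):
    """Returns (keys, values, seq_len) with shape (seq_len, num_kv_heads, head_dim)."""
    block_size = key_pool.shape[1]
    seq_len = int(seq_lengths[sequence_idx])
    positions = jnp.arange(seq_len)
    physical = block_table[sequence_idx, positions // block_size]
    offsets = positions % block_size
    # Advanced indexing pairs each token's physical block with its slot.
    return key_pool[physical, offsets], value_pool[physical, offsets], seq_len


# Pool entries are numbered so the gathered order is easy to read off.
key_pool = jnp.arange(16, dtype=jnp.float32).reshape(4, 2, 1, 2)
value_pool = -key_pool
block_table = jnp.array([[3, 1]], jnp.int32)
seq_lengths = jnp.array([3], jnp.int32)

keys, values, seq_len = gather_kv_sketch(
    key_pool, value_pool, block_table, seq_lengths, 0
)
```

The result copies seq_len tokens out of the pool, so the cost of each call grows with the sequence length.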

Pair PagedKVCache with attnax.paged_attention() to attend against a sequence stored in the cache.
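For intuition about what attending against a stored sequence involves, here is a sketch of single-token attention over keys and values in the (seq_len, num_kv_heads, head_dim) layout that gather_kv returns. It is not attnax.paged_attention(), whose signature is not shown on this page; `attend_gathered_sketch` is a hypothetical helper, and the grouping of consecutive query heads onto one KV head is an assumption.

```python
import jax
import jax.numpy as jnp


def attend_gathered_sketch(query, keys, values):
    """query: (num_heads, head_dim); keys/values: (seq_len, num_kv_heads, head_dim)."""
    num_heads, head_dim = query.shape
    num_kv_heads = keys.shape[1]
    group = num_heads // num_kv_heads  # assumes num_heads % num_kv_heads == 0
    # Expand KV heads so consecutive query heads share one KV head.
    keys = jnp.repeat(keys, group, axis=1)      # (seq_len, num_heads, head_dim)
    values = jnp.repeat(values, group, axis=1)  # (seq_len, num_heads, head_dim)
    scores = jnp.einsum("hd,shd->hs", query, keys) / jnp.sqrt(head_dim)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hs,shd->hd", weights, values)


rng_q, rng_k = jax.random.split(jax.random.PRNGKey(0))
query = jax.random.normal(rng_q, (4, 8))    # 4 query heads
keys = jax.random.normal(rng_k, (5, 2, 8))  # 5 cached tokens, 2 KV heads
values = jnp.ones((5, 2, 8))
out = attend_gathered_sketch(query, keys, values)
```

No causal mask is applied because the single query is assumed to be the newest token, which may attend to every cached position.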