KV cache
Key/value caches for autoregressive self-attention. Keys are cached after RoPE rotation, and both keys and values use KV-head layout. Cross-attention caching is not supported.
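| KVLayerCache | KV cache buffers for one attention layer. |
| init_kv_layer_cache | Creates a zero-filled KVLayerCache. |
| update_kv_layer_cache | Writes new keys/values into cache at [start, start + chunk). |
| init_decoder_kv_caches | Creates one empty KVLayerCache per layer. |
| init_decoder_kv_caches_from_config | Builds per-layer KV caches from a TransformerConfig. |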
Per-layer cache
- class attnax.KVLayerCache(keys, values, length)
Bases: object
KV cache buffers for one attention layer.
Stores keys (after RoPE rotation) and values in KV-head layout, with num_kv_heads heads.
- Parameters:
keys (Array) – Cached keys of shape (batch, num_kv_heads, max_len, head_dim).
values (Array) – Cached values of the same shape as keys.
length (Array) – int32 scalar number of valid positions.
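Because the buffers are allocated at max_len up front, length is what separates written slots from zero padding. A minimal sketch of how a consumer might turn length into an attention mask, assuming masking by position index; valid_position_mask is a hypothetical helper, not part of attnax:

```python
import jax
import jax.numpy as jnp

def valid_position_mask(length, max_len):
    # Hypothetical helper: True for slots [0, length) that hold written
    # keys/values, False for the zero-filled padding slots [length, max_len).
    return jnp.arange(max_len) < length

length = jnp.asarray(3, jnp.int32)  # plays the role of KVLayerCache.length
mask = valid_position_mask(length, max_len=6)

# Padding slots get -inf scores so softmax assigns them zero weight.
scores = jnp.zeros((6,))
masked_scores = jnp.where(mask, scores, -jnp.inf)
weights = jax.nn.softmax(masked_scores)  # uniform over the 3 valid slots
```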
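- attnax.init_kv_layer_cache(batch_size, num_kv_heads, head_dim, max_len, dtype)
Creates a zero-filled KVLayerCache.
- Parameters:
batch_size (int) – Batch dimension.
num_kv_heads (int) – Number of KV heads.
head_dim (int) – Per-head feature dimensionality.
max_len (int) – Maximum cached length.
dtype (dtype) – Buffer dtype.
- Returns:
KVLayerCache with zero buffers and length = 0.
- Return type:
KVLayerCache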
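The sketch below mirrors the documented result (zero buffers of shape (batch, num_kv_heads, max_len, head_dim) and length = 0) using a NamedTuple, which JAX treats as a pytree, so it can be passed through jit. LayerCacheSketch and init_layer_cache_sketch are hypothetical stand-ins, not the attnax implementation:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class LayerCacheSketch(NamedTuple):
    """Hypothetical stand-in for KVLayerCache; NamedTuples are JAX pytrees."""
    keys: jax.Array    # (batch, num_kv_heads, max_len, head_dim)
    values: jax.Array  # same shape as keys
    length: jax.Array  # int32 scalar: number of valid positions

def init_layer_cache_sketch(batch_size, num_kv_heads, head_dim, max_len, dtype):
    # Argument order follows init_kv_layer_cache; note that the buffer
    # shape puts max_len before head_dim.
    shape = (batch_size, num_kv_heads, max_len, head_dim)
    return LayerCacheSketch(
        keys=jnp.zeros(shape, dtype),
        values=jnp.zeros(shape, dtype),
        length=jnp.zeros((), jnp.int32),
    )

cache = init_layer_cache_sketch(2, 4, 64, 128, jnp.bfloat16)
```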
- attnax.update_kv_layer_cache(cache, keys_new, values_new, start)
Writes new keys/values into cache at positions [start, start + chunk).
- Parameters:
cache (KVLayerCache) – Existing KVLayerCache.
keys_new (Array) – Keys of shape (batch, num_kv_heads, chunk, head_dim).
values_new (Array) – Values of the same shape as keys_new.
start (int) – First position to write at.
- Returns:
Updated KVLayerCache with length = start + chunk.
- Raises:
ValueError – If start + chunk > cache.max_len.
- Return type:
KVLayerCache
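A sketch of the write described above, assuming it is a dynamic_update_slice along the sequence axis (axis 2) and that start is a Python int, so the bounds check can run eagerly. LayerCacheSketch and update_layer_cache_sketch are hypothetical; the real update_kv_layer_cache may differ:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class LayerCacheSketch(NamedTuple):
    keys: jax.Array    # (batch, num_kv_heads, max_len, head_dim)
    values: jax.Array  # same shape as keys
    length: jax.Array  # int32 scalar

def update_layer_cache_sketch(cache, keys_new, values_new, start):
    chunk = keys_new.shape[2]
    max_len = cache.keys.shape[2]
    # Eager bounds check mirroring the documented ValueError (needs a Python int start).
    if start + chunk > max_len:
        raise ValueError(f"start + chunk = {start + chunk} exceeds max_len = {max_len}")
    # Overwrite positions [start, start + chunk) on the sequence axis.
    idx = (0, 0, start, 0)
    keys = jax.lax.dynamic_update_slice(cache.keys, keys_new.astype(cache.keys.dtype), idx)
    values = jax.lax.dynamic_update_slice(cache.values, values_new.astype(cache.values.dtype), idx)
    return LayerCacheSketch(keys, values, jnp.asarray(start + chunk, jnp.int32))

# Prefill 5 positions, then decode one token at position 5.
cache = LayerCacheSketch(jnp.zeros((1, 2, 8, 4)), jnp.zeros((1, 2, 8, 4)), jnp.int32(0))
prompt = jnp.ones((1, 2, 5, 4))
cache = update_layer_cache_sketch(cache, prompt, prompt, start=0)
step = jnp.full((1, 2, 1, 4), 2.0)
cache = update_layer_cache_sketch(cache, step, step, start=5)
```

Under jax.jit a traced start cannot drive the Python if, so a jit-friendly variant would have to perform the check outside the traced function.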
Whole-model caches
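- attnax.init_decoder_kv_caches(*, num_layers, batch_size, num_kv_heads, head_dim, max_len, dtype)
Creates one empty KVLayerCache per layer.
- Parameters:
num_layers (int) – Number of layers.
batch_size (int) – Batch dimension.
num_kv_heads (int) – Number of KV heads.
head_dim (int) – Per-head feature dimensionality.
max_len (int) – Maximum cached length.
dtype (dtype) – Buffer dtype.
- Returns:
Tuple of num_layers KVLayerCache objects.
- Return type:
tuple[KVLayerCache, ...]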
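Because every layer allocates max_len slots for both keys and values up front, the memory held by init_decoder_kv_caches grows linearly in num_layers, batch_size, num_kv_heads, max_len and head_dim. A back-of-the-envelope sketch; the helper names are hypothetical and the shape values are illustrative, not defaults of any attnax config:

```python
import jax.numpy as jnp

def kv_layer_cache_bytes(batch_size, num_kv_heads, head_dim, max_len, dtype):
    # keys + values, each of shape (batch, num_kv_heads, max_len, head_dim)
    return 2 * batch_size * num_kv_heads * max_len * head_dim * jnp.dtype(dtype).itemsize

def decoder_kv_caches_bytes(*, num_layers, **layer_kwargs):
    # One identically shaped cache per layer.
    return num_layers * kv_layer_cache_bytes(**layer_kwargs)

per_layer = kv_layer_cache_bytes(1, 8, 128, 8192, jnp.bfloat16)
total = decoder_kv_caches_bytes(num_layers=32, batch_size=1, num_kv_heads=8,
                                head_dim=128, max_len=8192, dtype=jnp.bfloat16)
print(per_layer // 2**20, "MiB per layer;", total // 2**30, "GiB total")  # 32 MiB per layer; 1 GiB total
```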
- attnax.init_decoder_kv_caches_from_config(config, *, batch_size, max_len=None, dtype=jax.numpy.float32)
Builds per-layer KV caches from a TransformerConfig.
- Parameters:
config (TransformerConfig) – Transformer hyperparameters.
batch_size (int) – Batch dimension.
max_len (int | None) – Maximum cached length. Defaults to config.kv_cache_max_len if set, otherwise config.max_len.
dtype (dtype) – Buffer dtype.
- Returns:
Tuple of config.num_layers KVLayerCache objects.
- Return type:
tuple[KVLayerCache, ...]
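The max_len fallback chain can be written out as below, assuming "if set" means "is not None". ConfigSketch is a hypothetical stand-in exposing only the fields this function reads; the real TransformerConfig has more:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class ConfigSketch:
    """Hypothetical stand-in for TransformerConfig (only the fields used here)."""
    num_layers: int
    max_len: int
    kv_cache_max_len: Optional[int] = None

def resolve_cache_len(config, max_len=None):
    # 1. An explicit max_len argument wins.
    if max_len is not None:
        return max_len
    # 2. Otherwise use the cache-specific limit, if the config sets one.
    #    (Under this assumption, kv_cache_max_len=0 would count as set.)
    if config.kv_cache_max_len is not None:
        return config.kv_cache_max_len
    # 3. Fall back to the model's maximum sequence length.
    return config.max_len
```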
Paged KV cache#
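| PagedKVCache | Paged KV cache storage and per-sequence block table. |
| init_paged_kv_cache | Allocates an empty PagedKVCache. |
| allocate_blocks | Reserves blocks for num_new_tokens additional tokens. |
| append_kv | Writes new keys/values for one sequence into the pool. |
| gather_kv | Materialises the contiguous KV view for one sequence. |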
- class attnax.PagedKVCache(key_pool, value_pool, block_table, seq_lengths)
Bases: object
Paged KV cache storage and per-sequence block table.
- Parameters:
key_pool (Array)
value_pool (Array)
block_table (Array)
seq_lengths (Array)
- key_pool
Float array of shape (num_blocks, block_size, num_kv_heads, head_dim) holding the physical key blocks.
- Type:
jax.Array
- value_pool
Float array of the same shape holding the physical value blocks.
- Type:
jax.Array
- block_table
int32 array of shape (batch, max_blocks_per_seq) mapping each logical block to a physical block index. Unused slots are -1.
- Type:
jax.Array
- seq_lengths
int32 array of shape (batch,) giving the number of valid tokens per sequence.
- Type:
jax.Array
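The block table is the indirection that lets many sequences share one pool. Assuming tokens fill logical blocks in order, token position pos of a sequence lives in logical block pos // block_size at offset pos % block_size, and the table maps that logical block to a physical one. A sketch of the lookup; physical_slot is a hypothetical helper:

```python
import jax.numpy as jnp

def physical_slot(block_table, seq, pos, block_size):
    # Hypothetical helper: map token `pos` of sequence `seq` to (physical_block, offset).
    logical_block = pos // block_size
    offset = pos % block_size
    physical_block = block_table[seq, logical_block]  # -1 means "not allocated"
    return physical_block, offset

# Sequence 0 owns physical blocks 5 and 2; sequence 1 owns block 0.
block_table = jnp.array([[5, 2, -1], [0, -1, -1]], dtype=jnp.int32)
blk, off = physical_slot(block_table, seq=0, pos=6, block_size=4)  # -> block 2, offset 2

# The cached key for that token is a (num_kv_heads, head_dim) slice of the pool.
key_pool = jnp.zeros((8, 4, 2, 16))  # (num_blocks, block_size, num_kv_heads, head_dim)
key_vec = key_pool[blk, off]
```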
- attnax.init_paged_kv_cache(*, num_blocks, block_size, num_kv_heads, head_dim, batch_size, max_blocks_per_seq, dtype)
Allocates an empty PagedKVCache.
- Parameters:
num_blocks (int) – Total physical blocks in the pool.
block_size (int) – Tokens per physical block.
num_kv_heads (int) – Number of KV heads.
head_dim (int) – Per-head feature dimensionality.
batch_size (int) – Number of sequences.
max_blocks_per_seq (int) – Maximum logical blocks per sequence.
dtype (dtype) – Buffer dtype.
- Returns:
PagedKVCache with zero pools, a block table filled with -1, and seq_lengths = 0.
- Return type:
PagedKVCache
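A sketch reproducing the documented initial state with a hypothetical NamedTuple (PagedCacheSketch and init_paged_sketch are not attnax names). It also shows the two capacity limits implied by the shapes: the pool holds num_blocks * block_size tokens shared across the whole batch, while a single sequence can hold at most max_blocks_per_seq * block_size tokens:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class PagedCacheSketch(NamedTuple):
    key_pool: jax.Array     # (num_blocks, block_size, num_kv_heads, head_dim)
    value_pool: jax.Array   # same shape as key_pool
    block_table: jax.Array  # (batch, max_blocks_per_seq); -1 marks unused slots
    seq_lengths: jax.Array  # (batch,) valid tokens per sequence

def init_paged_sketch(*, num_blocks, block_size, num_kv_heads, head_dim,
                      batch_size, max_blocks_per_seq, dtype):
    pool_shape = (num_blocks, block_size, num_kv_heads, head_dim)
    return PagedCacheSketch(
        key_pool=jnp.zeros(pool_shape, dtype),
        value_pool=jnp.zeros(pool_shape, dtype),
        block_table=jnp.full((batch_size, max_blocks_per_seq), -1, jnp.int32),
        seq_lengths=jnp.zeros((batch_size,), jnp.int32),
    )

cache = init_paged_sketch(num_blocks=16, block_size=4, num_kv_heads=2, head_dim=8,
                          batch_size=3, max_blocks_per_seq=4, dtype=jnp.float32)
pool_capacity = cache.key_pool.shape[0] * cache.key_pool.shape[1]          # 64 tokens, shared
max_tokens_per_seq = cache.block_table.shape[1] * cache.key_pool.shape[1]  # 16 tokens
```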
- attnax.allocate_blocks(cache, *, sequence_idx, num_new_tokens, free_block_ids)
Reserves physical blocks in row sequence_idx of the block table for num_new_tokens additional tokens.
- Parameters:
cache (PagedKVCache) – Current cache.
sequence_idx (int) – Row of the block table to update.
num_new_tokens (int) – Tokens about to be appended.
free_block_ids (Array) – 1-D array of free physical block indices.
- Returns:
Tuple (new_cache, blocks_consumed).
- Raises:
ValueError – If free_block_ids is shorter than the number of blocks needed.
- Return type:
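One plausible reading of allocate_blocks, written as a host-side sketch over raw arrays: compute how many blocks the sequence needs after appending, take the shortfall from the front of free_block_ids, and report how many were used so the caller can drop them from its free list. The prefix-consumption order, the extra max_blocks_per_seq guard and the function name are assumptions, not documented behaviour:

```python
import jax.numpy as jnp

def allocate_blocks_sketch(block_table, seq_lengths, *, sequence_idx,
                           num_new_tokens, free_block_ids, block_size):
    row = block_table[sequence_idx]
    allocated = int(jnp.sum(row >= 0))  # assumes allocated slots form a prefix
    total_tokens = int(seq_lengths[sequence_idx]) + num_new_tokens
    needed = -(-total_tokens // block_size)  # ceil division
    consumed = max(0, needed - allocated)
    if consumed > free_block_ids.shape[0]:
        raise ValueError(f"need {consumed} free blocks, only {free_block_ids.shape[0]} available")
    if allocated + consumed > row.shape[0]:  # sketch-only guard
        raise ValueError("sequence would exceed max_blocks_per_seq")
    # Fill the next `consumed` logical slots with the first free physical blocks.
    new_row = row.at[allocated:allocated + consumed].set(free_block_ids[:consumed])
    return block_table.at[sequence_idx].set(new_row), consumed

block_table = jnp.full((2, 4), -1, jnp.int32)
seq_lengths = jnp.zeros((2,), jnp.int32)
free = jnp.array([7, 3, 9], jnp.int32)

# 6 new tokens with block_size 4 -> 2 blocks.
table, used = allocate_blocks_sketch(block_table, seq_lengths, sequence_idx=0,
                                     num_new_tokens=6, free_block_ids=free, block_size=4)
```

The returned count is what lets the caller advance its free list (here, free[used:] remains free).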
- attnax.append_kv(cache, sequence_idx, keys_new, values_new)
Writes new keys/values for one sequence into the pool.
The block table for sequence_idx must be populated for every logical block required by the new tokens.
- Parameters:
cache (PagedKVCache) – Current cache.
sequence_idx (int) – Sequence to append to.
keys_new (Array) – Keys of shape (chunk, num_kv_heads, head_dim).
values_new (Array) – Values of the same shape.
- Returns:
Updated PagedKVCache with seq_lengths[sequence_idx] incremented by chunk.
- Raises:
ValueError – If shapes disagree or the block table is unallocated.
- Return type:
PagedKVCache
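A sketch of the scatter that append_kv performs, assuming new tokens are written starting at seq_lengths[sequence_idx] and that token pos lives in logical block pos // block_size at offset pos % block_size. It takes and returns raw arrays rather than a PagedKVCache, and append_kv_sketch is a hypothetical name:

```python
import jax.numpy as jnp

def append_kv_sketch(key_pool, value_pool, block_table, seq_lengths,
                     sequence_idx, keys_new, values_new):
    block_size = key_pool.shape[1]
    chunk = keys_new.shape[0]
    if keys_new.shape != values_new.shape or keys_new.shape[1:] != key_pool.shape[2:]:
        raise ValueError("keys_new/values_new must have shape (chunk, num_kv_heads, head_dim)")
    start = int(seq_lengths[sequence_idx])
    if start + chunk > block_table.shape[1] * block_size:
        raise ValueError("sequence would exceed max_blocks_per_seq * block_size tokens")
    pos = start + jnp.arange(chunk)  # logical positions of the new tokens
    phys = block_table[sequence_idx, pos // block_size]
    if bool(jnp.any(phys < 0)):
        raise ValueError("block table is unallocated for some new tokens")
    off = pos % block_size
    # Token i lands in the (num_kv_heads, head_dim) slot pool[phys[i], off[i]].
    key_pool = key_pool.at[phys, off].set(keys_new.astype(key_pool.dtype))
    value_pool = value_pool.at[phys, off].set(values_new.astype(value_pool.dtype))
    seq_lengths = seq_lengths.at[sequence_idx].add(chunk)
    return key_pool, value_pool, seq_lengths

key_pool = jnp.zeros((4, 2, 1, 1))  # 4 blocks of 2 tokens, 1 head, head_dim 1
value_pool = jnp.zeros((4, 2, 1, 1))
block_table = jnp.array([[3, 1, -1]], dtype=jnp.int32)  # logical 0, 1 -> physical 3, 1
seq_lengths = jnp.array([0], dtype=jnp.int32)
k_new = jnp.arange(1.0, 4.0).reshape(3, 1, 1)
key_pool, value_pool, seq_lengths = append_kv_sketch(
    key_pool, value_pool, block_table, seq_lengths, 0, k_new, 10 * k_new)
```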
- attnax.gather_kv(cache, sequence_idx)
Materialises the contiguous KV view for one sequence.
- Parameters:
cache (PagedKVCache) – Current cache.
sequence_idx (int) – Sequence to materialise.
- Returns:
Tuple (keys, values, seq_len) where keys and values have shape (seq_len, num_kv_heads, head_dim).
- Return type:
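gather_kv is the inverse lookup: it reads every valid position of one sequence back out of the pool in logical order. A sketch, assuming token pos sits in logical block pos // block_size at offset pos % block_size; gather_kv_sketch is a hypothetical name operating on raw arrays:

```python
import jax.numpy as jnp

def gather_kv_sketch(key_pool, value_pool, block_table, seq_lengths, sequence_idx):
    block_size = key_pool.shape[1]
    seq_len = int(seq_lengths[sequence_idx])  # Python int: output shape depends on data
    pos = jnp.arange(seq_len)
    phys = block_table[sequence_idx, pos // block_size]
    off = pos % block_size
    # Advanced indexing on the first two axes -> (seq_len, num_kv_heads, head_dim).
    return key_pool[phys, off], value_pool[phys, off], seq_len

key_pool = jnp.arange(8.0).reshape(4, 2, 1, 1)  # block b holds [2b, 2b + 1]
value_pool = 10 * key_pool
block_table = jnp.array([[3, 1, -1]], dtype=jnp.int32)
seq_lengths = jnp.array([3], dtype=jnp.int32)
keys, values, seq_len = gather_kv_sketch(key_pool, value_pool, block_table, seq_lengths, 0)
```

Because seq_len is read as a Python int, this version cannot be jitted over varying lengths; a jit-friendly variant would gather max_blocks_per_seq * block_size positions and mask by seq_len.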
Pair PagedKVCache with attnax.paged_attention() to
attend against a sequence stored in the cache.
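When testing a paged attention path, a simple reference is ordinary softmax attention over the contiguous view that gather_kv returns. The sketch below assumes a single decode step, grouped-query attention in which query head h reads KV head h // (num_heads // num_kv_heads), and 1/sqrt(head_dim) scaling. It is not the signature or implementation of attnax.paged_attention(), which this page does not document; reference_decode_attention is a hypothetical name:

```python
import jax
import jax.numpy as jnp

def reference_decode_attention(q, keys, values):
    """q: (num_heads, head_dim); keys/values: (seq_len, num_kv_heads, head_dim)."""
    num_heads, head_dim = q.shape
    num_kv_heads = keys.shape[1]
    if num_heads % num_kv_heads:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group = num_heads // num_kv_heads
    # Expand KV heads so query head h sees KV head h // group.
    k = jnp.repeat(keys, group, axis=1)    # (seq_len, num_heads, head_dim)
    v = jnp.repeat(values, group, axis=1)
    scores = jnp.einsum("hd,shd->hs", q, k) / jnp.sqrt(head_dim)
    weights = jax.nn.softmax(scores, axis=-1)  # no mask: gathered KV has exactly seq_len rows
    return jnp.einsum("hs,shd->hd", weights, v)

q = jnp.ones((4, 2))                 # 4 query heads
keys = jnp.zeros((3, 2, 2))          # 2 KV heads -> uniform attention weights
values = jnp.zeros((3, 2, 2)).at[:, 1, :].set(5.0)
out = reference_decode_attention(q, keys, values)
```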