Getting started
Attnax is an attention and transformer library for JAX. Its main abstraction is the AttentionFn protocol — a pure JAX function with signature (q, k, v, *, mask, score_mod, ...) -> out — together with ScoreMod callables that fold biases and sparsity patterns into the pre-softmax scores. Every bundled kernel (standard_attention, memory_efficient_attention, flash_attention, pallas_flash_attention, linear_attention, ring_attention, paged_attention, lite_attention) conforms to that protocol, as does any user-written kernel, and the same protocol is consumed by MultiHeadAttention via its attention_fn= argument.
This notebook walks through Attnax in three parts:
- Attention as a function. Pure-JAX kernels and ScoreMod biases on a single batch of Q/K/V tensors.
- A transformer layer. Wrapping a kernel in MultiHeadAttention, selecting a backend, and supplying a custom kernel.
- A full transformer. TransformerConfig, masks, training with Optax, autoregressive inference with KV caching, Mixture-of-Experts, and the Vision Transformer.
In Colab, the next cell installs Attnax. Locally, run pip install attnax once and re-run the notebook.
pip install attnax
Begin by importing the libraries used throughout the notebook:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import attnax
Attention as a function
The most general entry point in Attnax is standard_attention — a pure JAX function that consumes Q/K/V tensors of shape (batch, num_heads, seq, head_dim) and returns the attended output of the same shape as the queries. We will start by computing scaled dot-product attention on a small synthetic batch:
from attnax import standard_attention
batch, num_heads, seq, head_dim = 1, 4, 16, 32
q = jax.random.normal(jax.random.key(0), (batch, num_heads, seq, head_dim))
k = jax.random.normal(jax.random.key(1), (batch, num_heads, seq, head_dim))
v = jax.random.normal(jax.random.key(2), (batch, num_heads, seq, head_dim))
out = standard_attention(q, k, v)
out.shape
The same call signature is shared by every kernel in attnax.kernels. memory_efficient_attention computes the same softmax with $O(n)$ activation memory using a block-wise online softmax, which is useful at long sequence lengths:
from attnax import memory_efficient_attention
out_mem = memory_efficient_attention(q, k, v)
jnp.allclose(out, out_mem, atol=1e-5)
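To make the block-wise online softmax concrete, here is a minimal pure-JAX sketch. It is not Attnax's implementation: the function names, the Python loop, and the block size are illustrative assumptions, and a production kernel would typically use jax.lax.scan instead of a Python loop. The sketch keeps a running row maximum and denominator so that only one (seq_q, block_size) score tile exists at a time.

```python
import jax
import jax.numpy as jnp


def naive_attention(q, k, v):
    """Reference softmax attention that materialises the full score matrix."""
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)


def blockwise_attention(q, k, v, block_size=4):
    """Online-softmax attention over key/value blocks (illustrative sketch)."""
    scale = q.shape[-1] ** -0.5
    seq_kv = k.shape[2]
    # Running row maximum, softmax denominator, and unnormalised output.
    row_max = jnp.full(q.shape[:-1], -jnp.inf, dtype=q.dtype)
    denom = jnp.zeros(q.shape[:-1], dtype=q.dtype)
    acc = jnp.zeros_like(q)
    for start in range(0, seq_kv, block_size):
        k_blk = k[:, :, start:start + block_size]
        v_blk = v[:, :, start:start + block_size]
        scores = jnp.einsum("bhqd,bhkd->bhqk", q, k_blk) * scale
        new_max = jnp.maximum(row_max, scores.max(axis=-1))
        # Rescale the previous partial sums to the new running maximum.
        correction = jnp.exp(row_max - new_max)
        probs = jnp.exp(scores - new_max[..., None])
        denom = denom * correction + probs.sum(axis=-1)
        acc = acc * correction[..., None] + jnp.einsum(
            "bhqk,bhkd->bhqd", probs, v_blk
        )
        row_max = new_max
    return acc / denom[..., None]


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q_demo = jax.random.normal(kq, (1, 2, 8, 16))
k_demo = jax.random.normal(kk, (1, 2, 8, 16))
v_demo = jax.random.normal(kv, (1, 2, 8, 16))
max_err = jnp.abs(
    naive_attention(q_demo, k_demo, v_demo)
    - blockwise_attention(q_demo, k_demo, v_demo)
).max()
```

The two functions agree up to floating-point rounding, which is the same property the jnp.allclose check above verifies for the bundled kernels.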
flash_attention dispatches to jax.nn.dot_product_attention on GPU and falls back to memory_efficient_attention elsewhere; pallas_flash_attention lowers the same algorithm to a Pallas kernel on Pallas-capable backends and falls back to memory_efficient_attention otherwise. The remaining bundled kernels — linear_attention, ring_attention, paged_attention, lite_attention — are covered later in the notebook.
Composing biases with ScoreMod
In Attnax, every attention bias — relative position, sliding window, causal masking, prefix-LM, document packing — is a ScoreMod: a callable applied to the pre-softmax scores. Constructors for the common variants live in attnax.kernels.score_mods and compose with compose_score_mods.
We can add an ALiBi bias and a causal sliding window with one call:
from attnax import alibi_mod, compose_score_mods, sliding_window_mod
mod = compose_score_mods(
    alibi_mod(num_heads=num_heads),
    sliding_window_mod(window_size=8, causal=True),
)
out_biased = standard_attention(q, k, v, score_mod=mod)
out_biased.shape
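As a rough illustration of what a score modification does, the sketch below treats a mod as a plain function from a (batch, num_heads, seq_q, seq_kv) score tensor to a tensor of the same shape. This signature, and the helper names alibi_bias, causal_window, and compose, are assumptions made for the example; Attnax's actual ScoreMod signature may differ, and the symmetric distance used for ALiBi here is one possible convention.

```python
import jax
import jax.numpy as jnp


def alibi_bias(num_heads):
    """Return a mod that subtracts slope_h * |i - j| from the scores."""
    # Geometric slopes 2^(-8h/H) for h = 1..H (the power-of-two ALiBi recipe).
    slopes = 2.0 ** (-8.0 * jnp.arange(1, num_heads + 1) / num_heads)

    def mod(scores):
        seq_q, seq_k = scores.shape[-2:]
        distance = jnp.abs(jnp.arange(seq_q)[:, None] - jnp.arange(seq_k)[None, :])
        return scores - slopes[:, None, None] * distance

    return mod


def causal_window(window_size):
    """Return a mod that keeps only keys j with 0 <= i - j < window_size."""

    def mod(scores):
        seq_q, seq_k = scores.shape[-2:]
        offset = jnp.arange(seq_q)[:, None] - jnp.arange(seq_k)[None, :]
        keep = (offset >= 0) & (offset < window_size)
        return jnp.where(keep, scores, -jnp.inf)

    return mod


def compose(*mods):
    """Apply the mods left to right."""

    def composed(scores):
        for m in mods:
            scores = m(scores)
        return scores

    return composed


demo_mod = compose(alibi_bias(num_heads=4), causal_window(window_size=3))
demo_weights = jax.nn.softmax(demo_mod(jnp.zeros((1, 4, 6, 6))), axis=-1)
```

Because the diagonal is always kept, every row retains at least one finite score and the softmax stays well defined.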
Per-call mods stack on top of whatever was passed at construction time, which is convenient for biases that change every batch — such as the document-packing masks used in long-context training:
from attnax import document_mask_mod
doc_ids = jnp.array(
    [[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]], dtype=jnp.int32
)
out_packed = standard_attention(q, k, v, score_mod=document_mask_mod(doc_ids))
out_packed.shape
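For intuition, a document-packing mask can be derived from the document IDs with a single broadcasted comparison. The sketch below is pure jax.numpy and does not call Attnax; the variable names are illustrative, and whether document_mask_mod builds exactly this mask internally is an assumption.

```python
import jax.numpy as jnp

doc_ids_demo = jnp.array([[0, 0, 0, 1, 1, 2]], dtype=jnp.int32)
# same_doc[b, q, k] is True when query q and key k belong to the same document.
same_doc = doc_ids_demo[:, :, None] == doc_ids_demo[:, None, :]
# Insert a head axis so the mask broadcasts over (batch, num_heads, seq_q, seq_kv).
doc_mask = same_doc[:, None, :, :]
```

The result is block-diagonal: tokens attend only within their own document, so packed sequences do not leak information into each other.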
A transformer layer
In practice we rarely call kernels directly: we wrap them in a layer that owns the Q/K/V projections, optional rotary positional embeddings, dropout, and KV cache plumbing. That layer is MultiHeadAttention. It supports full multi-head attention, grouped-query attention (1 < num_kv_heads < num_heads), and multi-query attention (num_kv_heads == 1); the same score_mod= argument we used on the bare kernel is available at construction time:
from attnax import MultiHeadAttention
attn = MultiHeadAttention(
    nnx.Rngs(0),
    num_heads=8,
    in_features=512,
    num_kv_heads=2,  # grouped-query attention; 1 for MQA
    use_rope=True,
    score_mod=alibi_mod(num_heads=8),
)
attn(jnp.zeros((1, 32, 512)), deterministic=True).shape
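The relationship between num_heads and num_kv_heads can be sketched in a few lines of pure JAX. The example assumes that each KV head serves a contiguous group of query heads and expands K/V with jnp.repeat; how Attnax maps query heads to KV heads internally is not shown in this notebook, so treat the grouping as an assumption.

```python
import jax
import jax.numpy as jnp


def grouped_query_attention(q, k, v):
    """q: (batch, num_heads, seq, d); k, v: (batch, num_kv_heads, seq, d)."""
    num_heads, num_kv_heads = q.shape[1], k.shape[1]
    assert num_heads % num_kv_heads == 0
    group = num_heads // num_kv_heads
    # Each KV head is shared by `group` consecutive query heads.
    k = jnp.repeat(k, group, axis=1)
    v = jnp.repeat(v, group, axis=1)
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * q.shape[-1] ** -0.5
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q_gqa = jax.random.normal(kq, (1, 8, 5, 16))
k_gqa = jax.random.normal(kk, (1, 2, 5, 16))  # 2 KV heads shared by 8 query heads
v_gqa = jax.random.normal(kv, (1, 2, 5, 16))
out_gqa = grouped_query_attention(q_gqa, k_gqa, v_gqa)
```

With num_kv_heads equal to num_heads the repeat is a no-op (standard multi-head attention), and with num_kv_heads=1 every query head shares one K/V pair (multi-query attention).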
Picking a built-in backend
Which kernel MultiHeadAttention calls is controlled by the AttentionType enum (or the attention_fn= escape hatch covered next). The enum entries are:
- STANDARD: scaled dot-product, $O(n^2)$ activation memory.
- MEMORY_EFFICIENT: block-wise online softmax, $O(n)$ activation memory.
- FLASH: jax.nn.dot_product_attention on GPU; falls back to MEMORY_EFFICIENT elsewhere.
- PALLAS_FLASH: Pallas-lowered FlashAttention with score_mod in the inner loop; falls back to MEMORY_EFFICIENT on CPU or when the Pallas kernel fails to lower.
- LINEAR: chunkwise-parallel softmax-free linear attention. Does not accept score_mod.
- LITE: element-wise gated attention; not a drop-in replacement for full softmax attention.
ring_attention and paged_attention are not enum entries because they take additional non-generic arguments (axis_name for ring, a PagedKVCache for paged). They are passed directly via attention_fn=.
from attnax import AttentionType, TransformerConfig
flash_config = TransformerConfig(
    vocab_size=32000,
    attention_type=AttentionType.FLASH,
    attention_block_size=512,
)
pallas_config = TransformerConfig(
    vocab_size=32000, attention_type=AttentionType.PALLAS_FLASH,
)
linear_config = TransformerConfig(
    vocab_size=32000, attention_type=AttentionType.LINEAR,
)
Plugging in a custom kernel
Any callable matching AttentionFn plugs into MultiHeadAttention via attention_fn=. The kernel receives the already-projected, already-rotated (batch, num_heads, seq, head_dim) tensors and is responsible for the softmax compute alone. Here we re-implement scaled dot-product attention from scratch and verify it matches standard_attention:
def my_attention(query, key, value, *, mask=None, score_mod=None,
                 dropout_rng=None, dropout_rate=0.0, deterministic=True):
    # score_mod and the dropout arguments are accepted for protocol
    # compatibility but are not applied in this minimal kernel.
    del dropout_rng, dropout_rate, deterministic
    scale = jax.lax.rsqrt(jnp.asarray(query.shape[-1], query.dtype))
    scores = jnp.einsum("bhqd,bhkd->bhqk", query, key) * scale
    if mask is not None:
        # True means attend; masked positions get the most negative finite score.
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, value)
attn_custom = MultiHeadAttention(
    nnx.Rngs(0), num_heads=8, in_features=512, attention_fn=my_attention,
)
attn_custom(jnp.zeros((1, 4, 512)), deterministic=True).shape
A real custom kernel is rarely written from scratch; the same attention_fn= slot accepts pallas_flash_attention, ring_attention partial-applied with axis_name=..., paged_attention partial-applied with a PagedKVCache, or a Triton kernel wrapped through jax-triton:
from attnax import pallas_flash_attention
attn_pallas = MultiHeadAttention(
    nnx.Rngs(0), num_heads=8, in_features=512,
    attention_fn=pallas_flash_attention,
)
attn_pallas(jnp.zeros((1, 4, 512)), deterministic=True).shape
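Partial application is the general trick for kernels that need extra, non-generic arguments. The sketch below uses a hypothetical windowed_attention kernel (not part of Attnax) whose keyword-only window argument plays the role that axis_name plays for ring_attention: once it is bound with functools.partial, the resulting callable takes only the generic keyword arguments used by my_attention above.

```python
import functools

import jax
import jax.numpy as jnp


def windowed_attention(query, key, value, *, window, mask=None, score_mod=None,
                       dropout_rng=None, dropout_rate=0.0, deterministic=True):
    """Hypothetical kernel: causal attention over the last `window` keys."""
    del score_mod, dropout_rng, dropout_rate, deterministic  # unused in this sketch
    seq_q, seq_k = query.shape[-2], key.shape[-2]
    offset = jnp.arange(seq_q)[:, None] - jnp.arange(seq_k)[None, :]
    keep = (offset >= 0) & (offset < window)
    if mask is not None:
        keep = keep & mask
    scores = jnp.einsum("bhqd,bhkd->bhqk", query, key) * query.shape[-1] ** -0.5
    scores = jnp.where(keep, scores, jnp.finfo(scores.dtype).min)
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), value)


# Bind the kernel-specific argument; only the generic keywords remain.
local_attention = functools.partial(windowed_attention, window=2)

x_demo = jax.random.normal(jax.random.PRNGKey(0), (1, 2, 6, 8))
out_local = local_attention(x_demo, x_demo, x_demo, mask=None, score_mod=None)
```

A callable built this way could then be passed wherever an AttentionFn is expected, assuming the layer calls it with the same keyword arguments.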
A full transformer
TransformerConfig bundles every transformer hyperparameter, and TransformerEncoder wires the token embeddings, positional encoding, num_layers encoder blocks, and a final norm. The defaults reproduce a pre-norm BERT-style encoder:
from attnax import TransformerEncoder
config = TransformerConfig(
    vocab_size=32000, d_model=512, num_heads=8, num_layers=6, d_ff=2048,
    dropout_rate=0.1, max_len=512,
)
encoder = TransformerEncoder(nnx.Rngs(0), config)
ids = jnp.ones((2, 10), dtype=jnp.int32)
encoder(ids, deterministic=True).shape
The same config also selects rotary positional embeddings on Q/K, RMSNorm, gated SwiGLU feed-forwards, and grouped-query attention; each is toggled by an individual field:
llm_config = TransformerConfig(
    vocab_size=32000, d_model=512, num_heads=8, num_layers=6,
    pos_emb_type="rope",
    norm_type="rms",
    ff_activation="swiglu",
    num_kv_heads=2,
    rope_base=10000.0,
)
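For readers unfamiliar with two of these options, the sketch below writes out RMSNorm and a SwiGLU feed-forward in plain JAX. These are the textbook formulations, not Attnax's modules: the function names, the eps value, and the absence of bias terms are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, scale, eps=1e-6):
    """Divide by the root mean square over the feature axis, then rescale."""
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return x / rms * scale


def swiglu(x, w_gate, w_up, w_down):
    """Gated feed-forward: silu(x W_gate) * (x W_up), projected back by W_down."""
    return (jax.nn.silu(x @ w_gate) * (x @ w_up)) @ w_down


keys = jax.random.split(jax.random.PRNGKey(0), 4)
x_ff = jax.random.normal(keys[0], (2, 5, 8))
normed = rms_norm(x_ff, scale=jnp.ones((8,)))
ff_out = swiglu(
    normed,
    jax.random.normal(keys[1], (8, 16)),
    jax.random.normal(keys[2], (8, 16)),
    jax.random.normal(keys[3], (16, 8)),
)
```

Unlike LayerNorm, RMSNorm does not subtract the mean, and SwiGLU uses three weight matrices rather than the two of a plain feed-forward layer.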
Masks
Attention masks are boolean tensors broadcastable to (batch, num_heads, seq_q, seq_kv); True means attend. make_padding_mask, make_causal_mask, make_sliding_window_mask, and make_document_mask construct the standard variants; combine_masks AND-reduces any mix of masks and None:
from attnax import combine_masks, make_causal_mask, make_padding_mask
masked_ids = jnp.array([[1, 2, 3, 0, 0], [4, 5, 6, 7, 8]])
self_mask = combine_masks(
    make_padding_mask(masked_ids, pad_token_id=0),
    make_causal_mask(masked_ids.shape[1]),
)
self_mask.shape
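The mask semantics described above (boolean, True means attend, AND-combined) can be reproduced with plain jax.numpy. The sketch below is an assumption about what the helpers compute, written without calling them; the exact shapes Attnax returns may differ as long as they broadcast to (batch, num_heads, seq_q, seq_kv).

```python
import jax.numpy as jnp

ids_demo = jnp.array([[1, 2, 3, 0, 0], [4, 5, 6, 7, 8]])
# Padding mask: a key position is attendable when it is not the pad token (0).
pad_mask = (ids_demo != 0)[:, None, None, :]  # (batch, 1, 1, seq_kv)
# Causal mask: query i may attend to keys j <= i.
seq_len = ids_demo.shape[1]
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None]
# AND-combining broadcasts to (batch, 1, seq_q, seq_kv).
combined_mask = pad_mask & causal_mask
```

For the first sequence, the last query can see only the three non-pad tokens; for the second, the mask is purely causal.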
Training
TransformerEncoder returns (batch, seq, d_model) hidden states. Adding a small linear head and pairing the model with Optax and flax.nnx.Optimizer gives a complete training step. nnx.Optimizer is a stateful object whose update method mutates the model in place — there is no separate “params” dict to thread through:
import optax
class Model(nnx.Module):
    def __init__(self, rngs, config):
        self.encoder = TransformerEncoder(rngs, config)
        self.proj = nnx.Linear(config.d_model, config.vocab_size, rngs=rngs)
    def __call__(self, ids, *, padding_mask=None, deterministic=True):
        h = self.encoder(
            ids, padding_mask=padding_mask, deterministic=deterministic,
        )
        return self.proj(h)
lm = Model(nnx.Rngs(0), config)
optimizer = nnx.Optimizer(lm, optax.adamw(1e-4), wrt=nnx.Param)
@nnx.jit
def train_step(model, optimizer, batch):
    def loss_fn(model):
        logits = model(batch["input_ids"], deterministic=False)
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["labels"],
        ).mean()
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    # Pass model and grads positionally, as in the Flax NNX Optimizer examples.
    optimizer.update(model, grads)
    return loss
batch = {
    "input_ids": jnp.ones((2, 10), dtype=jnp.int32),
    "labels": jnp.zeros((2, 10), dtype=jnp.int32),
}
loss = train_step(lm, optimizer, batch)
print(f"loss = {loss:.4f}")
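The loss above averages over every position. When batches contain padding, one common variant averages only over valid tokens. The sketch below is a pure-JAX illustration of that idea; the function name masked_cross_entropy and the valid array are hypothetical and not part of Attnax or Optax.

```python
import jax
import jax.numpy as jnp


def masked_cross_entropy(logits, labels, valid):
    """Mean negative log-likelihood over positions where `valid` is True."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    token_nll = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    valid = valid.astype(logits.dtype)
    # Guard against an all-padding batch to avoid dividing by zero.
    return (token_nll * valid).sum() / jnp.maximum(valid.sum(), 1.0)


logits_demo = jnp.zeros((2, 3, 4))
labels_demo = jnp.zeros((2, 3), dtype=jnp.int32)
valid_demo = jnp.array([[True, True, False], [True, False, False]])
loss_demo = masked_cross_entropy(logits_demo, labels_demo, valid_demo)
# Changing logits at a padded position leaves the loss unchanged.
loss_perturbed = masked_cross_entropy(
    logits_demo.at[0, 2, 1].set(50.0), labels_demo, valid_demo
)
```

With uniform logits over four classes, each valid token contributes log 4 to the average, and padded positions contribute nothing.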
Encoder–decoder and cross-attention
For sequence-to-sequence models, pair a TransformerEncoder with one or more DecoderBlocks. Each decoder block runs masked self-attention on the target, cross-attention against the encoder output, and a feed-forward:
from attnax import DecoderBlock
enc_dec_config = TransformerConfig(vocab_size=32000, d_model=512, num_heads=8)
enc = TransformerEncoder(nnx.Rngs(0), enc_dec_config)
encoded = enc(jnp.ones((2, 16), dtype=jnp.int32), deterministic=True)
dec_block = DecoderBlock(
    nnx.Rngs(1),
    d_model=enc_dec_config.d_model,
    num_heads=enc_dec_config.num_heads,
    d_ff=enc_dec_config.d_ff,
)
tgt = jnp.zeros((2, 8, enc_dec_config.d_model))
dec_block(
    tgt,
    encoder_output=encoded,
    self_mask=make_causal_mask(8),
    deterministic=True,
).shape
For cross-attention without the surrounding decoder block, call MultiHeadAttention with context=. KV caching is supported only on self-attention; cross-attention recomputes K/V from the (fixed) encoder output on every call:
cross = MultiHeadAttention(
    nnx.Rngs(0), num_heads=8, in_features=enc_dec_config.d_model,
)
cross(tgt, context=encoded, deterministic=True).shape
Autoregressive inference with KV caching
For decoding one token at a time, preallocate a KVLayerCache per layer. The cache stores post-RoPE K/V in KV-head layout, so the same buffers work for MHA, GQA, and MQA without conversion. When layer_kv_caches= is passed, the model returns (output, updated_caches):
from attnax import init_decoder_kv_caches_from_config
caches = init_decoder_kv_caches_from_config(
    config, batch_size=1, max_len=2048,
)
prompt = jnp.ones((1, 4), dtype=jnp.int32)
y, caches = encoder(prompt, layer_kv_caches=caches, deterministic=True)
y.shape, int(caches[0].length)
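The idea behind a preallocated per-layer cache can be sketched with a fixed-size buffer, a length counter, and jax.lax.dynamic_update_slice. This is a simplified stand-in, not the KVLayerCache implementation: the dict layout and the helper names init_cache, append, and attend are assumptions, and attend assumes one query head per KV head and single-token decoding.

```python
import jax
import jax.numpy as jnp


def init_cache(batch, num_kv_heads, max_len, head_dim, dtype=jnp.float32):
    shape = (batch, num_kv_heads, max_len, head_dim)
    return {"k": jnp.zeros(shape, dtype), "v": jnp.zeros(shape, dtype),
            "length": jnp.zeros((), jnp.int32)}


def append(cache, k_new, v_new):
    """Write (batch, num_kv_heads, t, head_dim) tensors at the current length."""
    zero = jnp.zeros((), cache["length"].dtype)
    start = (zero, zero, cache["length"], zero)
    return {
        "k": jax.lax.dynamic_update_slice(cache["k"], k_new, start),
        "v": jax.lax.dynamic_update_slice(cache["v"], v_new, start),
        "length": cache["length"] + k_new.shape[2],
    }


def attend(q, cache):
    """Attend from q to the filled part of the cache; empty slots are masked."""
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, cache["k"]) * q.shape[-1] ** -0.5
    valid = jnp.arange(cache["k"].shape[2]) < cache["length"]
    scores = jnp.where(valid, scores, -jnp.inf)
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), cache["v"])


cache_demo = init_cache(batch=1, num_kv_heads=2, max_len=8, head_dim=4)
cache_demo = append(cache_demo, jnp.ones((1, 2, 3, 4)), jnp.ones((1, 2, 3, 4)))  # prompt
cache_demo = append(cache_demo, 2 * jnp.ones((1, 2, 1, 4)), 2 * jnp.ones((1, 2, 1, 4)))  # one step
out_step = attend(jnp.zeros((1, 2, 1, 4)), cache_demo)
```

Because the buffers never change shape, each decoding step can be compiled once and reused, which is the main reason for preallocating to max_len.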
For inference servers serving many heterogeneous sequence lengths, PagedKVCache (à la vLLM) stores K/V in a pool of fixed-size physical blocks addressed through per-sequence block tables. init_paged_kv_cache, allocate_blocks, and append_kv populate it; paged_attention attends against it:
from attnax import (
    allocate_blocks, append_kv, init_paged_kv_cache, paged_attention,
)
num_kv_heads, head_dim = 4, 16
paged = init_paged_kv_cache(
    num_blocks=16, block_size=8,
    num_kv_heads=num_kv_heads, head_dim=head_dim,
    batch_size=2, max_blocks_per_seq=4, dtype=jnp.float32,
)
free = jnp.arange(16, dtype=jnp.int32)
for seq_idx, n_tokens in enumerate([12, 5]):
    paged, used = allocate_blocks(
        paged, sequence_idx=seq_idx,
        num_new_tokens=n_tokens, free_block_ids=free,
    )
    free = free[used:]
    keys = jax.random.normal(
        jax.random.key(seq_idx), (n_tokens, num_kv_heads, head_dim),
    )
    values = jax.random.normal(
        jax.random.key(seq_idx + 100), (n_tokens, num_kv_heads, head_dim),
    )
    paged = append_kv(
        paged, sequence_idx=seq_idx, keys_new=keys, values_new=values,
    )
q = jax.random.normal(jax.random.key(42), (4, 1, head_dim))
paged_attention(q, paged, sequence_idx=0).shape
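The block-table indirection can be illustrated without any Attnax calls. In the sketch below, a single-head pool of physical blocks is addressed through one sequence's block table, so logical token t lives in physical block block_table[t // block_size] at slot t % block_size. The layout, names, and sizes are assumptions chosen for the example and do not describe PagedKVCache's actual fields.

```python
import jax.numpy as jnp

block_size, num_blocks, head_dim = 4, 6, 2
# Physical pool: every block holds `block_size` token slots.
pool = jnp.arange(
    num_blocks * block_size * head_dim, dtype=jnp.float32
).reshape(num_blocks, block_size, head_dim)
# Logical-to-physical mapping for one sequence: logical block 0 -> physical 5, etc.
block_table = jnp.array([5, 2, 0])
seq_len = 10  # only the first 10 of the 12 addressable slots hold tokens


def gather_sequence(pool, block_table, seq_len):
    """Return the sequence's K (or V) rows in logical order."""
    blocks = pool[block_table]                   # (n_blocks, block_size, head_dim)
    flat = blocks.reshape(-1, blocks.shape[-1])  # (n_blocks * block_size, head_dim)
    return flat[:seq_len]


gathered = gather_sequence(pool, block_table, seq_len)
```

Sequences of different lengths share the same pool and differ only in their block tables, which is what lets a server avoid reserving max_len slots per request.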
Mixture of Experts
MixtureOfExperts is a top-$k$ routed feed-forward that drops in for FeedForward inside any transformer block. It returns (output, aux), where aux["load_balance_loss"] should be added to the training loss with a small coefficient (≈ 0.01) and aux["router_entropy"] is a diagnostic for router collapse:
from attnax import MixtureOfExperts
moe = MixtureOfExperts(
    nnx.Rngs(0), d_model=64, d_ff=128,
    num_experts=4, top_k=2, ff_activation="swiglu",
)
y, aux = moe(jax.random.normal(jax.random.key(0), (2, 16, 64)), deterministic=False)
y.shape, float(aux["load_balance_loss"]), float(aux["router_entropy"])
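To show what the router and the two auxiliary quantities might look like, here is a pure-JAX sketch of top-k routing with a Switch-Transformer-style load-balance term and a mean router entropy. The exact formulas Attnax uses are not given in this notebook, so the loss definition, the renormalised gates, and the name route_top_k are assumptions.

```python
import jax
import jax.numpy as jnp


def route_top_k(router_logits, top_k):
    """router_logits: (tokens, num_experts)."""
    num_experts = router_logits.shape[-1]
    probs = jax.nn.softmax(router_logits, axis=-1)
    gate, expert_idx = jax.lax.top_k(probs, top_k)
    gate = gate / gate.sum(axis=-1, keepdims=True)  # renormalise the kept gates
    # Fraction of routing slots that each expert receives.
    dispatch = jax.nn.one_hot(expert_idx, num_experts).sum(axis=1)
    load = dispatch.mean(axis=0) / top_k
    # Mean router probability per expert.
    importance = probs.mean(axis=0)
    load_balance_loss = num_experts * jnp.sum(load * importance)
    router_entropy = -jnp.sum(probs * jnp.log(probs + 1e-9), axis=-1).mean()
    return gate, expert_idx, load_balance_loss, router_entropy


gate_demo, idx_demo, lb_demo, ent_demo = route_top_k(jnp.zeros((16, 4)), top_k=2)
```

Under this definition, a perfectly uniform router scores 1.0 on the load-balance term and log(num_experts) on entropy; entropy collapsing toward zero signals that the router is committing to a few experts.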
Vision Transformer
The same encoder blocks, attention backends, and feed-forward variants used for text are exposed for images through VisionTransformer and VisionTransformerConfig. An image is patchified, an optional [CLS] token is prepended, a learnable absolute positional embedding is added, and num_layers standard encoder blocks are applied:
from attnax import VisionTransformer, VisionTransformerConfig
vit_config = VisionTransformerConfig(
    image_size=224, patch_size=16, num_channels=3, num_classes=1000,
    d_model=768, num_heads=12, num_layers=12, d_ff=3072,
)
vit = VisionTransformer(nnx.Rngs(0), vit_config)
vit(jnp.zeros((2, 224, 224, 3)), deterministic=True).shape
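The patchify step can be written as a reshape and a transpose. The sketch below assumes channels-last images, as in the call above, and flattens each patch into a vector; the helper name patchify is illustrative, and Attnax may instead implement this step as a strided convolution or with a different flattening order.

```python
import jax.numpy as jnp


def patchify(images, patch_size):
    """(batch, H, W, C) -> (batch, num_patches, patch_size * patch_size * C)."""
    b, h, w, c = images.shape
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (b, gh, gw, patch, patch, c)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


patches = patchify(jnp.zeros((2, 224, 224, 3)), patch_size=16)
small = patchify(jnp.arange(16.0).reshape(1, 4, 4, 1), patch_size=2)
```

A 224×224 image with 16×16 patches yields 14 × 14 = 196 tokens, each of dimension 16 × 16 × 3 = 768 before the linear projection to d_model.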
Setting num_classes=None drops the classification head and returns the full token sequence. The LLM-style fields work here too:
vit_llm_config = VisionTransformerConfig(
    image_size=224, patch_size=16, num_classes=1000,
    norm_type="rms",
    ff_activation="swiglu",
    num_kv_heads=4,
    attention_type=AttentionType.MEMORY_EFFICIENT,
)
Composing custom architectures
When none of the bundled wrappers fits, every lower-level component composes. The example below is a minimal language model assembled directly from TokenEmbedding, PositionalEncoding, and EncoderBlock:
from attnax import EncoderBlock, PositionalEncoding, TokenEmbedding
class CustomLM(nnx.Module):
    def __init__(self, rngs, config):
        self.embed = TokenEmbedding(
            rngs, config.vocab_size, config.d_model,
        )
        self.pos = PositionalEncoding(config.max_len, config.d_model)
        self.blocks = nnx.List([
            EncoderBlock(
                rngs,
                d_model=config.d_model,
                num_heads=config.num_heads,
                d_ff=config.d_ff,
                norm_type=config.norm_type,
                ff_activation=config.ff_activation,
                num_kv_heads=config.num_kv_heads,
            )
            for _ in range(config.num_layers)
        ])
        self.head = nnx.Linear(
            config.d_model, config.vocab_size, rngs=rngs,
        )
    def __call__(self, ids, deterministic=True):
        x = self.pos(self.embed(ids))
        for block in self.blocks:
            x = block(x, deterministic=deterministic)
        return self.head(x)
For the full API surface — every kernel, score-mod, mask helper, and module covered above — refer to the API reference on the docs site.