Getting started

Open in Colab

Attnax is an attention and transformer library for JAX. Its main abstraction is the AttentionFn protocol — a pure JAX function with signature (q, k, v, *, mask, score_mod, ...) -> out — together with ScoreMod callables that fold biases and sparsity patterns into the pre-softmax scores. Every bundled kernel (standard_attention, memory_efficient_attention, flash_attention, pallas_flash_attention, linear_attention, ring_attention, paged_attention, lite_attention) conforms to that protocol, as does any user-written kernel, and the same protocol is consumed by MultiHeadAttention via its attention_fn= argument.

This notebook walks through Attnax in three parts:

  • Attention as a function. Pure-JAX kernels and ScoreMod biases on a single batch of Q/K/V tensors.

  • A transformer layer. Wrapping a kernel in MultiHeadAttention, selecting a backend, and supplying a custom kernel.

  • A full transformer. TransformerConfig, masks, training with Optax, autoregressive inference with KV caching, Mixture-of-Experts, and the Vision Transformer.

In Colab the next cell installs Attnax. Locally, pip install attnax once and re-run the notebook.

%pip install attnax

Begin by importing the libraries used throughout the notebook:

import jax
import jax.numpy as jnp
import flax.nnx as nnx

import attnax

Attention as a function

The most direct entry point in Attnax is standard_attention, a pure JAX function that consumes Q/K/V tensors of shape (batch, num_heads, seq, head_dim) and returns the attended output with the same shape as the queries. We start by computing scaled dot-product attention on a small synthetic batch:

from attnax import standard_attention

batch, num_heads, seq, head_dim = 1, 4, 16, 32
q = jax.random.normal(jax.random.key(0), (batch, num_heads, seq, head_dim))
k = jax.random.normal(jax.random.key(1), (batch, num_heads, seq, head_dim))
v = jax.random.normal(jax.random.key(2), (batch, num_heads, seq, head_dim))

out = standard_attention(q, k, v)
out.shape

The same call signature is shared by every kernel in attnax.kernels. memory_efficient_attention computes the same softmax with $O(n)$ activation memory using a block-wise online softmax, which is useful at long sequence lengths:

from attnax import memory_efficient_attention

out_mem = memory_efficient_attention(q, k, v)
jnp.allclose(out, out_mem, atol=1e-5)
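
The trick that keeps activation memory linear fits in a few lines of plain JAX. The sketch below is an illustration, not Attnax's implementation (online_softmax_attention and block_size are hypothetical names). It walks over the keys in blocks and keeps a running maximum, a running softmax denominator, and a running weighted sum of values, rescaling all three whenever a new block raises the maximum. It then checks the result against an ordinary full-matrix softmax:

```python
import jax
import jax.numpy as jnp


def online_softmax_attention(q, k, v, block_size=4):
    """Attention over key/value blocks with a running (online) softmax.

    q, k, v: (batch, heads, seq, head_dim). Each step materialises only a
    (seq_q, block_size) score tile per head instead of the full (seq_q, seq_kv) matrix.
    """
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    m = jnp.full(q.shape[:-1], -jnp.inf)              # running row maximum
    l = jnp.zeros(q.shape[:-1])                       # running softmax denominator
    acc = jnp.zeros(q.shape[:-1] + (v.shape[-1],))    # running weighted sum of values
    for start in range(0, k.shape[2], block_size):
        k_blk = k[:, :, start:start + block_size]
        v_blk = v[:, :, start:start + block_size]
        s = jnp.einsum("bhqd,bhkd->bhqk", q, k_blk) * scale
        m_new = jnp.maximum(m, s.max(axis=-1))
        correction = jnp.exp(m - m_new)               # rescales the old statistics
        p = jnp.exp(s - m_new[..., None])
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[..., None] + jnp.einsum("bhqk,bhkd->bhqd", p, v_blk)
        m = m_new
    return acc / l[..., None]


key_q, key_k, key_v = jax.random.split(jax.random.key(0), 3)
q_demo = jax.random.normal(key_q, (1, 2, 16, 8))
k_demo = jax.random.normal(key_k, (1, 2, 16, 8))
v_demo = jax.random.normal(key_v, (1, 2, 16, 8))

# Reference: materialise the whole score matrix and softmax it in one go.
scores_demo = jnp.einsum("bhqd,bhkd->bhqk", q_demo, k_demo) / jnp.sqrt(8.0)
reference = jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores_demo, axis=-1), v_demo)
blocked = online_softmax_attention(q_demo, k_demo, v_demo, block_size=4)
print(jnp.allclose(reference, blocked, atol=1e-5))
```

Only the running statistics, of size O(seq_q), persist across blocks. Production kernels typically also tile the queries and fuse the loop into one compiled kernel.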

flash_attention dispatches to jax.nn.dot_product_attention on GPU and falls back to memory_efficient_attention elsewhere; pallas_flash_attention lowers the same algorithm to a Pallas kernel on Pallas-capable backends and falls back to memory_efficient_attention otherwise. The remaining bundled kernels — linear_attention, ring_attention, paged_attention, lite_attention — are covered later in the notebook.

Composing biases with ScoreMod

In Attnax, every attention bias — relative position, sliding window, causal masking, prefix-LM, document packing — is a ScoreMod: a callable applied to the pre-softmax scores. Constructors for the common variants live in attnax.kernels.score_mods and compose with compose_score_mods.

We can add an ALiBi bias and a causal sliding window with one call:

from attnax import alibi_mod, compose_score_mods, sliding_window_mod

mod = compose_score_mods(
    alibi_mod(num_heads=num_heads),
    sliding_window_mod(window_size=8, causal=True),
)
out_biased = standard_attention(q, k, v, score_mod=mod)
out_biased.shape
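
Conceptually, each score mod rewrites the (batch, num_heads, seq_q, seq_kv) score tensor before the softmax, and composing mods applies them one after another. The pure-JAX sketch below illustrates this with hypothetical helpers (make_alibi, make_sliding_window, compose) that take and return the whole score tensor; Attnax's real ScoreMod signature may differ. The ALiBi slopes follow the geometric schedule 2^(-8/n), 2^(-16/n), ..., which assumes a power-of-two head count, and the distance penalty is applied symmetrically, which matches standard ALiBi on the causal positions that remain visible:

```python
import jax
import jax.numpy as jnp

NEG_INF = -1e30  # large finite negative: fully masked rows stay finite instead of NaN


def make_alibi(num_heads):
    # Hypothetical helper: per-head slopes 2^(-8/n), 2^(-16/n), ...
    slopes = 2.0 ** (-8.0 * jnp.arange(1, num_heads + 1) / num_heads)

    def mod(scores):
        seq_q, seq_kv = scores.shape[-2:]
        distance = jnp.abs(jnp.arange(seq_q)[:, None] - jnp.arange(seq_kv)[None, :])
        return scores - slopes[None, :, None, None] * distance
    return mod


def make_sliding_window(window_size):
    # Hypothetical helper: causal window of the current token plus window_size - 1 before it.
    def mod(scores):
        seq_q, seq_kv = scores.shape[-2:]
        offset = jnp.arange(seq_q)[:, None] - jnp.arange(seq_kv)[None, :]
        allowed = (offset >= 0) & (offset < window_size)
        return jnp.where(allowed, scores, NEG_INF)
    return mod


def compose(*mods):
    # Apply each mod in order to the same score tensor.
    def mod(scores):
        for m in mods:
            scores = m(scores)
        return scores
    return mod


scores = jnp.zeros((1, 4, 16, 16))
biased = compose(make_alibi(4), make_sliding_window(8))(scores)
weights = jax.nn.softmax(biased, axis=-1)
print(weights[0, 0, 10])  # nonzero only for key positions 3..10
```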

Score mods can also be built per call, which suits biases that change every batch, such as the document-packing masks used in long-context training. (On a MultiHeadAttention layer, introduced below, a per-call mod stacks on top of whatever was passed at construction time.)

from attnax import document_mask_mod

doc_ids = jnp.array(
    [[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]], dtype=jnp.int32
)
out_packed = standard_attention(q, k, v, score_mod=document_mask_mod(doc_ids))
out_packed.shape
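
Under the hood, document packing reduces to a block-diagonal mask: query position i may attend to key position j only when both carry the same document id. A minimal sketch of that rule, using a hypothetical document_mask helper (not the library function) that returns a boolean mask with True meaning attend:

```python
import jax.numpy as jnp


def document_mask(doc_ids, causal=False):
    """doc_ids: (batch, seq) ints -> (batch, 1, seq, seq) bool mask, True = attend."""
    same_doc = doc_ids[:, :, None] == doc_ids[:, None, :]
    if causal:
        seq = doc_ids.shape[1]
        same_doc = same_doc & jnp.tril(jnp.ones((seq, seq), dtype=bool))
    return same_doc[:, None]  # singleton head axis broadcasts over num_heads


example_doc_ids = jnp.array(
    [[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]], dtype=jnp.int32
)
doc_mask = document_mask(example_doc_ids)
print(doc_mask.shape)         # (1, 1, 16, 16)
print(int(doc_mask.sum()))    # 4 documents x 4 x 4 = 64 visible pairs
```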

A transformer layer

In practice we rarely call kernels directly: we wrap them in a layer that owns the Q/K/V projections, optional rotary positional embeddings, dropout, and KV cache plumbing. That layer is MultiHeadAttention. It supports full multi-head attention, grouped-query attention (1 < num_kv_heads < num_heads), and multi-query attention (num_kv_heads == 1); the same score_mod= argument we used on the bare kernel is available at construction time:

from attnax import MultiHeadAttention

attn = MultiHeadAttention(
    nnx.Rngs(0),
    num_heads=8,
    in_features=512,
    num_kv_heads=2,          # grouped-query attention; 1 for MQA
    use_rope=True,
    score_mod=alibi_mod(num_heads=8),
)
attn(jnp.zeros((1, 32, 512)), deterministic=True).shape
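
In grouped-query attention each K/V head is shared by num_heads // num_kv_heads query heads. One straightforward way to express this, shown here as a sketch rather than Attnax's internal code (grouped_query_attention is a hypothetical name), is to repeat the K/V heads along the head axis and then run ordinary attention; MQA is the special case num_kv_heads == 1:

```python
import jax
import jax.numpy as jnp


def grouped_query_attention(q, k, v):
    """q: (batch, num_heads, seq, d); k, v: (batch, num_kv_heads, seq, d)."""
    num_heads, num_kv_heads = q.shape[1], k.shape[1]
    assert num_heads % num_kv_heads == 0, "num_heads must be a multiple of num_kv_heads"
    group = num_heads // num_kv_heads
    # After repeating, query head h reads K/V head h // group.
    k = jnp.repeat(k, group, axis=1)
    v = jnp.repeat(v, group, axis=1)
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)


key_q, key_k, key_v = jax.random.split(jax.random.key(0), 3)
q_gqa = jax.random.normal(key_q, (1, 8, 32, 64))   # 8 query heads
k_gqa = jax.random.normal(key_k, (1, 2, 32, 64))   # 2 shared K/V heads
v_gqa = jax.random.normal(key_v, (1, 2, 32, 64))
out_gqa = grouped_query_attention(q_gqa, k_gqa, v_gqa)
print(out_gqa.shape)  # (1, 8, 32, 64)
```

Only num_kv_heads K/V heads are stored, so a KV cache for the layer above is num_heads / num_kv_heads = 4 times smaller than with full multi-head attention. Optimized kernels usually avoid materializing the repeat.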

Picking a built-in backend

Which kernel MultiHeadAttention calls is controlled by the AttentionType enum (or the attention_fn= escape hatch covered next). The enum entries are:

  • STANDARD: scaled dot-product, $O(n^2)$ activation memory.

  • MEMORY_EFFICIENT: block-wise online softmax, $O(n)$ activation memory.

  • FLASH: jax.nn.dot_product_attention on GPU; falls back to MEMORY_EFFICIENT elsewhere.

  • PALLAS_FLASH: Pallas-lowered FlashAttention with score_mod applied in the inner loop; falls back to MEMORY_EFFICIENT on CPU or when the Pallas kernel fails to lower.

  • LINEAR: chunkwise-parallel, softmax-free linear attention. Does not accept score_mod.

  • LITE: element-wise gated attention; not a drop-in replacement for full softmax attention.

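ring_attention and paged_attention are not enum entries because they take additional non-generic arguments (axis_name for ring, a PagedKVCache for paged); they are passed via attention_fn= with those arguments partially applied. At the model level, the enum is selected through TransformerConfig.attention_type: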

from attnax import AttentionType, TransformerConfig

flash_config = TransformerConfig(
    vocab_size=32000,
    attention_type=AttentionType.FLASH,
    attention_block_size=512,
)
pallas_config = TransformerConfig(
    vocab_size=32000, attention_type=AttentionType.PALLAS_FLASH,
)
linear_config = TransformerConfig(
    vocab_size=32000, attention_type=AttentionType.LINEAR,
)
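
To give a feel for what the LINEAR backend computes, here is the textbook non-causal, softmax-free formulation in plain JAX. The softmax weights are replaced by a positive feature map φ, so attention becomes φ(Q)(φ(K)ᵀV) normalized by φ(Q)·Σⱼφ(Kⱼ), which costs O(n·d²) instead of O(n²·d). The feature map φ(x) = elu(x) + 1 and the function names are assumptions for illustration; this version is neither causal nor chunked, so it is not Attnax's chunkwise-parallel kernel:

```python
import jax
import jax.numpy as jnp


def feature_map(x):
    # elu(x) + 1 keeps every feature strictly positive, so the normaliser never vanishes.
    return jax.nn.elu(x) + 1.0


def linear_attention_sketch(q, k, v):
    """Non-causal linear attention; q, k, v: (batch, heads, seq, d)."""
    q, k = feature_map(q), feature_map(k)
    kv = jnp.einsum("bhkd,bhke->bhde", k, v)        # (d, d_v) summary, size independent of seq
    z = k.sum(axis=2)                                # (batch, heads, d) normaliser
    num = jnp.einsum("bhqd,bhde->bhqe", q, kv)
    den = jnp.einsum("bhqd,bhd->bhq", q, z)[..., None]
    return num / den


def quadratic_reference(q, k, v):
    # The same quantity computed the O(n^2) way, for checking.
    weights = jnp.einsum("bhqd,bhkd->bhqk", feature_map(q), feature_map(k))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)


keys = jax.random.split(jax.random.key(0), 3)
q_lin = jax.random.normal(keys[0], (1, 2, 32, 16))
k_lin = jax.random.normal(keys[1], (1, 2, 32, 16))
v_lin = jax.random.normal(keys[2], (1, 2, 32, 16))
fast = linear_attention_sketch(q_lin, k_lin, v_lin)
slow = quadratic_reference(q_lin, k_lin, v_lin)
print(jnp.allclose(fast, slow, rtol=1e-4, atol=1e-4))
```

Because there is no softmax over a materialized score matrix, there is also no natural place to add a pre-softmax bias, which is consistent with LINEAR not accepting score_mod.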

Plugging in a custom kernel

Any callable matching AttentionFn plugs into MultiHeadAttention via attention_fn=. The kernel receives the already-projected (and, with use_rope=True, already-rotated) tensors of shape (batch, num_heads, seq, head_dim) and is responsible only for the attention computation itself. Here we re-implement scaled dot-product attention from scratch and check that it matches standard_attention:

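def my_attention(query, key, value, *, mask=None, score_mod=None,
                 dropout_rng=None, dropout_rate=0.0, deterministic=True):
    # Minimal kernel: the dropout arguments are accepted for protocol compatibility
    # but unused, and score_mod is not applied, so leave it at its default (None).
    del dropout_rng, dropout_rate, deterministic
    # Scaled dot-product scores: q . k / sqrt(head_dim).
    scale = jax.lax.rsqrt(jnp.asarray(query.shape[-1], query.dtype))
    scores = jnp.einsum("bhqd,bhkd->bhqk", query, key) * scale
    if mask is not None:
        # True means attend; masked-out positions get the most negative finite value.
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, value)


# Sanity check against the reference kernel on the Q/K/V tensors from the first section.
print(jnp.allclose(my_attention(q, k, v), standard_attention(q, k, v), atol=1e-5))

attn_custom = MultiHeadAttention(
    nnx.Rngs(0), num_heads=8, in_features=512, attention_fn=my_attention,
)
attn_custom(jnp.zeros((1, 4, 512)), deterministic=True).shape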

In practice a custom kernel is rarely written from scratch: the same attention_fn= slot accepts pallas_flash_attention, ring_attention partially applied with axis_name=..., paged_attention partially applied with a PagedKVCache, or a Triton kernel wrapped through jax-triton:

from attnax import pallas_flash_attention

attn_pallas = MultiHeadAttention(
    nnx.Rngs(0), num_heads=8, in_features=512,
    attention_fn=pallas_flash_attention,
)
attn_pallas(jnp.zeros((1, 4, 512)), deterministic=True).shape
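
Kernels that need extra configuration are adapted to the AttentionFn signature by binding that configuration with functools.partial, which is what "partially applied" means above. The sketch shows the pattern with a hypothetical softcapped_attention kernel (tanh logit soft-capping; not a bundled Attnax kernel) whose softcap argument is fixed ahead of time:

```python
import functools

import jax
import jax.numpy as jnp


def softcapped_attention(query, key, value, *, softcap, mask=None, score_mod=None,
                         dropout_rng=None, dropout_rate=0.0, deterministic=True):
    """Hypothetical kernel: scaled dot-product attention with tanh logit soft-capping."""
    del dropout_rng, dropout_rate, deterministic  # no dropout in this sketch
    if score_mod is not None:
        raise NotImplementedError("this sketch does not apply score_mod")
    scores = jnp.einsum("bhqd,bhkd->bhqk", query, key) / jnp.sqrt(query.shape[-1])
    # Squash logits smoothly into (-softcap, softcap).
    scores = softcap * jnp.tanh(scores / softcap)
    if mask is not None:
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), value)


# Bind the extra argument; what remains has the same call signature as my_attention.
capped_fn = functools.partial(softcapped_attention, softcap=30.0)

x = jax.random.normal(jax.random.key(0), (1, 8, 4, 64))
out_capped = capped_fn(x, x, x)
print(out_capped.shape)  # (1, 8, 4, 64)
```

The bound callable would then be passed like any other kernel, e.g. MultiHeadAttention(nnx.Rngs(0), num_heads=8, in_features=512, attention_fn=capped_fn).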

A full transformer

TransformerConfig bundles every transformer hyperparameter, and TransformerEncoder wires together the token embeddings, positional encoding, num_layers encoder blocks, and a final norm. The defaults reproduce a pre-norm, BERT-style encoder:

from attnax import TransformerEncoder

config = TransformerConfig(
    vocab_size=32000, d_model=512, num_heads=8, num_layers=6, d_ff=2048,
    dropout_rate=0.1, max_len=512,
)
encoder = TransformerEncoder(nnx.Rngs(0), config)
ids = jnp.ones((2, 10), dtype=jnp.int32)
encoder(ids, deterministic=True).shape
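
"Pre-norm" describes where normalization sits in each residual branch: the input of each sub-layer is normalized, while the residual stream itself is carried through the additions untouched. A schematic sketch with a parameter-free LayerNorm and stand-in sub-layer callables (all names here are hypothetical, not Attnax's EncoderBlock), contrasted with the post-norm layout of the original Transformer:

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    # Parameter-free LayerNorm over the feature axis, enough for this illustration.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def pre_norm_block(x, attn_fn, ff_fn):
    # Pre-norm: normalise the input of each sub-layer; the residual stream stays raw.
    x = x + attn_fn(layer_norm(x))
    x = x + ff_fn(layer_norm(x))
    return x


def post_norm_block(x, attn_fn, ff_fn):
    # Post-norm: normalise after each residual addition.
    x = layer_norm(x + attn_fn(x))
    x = layer_norm(x + ff_fn(x))
    return x


def zero_sublayer(h):
    # A sub-layer that contributes nothing, to expose what the residual path does alone.
    return jnp.zeros_like(h)


x = jax.random.normal(jax.random.key(0), (2, 10, 512)) * 3.0 + 1.0
pre = pre_norm_block(x, zero_sublayer, zero_sublayer)    # identical to x
post = post_norm_block(x, zero_sublayer, zero_sublayer)  # x renormalised
```

With the sub-layers switched off, the pre-norm block is an exact identity while the post-norm block still rescales its input; that clean identity path is the usual explanation for why deep pre-norm stacks tend to train more stably.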

The same config switches the model to rotary positional embeddings on Q/K, RMSNorm, gated SwiGLU feed-forwards, and grouped-query attention, each controlled by an individual field:

llm_config = TransformerConfig(
    vocab_size=32000, d_model=512, num_heads=8, num_layers=6,
    pos_emb_type="rope",
    norm_type="rms",
    ff_activation="swiglu",
    num_kv_heads=2,
    rope_base=10000.0,
)
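
Each of these options corresponds to a small, well-known formula, written out below in plain JAX for reference. RMSNorm rescales by the root-mean-square without centering, SwiGLU gates one projection with the SiLU of another, and RoPE rotates pairs of Q/K features by position-dependent angles with frequencies base^(-2i/d), where base plays the role of rope_base. The half-split feature pairing is one common RoPE convention and Attnax may pair features differently; the weight arguments are placeholders:

```python
import jax
import jax.numpy as jnp


def rms_norm(x, scale, eps=1e-6):
    # Divide by the root-mean-square over features; unlike LayerNorm, no mean is subtracted.
    rms = jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * scale


def swiglu(x, w_gate, w_up, w_down):
    # SiLU(x W_gate) gates (x W_up) element-wise; W_down projects back to d_model.
    return (jax.nn.silu(x @ w_gate) * (x @ w_up)) @ w_down


def rope(x, positions, base=10000.0):
    """Rotate feature pairs (i, i + head_dim/2) of x: (..., seq, head_dim) by position."""
    half = x.shape[-1] // 2
    freqs = base ** (-2.0 * jnp.arange(half) / x.shape[-1])  # base^(-2i/d)
    angles = positions[:, None] * freqs[None, :]             # (seq, half)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


keys = jax.random.split(jax.random.key(0), 6)
x = jax.random.normal(keys[0], (2, 10, 64)) * 5.0
normed = rms_norm(x, jnp.ones(64))
ff_out = swiglu(
    x,
    jax.random.normal(keys[1], (64, 128)) * 0.1,
    jax.random.normal(keys[2], (64, 128)) * 0.1,
    jax.random.normal(keys[3], (128, 64)) * 0.1,
)

# RoPE makes the q.k score depend only on the relative offset between positions.
q_vec = jax.random.normal(keys[4], (1, 64))
k_vec = jax.random.normal(keys[5], (1, 64))


def rotated_score(m, n):
    qm = rope(q_vec, jnp.array([m], dtype=jnp.float32))
    kn = rope(k_vec, jnp.array([n], dtype=jnp.float32))
    return jnp.dot(qm[0], kn[0])


print(rotated_score(7, 3), rotated_score(104, 100))  # equal up to float rounding
```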

Masks

Attention masks are boolean tensors broadcastable to (batch, num_heads, seq_q, seq_kv); True means attend. make_padding_mask, make_causal_mask, make_sliding_window_mask, and make_document_mask construct the standard variants; combine_masks AND-reduces any mix of masks and None:

from attnax import combine_masks, make_causal_mask, make_padding_mask

masked_ids = jnp.array([[1, 2, 3, 0, 0], [4, 5, 6, 7, 8]])
self_mask = combine_masks(
    make_padding_mask(masked_ids, pad_token_id=0),
    make_causal_mask(masked_ids.shape[1]),
)
self_mask.shape
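
Each of these helpers amounts to a few broadcasting operations. The following pure-JAX sketch rebuilds the padding, causal, and combined masks for the ids above, assuming the (batch, num_heads, seq_q, seq_kv) convention with a singleton head axis; padding_mask, causal_mask, and combine are illustrative re-implementations, not the library functions:

```python
import jax.numpy as jnp


def padding_mask(ids, pad_token_id=0):
    # Hide padded *keys*: shape (batch, 1, 1, seq_kv).
    return (ids != pad_token_id)[:, None, None, :]


def causal_mask(seq_len):
    # Lower-triangular: query i sees keys 0..i. Shape (1, 1, seq, seq).
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None]


def combine(*masks):
    # Logical AND of all non-None masks, broadcasting to a common shape.
    result = None
    for m in masks:
        if m is None:
            continue
        result = m if result is None else jnp.logical_and(result, m)
    return result


ids_demo = jnp.array([[1, 2, 3, 0, 0], [4, 5, 6, 7, 8]])
mask_demo = combine(padding_mask(ids_demo), causal_mask(ids_demo.shape[1]), None)
print(mask_demo.shape)  # (2, 1, 5, 5)
```

Masking only the key axis means padded query rows still see the real tokens, which keeps every softmax row well defined; their outputs are simply ignored downstream.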

Training

TransformerEncoder returns (batch, seq, d_model) hidden states. Adding a small linear head and pairing the model with Optax and flax.nnx.Optimizer gives a complete training step. nnx.Optimizer is a stateful object: optimizer.update(model, grads) updates the model's parameters in place, so there is no separate “params” dict to thread through:

import optax


class Model(nnx.Module):
    def __init__(self, rngs, config):
        self.encoder = TransformerEncoder(rngs, config)
        self.proj = nnx.Linear(config.d_model, config.vocab_size, rngs=rngs)

    def __call__(self, ids, *, padding_mask=None, deterministic=True):
        h = self.encoder(
            ids, padding_mask=padding_mask, deterministic=deterministic,
        )
        return self.proj(h)


lm = Model(nnx.Rngs(0), config)
optimizer = nnx.Optimizer(lm, optax.adamw(1e-4), wrt=nnx.Param)


@nnx.jit
def train_step(model, optimizer, batch):
    def loss_fn(model):
        logits = model(batch["input_ids"], deterministic=False)
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["labels"],
        ).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)  # applies the Optax update to `model` in place
    return loss


batch = {
    "input_ids": jnp.ones((2, 10), dtype=jnp.int32),
    "labels": jnp.zeros((2, 10), dtype=jnp.int32),
}
loss = train_step(lm, optimizer, batch)
print(f"loss = {loss:.4f}")

Encoder–decoder and cross-attention

For sequence-to-sequence models, pair a TransformerEncoder with one or more DecoderBlocks. Each decoder block runs masked self-attention on the target, cross-attention against the encoder output, and a feed-forward:

from attnax import DecoderBlock

enc_dec_config = TransformerConfig(vocab_size=32000, d_model=512, num_heads=8)
enc = TransformerEncoder(nnx.Rngs(0), enc_dec_config)
encoded = enc(jnp.ones((2, 16), dtype=jnp.int32), deterministic=True)

dec_block = DecoderBlock(
    nnx.Rngs(1),
    d_model=enc_dec_config.d_model,
    num_heads=enc_dec_config.num_heads,
    d_ff=enc_dec_config.d_ff,
)
tgt = jnp.zeros((2, 8, enc_dec_config.d_model))
dec_block(
    tgt,
    encoder_output=encoded,
    self_mask=make_causal_mask(8),
    deterministic=True,
).shape

For cross-attention without the surrounding decoder block, call MultiHeadAttention with context=. KV caching is supported only on self-attention; cross-attention recomputes K/V from the (fixed) encoder output on every call:

cross = MultiHeadAttention(
    nnx.Rngs(0), num_heads=8, in_features=enc_dec_config.d_model,
)
cross(tgt, context=encoded, deterministic=True).shape
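
In cross-attention the queries come from the target sequence while the keys and values come from the encoder output, so the score matrix is rectangular (seq_tgt × seq_src) and no causal mask is needed. A minimal multi-head sketch with hypothetical projection matrices w_q, w_k, w_v (not MultiHeadAttention's parameters), omitting the output projection:

```python
import jax
import jax.numpy as jnp


def cross_attention(tgt, memory, w_q, w_k, w_v, num_heads):
    """tgt: (batch, seq_tgt, d_model); memory: (batch, seq_src, d_model)."""
    b, t, d = tgt.shape
    s = memory.shape[1]
    head_dim = d // num_heads

    def split_heads(x, seq):
        # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
        return x.reshape(b, seq, num_heads, head_dim).transpose(0, 2, 1, 3)

    q = split_heads(tgt @ w_q, t)     # queries come from the target sequence
    k = split_heads(memory @ w_k, s)  # keys and values come from the encoder output
    v = split_heads(memory @ w_v, s)
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(head_dim)  # (b, h, t, s)
    out = jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)
    return out.transpose(0, 2, 1, 3).reshape(b, t, d)  # merge heads back


keys = jax.random.split(jax.random.key(0), 5)
tgt_demo = jax.random.normal(keys[0], (2, 8, 64))
memory_demo = jax.random.normal(keys[1], (2, 16, 64))
w_q = jax.random.normal(keys[2], (64, 64)) / 8.0
w_k = jax.random.normal(keys[3], (64, 64)) / 8.0
w_v = jax.random.normal(keys[4], (64, 64)) / 8.0
out_cross = cross_attention(tgt_demo, memory_demo, w_q, w_k, w_v, num_heads=8)
print(out_cross.shape)  # (2, 8, 64)
```

Because k and v depend only on memory, a caller could compute them once per source sequence; as noted above, MultiHeadAttention instead recomputes them on every call.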

Autoregressive inference with KV caching

For decoding one token at a time, preallocate a KVLayerCache per layer. The cache stores post-RoPE K/V in KV-head layout, so the same buffers work for MHA, GQA, and MQA without conversion. When layer_kv_caches= is passed, the model returns (output, updated_caches):

from attnax import init_decoder_kv_caches_from_config

caches = init_decoder_kv_caches_from_config(
    config, batch_size=1, max_len=config.max_len,  # stay within the positional-encoding range
)
prompt = jnp.ones((1, 4), dtype=jnp.int32)
y, caches = encoder(prompt, layer_kv_caches=caches, deterministic=True)
y.shape, int(caches[0].length)
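
A preallocated KV cache is, at its core, a fixed-size buffer plus a write position: each decoding step writes the new K/V at index length with jax.lax.dynamic_update_slice and attends only to entries before the updated length, so array shapes never change and the step compiles once under jax.jit. A minimal single-layer sketch of that idea; SimpleKVCache, init_cache, and decode_step are hypothetical names, not Attnax's KVLayerCache:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SimpleKVCache(NamedTuple):
    k: jax.Array       # (batch, num_kv_heads, max_len, head_dim)
    v: jax.Array       # (batch, num_kv_heads, max_len, head_dim)
    length: jax.Array  # scalar int32: number of positions written so far


def init_cache(batch, num_kv_heads, max_len, head_dim):
    zeros = jnp.zeros((batch, num_kv_heads, max_len, head_dim))
    return SimpleKVCache(zeros, zeros, jnp.array(0, dtype=jnp.int32))


@jax.jit
def decode_step(cache, q_new, k_new, v_new):
    """Attend one new token; q_new, k_new, v_new: (batch, heads, 1, head_dim)."""
    start = (0, 0, cache.length, 0)
    k = jax.lax.dynamic_update_slice(cache.k, k_new, start)
    v = jax.lax.dynamic_update_slice(cache.v, v_new, start)
    length = cache.length + 1
    scores = jnp.einsum("bhqd,bhkd->bhqk", q_new, k) / jnp.sqrt(q_new.shape[-1])
    # Slots at or beyond `length` have not been written yet: hide them.
    valid = jnp.arange(k.shape[2]) < length
    scores = jnp.where(valid, scores, jnp.finfo(scores.dtype).min)
    out = jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)
    return out, SimpleKVCache(k, v, length)


# Decode 5 tokens one at a time into a 16-slot cache.
key_q, key_k, key_v = jax.random.split(jax.random.key(0), 3)
steps = 5
qs = jax.random.normal(key_q, (1, 2, steps, 8))
ks = jax.random.normal(key_k, (1, 2, steps, 8))
vs = jax.random.normal(key_v, (1, 2, steps, 8))

cache = init_cache(batch=1, num_kv_heads=2, max_len=16, head_dim=8)
outputs = []
for t in range(steps):
    out_t, cache = decode_step(
        cache, qs[:, :, t:t + 1], ks[:, :, t:t + 1], vs[:, :, t:t + 1],
    )
    outputs.append(out_t)
cached_out = jnp.concatenate(outputs, axis=2)  # (1, 2, steps, 8)
```

Token by token, this reproduces full causal attention over the sequence. Attnax's cache additionally stores post-RoPE keys and the KV-head layout for GQA/MQA, which this sketch leaves out.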

For inference servers handling many concurrent sequences of very different lengths, PagedKVCache (à la vLLM) stores K/V in a pool of fixed-size physical blocks addressed through per-sequence block tables. init_paged_kv_cache, allocate_blocks, and append_kv populate it; paged_attention attends against it:

from attnax import (
    allocate_blocks, append_kv, init_paged_kv_cache, paged_attention,
)

num_kv_heads, head_dim = 4, 16
paged = init_paged_kv_cache(
    num_blocks=16, block_size=8,
    num_kv_heads=num_kv_heads, head_dim=head_dim,
    batch_size=2, max_blocks_per_seq=4, dtype=jnp.float32,
)
free = jnp.arange(16, dtype=jnp.int32)
for seq_idx, n_tokens in enumerate([12, 5]):
    paged, used = allocate_blocks(
        paged, sequence_idx=seq_idx,
        num_new_tokens=n_tokens, free_block_ids=free,
    )
    free = free[used:]
    keys = jax.random.normal(
        jax.random.key(seq_idx), (n_tokens, num_kv_heads, head_dim),
    )
    values = jax.random.normal(
        jax.random.key(seq_idx + 100), (n_tokens, num_kv_heads, head_dim),
    )
    paged = append_kv(
        paged, sequence_idx=seq_idx, keys_new=keys, values_new=values,
    )

q = jax.random.normal(jax.random.key(42), (4, 1, head_dim))
paged_attention(q, paged, sequence_idx=0).shape
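
The essential mechanism of a paged cache is one level of indirection: logical token position p of a sequence lives in physical block block_table[p // block_size] at offset p % block_size. A minimal sketch of reading one sequence's keys back out of a block pool; gather_sequence and the pool layout are hypothetical choices for illustration, not PagedKVCache's actual layout:

```python
import jax.numpy as jnp


def gather_sequence(pool, block_table, seq_len, block_size):
    """pool: (num_blocks, block_size, heads, dim); block_table: (max_blocks,) block ids."""
    positions = jnp.arange(seq_len)
    physical_block = block_table[positions // block_size]  # which block holds each token
    offset = positions % block_size                         # where inside that block
    return pool[physical_block, offset]                     # (seq_len, heads, dim)


# Fill every slot with 100 * block_id + offset so the gather order is easy to read.
num_blocks, block_size = 16, 8
slot_ids = jnp.arange(num_blocks)[:, None] * 100 + jnp.arange(block_size)[None, :]
pool = jnp.broadcast_to(
    slot_ids.astype(jnp.float32)[:, :, None, None], (num_blocks, block_size, 1, 2)
)
# A 12-token sequence stored in physical blocks 5 and then 2; unused entries are padding.
block_table = jnp.array([5, 2, 0, 0], dtype=jnp.int32)
gathered = gather_sequence(pool, block_table, seq_len=12, block_size=block_size)
print(gathered[:, 0, 0])  # 500..507 followed by 200..203
```

Since a sequence's blocks need not be contiguous, it can grow one block at a time without reserving max_len slots up front, and freed blocks can be handed to other sequences.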

Mixture of Experts

MixtureOfExperts is a top-$k$ routed feed-forward that drops in for FeedForward inside any transformer block. It returns (output, aux), where aux["load_balance_loss"] should be added to the training loss with a small coefficient (≈ 0.01) and aux["router_entropy"] is a diagnostic for router collapse:

from attnax import MixtureOfExperts

moe = MixtureOfExperts(
    nnx.Rngs(0), d_model=64, d_ff=128,
    num_experts=4, top_k=2, ff_activation="swiglu",
)
y, aux = moe(jax.random.normal(jax.random.key(0), (2, 16, 64)), deterministic=False)
y.shape, float(aux["load_balance_loss"]), float(aux["router_entropy"])
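
A common way to define these two diagnostics follows the Switch Transformer: the load-balancing loss is num_experts · Σᵢ fᵢ · Pᵢ, where fᵢ is the share of token-to-expert assignments routed to expert i and Pᵢ is the mean router probability for expert i, so it equals 1 under uniform routing and grows as routing concentrates. The sketch below implements that formulation with a hypothetical router_stats helper; Attnax's exact definitions (for example how top-k > 1 assignments are counted) may differ:

```python
import jax
import jax.numpy as jnp


def router_stats(logits, top_k):
    """logits: (num_tokens, num_experts) router outputs."""
    num_experts = logits.shape[-1]
    probs = jax.nn.softmax(logits, axis=-1)
    top_vals, top_idx = jax.lax.top_k(probs, top_k)               # (tokens, k)
    # Renormalise the selected gates so each token's expert weights sum to 1.
    gates = top_vals / top_vals.sum(axis=-1, keepdims=True)
    # f_i: share of all (token, slot) assignments that went to expert i.
    assignments = jax.nn.one_hot(top_idx, num_experts).sum(axis=1)  # (tokens, E)
    f = assignments.sum(axis=0) / assignments.sum()
    # P_i: mean router probability for expert i.
    p = probs.mean(axis=0)
    load_balance_loss = num_experts * jnp.sum(f * p)
    # Mean per-token entropy of the router distribution; low values signal collapse.
    entropy = -jnp.sum(probs * jnp.log(probs + 1e-9), axis=-1).mean()
    return top_idx, gates, load_balance_loss, entropy


uniform_logits = jnp.zeros((32, 4))
_, gates, lb_uniform, ent_uniform = router_stats(uniform_logits, top_k=2)
collapsed_logits = jnp.zeros((32, 4)).at[:, 0].set(20.0)  # every token prefers expert 0
_, _, lb_collapsed, ent_collapsed = router_stats(collapsed_logits, top_k=2)
print(float(lb_uniform), float(lb_collapsed))  # about 1.0 vs about 2.0
```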

Vision Transformer

The same encoder blocks, attention backends, and feed-forward variants used for text are exposed for images through VisionTransformer and VisionTransformerConfig. An image is patchified, an optional [CLS] token is prepended, a learnable absolute positional embedding is added, and num_layers standard encoder blocks are applied:

from attnax import VisionTransformer, VisionTransformerConfig

vit_config = VisionTransformerConfig(
    image_size=224, patch_size=16, num_channels=3, num_classes=1000,
    d_model=768, num_heads=12, num_layers=12, d_ff=3072,
)
vit = VisionTransformer(nnx.Rngs(0), vit_config)
vit(jnp.zeros((2, 224, 224, 3)), deterministic=True).shape
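
Patchification is a reshape: a (batch, height, width, channels) image is cut into non-overlapping patch_size × patch_size tiles, and each tile is flattened into one token of length patch_size² · channels. For the configuration above that gives (224 / 16)² = 196 tokens of length 16 · 16 · 3 = 768 (197 tokens once a [CLS] token is prepended). A minimal pure-JAX sketch with a hypothetical patchify helper; the learned linear projection to d_model that follows is omitted:

```python
import jax.numpy as jnp


def patchify(images, patch_size):
    """(batch, H, W, C) -> (batch, num_patches, patch_size * patch_size * C)."""
    b, h, w, c = images.shape
    assert h % patch_size == 0 and w % patch_size == 0
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (b, gh, gw, p, p, c): group pixels by tile
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


tokens = patchify(jnp.zeros((2, 224, 224, 3)), patch_size=16)
print(tokens.shape)  # (2, 196, 768)
```

Patches are numbered row by row, so token 1 is the tile in the first row and second column of the image.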

Setting num_classes=None drops the classification head and returns the full token sequence. The LLM-style fields work here too:

vit_llm_config = VisionTransformerConfig(
    image_size=224, patch_size=16, num_classes=1000,
    norm_type="rms",
    ff_activation="swiglu",
    num_kv_heads=4,
    attention_type=AttentionType.MEMORY_EFFICIENT,
)

Composing custom architectures

When none of the bundled wrappers fits, the lower-level components can be composed directly. The example below is a minimal language model assembled from TokenEmbedding, PositionalEncoding, and EncoderBlock:

from attnax import EncoderBlock, PositionalEncoding, TokenEmbedding


class CustomLM(nnx.Module):
    def __init__(self, rngs, config):
        self.embed = TokenEmbedding(
            rngs, config.vocab_size, config.d_model,
        )
        self.pos = PositionalEncoding(config.max_len, config.d_model)
        self.blocks = nnx.List([
            EncoderBlock(
                rngs,
                d_model=config.d_model,
                num_heads=config.num_heads,
                d_ff=config.d_ff,
                norm_type=config.norm_type,
                ff_activation=config.ff_activation,
                num_kv_heads=config.num_kv_heads,
            )
            for _ in range(config.num_layers)
        ])
        self.head = nnx.Linear(
            config.d_model, config.vocab_size, rngs=rngs,
        )

    def __call__(self, ids, deterministic=True):
        x = self.pos(self.embed(ids))
        for block in self.blocks:
            x = block(x, deterministic=deterministic)
        return self.head(x)

For the full API surface — every kernel, score-mod, mask helper, and module covered above — refer to the API reference on the docs site.