Attnax
Attnax (Attention for JAX) is a library of attention kernels and transformer components for JAX and Flax. It ships the layers and kernels — attention, KV caches, feed-forward, Mixture-of-Experts, normalisation, positional encodings, and a Vision Transformer — that you assemble into your own attention-based model.
Attention backends cover standard, memory-efficient, FlashAttention, Pallas, ring, linear, and paged attention. Every backend is a pure JAX function with the same call signature.
ALiBi, sliding-window, prefix-LM, and document masks are expressed as score-mod callables and can be composed in a single line.
Contiguous and paged KV caches support autoregressive inference and batched serving.
The ViT encoder reuses the same kernels and blocks as the text stack.
Attnax depends only on JAX and Flax.
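To make the score-mod idea concrete, here is a minimal sketch in plain JAX. It does not use Attnax's API: the `(score, q_idx, kv_idx) -> score` convention and the names `causal`, `sliding_window`, `alibi`, `compose`, and `attention` are assumptions chosen for illustration, and the attention function is an unoptimised single-head reference rather than any of the library's backends.

```python
import jax
import jax.numpy as jnp

# Assumed convention (hypothetical, not Attnax's documented API):
# a score-mod is a callable (score, q_idx, kv_idx) -> score that
# adjusts each attention logit based on query and key positions.

def causal(score, q_idx, kv_idx):
    # A query may only attend to keys at or before its own position.
    return jnp.where(kv_idx <= q_idx, score, -jnp.inf)

def sliding_window(window):
    # Mask keys that lie `window` or more positions behind the query.
    def mod(score, q_idx, kv_idx):
        return jnp.where(q_idx - kv_idx < window, score, -jnp.inf)
    return mod

def alibi(slope):
    # ALiBi-style linear penalty proportional to query-key distance.
    def mod(score, q_idx, kv_idx):
        return score - slope * jnp.abs(q_idx - kv_idx)
    return mod

def compose(*mods):
    # Apply score-mods left to right, yielding a single score-mod.
    def mod(score, q_idx, kv_idx):
        for m in mods:
            score = m(score, q_idx, kv_idx)
        return score
    return mod

def attention(q, k, v, score_mod=None):
    """Reference single-head attention; q: (Tq, d), k and v: (Tk, d)."""
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    if score_mod is not None:
        q_idx = jnp.arange(q.shape[0])[:, None]   # (Tq, 1)
        kv_idx = jnp.arange(k.shape[0])[None, :]  # (1, Tk)
        scores = score_mod(scores, q_idx, kv_idx)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v

# Compose three score-mods in one line and run attention with them.
mod = compose(causal, sliding_window(4), alibi(0.5))
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (8, 16))
k = jax.random.normal(kk, (8, 16))
v = jax.random.normal(kv, (8, 16))
out = attention(q, k, v, score_mod=mod)
```

Because each score-mod only sees a logit and its two positions, masks and biases can be mixed freely, and the composed callable can be handed to any attention function that follows the same convention.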
Installation
pip install attnax
Or from source:
git clone https://github.com/glibtkachenko/attnax.git
cd attnax
pip install -e .
Requires Python 3.10+, JAX 0.10.0+, and Flax 0.12.7+.