Attnax

Attnax (Attention for JAX) is a library of attention kernels and transformer components for JAX and Flax. It provides the layers and kernels (attention, KV caches, feed-forward, Mixture-of-Experts, normalisation, positional encodings, and a Vision Transformer) that you assemble into your own attention-based model.

Pluggable kernels

Standard, memory-efficient, FlashAttention, Pallas, ring, linear, and paged attention. Every backend is a pure JAX function with the same call signature.
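To illustrate what "the same call signature" buys you, here is a minimal sketch in plain JAX, not Attnax's actual API. The function names `standard_attention` and `chunked_attention` and the `[batch, seq, heads, head_dim]` layout are assumptions for illustration. The chunked variant is a memory-efficient backend that streams keys in blocks with an online softmax. Because both share one signature, either can be passed wherever an attention function is expected.

```python
import jax
import jax.numpy as jnp

# Hypothetical shared signature (not Attnax's real API):
#   attention_fn(q, k, v) -> out, all shaped [batch, seq, heads, head_dim]

def standard_attention(q, k, v):
    """Materialises the full [batch, heads, q_len, kv_len] score matrix."""
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)

def chunked_attention(q, k, v, chunk_size=4):
    """Memory-efficient variant: visits keys in chunks using an online softmax."""
    scale = q.shape[-1] ** -0.5
    b, q_len, h, d = q.shape
    m = jnp.full((b, h, q_len), -jnp.inf)   # running max of scores
    l = jnp.zeros((b, h, q_len))            # running softmax denominator
    acc = jnp.zeros((b, h, q_len, d))       # running weighted sum of values
    for start in range(0, k.shape[1], chunk_size):
        k_c = k[:, start:start + chunk_size]
        v_c = v[:, start:start + chunk_size]
        s = jnp.einsum("bqhd,bkhd->bhqk", q, k_c) * scale
        m_new = jnp.maximum(m, s.max(axis=-1))
        correction = jnp.exp(m - m_new)     # rescales earlier partial sums
        p = jnp.exp(s - m_new[..., None])
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[..., None] + jnp.einsum("bhqk,bkhd->bhqd", p, v_c)
        m = m_new
    out = acc / l[..., None]
    return jnp.transpose(out, (0, 2, 1, 3))  # back to [batch, q_len, heads, head_dim]

# Usage: the two backends are interchangeable.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 5, 3, 8))
k = jax.random.normal(kk, (2, 10, 3, 8))
v = jax.random.normal(kv, (2, 10, 3, 8))
for attention_fn in (standard_attention, chunked_attention):
    print(attention_fn.__name__, attention_fn(q, k, v).shape)
```

The chunked version never holds more than one `[q_len, chunk_size]` score block per head, at the cost of a Python loop that unrolls under `jax.jit`; real kernels such as FlashAttention fuse the same recurrence on-chip.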

Composable biases

ALiBi, sliding window, prefix-LM, and document masks expressed as score-mod callables and composed in a single line.
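A score-mod callable rewrites each attention score from its position indices, so biases and masks stack by function composition. The sketch below assumes a hypothetical signature `(score, head, q_idx, kv_idx) -> score`; the helpers `causal`, `sliding_window`, `alibi`, `compose`, and `apply_score_mod` are illustrative names, not Attnax's API. ALiBi uses the geometric head slopes `2^(-8i/num_heads)`.

```python
import jax
import jax.numpy as jnp

# Hypothetical score-mod signature: (score, head, q_idx, kv_idx) -> score

def causal(score, head, q_idx, kv_idx):
    return jnp.where(kv_idx <= q_idx, score, -jnp.inf)

def sliding_window(window):
    def mod(score, head, q_idx, kv_idx):
        # keep only keys fewer than `window` positions behind the query
        return jnp.where(q_idx - kv_idx < window, score, -jnp.inf)
    return mod

def alibi(num_heads):
    # per-head slopes 2^(-8 * i / num_heads) for i = 1..num_heads
    slopes = 2.0 ** (-8.0 * jnp.arange(1, num_heads + 1) / num_heads)
    def mod(score, head, q_idx, kv_idx):
        return score - slopes[head] * (q_idx - kv_idx)  # linear distance penalty
    return mod

def compose(*mods):
    def mod(score, head, q_idx, kv_idx):
        for m in mods:
            score = m(score, head, q_idx, kv_idx)
        return score
    return mod

def apply_score_mod(scores, score_mod):
    """scores: [heads, q_len, kv_len]; passes broadcastable index grids."""
    h, q_len, kv_len = scores.shape
    head = jnp.arange(h)[:, None, None]
    q_idx = jnp.arange(q_len)[None, :, None]
    kv_idx = jnp.arange(kv_len)[None, None, :]
    return score_mod(scores, head, q_idx, kv_idx)

# Composing three biases in a single line:
score_mod = compose(causal, sliding_window(2), alibi(num_heads=4))
scores = apply_score_mod(jnp.zeros((4, 4, 4)), score_mod)
weights = jax.nn.softmax(scores, axis=-1)
```

Expressing masks as `-inf` scores lets every backend apply them the same way; a production kernel would also skip fully masked blocks instead of computing and discarding them.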

KV caching

Contiguous and paged KV caches for autoregressive inference and batched serving.
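A contiguous KV cache preallocates fixed-size key and value buffers and writes each new step at a running offset, so shapes stay static under `jax.jit`. This is a minimal sketch assuming a hypothetical `KVCache` tuple and helper names (`init_cache`, `append`, `attend`) that are not Attnax's API. A paged cache would instead split the buffers into fixed-size blocks addressed through a block table; that variant is not shown.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class KVCache(NamedTuple):
    """Hypothetical contiguous cache: fixed-size buffers plus a fill counter."""
    k: jax.Array       # [batch, max_len, heads, head_dim]
    v: jax.Array       # [batch, max_len, heads, head_dim]
    length: jax.Array  # int32 scalar: number of filled positions

def init_cache(batch, max_len, heads, head_dim, dtype=jnp.float32):
    shape = (batch, max_len, heads, head_dim)
    return KVCache(jnp.zeros(shape, dtype), jnp.zeros(shape, dtype), jnp.array(0, jnp.int32))

@jax.jit
def append(cache, k_new, v_new):
    """Writes k_new/v_new ([batch, t, heads, head_dim]) starting at cache.length."""
    zero = jnp.zeros((), jnp.int32)
    start = (zero, cache.length, zero, zero)
    k = jax.lax.dynamic_update_slice(cache.k, k_new, start)
    v = jax.lax.dynamic_update_slice(cache.v, v_new, start)
    return KVCache(k, v, cache.length + k_new.shape[1])

def attend(q, cache):
    """Decode step: q is [batch, 1, heads, head_dim]; unfilled slots are masked."""
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, cache.k) * scale
    valid = jnp.arange(cache.k.shape[1]) < cache.length
    scores = jnp.where(valid, scores, -jnp.inf)
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), cache.v)

# Usage: prefill a 3-token prompt, then decode one token.
k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
cache = init_cache(batch=1, max_len=8, heads=2, head_dim=4)
cache = append(cache, jax.random.normal(k1, (1, 3, 2, 4)), jax.random.normal(k2, (1, 3, 2, 4)))
step_k = jax.random.normal(k3, (1, 1, 2, 4))
step_v = jax.random.normal(k4, (1, 1, 2, 4))
cache = append(cache, step_k, step_v)
out = attend(jax.random.normal(k5, (1, 1, 2, 4)), cache)
```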

Vision Transformer

ViT encoder reusing the same kernels and blocks as the text stack.
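A ViT turns an image into a token sequence by cutting it into non-overlapping patches and linearly projecting each one; from there the tokens can flow through the same attention blocks as text. The sketch below is a generic illustration assuming channels-last images; `patchify`, `w_embed`, and `pos_embed` are hypothetical names, not part of Attnax.

```python
import jax
import jax.numpy as jnp

def patchify(images, patch_size):
    """[batch, H, W, C] -> [batch, num_patches, patch_size * patch_size * C]."""
    b, h, w, c = images.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "image size must be divisible by patch size"
    x = jnp.reshape(images, (b, h // p, p, w // p, p, c))
    x = jnp.transpose(x, (0, 1, 3, 2, 4, 5))  # [b, h/p, w/p, p, p, c]
    return jnp.reshape(x, (b, (h // p) * (w // p), p * p * c))

# Usage: 4x4 single-channel image, 2x2 patches -> 4 tokens of 4 values each.
images = jnp.arange(16, dtype=jnp.float32).reshape(1, 4, 4, 1)
patches = patchify(images, patch_size=2)

# Hypothetical linear patch embedding plus learned position embedding.
dim = 8
key_w, key_pos = jax.random.split(jax.random.PRNGKey(0))
w_embed = jax.random.normal(key_w, (patches.shape[-1], dim)) * 0.02
pos_embed = jax.random.normal(key_pos, (1, patches.shape[1], dim)) * 0.02
tokens = patches @ w_embed + pos_embed  # [batch, num_patches, dim]
```

The patches are ordered row by row, so the first token holds pixels `(0,0), (0,1), (1,0), (1,1)`.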

Minimal dependencies

Depends only on JAX and Flax.

Installation

pip install attnax

Or from source:

git clone https://github.com/glibtkachenko/attnax.git
cd attnax
pip install -e .

Requires Python 3.10+, JAX 0.10.0+, and Flax 0.12.7+.
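To confirm an environment meets these minimums before installing, a small standard-library check like the following may help. It is a sketch, not a tool shipped with Attnax: `MINIMUMS`, `parse_version`, and `check_requirements` are hypothetical names, and the parser only compares the leading numeric `major.minor.patch` components, ignoring pre-release suffixes.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Minimum versions taken from the requirements line above.
MINIMUMS = {"jax": (0, 10, 0), "flax": (0, 12, 7)}

def parse_version(text):
    """Leading numeric components, padded to 3: '0.12.7rc1' -> (0, 12, 7)."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)

def check_requirements():
    """Returns a list of human-readable problems; an empty list means all good."""
    problems = []
    if sys.version_info < (3, 10):
        problems.append(f"Python {sys.version.split()[0]} is older than 3.10")
    for name, minimum in MINIMUMS.items():
        try:
            installed = parse_version(version(name))
        except PackageNotFoundError:
            problems.append(f"{name} is not installed")
            continue
        if installed < minimum:
            wanted = ".".join(map(str, minimum))
            problems.append(f"{name} {version(name)} is older than {wanted}")
    return problems

for problem in check_requirements():
    print(problem)
```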