Source code for attnax.transformer
# SPDX-License-Identifier: Apache-2.0
"""Transformer encoder and decoder wrappers."""
from __future__ import annotations
from typing import Optional, Union, overload
import flax.nnx as nnx
import jax.numpy as jnp
from .blocks import EncoderBlock
from .cache import KVLayerCache
from .config import TransformerConfig
from .embeddings import PositionalEncoding, TokenEmbedding
from .kernels.score_mods import alibi_mod
from .masking import combine_masks
from .norms import create_norm
def _build_layers(
    rngs: nnx.Rngs, config: TransformerConfig
) -> nnx.List:
    """Create the stack of ``config.num_layers`` :class:`EncoderBlock` layers.

    Every block receives the same keyword arguments, derived from
    ``config``. How positions are handled depends on
    ``config.pos_emb_type``:

    * ``"rope"``: rotary embeddings are enabled in each block. The table
      size is ``config.rope_max_positions``, or ``config.max_len`` when
      that value is falsy.
    * ``"alibi"``: one :func:`alibi_mod` score modifier is created and
      handed to every block.
    * Any other value: neither RoPE nor a score modifier is used.

    Args:
        rngs: Flax NNX random key container, passed to each block.
        config: Transformer hyperparameters.

    Returns:
        An ``nnx.List`` holding ``config.num_layers`` blocks.
    """
    pos_kind = config.pos_emb_type
    block_kwargs = dict(
        d_model=config.d_model,
        num_heads=config.num_heads,
        d_ff=config.d_ff,
        dropout_rate=config.dropout_rate,
        pre_norm=config.use_pre_norm,
        attention_type=config.attention_type,
        attention_block_size=config.attention_block_size,
        linear_attention_chunk_size=config.linear_attention_chunk_size,
        norm_type=config.norm_type,
        ff_activation=config.ff_activation,
        num_kv_heads=config.num_kv_heads,
        attention_window=config.attention_window,
        # A single modifier object is built once and shared by all blocks.
        score_mod=(
            alibi_mod(num_heads=config.num_heads)
            if pos_kind == "alibi"
            else None
        ),
        use_rope=(pos_kind == "rope"),
        rope_base=config.rope_base,
        # ``or`` falls back to max_len for None (or any other falsy value).
        rope_max_positions=config.rope_max_positions or config.max_len,
    )
    return nnx.List(
        [EncoderBlock(rngs, **block_kwargs) for _ in range(config.num_layers)]
    )
[docs]
class TransformerEncoder(nnx.Module):
    """Stack of encoder blocks on top of a token embedding.

    The forward pass runs these steps in order:

    1. Embed ``input_ids``.
    2. Add sinusoidal positional encodings, but only when
       ``config.pos_emb_type == "sinusoidal"``.
    3. Apply dropout.
    4. Run ``config.num_layers`` :class:`EncoderBlock` layers.
    5. Apply a final normalisation.

    The encoder builds no mask of its own; supply ``padding_mask`` when
    some keys must be hidden.

    Args:
        rngs: Flax NNX random key container.
        config: Transformer hyperparameters.
    """

    def __init__(self, rngs: nnx.Rngs, config: TransformerConfig):
        cfg = config
        self.config = cfg
        # Submodules are created in a fixed order. Keep it: constructors
        # that take ``rngs`` may draw from the shared random streams.
        self.token_embed = TokenEmbedding(rngs, cfg.vocab_size, cfg.d_model)
        self.pos_encoding = PositionalEncoding(cfg.max_len, cfg.d_model)
        self.dropout = nnx.Dropout(rate=cfg.dropout_rate, rngs=rngs)
        self.layers = _build_layers(rngs, cfg)
        self.final_ln = create_norm(cfg.norm_type, cfg.d_model, rngs=rngs)

    @overload
    def __call__(
        self,
        input_ids: jnp.ndarray,
        *,
        padding_mask: Optional[jnp.ndarray] = None,
        deterministic: Optional[bool] = None,
        position_ids: Optional[jnp.ndarray] = None,
        layer_kv_caches: None = None,
    ) -> jnp.ndarray: ...

    @overload
    def __call__(
        self,
        input_ids: jnp.ndarray,
        *,
        padding_mask: Optional[jnp.ndarray] = None,
        deterministic: Optional[bool] = None,
        position_ids: Optional[jnp.ndarray] = None,
        layer_kv_caches: tuple[KVLayerCache, ...],
    ) -> tuple[jnp.ndarray, tuple[KVLayerCache, ...]]: ...

    def __call__(
        self,
        input_ids: jnp.ndarray,
        *,
        padding_mask: Optional[jnp.ndarray] = None,
        deterministic: Optional[bool] = None,
        position_ids: Optional[jnp.ndarray] = None,
        layer_kv_caches: Optional[tuple[KVLayerCache, ...]] = None,
    ) -> Union[jnp.ndarray, tuple[jnp.ndarray, tuple[KVLayerCache, ...]]]:
        """Run the encoder over a batch of token ids.

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``.
            padding_mask: Boolean key-padding mask, forwarded unchanged to
                every layer as ``mask``.
            deterministic: If ``True``, disables dropout.
            position_ids: Integer positions of shape ``(batch, seq_len)``
                for RoPE, forwarded unchanged to every layer.
            layer_kv_caches: Optional per-layer :class:`KVLayerCache`
                tuple with one entry per layer.

        Returns:
            The normalised hidden states of shape
            ``(batch, seq_len, d_model)``. When ``layer_kv_caches`` is
            given, returns ``(output, updated_caches)`` instead.

        Raises:
            ValueError: If ``layer_kv_caches`` does not have exactly one
                entry per layer.
        """
        hidden = self.token_embed(input_ids)
        # NOTE(review): unlike the decoder, the sinusoidal encoding here
        # always starts at position 0, even when ``layer_kv_caches`` is
        # given. Confirm this is intended for cached encoding.
        if self.config.pos_emb_type == "sinusoidal":
            hidden = self.pos_encoding(hidden)
        hidden = self.dropout(hidden, deterministic=deterministic)

        # Keyword arguments shared by every layer call on both paths.
        layer_kwargs = dict(
            mask=padding_mask,
            deterministic=deterministic,
            position_ids=position_ids,
        )

        if layer_kv_caches is None:
            for block in self.layers:
                hidden = block(hidden, **layer_kwargs)
            return self.final_ln(hidden)

        n_caches = len(layer_kv_caches)
        n_layers = len(self.layers)
        if n_caches != n_layers:
            raise ValueError(
                f"layer_kv_caches length {n_caches} must match "
                f"num_layers {n_layers}"
            )

        updated: list[KVLayerCache] = []
        for block, cache in zip(self.layers, layer_kv_caches):
            # With a cache, each layer returns (hidden, updated_cache).
            hidden, cache_out = block(
                hidden, self_attn_kv_cache=cache, **layer_kwargs
            )
            updated.append(cache_out)
        return self.final_ln(hidden), tuple(updated)
[docs]
class TransformerDecoder(nnx.Module):
    """Decoder-only transformer stack.

    The stack consists of a token embedding, positional encoding,
    ``config.num_layers`` blocks and a final normalisation. A causal mask
    is built internally and AND-combined with ``padding_mask`` via
    :func:`combine_masks`.

    When ``layer_kv_caches`` is set, the new tokens occupy positions
    ``past_len .. past_len + seq_len - 1``, where ``past_len`` is the
    length of the first layer's cache. In that case:

    * ``position_ids`` default to ``arange(past_len, past_len + seq_len)``;
    * sinusoidal encodings start at ``past_len``.

    Args:
        rngs: Flax NNX random key container.
        config: Transformer hyperparameters.
    """

    def __init__(self, rngs: nnx.Rngs, config: TransformerConfig):
        self.config = config
        # Construction order matters: constructors that take ``rngs`` may
        # draw from the shared random streams.
        self.token_embed = TokenEmbedding(rngs, config.vocab_size, config.d_model)
        self.pos_encoding = PositionalEncoding(config.max_len, config.d_model)
        self.dropout = nnx.Dropout(rate=config.dropout_rate, rngs=rngs)
        self.layers = _build_layers(rngs, config)
        self.final_ln = create_norm(config.norm_type, config.d_model, rngs=rngs)

    @overload
    def __call__(
        self,
        input_ids: jnp.ndarray,
        *,
        padding_mask: Optional[jnp.ndarray] = None,
        deterministic: Optional[bool] = None,
        position_ids: Optional[jnp.ndarray] = None,
        layer_kv_caches: None = None,
    ) -> jnp.ndarray: ...

    @overload
    def __call__(
        self,
        input_ids: jnp.ndarray,
        *,
        padding_mask: Optional[jnp.ndarray] = None,
        deterministic: Optional[bool] = None,
        position_ids: Optional[jnp.ndarray] = None,
        layer_kv_caches: tuple[KVLayerCache, ...],
    ) -> tuple[jnp.ndarray, tuple[KVLayerCache, ...]]: ...

    def __call__(
        self,
        input_ids: jnp.ndarray,
        *,
        padding_mask: Optional[jnp.ndarray] = None,
        deterministic: Optional[bool] = None,
        position_ids: Optional[jnp.ndarray] = None,
        layer_kv_caches: Optional[tuple[KVLayerCache, ...]] = None,
    ) -> Union[jnp.ndarray, tuple[jnp.ndarray, tuple[KVLayerCache, ...]]]:
        """Applies the decoder.

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``.
            padding_mask: Boolean key-padding mask of shape
                ``(batch, 1, 1, past_len + seq_len)``.
            deterministic: If ``True``, disables dropout.
            position_ids: Integer positions of shape ``(batch, seq_len)``
                for RoPE. Defaults to ``past_len + arange(seq_len)``,
                broadcast over the batch.
            layer_kv_caches: Per-layer :class:`KVLayerCache` tuple of length
                ``config.num_layers``.

        Returns:
            Array of shape ``(batch, seq_len, d_model)``, or
            ``(output, updated_caches)`` when ``layer_kv_caches`` is set.

        Raises:
            ValueError: If ``layer_kv_caches`` does not have exactly one
                entry per layer.
        """
        # Validate the cache tuple before indexing it. Otherwise an empty
        # tuple would fail on ``layer_kv_caches[0]`` with IndexError
        # instead of the intended ValueError.
        if layer_kv_caches is not None and len(layer_kv_caches) != len(
            self.layers
        ):
            raise ValueError(
                f"layer_kv_caches length {len(layer_kv_caches)} must match "
                f"num_layers {len(self.layers)}"
            )

        batch, seq_q = input_ids.shape
        # ``int(...)`` requires a concrete (non-traced) cache length.
        # An empty cache tuple is only reachable for a zero-layer stack
        # (see validation above) and means "no history".
        past_len = int(layer_kv_caches[0].length) if layer_kv_caches else 0
        seq_kv = past_len + seq_q

        # Absolute positions of the new (query) tokens.
        q_idx = jnp.arange(seq_q) + past_len
        if position_ids is None:
            position_ids = jnp.broadcast_to(
                q_idx[None, :].astype(jnp.int32), (batch, seq_q)
            )

        # True where key index <= the query's absolute position.
        # Shape: (1, 1, seq_q, seq_kv).
        k_idx = jnp.arange(seq_kv)
        causal_mask = (q_idx[:, None] >= k_idx[None, :])[None, None, :, :]
        mask = combine_masks(padding_mask, causal_mask)

        x = self.token_embed(input_ids)
        if self.config.pos_emb_type == "sinusoidal":
            x = self.pos_encoding(x, start=past_len)
        x = self.dropout(x, deterministic=deterministic)

        if layer_kv_caches is None:
            for layer in self.layers:
                x = layer(
                    x,
                    mask=mask,
                    deterministic=deterministic,
                    position_ids=position_ids,
                )
            return self.final_ln(x)

        new_caches: list[KVLayerCache] = []
        for layer, kv in zip(self.layers, layer_kv_caches):
            # With a cache, each layer returns (hidden, updated_cache).
            x, kv_next = layer(
                x,
                mask=mask,
                deterministic=deterministic,
                position_ids=position_ids,
                self_attn_kv_cache=kv,
            )
            new_caches.append(kv_next)
        return self.final_ln(x), tuple(new_caches)