Source code for attnax.config

# SPDX-License-Identifier: Apache-2.0

"""Transformer configuration."""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Literal


[docs] class AttentionType(str, Enum): """Built-in attention backend selectors. Attributes: STANDARD: Scaled dot-product attention with :math:`O(n^2)` memory. MEMORY_EFFICIENT: Block-wise online-softmax attention. FLASH: :func:`jax.nn.dot_product_attention` on GPU, :attr:`MEMORY_EFFICIENT` elsewhere. PALLAS_FLASH: Pallas-lowered FlashAttention with ``score_mod`` in the inner loop; falls back to :attr:`MEMORY_EFFICIENT`. LINEAR: Chunkwise-parallel linear attention. Ignores ``score_mod``. LITE: Element-wise gated attention. Ring and paged attention are not enum entries because they require additional arguments (``axis_name`` and ``PagedKVCache``); pass :func:`~attnax.kernels.ring_attention` or :func:`~attnax.kernels.paged_attention` directly via ``attention_fn=``. """ STANDARD = "standard" MEMORY_EFFICIENT = "memory_efficient" FLASH = "flash" PALLAS_FLASH = "pallas_flash" LINEAR = "linear" LITE = "lite"
PosEmbType = Literal["sinusoidal", "rope", "alibi", "none"] NormKind = Literal["layer", "rms"] FfActivation = Literal["gelu", "relu", "swiglu", "geglu"] Pool = Literal["cls", "mean"]
[docs] @dataclass(frozen=True, kw_only=True) class TransformerConfig: """Transformer stack hyperparameters. Args: vocab_size: Vocabulary size. d_model: Hidden size. num_heads: Number of query heads. num_layers: Number of transformer blocks. d_ff: Feed-forward hidden width. dropout_rate: Dropout probability. max_len: Maximum sequence length. use_pre_norm: Pre-norm when ``True``, post-norm otherwise. pos_emb_type: Positional embedding variant. ff_activation: Feed-forward activation. norm_type: Normalisation kind. pad_token_id: Token id treated as padding. attention_type: Attention backend. attention_block_size: Block size for memory-efficient, flash and pallas_flash backends. linear_attention_chunk_size: Chunk size for the linear backend. rope_base: RoPE base :math:`\\theta`. rope_max_positions: Length of the RoPE table. Defaults to ``max_len``. num_kv_heads: KV heads for GQA/MQA. Defaults to ``num_heads``. attention_window: Causal sliding-window size. kv_cache_max_len: Maximum length of the KV cache. Defaults to ``max_len``. """ vocab_size: int d_model: int = 512 num_heads: int = 8 num_layers: int = 6 d_ff: int = 2048 dropout_rate: float = 0.1 max_len: int = 512 use_pre_norm: bool = True pos_emb_type: PosEmbType = "sinusoidal" ff_activation: FfActivation = "gelu" norm_type: NormKind = "layer" pad_token_id: int = 0 attention_type: AttentionType = AttentionType.STANDARD attention_block_size: int = 512 linear_attention_chunk_size: int = 64 rope_base: float = 10000.0 rope_max_positions: int | None = None num_kv_heads: int | None = None attention_window: int | None = None kv_cache_max_len: int | None = None
[docs] @dataclass(frozen=True, kw_only=True) class VisionTransformerConfig: """Vision Transformer hyperparameters. Defaults reproduce ViT-Base/16 on 224x224 RGB images. Args: image_size: ``int`` or ``(height, width)``. Must be divisible by ``patch_size``. patch_size: ``int`` or ``(patch_height, patch_width)``. num_channels: Number of image channels. num_classes: Output classes. ``None`` disables the head and returns the token sequence. use_cls_token: Prepend a learnable ``[CLS]`` token. pool: Pooling strategy before the classification head. d_model: Hidden size. num_heads: Number of query heads. num_layers: Number of transformer blocks. d_ff: Feed-forward hidden width. dropout_rate: Dropout probability. use_pre_norm: Pre-norm when ``True``, post-norm otherwise. ff_activation: Feed-forward activation. norm_type: Normalisation kind. attention_type: Attention backend. attention_block_size: Block size for memory-efficient, flash and pallas_flash backends. linear_attention_chunk_size: Chunk size for the linear backend. num_kv_heads: KV heads for GQA/MQA. attention_window: Causal sliding-window size. Raises: ValueError: If ``image_size`` is not divisible by ``patch_size``, or if ``pool == "cls"`` while ``use_cls_token`` is ``False``. References: Dosovitskiy et al., `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_, 2020. """ image_size: int | tuple[int, int] = 224 patch_size: int | tuple[int, int] = 16 num_channels: int = 3 num_classes: int | None = None use_cls_token: bool = True pool: Pool = "cls" d_model: int = 768 num_heads: int = 12 num_layers: int = 12 d_ff: int = 3072 dropout_rate: float = 0.0 use_pre_norm: bool = True ff_activation: FfActivation = "gelu" norm_type: NormKind = "layer" attention_type: AttentionType = AttentionType.STANDARD attention_block_size: int = 512 linear_attention_chunk_size: int = 64 num_kv_heads: int | None = None attention_window: int | None = None def __post_init__(self) -> None: img_h, img_w = _pair(self.image_size) patch_h, patch_w = _pair(self.patch_size) if img_h % patch_h != 0 or img_w % patch_w != 0: raise ValueError( f"image_size {(img_h, img_w)} must be divisible by patch_size " f"{(patch_h, patch_w)} along both axes." ) if self.pool == "cls" and not self.use_cls_token: raise ValueError( "pool='cls' requires use_cls_token=True; set pool='mean' or " "enable the CLS token." )
def _pair(value: int | tuple[int, int]) -> tuple[int, int]: """Returns ``(value, value)`` for an int, or ``value`` for a tuple.""" if isinstance(value, int): return (value, value) return value