Configuration

class attnax.TransformerConfig(*, vocab_size, d_model=512, num_heads=8, num_layers=6, d_ff=2048, dropout_rate=0.1, max_len=512, use_pre_norm=True, pos_emb_type='sinusoidal', ff_activation='gelu', norm_type='layer', pad_token_id=0, attention_type=AttentionType.STANDARD, attention_block_size=512, linear_attention_chunk_size=64, rope_base=10000.0, rope_max_positions=None, num_kv_heads=None, attention_window=None, kv_cache_max_len=None)

Bases: object

Hyperparameters for a transformer stack. All arguments are keyword-only, and only vocab_size is required.

Parameters:
  • vocab_size (int) – Vocabulary size.

  • d_model (int) – Hidden size.

  • num_heads (int) – Number of query heads.

  • num_layers (int) – Number of transformer blocks.

  • d_ff (int) – Feed-forward hidden width.

  • dropout_rate (float) – Dropout probability.

  • max_len (int) – Maximum sequence length.

  • use_pre_norm (bool) – Normalise the input of each sub-layer (pre-norm) when True; normalise after each residual addition (post-norm) otherwise.

  • pos_emb_type (Literal['sinusoidal', 'rope', 'alibi', 'none']) – Positional embedding variant. See PosEmbType.

  • ff_activation (Literal['gelu', 'relu', 'swiglu', 'geglu']) – Feed-forward activation. See FfActivation.

  • norm_type (Literal['layer', 'rms']) – Normalisation kind. See NormKind.

  • pad_token_id (int) – Token id treated as padding.

  • attention_type (AttentionType) – Attention backend.

  • attention_block_size (int) – Block size for memory-efficient, flash and pallas_flash backends.

  • linear_attention_chunk_size (int) – Chunk size for the linear backend.

  • rope_base (float) – RoPE base θ.

  • rope_max_positions (int | None) – Length of the RoPE table. Defaults to max_len.

  • num_kv_heads (int | None) – Number of key/value heads for grouped-query (GQA) or multi-query (MQA) attention. Defaults to num_heads, i.e. standard multi-head attention.

  • attention_window (int | None) – Size of the causal sliding attention window, in tokens.

  • kv_cache_max_len (int | None) – Maximum length of the KV cache. Defaults to max_len.
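Three optional fields above fall back to another field when left as None: rope_max_positions and kv_cache_max_len fall back to max_len, and num_kv_heads falls back to num_heads. The sketch below illustrates that resolution with a hypothetical stand-in dataclass (ConfigSketch is not part of attnax), plus two hypothetical helpers for the GQA and sliding-window options. It assumes the common GQA layout, in which query heads are split into equal, contiguous groups that share one KV head, and a window convention in which query position i sees key positions j with i - window < j <= i. attnax's exact conventions may differ.

```python
from dataclasses import dataclass, replace
from typing import Optional


def _or_default(value: Optional[int], default: int) -> int:
    # Treat only None as "unset", so an explicit 0 would be kept as-is.
    return default if value is None else value


@dataclass(frozen=True)
class ConfigSketch:
    """Hypothetical stand-in for a subset of TransformerConfig's fields."""
    vocab_size: int
    num_heads: int = 8
    max_len: int = 512
    rope_max_positions: Optional[int] = None
    num_kv_heads: Optional[int] = None
    attention_window: Optional[int] = None
    kv_cache_max_len: Optional[int] = None

    def resolved(self) -> "ConfigSketch":
        """Return a copy with the documented fallbacks filled in."""
        return replace(
            self,
            rope_max_positions=_or_default(self.rope_max_positions, self.max_len),
            num_kv_heads=_or_default(self.num_kv_heads, self.num_heads),
            kv_cache_max_len=_or_default(self.kv_cache_max_len, self.max_len),
        )


def kv_head_for(query_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Map a query head to the KV head it shares (assumed contiguous, equal groups)."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("this sketch requires num_heads % num_kv_heads == 0")
    group_size = num_heads // num_kv_heads
    return query_head // group_size


def window_allowed(i: int, j: int, window: Optional[int]) -> bool:
    """Whether query position i may attend to key position j (assumed convention)."""
    if j > i:  # causal: never attend to future positions
        return False
    return window is None or i - j < window
```

With num_heads=8 and num_kv_heads=2, query heads 0 to 3 share KV head 0 and heads 4 to 7 share KV head 1; num_kv_heads=1 collapses to MQA, where every query head shares a single KV head.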

String literal aliases

attnax.PosEmbType: Literal['sinusoidal', 'rope', 'alibi', 'none']

Positional embedding variant.

  • "sinusoidal": fixed sinusoidal absolute positions.

  • "rope": rotary positional embeddings on Q and K.

  • "alibi": per-head ALiBi additive bias via alibi_mod().

  • "none": no positional information.
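The three non-trivial variants can be sketched in pure Python for a single position or a single head vector. These are textbook formulations, not attnax's code: the interleaved sin/cos layout, the rotation of adjacent pairs for RoPE, and the ALiBi slope schedule 2^(-8h/n) for a power-of-two head count n are assumptions, and the actual signature of alibi_mod() is not reproduced here. The default base of 10000.0 mirrors rope_base.

```python
import math
from typing import List


def sinusoidal_embedding(pos: int, d_model: int, base: float = 10000.0) -> List[float]:
    """Fixed sinusoidal encoding of one position (interleaved sin/cos layout)."""
    out: List[float] = []
    for i in range(0, d_model, 2):  # i = 2k, so the exponent is 2k / d_model
        angle = pos / base ** (i / d_model)
        out.append(math.sin(angle))
        if i + 1 < d_model:
            out.append(math.cos(angle))
    return out


def rope_rotate(x: List[float], pos: int, base: float = 10000.0) -> List[float]:
    """Rotate each pair (x[2k], x[2k+1]) by the angle pos * base**(-2k/d)."""
    d = len(x)
    if d % 2:
        raise ValueError("RoPE needs an even head dimension")
    out: List[float] = []
    for k in range(0, d, 2):
        theta = pos * base ** (-k / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[k], x[k + 1]
        out.extend([a * c - b * s, a * s + b * c])
    return out


def alibi_slopes(num_heads: int) -> List[float]:
    """Per-head slopes 2**(-8h/num_heads) for h = 1..num_heads."""
    return [2.0 ** (-8.0 * h / num_heads) for h in range(1, num_heads + 1)]


def alibi_bias(slope: float, i: int, j: int) -> float:
    """Additive attention-logit bias for query i and key j (causal, j <= i)."""
    return -slope * (i - j)
```

The property that makes RoPE a relative encoding is that the dot product of a rotated query at position m and a rotated key at position n depends only on m - n. ALiBi achieves a similar effect without touching Q or K, by penalising logits linearly in distance with a different slope per head.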

attnax.NormKind: Literal['layer', 'rms']

Normalisation variant.

  • "layer": LayerNorm.

  • "rms": RMSNorm.
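The two normalisation kinds differ in whether the input is centred before scaling. This sketch omits the learned gain (and LayerNorm's bias) and uses an assumed eps of 1e-6; attnax's epsilon and parameterisation are not documented in this section.

```python
import math
from typing import List


def layer_norm(x: List[float], eps: float = 1e-6) -> List[float]:
    """LayerNorm without learned scale/bias: subtract the mean, divide by the std."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def rms_norm(x: List[float], eps: float = 1e-6) -> List[float]:
    """RMSNorm without learned scale: divide by the root mean square, no centring."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]
```

A constant input shows the difference: LayerNorm maps it to zeros because the mean is removed, while RMSNorm maps it to values close to one. Skipping the mean also makes RMSNorm slightly cheaper.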

attnax.FfActivation: Literal['gelu', 'relu', 'swiglu', 'geglu']

Feed-forward activation.

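  • "gelu", "relu": two-layer MLP with the named activation, down(act(up(x))).

  • "swiglu": down(SiLU(gate(x)) * up(x)), where * is the elementwise product.

  • "geglu": down(GELU(gate(x)) * up(x)), where * is the elementwise product.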
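The feed-forward variants can be written out for one token vector with plain lists. Biases are omitted, the weight names (w_in, w_out, w_gate, w_up, w_down) are hypothetical, and GELU uses the exact erf form; whether attnax uses the exact or the tanh-approximated GELU is not stated here.

```python
import math
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape (out_features, in_features)


def matvec(w: Matrix, x: Vector) -> Vector:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def gelu(v: float) -> float:
    # Exact (erf-based) GELU; some libraries use a tanh approximation instead.
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def relu(v: float) -> float:
    return max(v, 0.0)


def silu(v: float) -> float:
    # v * sigmoid(v), arranged so math.exp never overflows for large |v|.
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def mlp(x: Vector, w_in: Matrix, w_out: Matrix, act: Callable[[float], float]) -> Vector:
    """'gelu' / 'relu': down(act(up(x)))."""
    hidden = [act(h) for h in matvec(w_in, x)]
    return matvec(w_out, hidden)


def gated_ffn(x: Vector, w_gate: Matrix, w_up: Matrix, w_down: Matrix,
              act: Callable[[float], float]) -> Vector:
    """'swiglu' (act=silu) / 'geglu' (act=gelu): down(act(gate(x)) * up(x))."""
    gate = [act(g) for g in matvec(w_gate, x)]
    up = matvec(w_up, x)
    return matvec(w_down, [g * u for g, u in zip(gate, up)])
```

The gated variants need three weight matrices instead of two, so at the same d_ff they carry about 1.5x the feed-forward parameters.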

attnax.Pool: Literal['cls', 'mean']

VisionTransformer pooling strategy.

  • "cls": take the [CLS] token.

  • "mean": mean over all tokens.
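Both pooling strategies reduce the encoder's (sequence length, d_model) output to a single vector. The sketch below assumes the [CLS] token sits at position 0, and for "mean" it averages every row, [CLS] included; whether attnax's mean pooling includes the [CLS] position is not stated above. pool is a hypothetical helper, not attnax API.

```python
from typing import List

Tokens = List[List[float]]  # encoder outputs, shape (seq_len, d_model)


def pool(tokens: Tokens, strategy: str = "cls") -> List[float]:
    """Reduce a (seq_len, d_model) sequence to one d_model vector."""
    if strategy == "cls":
        # Assumes the [CLS] token was prepended at position 0.
        return list(tokens[0])
    if strategy == "mean":
        # Average each feature over all positions (column-wise mean).
        n = len(tokens)
        return [sum(col) / n for col in zip(*tokens)]
    raise ValueError(f"unknown pool strategy: {strategy!r}")
```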