Configuration
- class attnax.TransformerConfig(*, vocab_size, d_model=512, num_heads=8, num_layers=6, d_ff=2048, dropout_rate=0.1, max_len=512, use_pre_norm=True, pos_emb_type='sinusoidal', ff_activation='gelu', norm_type='layer', pad_token_id=0, attention_type=AttentionType.STANDARD, attention_block_size=512, linear_attention_chunk_size=64, rope_base=10000.0, rope_max_positions=None, num_kv_heads=None, attention_window=None, kv_cache_max_len=None)
Bases: object
Transformer stack hyperparameters.
- Parameters:
vocab_size (int) – Vocabulary size.
d_model (int) – Hidden size.
num_heads (int) – Number of query heads.
num_layers (int) – Number of transformer blocks.
d_ff (int) – Feed-forward hidden width.
dropout_rate (float) – Dropout probability.
max_len (int) – Maximum sequence length.
use_pre_norm (bool) – Pre-norm when True, post-norm otherwise.
pos_emb_type (Literal['sinusoidal', 'rope', 'alibi', 'none']) – Positional embedding variant.
ff_activation (Literal['gelu', 'relu', 'swiglu', 'geglu']) – Feed-forward activation.
norm_type (Literal['layer', 'rms']) – Normalisation kind.
pad_token_id (int) – Token id treated as padding.
attention_type (AttentionType) – Attention backend.
attention_block_size (int) – Block size for memory-efficient, flash and pallas_flash backends.
linear_attention_chunk_size (int) – Chunk size for the linear backend.
rope_base (float) – RoPE base \(\theta\).
rope_max_positions (int | None) – Length of the RoPE table. Defaults to max_len.
num_kv_heads (int | None) – Number of key/value heads for grouped-query (GQA) or multi-query (MQA) attention. Defaults to num_heads.
attention_window (int | None) – Causal sliding-window size.
kv_cache_max_len (int | None) – Maximum length of the KV cache. Defaults to max_len.
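The None-valued defaults and the GQA/sliding-window parameters above can be made concrete with a small sketch. ConfigSketch, resolve_defaults, kv_head_for_query_head and sliding_window_causal_mask are hypothetical helpers, not part of attnax. Two further assumptions are made: query heads are split into equal contiguous groups per KV head, and a window of w lets each query see itself plus the w - 1 preceding keys.

```python
from dataclasses import dataclass, replace
from typing import List, Optional

import numpy as np


@dataclass(frozen=True)
class ConfigSketch:
    """Hypothetical stand-in mirroring a few TransformerConfig fields."""

    num_heads: int = 8
    max_len: int = 512
    rope_max_positions: Optional[int] = None
    num_kv_heads: Optional[int] = None
    attention_window: Optional[int] = None
    kv_cache_max_len: Optional[int] = None


def resolve_defaults(cfg: ConfigSketch) -> ConfigSketch:
    """Fill None-valued fields with the defaults documented above."""
    return replace(
        cfg,
        rope_max_positions=cfg.max_len if cfg.rope_max_positions is None else cfg.rope_max_positions,
        num_kv_heads=cfg.num_heads if cfg.num_kv_heads is None else cfg.num_kv_heads,
        kv_cache_max_len=cfg.max_len if cfg.kv_cache_max_len is None else cfg.kv_cache_max_len,
    )


def kv_head_for_query_head(num_heads: int, num_kv_heads: int) -> List[int]:
    """Map each query head to the KV head it shares (assumed contiguous grouping)."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads
    return [h // group_size for h in range(num_heads)]


def sliding_window_causal_mask(seq_len: int, window: Optional[int]) -> np.ndarray:
    """Boolean (seq_len, seq_len) mask; True where query i may attend to key j."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    mask = j <= i  # causal: no attention to future keys
    if window is not None:
        # Assumed semantics: the query itself plus window - 1 earlier keys.
        mask = mask & ((i - j) < window)
    return mask


cfg = resolve_defaults(ConfigSketch(num_heads=8, max_len=1024, num_kv_heads=2))
print(cfg.rope_max_positions, cfg.num_kv_heads, cfg.kv_cache_max_len)  # 1024 2 1024
print(kv_head_for_query_head(8, 2))  # [0, 0, 0, 0, 1, 1, 1, 1]
print(sliding_window_causal_mask(4, 2).astype(int))
```

With num_kv_heads=1 every query head maps to KV head 0 (MQA), and leaving num_kv_heads unset reproduces standard multi-head attention.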
String literal aliases
- attnax.PosEmbType: Literal['sinusoidal', 'rope', 'alibi', 'none']
Positional embedding variant.
"sinusoidal": fixed sinusoidal absolute positions.
"rope": rotary positional embeddings on Q and K.
"alibi": per-head ALiBi additive bias via alibi_mod().
"none": no positional information.
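The three non-trivial variants can be sketched in numpy as follows. This is a textbook-style illustration, not attnax's code: it does not call alibi_mod() (whose signature is not documented here), the base parameter of rope_rotate is assumed to play the role of rope_base, and RoPE is shown with adjacent channel pairs (2k, 2k+1), whereas some implementations pair channel k with k + head_dim/2 instead.

```python
import numpy as np


def sinusoidal_table(max_len: int, d_model: int) -> np.ndarray:
    """Fixed absolute table: sin on even channels, cos on odd channels (d_model even)."""
    pos = np.arange(max_len)[:, None]
    inv_freq = 1.0 / (10000.0 ** (np.arange(0, d_model, 2) / d_model))
    angles = pos * inv_freq[None, :]
    table = np.zeros((max_len, d_model))
    table[:, 0::2] = np.sin(angles)
    table[:, 1::2] = np.cos(angles)
    return table


def rope_rotate(x: np.ndarray, positions: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """Rotate channel pairs (2k, 2k+1) of x, shape (seq, head_dim), by position-dependent angles."""
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    angles = positions[:, None] * inv_freq[None, :]
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out


def alibi_bias(num_heads: int, seq_len: int) -> np.ndarray:
    """Per-head additive bias -m_h * (i - j) with geometric slopes m_h = 2^(-8h/H)."""
    slopes = 2.0 ** (-8.0 * np.arange(1, num_heads + 1) / num_heads)
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    # Entries with j > i are positive but are removed by the causal mask anyway.
    return -slopes[:, None, None] * (i - j)[None, :, :]


rng = np.random.default_rng(0)
q, k = rng.normal(size=(1, 8)), rng.normal(size=(1, 8))
s1 = rope_rotate(q, np.array([5.0])) @ rope_rotate(k, np.array([2.0])).T
s2 = rope_rotate(q, np.array([13.0])) @ rope_rotate(k, np.array([10.0])).T
print(np.allclose(s1, s2))  # True: the RoPE score depends only on the offset 5 - 2 == 13 - 10
```

The last lines show the property that motivates "rope": rotating Q and K makes the attention logit a function of relative position only, while "alibi" reaches a similar effect by adding a distance penalty to the logits instead of touching Q and K.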
- attnax.NormKind: Literal['layer', 'rms']
Normalisation variant.
"layer": flax.nnx.LayerNorm.
"rms": RMSNorm.
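The difference between the two variants fits in a few lines of numpy. This sketch is not the flax or attnax implementation; the eps value is chosen for illustration only, and the bias term in layer_norm reflects the usual LayerNorm formulation.

```python
import numpy as np


def layer_norm(x: np.ndarray, scale: np.ndarray, bias: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Centre and rescale over the last axis, then apply a learned affine map."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * scale + bias


def rms_norm(x: np.ndarray, scale: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Rescale by the root-mean-square over the last axis; no centring and no bias."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * scale


x = np.array([[-2.0, 0.0, 2.0]])  # already zero-mean
print(np.allclose(layer_norm(x, np.ones(3), np.zeros(3)), rms_norm(x, np.ones(3))))  # True
```

For zero-mean inputs the two coincide; "rms" skips the mean subtraction, which saves a reduction per call at the cost of not re-centring activations.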
- attnax.FfActivation: Literal['gelu', 'relu', 'swiglu', 'geglu']
Feed-forward activation.
"gelu", "relu": two-layer MLP.
"swiglu": SiLU(gate(x)) * up(x), followed by the down projection.
"geglu": GELU(gate(x)) * up(x), followed by the down projection.
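The two feed-forward shapes described above can be sketched as follows, assuming bias-free projections and the tanh approximation of GELU; whether attnax uses biases or the exact GELU is not stated here. The weight names w_gate, w_up and w_down are illustrative.

```python
import numpy as np


def gelu(x: np.ndarray) -> np.ndarray:
    """Tanh approximation of GELU (an assumption; the exact erf form also exists)."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


def silu(x: np.ndarray) -> np.ndarray:
    """SiLU (swish): x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def mlp(x, w_in, w_out, act):
    """'gelu' / 'relu': act(x @ w_in) @ w_out, two weight matrices."""
    return act(x @ w_in) @ w_out


def gated_ff(x, w_gate, w_up, w_down, act):
    """'swiglu' (act=silu) / 'geglu' (act=gelu): act(gate(x)) * up(x), then down."""
    return (act(x @ w_gate) * (x @ w_up)) @ w_down


rng = np.random.default_rng(0)
d_model, d_ff = 4, 16
x = rng.normal(size=(2, d_model))
w_in, w_out = rng.normal(size=(d_model, d_ff)), rng.normal(size=(d_ff, d_model))
w_gate, w_up = rng.normal(size=(d_model, d_ff)), rng.normal(size=(d_model, d_ff))
print(mlp(x, w_in, w_out, gelu).shape)               # (2, 4)
print(gated_ff(x, w_gate, w_up, w_out, silu).shape)  # (2, 4)
```

The gated variants carry three d_model x d_ff matrices instead of two, so at the same d_ff they have about 1.5x the feed-forward parameters.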
- attnax.Pool: Literal['cls', 'mean']
VisionTransformer pooling strategy.
"cls": take the [CLS] token.
"mean": mean over all tokens.
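A minimal sketch of the two pooling modes on a (batch, seq, d_model) array. It assumes the [CLS] token sits at position 0 and, following the wording above literally, that "mean" averages every position including [CLS]; pool is a hypothetical helper, not the VisionTransformer code.

```python
import numpy as np


def pool(tokens: np.ndarray, mode: str) -> np.ndarray:
    """Reduce (batch, seq, d_model) token states to (batch, d_model)."""
    if mode == "cls":
        # Assumes the [CLS] token is prepended at position 0.
        return tokens[:, 0, :]
    if mode == "mean":
        # Averages every position, [CLS] included.
        return tokens.mean(axis=1)
    raise ValueError(f"unknown pooling mode: {mode!r}")


tokens = np.arange(12, dtype=float).reshape(1, 3, 4)
print(pool(tokens, "cls"))   # [[0. 1. 2. 3.]]
print(pool(tokens, "mean"))  # [[4. 5. 6. 7.]]
```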