Encoder and decoder blocks
| attnax.EncoderBlock | Transformer encoder block. |
| attnax.DecoderBlock | Transformer decoder block. |
Encoder block
- class attnax.EncoderBlock(*args, **kwargs)
Bases: Module
Transformer encoder block.
Self-attention sub-layer followed by a feed-forward sub-layer, each with a residual connection and normalisation. When pre_norm is True the layout is
    x = x + self_attn(norm1(x))
    x = x + ffn(norm2(x))
otherwise normalisation is applied after the residual sum.
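The two layouts can be sketched in NumPy as follows. This is a minimal illustration, not the attnax implementation: `layer_norm`, `encoder_block`, and the `self_attn` / `ffn` callables are hypothetical stand-ins, one parameter-free normalisation takes the place of both `norm1` and `norm2`, and dropout is omitted.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # Normalise over the last (feature) axis; no learned scale or bias here.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def encoder_block(x, self_attn, ffn, pre_norm=True):
    """Apply the two sub-layers in the pre-norm or post-norm layout."""
    if pre_norm:
        # Pre-norm: normalise the sub-layer input, leave the residual path untouched.
        x = x + self_attn(layer_norm(x))
        x = x + ffn(layer_norm(x))
    else:
        # Post-norm: normalise after the residual sum.
        x = layer_norm(x + self_attn(x))
        x = layer_norm(x + ffn(x))
    return x
```

In the pre-norm layout the residual path carries `x` through unnormalised, so a sub-layer that returns zeros leaves the input unchanged. In the post-norm layout every sub-layer output is renormalised.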
- Parameters:
rngs – Flax NNX random key container.
d_model – Model dimensionality.
num_heads – Number of query attention heads.
d_ff – Feed-forward hidden width.
dropout_rate – Dropout probability.
pre_norm – Pre-norm layout when True, post-norm otherwise.
attention_type – Attention backend.
attention_block_size – Block size for memory-efficient, flash and pallas_flash backends.
linear_attention_chunk_size – Chunk size for the linear backend.
norm_type – Normalisation kind.
ff_activation – Feed-forward activation.
num_kv_heads – Number of key/value heads for grouped-query (GQA) or multi-query (MQA) attention.
attention_window – Causal sliding-window size for self-attention.
score_mod – ScoreMod applied to self-attention.
use_rope – Apply RoPE in self-attention.
rope_base – RoPE base \(\theta\).
rope_max_positions – Length of the RoPE table.
args (Any)
kwargs (Any)
- Return type:
Any
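To make `rope_base` and `rope_max_positions` concrete, here is a minimal NumPy sketch of a standard RoPE table and its application. It is an illustration under stated assumptions, not the attnax code: `rope_table` and `apply_rope` are hypothetical names, and the pairing of consecutive features `(x[2i], x[2i+1])` is one common convention that the library may or may not use.

```python
import numpy as np


def rope_table(max_positions, head_dim, base=10000.0):
    # One rotation frequency per feature pair: base ** (-2i / head_dim).
    inv_freq = base ** (-np.arange(0, head_dim, 2) / head_dim)
    # angles[p, i] = p * inv_freq[i], shape (max_positions, head_dim // 2).
    angles = np.arange(max_positions)[:, None] * inv_freq[None, :]
    return np.cos(angles), np.sin(angles)


def apply_rope(x, cos, sin):
    # x: float array of shape (seq_len, head_dim); seq_len must not exceed the table length.
    seq_len = x.shape[0]
    c, s = cos[:seq_len], sin[:seq_len]
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    # Rotate each (even, odd) pair by the angle for its position.
    out[:, 0::2] = x_even * c - x_odd * s
    out[:, 1::2] = x_even * s + x_odd * c
    return out
```

Because each pair is only rotated, vector norms are preserved, and the dot product between a rotated query at position m and a rotated key at position n depends only on m - n.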
Decoder block
- class attnax.DecoderBlock(*args, **kwargs)
Bases: Module
Transformer decoder block.
Three sub-layers — self-attention, cross-attention, feed-forward — with residual connections. KV caching is supported on self-attention only.
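The sub-layer order and the role of the self-attention KV cache can be sketched in NumPy. This is a hedged illustration rather than the attnax implementation: all function names and the `cache` dictionary are hypothetical, projections are replaced by the identity, a single head and a pre-norm layout are assumed, and dropout is omitted.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def attend(q, k, v, mask=None):
    # Scaled dot-product attention; masked-out positions get a score of -inf.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)
    return softmax(scores) @ v


def decoder_step(x_t, memory, cache, ffn):
    """One decoding step for a single new token x_t of shape (1, d)."""
    h = layer_norm(x_t)
    # Self-attention: append this step's key/value, then attend over the whole cache.
    cache["k"] = np.concatenate([cache["k"], h], axis=0)
    cache["v"] = np.concatenate([cache["v"], h], axis=0)
    x_t = x_t + attend(h, cache["k"], cache["v"])
    # Cross-attention over the encoder output; nothing is cached in this sketch.
    x_t = x_t + attend(layer_norm(x_t), memory, memory)
    # Feed-forward sub-layer.
    return x_t + ffn(layer_norm(x_t))


def decoder_full(x, memory, ffn):
    """The same block applied to a whole sequence with a causal mask and no cache."""
    n = x.shape[0]
    h = layer_norm(x)
    causal = np.tril(np.ones((n, n), dtype=bool))
    x = x + attend(h, h, h, mask=causal)
    x = x + attend(layer_norm(x), memory, memory)
    return x + ffn(layer_norm(x))
```

Starting from an empty cache, for example `{"k": np.zeros((0, d)), "v": np.zeros((0, d))}`, and calling `decoder_step` once per token gives the same outputs as `decoder_full` on the whole sequence. Each step only computes attention for the newest token.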
- Parameters:
rngs – Flax NNX random key container.
d_model – Model dimensionality.
num_heads – Number of query attention heads.
d_ff – Feed-forward hidden width.
dropout_rate – Dropout probability.
pre_norm – Pre-norm layout when True, post-norm otherwise.
attention_type – Attention backend.
attention_block_size – Block size for memory-efficient, flash and pallas_flash backends.
linear_attention_chunk_size – Chunk size for the linear backend.
norm_type – Normalisation kind.
ff_activation – Feed-forward activation.
num_kv_heads – Number of key/value heads for grouped-query (GQA) or multi-query (MQA) self-attention.
attention_window – Causal sliding-window size for self-attention.
score_mod – ScoreMod applied to self-attention.
use_rope – Apply RoPE in self-attention.
rope_base – RoPE base \(\theta\).
rope_max_positions – Length of the RoPE table.
args (Any)
kwargs (Any)
- Return type:
Any
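As an illustration of what a causal sliding window such as `attention_window` restricts, the NumPy sketch below builds a boolean mask. The helper name is hypothetical, and the exact boundary convention is an assumption: here each query position attends to itself and the `window - 1` positions before it, which may differ from the library's definition.

```python
import numpy as np


def sliding_window_causal_mask(seq_len, window):
    # True where query position i may attend to key position j.
    i = np.arange(seq_len)[:, None]  # query positions (rows)
    j = np.arange(seq_len)[None, :]  # key positions (columns)
    # Causal (j <= i) and within the last `window` positions (j > i - window).
    return (j <= i) & (j > i - window)
```

With a window at least as long as the sequence, the mask reduces to the ordinary lower-triangular causal mask.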