Encoder and decoder blocks

  • EncoderBlock(*args, **kwargs) – Transformer encoder block.

  • DecoderBlock(*args, **kwargs) – Transformer decoder block.

Encoder block

class attnax.EncoderBlock(*args, **kwargs)

Bases: Module

Transformer encoder block.

Self-attention sub-layer followed by a feed-forward sub-layer, each with a residual connection and normalisation. When pre_norm is True, the layout is

x = x + self_attn(norm1(x))
x = x + ffn(norm2(x))

Otherwise (post-norm), normalisation is applied after each residual sum:

x = norm1(x + self_attn(x))
x = norm2(x + ffn(x))
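The two residual layouts can be compared with a small NumPy sketch. It is not attnax code: self_attn and ffn are hypothetical stand-in callables, and both norm1 and norm2 are a parameter-free layer norm, whereas the real block's normalisation is selected by norm_type and has learned parameters.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # Normalise over the feature (last) axis; no learned scale/shift in this sketch.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def encoder_block(x, self_attn, ffn, pre_norm=True):
    """Residual layout only; self_attn and ffn are hypothetical stand-ins."""
    norm1 = norm2 = layer_norm
    if pre_norm:
        # Pre-norm: normalise the sub-layer input, add the un-normalised residual.
        x = x + self_attn(norm1(x))
        x = x + ffn(norm2(x))
    else:
        # Post-norm: normalise after each residual sum.
        x = norm1(x + self_attn(x))
        x = norm2(x + ffn(x))
    return x


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))                   # (seq_len, d_model)
w = rng.normal(size=(8, 8)) / np.sqrt(8)


def self_attn(h):
    return h @ w                               # stand-in: a linear map


def ffn(h):
    return np.maximum(h @ w, 0.0)              # stand-in: linear map + ReLU


y_pre = encoder_block(x, self_attn, ffn, pre_norm=True)
y_post = encoder_block(x, self_attn, ffn, pre_norm=False)
```

Because post-norm ends with a normalisation, every row of y_post has roughly zero mean and unit variance; the pre-norm output keeps an un-normalised residual stream, which is the usual reason pre-norm is considered easier to train in deep stacks.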

Parameters:
  • rngs – Flax NNX random key container.

  • d_model – Model dimensionality.

  • num_heads – Number of query attention heads.

  • d_ff – Feed-forward hidden width.

  • dropout_rate – Dropout probability.

  • pre_norm – Pre-norm layout when True, post-norm otherwise.

  • attention_type – Attention backend.

  • attention_block_size – Block size for memory-efficient, flash and pallas_flash backends.

  • linear_attention_chunk_size – Chunk size for the linear backend.

  • norm_type – Normalisation kind.

  • ff_activation – Feed-forward activation.

  • num_kv_heads – Number of key/value heads for grouped-query (GQA) or multi-query (MQA) attention.

  • attention_window – Causal sliding-window size for self-attention.

  • score_mod – ScoreMod applied to self-attention.

  • use_rope – Apply RoPE in self-attention.

  • rope_base – RoPE base \(\theta\).

  • rope_max_positions – Length of the RoPE table.

  • args (Any)

  • kwargs (Any)

Return type:

Any
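The attention_window parameter describes a causal sliding window. A minimal NumPy sketch of such a mask follows; it assumes that a window of size W lets query i see keys j with i - W < j <= i (its own position plus the W - 1 preceding ones), which may differ by one from attnax's convention. causal_sliding_window_mask and masked_softmax are hypothetical helper names.

```python
import numpy as np


def causal_sliding_window_mask(seq_len, window):
    """True where query i may attend to key j (assumed: i - window < j <= i)."""
    i = np.arange(seq_len)[:, None]   # query positions, as a column
    j = np.arange(seq_len)[None, :]   # key positions, as a row
    return (j <= i) & (j > i - window)


def masked_softmax(scores, mask):
    # Disallowed positions get -inf so they receive exactly zero weight.
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)


mask = causal_sliding_window_mask(5, window=2)
print(mask.astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [0 1 1 0 0]
#  [0 0 1 1 0]
#  [0 0 0 1 1]]

# With equal scores, each query spreads its weight evenly over its window.
weights = masked_softmax(np.zeros((5, 5)), mask)
```

Every row contains at least the diagonal, so the softmax never divides by zero; the first query sees only itself until the window fills.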

Decoder block

class attnax.DecoderBlock(*args, **kwargs)

Bases: Module

Transformer decoder block.

Three sub-layers (self-attention, cross-attention and feed-forward), each with a residual connection. KV caching is supported on self-attention only.
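Why caching applies to self-attention: during step-by-step decoding, each new token adds one key/value pair to the self-attention context, while cross-attention always reads the same encoder output. The NumPy sketch below is a generic illustration, not attnax's cache API (attend and the cache arrays are hypothetical). It uses grouped-query attention, so the cache stores only num_kv_heads heads, and checks that incremental decoding reproduces full causal self-attention.

```python
import numpy as np


def attend(q, k, v):
    """Causal grouped-query attention for one sequence.

    q: (num_heads, q_len, head_dim); k, v: (num_kv_heads, kv_len, head_dim).
    The q_len queries are assumed to be the last q_len positions of the keys.
    """
    num_heads, q_len, head_dim = q.shape
    num_kv_heads, kv_len, _ = k.shape
    group = num_heads // num_kv_heads
    # Query head h reads KV head h // group: repeat each KV head `group` times.
    k = np.repeat(k, group, axis=0)
    v = np.repeat(v, group, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    q_pos = np.arange(kv_len - q_len, kv_len)[:, None]
    k_pos = np.arange(kv_len)[None, :]
    scores = np.where(k_pos <= q_pos, scores, -np.inf)   # causal mask
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v


rng = np.random.default_rng(0)
num_heads, num_kv_heads, head_dim, seq_len = 4, 2, 8, 6
q = rng.normal(size=(num_heads, seq_len, head_dim))
k = rng.normal(size=(num_kv_heads, seq_len, head_dim))
v = rng.normal(size=(num_kv_heads, seq_len, head_dim))

# Whole sequence at once (training-style causal self-attention).
full = attend(q, k, v)

# Incremental decoding: append one position's K/V to the cache per step.
k_cache = np.zeros((num_kv_heads, 0, head_dim))
v_cache = np.zeros((num_kv_heads, 0, head_dim))
steps = []
for t in range(seq_len):
    k_cache = np.concatenate([k_cache, k[:, t:t + 1]], axis=1)
    v_cache = np.concatenate([v_cache, v[:, t:t + 1]], axis=1)
    steps.append(attend(q[:, t:t + 1], k_cache, v_cache))
incremental = np.concatenate(steps, axis=1)
```

The cache holds num_kv_heads rather than num_heads heads, so GQA/MQA shrinks cache memory by a factor of num_heads / num_kv_heads. Cross-attention keys and values would be computed once from the encoder output and reused at every step.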

Parameters:
  • rngs – Flax NNX random key container.

  • d_model – Model dimensionality.

  • num_heads – Number of query attention heads.

  • d_ff – Feed-forward hidden width.

  • dropout_rate – Dropout probability.

  • pre_norm – Pre-norm layout when True, post-norm otherwise.

  • attention_type – Attention backend.

  • attention_block_size – Block size for memory-efficient, flash and pallas_flash backends.

  • linear_attention_chunk_size – Chunk size for the linear backend.

  • norm_type – Normalisation kind.

  • ff_activation – Feed-forward activation.

  • num_kv_heads – Number of key/value heads for grouped-query (GQA) or multi-query (MQA) self-attention.

  • attention_window – Causal sliding-window size for self-attention.

  • score_mod – ScoreMod applied to self-attention.

  • use_rope – Apply RoPE in self-attention.

  • rope_base – RoPE base \(\theta\).

  • rope_max_positions – Length of the RoPE table.

  • args (Any)

  • kwargs (Any)

Return type:

Any
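Both blocks expose use_rope, rope_base and rope_max_positions. A minimal sketch of a precomputed RoPE table follows, with frequencies \(\theta_i = \text{rope\_base}^{-2i/d}\) for a head dimension \(d\). It assumes the "rotate-half" pairing (dimension i paired with i + d/2); attnax may pair interleaved dimensions instead, and rope_table/apply_rope are hypothetical names.

```python
import numpy as np


def rope_table(head_dim, max_positions, base=10000.0):
    """cos/sin tables of shape (max_positions, head_dim // 2)."""
    inv_freq = base ** (-np.arange(0, head_dim, 2) / head_dim)   # base^(-2i/d)
    angles = np.arange(max_positions)[:, None] * inv_freq[None, :]
    return np.cos(angles), np.sin(angles)


def apply_rope(x, positions, cos, sin):
    """Rotate x (seq_len, head_dim) by its positions, rotate-half pairing."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    c, s = cos[positions], sin[positions]
    # Each (x1[i], x2[i]) pair is rotated by angle position * inv_freq[i].
    return np.concatenate([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1)


head_dim, max_positions = 8, 16
cos, sin = rope_table(head_dim, max_positions)
rng = np.random.default_rng(0)
q = rng.normal(size=(1, head_dim))
k = rng.normal(size=(1, head_dim))


def score(m, n):
    # Attention logit between q placed at position m and k at position n.
    qm = apply_rope(q, np.array([m]), cos, sin)
    kn = apply_rope(k, np.array([n]), cos, sin)
    return (qm @ kn.T).item()
```

Because each pair is rotated by an angle proportional to its position, score(m, n) depends only on m - n; for example score(5, 3) equals score(9, 7). Positions must stay below rope_max_positions, since the table has exactly that many rows.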