Transformer wrappers

TransformerEncoder(*args, **kwargs)

Transformer encoder stack.

TransformerDecoder(*args, **kwargs)

Decoder-only transformer stack.

Transformer encoder

class attnax.TransformerEncoder(*args, **kwargs)

Bases: Module

Transformer encoder stack.

The stack consists of a token embedding, a positional encoding, config.num_layers encoder blocks, and a final normalisation. No mask is applied internally; pass padding_mask if the batch contains padded positions.

Parameters:
  • rngs – Flax NNX random key container.

  • config – Transformer hyperparameters.

  • args (Any)

  • kwargs (Any)

Return type:

Any
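Because the encoder applies no mask internally, the caller has to build padding_mask. The following is a minimal sketch in plain JAX, not attnax code. It assumes a hypothetical padding id of 0 and a boolean mask of shape (batch, seq_len) in which True marks a real token. The shape and polarity that attnax expects are not stated in this section, so check them before passing the result in.

```python
import jax.numpy as jnp

PAD_ID = 0  # assumption: hypothetical padding token id, not taken from attnax


def make_padding_mask(token_ids, pad_id=PAD_ID):
    """Return a boolean mask that is True where the token is not padding."""
    return token_ids != pad_id


# Two sequences padded to length 4.
token_ids = jnp.array([[5, 7, 9, 0],
                       [3, 0, 0, 0]])
padding_mask = make_padding_mask(token_ids)  # shape (2, 4), dtype bool
```

Deriving the mask from the token ids keeps it in step with the batch; if padding is tracked by sequence length instead, comparing jnp.arange(seq_len) against the lengths gives the same result.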

Transformer decoder

class attnax.TransformerDecoder(*args, **kwargs)

Bases: Module

Decoder-only transformer stack.

The stack consists of a token embedding, a positional encoding, config.num_layers blocks, and a final normalisation. A causal mask is applied internally and combined with padding_mask by logical AND. When layer_kv_caches is set, position_ids defaults to arange(past_len, past_len + seq_len).

Parameters:
  • rngs – Flax NNX random key container.

  • config – Transformer hyperparameters.

  • args (Any)

  • kwargs (Any)

Return type:

Any
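The masking and position defaults described for the decoder can be illustrated in plain JAX. This is a sketch under stated assumptions, not the attnax implementation: the helper names combined_mask and default_position_ids are hypothetical, the padding mask is assumed to be boolean with shape (batch, seq_len) and True for real tokens, and the mask covers only the case without a KV cache.

```python
import jax.numpy as jnp


def combined_mask(padding_mask):
    """AND a lower-triangular causal mask with a (batch, seq_len) padding mask.

    Hypothetical helper: returns a (batch, seq_len, seq_len) boolean mask in
    which entry [b, i, j] is True when query i may attend to key j.
    """
    seq_len = padding_mask.shape[-1]
    # Causal part: query i may see keys j <= i.
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    # Broadcast the padding mask over the query axis, then AND the two.
    return causal[None, :, :] & padding_mask[:, None, :]


def default_position_ids(past_len, seq_len):
    """Positions for new tokens when past_len tokens are already cached."""
    return jnp.arange(past_len, past_len + seq_len)


padding_mask = jnp.array([[True, True, False]])  # last position is padding
mask = combined_mask(padding_mask)               # shape (1, 3, 3)
position_ids = default_position_ids(past_len=4, seq_len=2)  # positions 4 and 5
```

With four tokens already cached and two new tokens, the new tokens get positions 4 and 5, so positional encodings continue from where the cached prefix ended rather than restarting at 0.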