Transformer wrappers
TransformerEncoder | Transformer encoder stack.
TransformerDecoder | Decoder-only transformer stack.
Transformer encoder
- class attnax.TransformerEncoder(*args, **kwargs)
Bases: Module
Transformer encoder stack.
Consists of token embedding, positional encoding, config.num_layers encoder blocks, and a final normalisation. No mask is applied internally, so every position can attend to every other; pass padding_mask to exclude padded positions.
- Parameters:
rngs – Flax NNX random key container.
config – Transformer hyperparameters.
args (Any)
kwargs (Any)
- Return type:
Any
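A padding mask for the encoder can be built directly from token ids. The sketch below uses plain JAX and does not call attnax. The pad id of 0, the boolean [batch, seq_len] layout with True marking real tokens, and the helper names are assumptions, since the docstring does not specify what padding_mask should look like.

```python
import jax.numpy as jnp

# Hypothetical pad id; the padding convention attnax expects is not documented here.
PAD_ID = 0


def make_padding_mask(token_ids: jnp.ndarray, pad_id: int = PAD_ID) -> jnp.ndarray:
    """Boolean [batch, seq_len] mask, True at real (non-pad) tokens.

    Assumption: this is the layout TransformerEncoder expects for padding_mask.
    """
    return token_ids != pad_id


def to_key_mask(padding_mask: jnp.ndarray) -> jnp.ndarray:
    """Broadcast a [batch, seq_len] padding mask to [batch, 1, 1, seq_len].

    The result broadcasts against [batch, heads, q_len, k_len] attention
    scores and hides padded keys from every query and every head.
    """
    return padding_mask[:, None, None, :]


token_ids = jnp.array([[5, 7, 9, 0, 0],
                       [3, 4, 0, 0, 0]])
padding_mask = make_padding_mask(token_ids)  # shape (2, 5)
key_mask = to_key_mask(padding_mask)         # shape (2, 1, 1, 5)
```

Because the encoder applies no causal mask, a key mask like this is the only restriction on attention: every query row sees the same set of unpadded keys.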
Transformer decoder
- class attnax.TransformerDecoder(*args, **kwargs)
Bases: Module
Decoder-only transformer stack.
Consists of token embedding, positional encoding, config.num_layers blocks, and a final normalisation. A causal mask is applied internally and AND-combined with padding_mask, so a position is attended to only if it is neither in the future nor padding. When layer_kv_caches is set, position_ids default to arange(past_len, past_len + seq_len), so the new tokens are numbered after the cached prefix.
- Parameters:
rngs – Flax NNX random key container.
config – Transformer hyperparameters.
args (Any)
kwargs (Any)
- Return type:
Any
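The decoder's masking and position rules can be sketched in plain JAX. This illustrates the stated behaviour and is not attnax's implementation. The helper names are hypothetical, and the padding-mask layout (boolean [batch, past_len + seq_len], True at real tokens, covering cached positions too) is an assumption the docstring does not confirm.

```python
import jax.numpy as jnp


def default_position_ids(past_len: int, seq_len: int) -> jnp.ndarray:
    # New tokens continue numbering from the end of the cached prefix.
    return jnp.arange(past_len, past_len + seq_len)


def causal_mask(past_len: int, seq_len: int) -> jnp.ndarray:
    """Boolean [seq_len, past_len + seq_len] mask, True where a query may attend.

    Query i sits at absolute position past_len + i and may see every key at an
    absolute position <= past_len + i: all cached keys, earlier new tokens,
    and itself.
    """
    q_pos = default_position_ids(past_len, seq_len)[:, None]  # [seq_len, 1]
    k_pos = jnp.arange(past_len + seq_len)[None, :]           # [1, total_len]
    return k_pos <= q_pos


def combined_mask(past_len: int, seq_len: int, padding_mask: jnp.ndarray) -> jnp.ndarray:
    """AND-combine the causal mask with a key padding mask.

    padding_mask: boolean [batch, past_len + seq_len] (assumed layout).
    Returns boolean [batch, 1, seq_len, past_len + seq_len], broadcastable over heads.
    """
    causal = causal_mask(past_len, seq_len)[None, None, :, :]
    keys = padding_mask[:, None, None, :]
    return causal & keys


# Two cached tokens, three new tokens; the last new token is padding.
past_len, seq_len = 2, 3
position_ids = default_position_ids(past_len, seq_len)
padding_mask = jnp.array([[True, True, True, True, False]])
mask = combined_mask(past_len, seq_len, padding_mask)
```

Comparing absolute positions for both queries and keys means the same rule covers the no-cache case: with past_len = 0 it reduces to the usual lower-triangular mask.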