Feed-forward
Position-wise feed-forward layers used inside EncoderBlock
and DecoderBlock. FeedForward is the dense layer, covering both the
plain MLP and its gated (GLU) variants; MixtureOfExperts is the sparse top-\(k\)
routed alternative and is a drop-in replacement at the same call site.
| FeedForward | Position-wise feed-forward network. |
| MixtureOfExperts | Top-\(k\) routed Mixture-of-Experts feed-forward. |
Dense feed-forward
- class attnax.FeedForward(*args, **kwargs)
Bases: Module
Position-wise feed-forward network.
For 'gelu' and 'relu':
\[y = W_2 \,\sigma(W_1 x)\]
For 'swiglu' and 'geglu':
\[y = W_d \,\bigl(\sigma(W_g x) \odot W_u x\bigr)\]
- Parameters:
rngs – Flax NNX random key container.
d_model – Input and output dimension.
d_ff – Hidden width.
dropout_rate – Dropout probability.
ff_activation – Activation variant.
args (Any)
kwargs (Any)
- Raises:
ValueError – If ff_activation is not a recognised value.
- Return type:
Any
References
Shazeer, GLU Variants Improve Transformer, 2020.
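The two formulas above can be sketched in plain JAX. This is a minimal illustration, not the attnax implementation: the function names `mlp_ffn` and `gated_ffn`, the (in, out) weight layout, and the absence of biases and dropout are all assumptions made for the example. Here \(\sigma\) is taken to be the chosen activation (GELU or ReLU for the plain MLP; SiLU for SwiGLU and GELU for GeGLU, following Shazeer).

```python
import jax
import jax.numpy as jnp

def mlp_ffn(x, w1, w2, act=jax.nn.gelu):
    """Plain MLP: y = W_2 act(W_1 x), applied independently at every position."""
    return act(x @ w1) @ w2

def gated_ffn(x, w_gate, w_up, w_down, act=jax.nn.silu):
    """Gated variant: y = W_d (act(W_g x) * W_u x). silu -> SwiGLU, gelu -> GeGLU."""
    return (act(x @ w_gate) * (x @ w_up)) @ w_down

# Hypothetical small sizes, chosen only for illustration.
d_model, d_ff = 8, 32
k1, k2, k3, kx = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (2, 5, d_model))            # (batch, seq_len, d_model)
w_in = jax.random.normal(k1, (d_model, d_ff)) * 0.1   # W_1 or W_g, stored as (in, out)
w_up = jax.random.normal(k2, (d_model, d_ff)) * 0.1   # W_u
w_out = jax.random.normal(k3, (d_ff, d_model)) * 0.1  # W_2 or W_d

y_mlp = mlp_ffn(x, w_in, w_out)            # same shape as x
y_glu = gated_ffn(x, w_in, w_up, w_out)    # same shape as x
```

Because both functions only contract over the last axis, each position is transformed independently of its neighbours, which is what "position-wise" refers to.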
Mixture of Experts
- class attnax.MixtureOfExperts(*args, **kwargs)
Bases: Module
Top-\(k\) routed Mixture-of-Experts feed-forward.
Every token independently selects its top-\(k\) of \(\text{num\_experts}\) experts via a learned linear router. Each expert runs an MLP / SwiGLU / GeGLU block with width d_ff; the outputs are combined weighted by the renormalised top-\(k\) softmax of the router logits. Dispatch is dense: every token is multiplied by every expert weight and unselected contributions are zeroed by the gate.
- Parameters:
rngs – Flax NNX random key container.
d_model – Input and output dimension.
d_ff – Per-expert hidden width.
num_experts – Number of experts.
top_k – Experts selected per token.
dropout_rate – Output dropout probability.
ff_activation – Activation variant.
router_jitter – Standard deviation of multiplicative Gaussian noise on router logits at training time.
capacity_factor – Reserved for sparse-dispatch implementations.
args (Any)
kwargs (Any)
- Raises:
ValueError – If top_k is not in [1, num_experts], or if either top_k or num_experts is non-positive.
- Return type:
Any
References
Fedus et al., Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity, 2022.
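The routing and dense dispatch described above can be sketched in plain JAX as follows. This is an illustrative sketch under stated assumptions, not the attnax source: the names `topk_gates` and `moe_dense`, the weight layout, the GELU MLP experts, and the omission of dropout and router jitter are all choices made for the example.

```python
import jax
import jax.numpy as jnp

def topk_gates(router_logits, top_k):
    """Renormalised top-k softmax: non-zero only on each token's top-k experts."""
    num_experts = router_logits.shape[-1]
    top_vals, top_idx = jax.lax.top_k(router_logits, top_k)  # (..., top_k)
    top_probs = jax.nn.softmax(top_vals, axis=-1)            # softmax over the selected logits only
    one_hot = jax.nn.one_hot(top_idx, num_experts)           # (..., top_k, num_experts)
    return jnp.einsum("...k,...ke->...e", top_probs, one_hot)

def moe_dense(x, w_router, w1, w2, top_k):
    """Dense dispatch: run every token through every expert, then gate the outputs."""
    gates = topk_gates(x @ w_router, top_k)                    # (B, S, E)
    hidden = jax.nn.gelu(jnp.einsum("bsd,edf->bsef", x, w1))   # (B, S, E, d_ff)
    expert_out = jnp.einsum("bsef,efd->bsed", hidden, w2)      # (B, S, E, d_model)
    # Unselected experts have gate 0, so their outputs drop out of the sum.
    return jnp.einsum("bse,bsed->bsd", gates, expert_out), gates

# Hypothetical small sizes, chosen only for illustration.
d_model, d_ff, num_experts, top_k = 8, 16, 4, 2
kr, k1, k2, kx = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (2, 5, d_model))
w_router = jax.random.normal(kr, (d_model, num_experts)) * 0.1
w1 = jax.random.normal(k1, (num_experts, d_model, d_ff)) * 0.1
w2 = jax.random.normal(k2, (num_experts, d_ff, d_model)) * 0.1

y, gates = moe_dense(x, w_router, w1, w2, top_k)   # y has the same shape as x
```

Taking the softmax over only the selected logits is equivalent to taking the full softmax and renormalising over the top-\(k\) entries. Dense dispatch costs compute proportional to the number of experts, but every array keeps a static shape, which is generally convenient under `jax.jit`.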
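Example
from flax import nnx
import jax.numpy as jnp
from attnax import MixtureOfExperts

batch, seq_len = 2, 16  # illustrative sizes
moe = MixtureOfExperts(
    nnx.Rngs(0), d_model=4096, d_ff=14336,
    num_experts=8, top_k=2, ff_activation="swiglu",
)
x = jnp.zeros((batch, seq_len, 4096))
# deterministic=False: training mode, following the usual Flax convention (dropout active).
y, aux = moe(x, deterministic=False)
main_loss = jnp.mean(y ** 2)  # placeholder standing in for the task loss
# Add the auxiliary load-balance term to the task loss with a small weight.
loss = main_loss + 0.01 * aux["load_balance_loss"]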
aux["load_balance_loss"] is the auxiliary load-balance loss and
aux["router_entropy"] is the mean Shannon entropy of the router
distribution.
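How attnax computes these two quantities is not stated here. As a rough guide, the sketch below follows the load-balance formulation from the Switch Transformers paper cited above, \(N \sum_i f_i P_i\), extended to top-\(k\) by counting each token once per selected expert, together with the per-token Shannon entropy averaged over tokens. The function name `router_stats`, the top-\(k\) extension, and the small epsilon inside the logarithm are assumptions of this example and may differ from the library.

```python
import jax
import jax.numpy as jnp

def router_stats(router_logits, top_k):
    """Return (load_balance_loss, router_entropy) for logits of shape (..., num_experts)."""
    num_experts = router_logits.shape[-1]
    probs = jax.nn.softmax(router_logits, axis=-1).reshape(-1, num_experts)  # (T, E)
    _, top_idx = jax.lax.top_k(probs, top_k)                                 # (T, top_k)
    assigned = jax.nn.one_hot(top_idx, num_experts).sum(axis=1)              # (T, E), entries 0 or 1
    frac_tokens = assigned.mean(axis=0) / top_k   # f_i: share of routing slots per expert
    mean_prob = probs.mean(axis=0)                # P_i: mean router probability per expert
    load_balance = num_experts * jnp.sum(frac_tokens * mean_prob)
    # Shannon entropy per token, averaged over tokens; epsilon guards log(0).
    entropy = -jnp.sum(probs * jnp.log(probs + 1e-9), axis=-1).mean()
    return load_balance, entropy

# Hypothetical logits for 2 sequences of 5 tokens routed over 4 experts.
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 4))
lb_loss, ent = router_stats(logits, top_k=2)
```

Under this formulation a uniform router gives a load-balance value of 1 and the maximum entropy \(\log(\text{num\_experts})\); a router that collapses onto a few experts pushes the loss up and the entropy down.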