Feed-forward#

Position-wise feed-forward layers used inside EncoderBlock and DecoderBlock. FeedForward is the dense layer and covers both the plain MLP and the gated (GLU) variants; MixtureOfExperts is the sparse top-\(k\) routed alternative and is a drop-in replacement at the same call site.

FeedForward(*args, **kwargs)

Position-wise feed-forward network.

MixtureOfExperts(*args, **kwargs)

Top-\(k\) routed Mixture-of-Experts feed-forward.

Dense feed-forward#

class attnax.FeedForward(*args, **kwargs)

Bases: Module

Position-wise feed-forward network.

For 'gelu' and 'relu':

\[y = W_2 \,\sigma(W_1 x)\]

For 'swiglu' and 'geglu':

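\[y = W_d \,\bigl(\sigma(W_g x) \odot W_u x\bigr)\]

In both forms \(\sigma\) is the activation named by ff_activation: GELU or ReLU in the first, SiLU for 'swiglu' and GELU for 'geglu' in the second. \(\odot\) is the element-wise product.
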
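The two formulas above can be sketched in plain JAX as follows. This is a minimal illustration, not the attnax implementation: the function names `mlp_ffn` and `gated_ffn` are hypothetical, weights are stored as (in, out) matrices so that `x @ w` plays the role of \(W x\), and biases and dropout are omitted.

```python
import jax
import jax.numpy as jnp

def mlp_ffn(x, w1, w2, activation=jax.nn.gelu):
    """y = W_2 sigma(W_1 x); pass jax.nn.relu for the 'relu' variant."""
    return activation(x @ w1) @ w2

def gated_ffn(x, w_gate, w_up, w_down, activation=jax.nn.silu):
    """y = W_d (sigma(W_g x) * W_u x); silu gives SwiGLU, gelu gives GeGLU."""
    return (activation(x @ w_gate) * (x @ w_up)) @ w_down

# Small illustrative sizes (assumptions, not library defaults).
d_model, d_ff = 8, 32
k1, k2, k3, kx = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (2, 5, d_model))            # (batch, seq_len, d_model)
w1 = jax.random.normal(k1, (d_model, d_ff)) * 0.02    # W_1 or W_g
w_up = jax.random.normal(k2, (d_model, d_ff)) * 0.02  # W_u
w2 = jax.random.normal(k3, (d_ff, d_model)) * 0.02    # W_2 or W_d

y_mlp = mlp_ffn(x, w1, w2)
y_glu = gated_ffn(x, w1, w_up, w2)
```

Both functions act on the last axis only, so each sequence position is transformed independently, which is what "position-wise" refers to.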
Parameters:
  • rngs – Flax NNX random key container.

  • d_model – Input and output dimension.

  • d_ff – Hidden width.

  • dropout_rate – Dropout probability.

  • ff_activation – Activation variant.

  • args (Any)

  • kwargs (Any)

Raises:

ValueError – If ff_activation is not a recognised value.

Return type:

Any

References

Shazeer, GLU Variants Improve Transformer, 2020.

Mixture of Experts#

class attnax.MixtureOfExperts(*args, **kwargs)

Bases: Module

Top-\(k\) routed Mixture-of-Experts feed-forward.

Every token independently selects its top-\(k\) of \(\text{num\_experts}\) experts via a learned linear router. Each expert is an MLP, SwiGLU, or GeGLU block of width d_ff (chosen by ff_activation); the expert outputs are combined as a weighted sum, with weights given by the softmax of the router logits renormalised over the selected top-\(k\). Dispatch is dense: every expert is evaluated on every token, and the gate zeroes the contributions of unselected experts, so compute scales with num_experts rather than top_k.
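The routing described above can be sketched in plain JAX. This is an illustrative sketch under stated assumptions, not the attnax code: `dense_moe` is a hypothetical name, the experts are plain GELU MLPs without biases, tokens are flattened to a single axis, and dropout and router jitter are left out.

```python
import jax
import jax.numpy as jnp

def dense_moe(x, w_router, w_in, w_out, top_k):
    """Dense-dispatch top-k MoE.

    x: (tokens, d_model); w_router: (d_model, E);
    w_in: (E, d_model, d_ff); w_out: (E, d_ff, d_model).
    """
    num_experts = w_router.shape[-1]
    logits = x @ w_router                                  # (T, E)
    top_vals, top_idx = jax.lax.top_k(logits, top_k)       # (T, k)
    # Softmax over the k selected logits equals the full softmax
    # renormalised over the selected experts.
    top_gates = jax.nn.softmax(top_vals, axis=-1)          # (T, k)
    # Scatter back to a dense (T, E) gate; unselected experts get 0.
    one_hot = jax.nn.one_hot(top_idx, num_experts)         # (T, k, E)
    gates = jnp.sum(one_hot * top_gates[..., None], axis=1)
    # Dense dispatch: every expert processes every token.
    hidden = jax.nn.gelu(jnp.einsum("td,edf->etf", x, w_in))
    expert_out = jnp.einsum("etf,efd->etd", hidden, w_out)  # (E, T, d_model)
    # The gate zeroes the contributions of unselected experts.
    y = jnp.einsum("te,etd->td", gates, expert_out)
    return y, gates

T, d_model, d_ff, E, k = 6, 8, 16, 4, 2   # illustrative sizes
keys = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(keys[0], (T, d_model))
w_router = jax.random.normal(keys[1], (d_model, E))
w_in = jax.random.normal(keys[2], (E, d_model, d_ff)) * 0.1
w_out = jax.random.normal(keys[3], (E, d_ff, d_model)) * 0.1
y, gates = dense_moe(x, w_router, w_in, w_out, top_k=k)
```

Each row of `gates` has exactly `top_k` non-zero entries that sum to one. The einsum over all experts is what makes the dispatch dense: it trades wasted compute on unselected experts for static shapes and no token shuffling.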

Parameters:
  • rngs – Flax NNX random key container.

  • d_model – Input and output dimension.

  • d_ff – Per-expert hidden width.

  • num_experts – Number of experts.

  • top_k – Experts selected per token.

  • dropout_rate – Output dropout probability.

  • ff_activation – Activation variant.

  • router_jitter – Standard deviation of multiplicative Gaussian noise on router logits at training time.

  • capacity_factor – Reserved for sparse-dispatch implementations; not needed by the dense dispatch described above.

  • args (Any)

  • kwargs (Any)
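As an illustration of the router_jitter parameter, multiplicative Gaussian noise on router logits might look like the sketch below. The helper name `jitter_logits` and the exact form of the noise (a factor of 1 + router_jitter · ε with ε drawn from a standard normal) are assumptions; the library may parameterise it differently.

```python
import jax
import jax.numpy as jnp

def jitter_logits(logits, key, router_jitter, deterministic):
    """Scale logits by (1 + std * eps), eps ~ N(0, 1), at training time only.

    The (1 + std * eps) form is an assumption made for this sketch.
    """
    if deterministic or router_jitter == 0.0:
        return logits  # evaluation: router logits are left untouched
    eps = jax.random.normal(key, logits.shape, dtype=logits.dtype)
    return logits * (1.0 + router_jitter * eps)

logits = jnp.ones((4, 8))                 # (tokens, num_experts)
key = jax.random.PRNGKey(0)
train_logits = jitter_logits(logits, key, router_jitter=0.1, deterministic=False)
eval_logits = jitter_logits(logits, key, router_jitter=0.1, deterministic=True)
```

Because the noise is multiplicative, its magnitude scales with the logits themselves, so tokens with near-tied experts are the ones most likely to have their selection perturbed.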

Raises:

ValueError – If top_k is not in [1, num_experts], or if top_k or num_experts is non-positive.

Return type:

Any

References

Fedus et al., Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity, 2022.

Example#

from flax import nnx
import jax.numpy as jnp
from attnax import MixtureOfExperts

batch, seq_len = 2, 16  # illustrative sizes

moe = MixtureOfExperts(
    nnx.Rngs(0), d_model=4096, d_ff=14336,
    num_experts=8, top_k=2, ff_activation="swiglu",
)
x = jnp.zeros((batch, seq_len, 4096))
y, aux = moe(x, deterministic=False)  # output plus auxiliary router statistics

main_loss = jnp.mean(y ** 2)  # placeholder for the task loss
loss = main_loss + 0.01 * aux["load_balance_loss"]

aux["load_balance_loss"] is the auxiliary load-balance loss and aux["router_entropy"] is the mean Shannon entropy of the router distribution.
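For intuition, the two statistics can be computed from router logits as sketched below. The page does not state the exact formulas attnax uses, so this assumes the Switch Transformer form of the load-balance loss, \(N \sum_i f_i P_i\), extended to top-\(k\) by counting routing slots; `router_aux` is a hypothetical helper, not a library function.

```python
import jax
import jax.numpy as jnp

def router_aux(logits, top_k):
    """logits: (tokens, num_experts). Returns Switch-style router statistics."""
    num_experts = logits.shape[-1]
    probs = jax.nn.softmax(logits, axis=-1)                      # (T, E)
    _, top_idx = jax.lax.top_k(logits, top_k)                    # (T, k)
    # f_i: fraction of routing slots (tokens * top_k) assigned to expert i.
    assigned = jax.nn.one_hot(top_idx, num_experts).sum(axis=1)  # (T, E)
    frac_slots = assigned.mean(axis=0) / top_k                   # sums to 1
    # P_i: mean router probability given to expert i.
    mean_prob = probs.mean(axis=0)                               # sums to 1
    load_balance = num_experts * jnp.sum(frac_slots * mean_prob)
    # Mean Shannon entropy (in nats) of the per-token router distribution.
    entropy = -jnp.sum(probs * jnp.log(probs + 1e-9), axis=-1).mean()
    return {"load_balance_loss": load_balance, "router_entropy": entropy}

stats = router_aux(jnp.zeros((16, 8)), top_k=2)  # uniform router over 8 experts
```

With a uniform router the loss in this formulation equals 1 and the entropy equals \(\log(\text{num\_experts})\). Concentrating both assignments and probability mass on a few experts pushes the loss above 1 and the entropy toward zero.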