Normalization
| RMSNorm | Root-mean-square normalisation with a learnable scale. |
| create_norm | Returns a normalisation module. |
- class attnax.RMSNorm(*args, **kwargs)
Bases: Module
Root-mean-square normalisation with a learnable scale, computed as
\[y = \gamma \cdot \frac{x}{\sqrt{\operatorname{mean}(x^2) + \epsilon}},\]
where the mean is taken over the last dimension of \(x\) and \(\gamma\) is the learnable scale.
- Parameters:
num_features – Size of the last dimension.
rngs – Flax NNX random key container (unused).
epsilon – Small constant added to the mean square before the square root, for numerical stability.
args (Any)
kwargs (Any)
- Return type:
Any
References
Zhang and Sennrich, Root Mean Square Layer Normalization, 2019.
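The formula above can be sketched directly in NumPy. This is a minimal, framework-agnostic illustration of the computation, not attnax's implementation: the function name `rms_norm` and the default `epsilon=1e-6` are assumptions, and `gamma` is passed in explicitly rather than stored as a Flax NNX parameter. The same calls exist in `jax.numpy`, so changing the import gives a JAX version.

```python
import numpy as np


def rms_norm(x: np.ndarray, gamma: np.ndarray, epsilon: float = 1e-6) -> np.ndarray:
    """Sketch of y = gamma * x / sqrt(mean(x**2) + epsilon) over the last axis.

    Hypothetical helper for illustration; the epsilon default is an assumption.
    """
    # Mean of squares over the last dimension, kept as size 1 for broadcasting.
    mean_sq = np.mean(np.square(x), axis=-1, keepdims=True)
    # Divide by the root mean square, then apply the per-feature scale gamma.
    return gamma * x / np.sqrt(mean_sq + epsilon)


# Two rows of two features; gamma of ones and epsilon of zero isolate the normalisation.
x = np.array([[3.0, 4.0], [1.0, -1.0]])
gamma = np.ones(2)
y = rms_norm(x, gamma, epsilon=0.0)
```

With `gamma` set to ones and `epsilon` set to zero, each output row has a root mean square of exactly 1. A positive `epsilon` pulls it slightly below 1, and is what keeps an all-zero row from producing a division by zero.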
- attnax.create_norm(norm_type, d_model, rngs)
Returns a normalisation module.
- Parameters:
norm_type (Literal['layer', 'rms']) – 'layer' for flax.nnx.LayerNorm, 'rms' for RMSNorm.
d_model (int) – Feature dimension.
rngs (Rngs) – Flax NNX random key container.
- Returns:
Normalisation module.
- Raises:
ValueError – If norm_type is not 'layer' or 'rms'.
- Return type: