Normalization

RMSNorm(*args, **kwargs)

Root-mean-square normalisation with a learnable scale.

create_norm(norm_type, d_model, rngs)

Returns a normalisation module.

class attnax.RMSNorm(*args, **kwargs)

Bases: Module

Root-mean-square normalisation with a learnable scale.

\[y = \gamma \cdot \frac{x}{\sqrt{\operatorname{mean}(x^2) + \epsilon}}\]

where the mean is taken over the last dimension, \(\gamma\) is the learnable scale, and \(\epsilon\) is the numerical-stability term.
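The formula can be sketched as a plain JAX function. This is a minimal illustration, not attnax's implementation: the function name `rms_norm`, the default `epsilon=1e-6`, and the example inputs are assumptions made for the sketch.

```python
import jax.numpy as jnp


def rms_norm(x, gamma, epsilon=1e-6):
    """Functional sketch of RMS normalisation (hypothetical helper, not the attnax API)."""
    # Mean of squares over the last dimension; keepdims so it broadcasts against x.
    mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    # Divide by the root mean square, then apply the per-feature scale.
    return gamma * x / jnp.sqrt(mean_sq + epsilon)


x = jnp.array([[3.0, 4.0], [1.0, 1.0]])
gamma = jnp.ones((2,))  # scale vector with one entry per feature (assumed initial value)
y = rms_norm(x, gamma)
```

With a scale of ones, each row of `y` has a root mean square close to 1, up to the small shift introduced by `epsilon`.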

Parameters:
  • num_features – Size of the last dimension.

  • rngs – Flax NNX random key container (unused).

  • epsilon – Numerical-stability term.

  • args (Any)

  • kwargs (Any)

Return type:

Any

References

Zhang and Sennrich, Root Mean Square Layer Normalization, 2019.

attnax.create_norm(norm_type, d_model, rngs)

Returns a normalisation module.

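Parameters:
  • norm_type – Which normalisation to build: 'layer' or 'rms'.

  • d_model – Feature (model) dimension to normalise over.

  • rngs – Flax NNX random key container.
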
Returns:

Normalisation module.

Raises:

ValueError – If norm_type is not 'layer' or 'rms'.

Return type:

Module
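The dispatch behaviour described for create_norm (two accepted names, ValueError otherwise) can be sketched without any module machinery. Everything below is hypothetical: `make_norm_fn`, `layer_norm`, `rms_norm`, and the default `epsilon` are illustrative stand-ins that return plain functions rather than attnax modules, and the learnable parameters are omitted.

```python
import jax.numpy as jnp


def layer_norm(x, epsilon=1e-6):
    # Subtract the mean and divide by the standard deviation over the last dimension.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + epsilon)


def rms_norm(x, epsilon=1e-6):
    # Divide by the root mean square over the last dimension (no mean subtraction).
    mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(mean_sq + epsilon)


def make_norm_fn(norm_type):
    """Hypothetical stand-in for create_norm: maps a name to a normalisation function."""
    table = {"layer": layer_norm, "rms": rms_norm}
    if norm_type not in table:
        raise ValueError(f"norm_type must be 'layer' or 'rms', got {norm_type!r}")
    return table[norm_type]


norm = make_norm_fn("rms")
out = norm(jnp.array([[1.0, 2.0, 3.0]]))
```

Validating the name up front means an unsupported norm_type fails at construction time rather than on the first forward pass.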