Source code for attnax.norms
# SPDX-License-Identifier: Apache-2.0
"""Normalization layers."""
from __future__ import annotations
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from .config import NormKind
[docs]
class RMSNorm(nnx.Module):
    """Root-mean-square normalisation with a learnable scale.

    .. math::

        y = \\gamma \\cdot
            \\frac{x}{\\sqrt{\\operatorname{mean}(x^2) + \\epsilon}}

    with mean over the last dimension.

    The mean of squares is accumulated in at least ``float32`` so that
    half-precision inputs (``float16`` / ``bfloat16``) neither overflow
    when squared nor lose precision in the reduction.

    Args:
        num_features: Size of the last dimension.
        rngs: Flax NNX random key container (unused; the scale is
            initialised deterministically to ones).
        epsilon: Numerical-stability term added under the square root.

    Attributes:
        scale: Learnable per-feature gain of shape ``(num_features,)``,
            initialised to ones.
        epsilon: Numerical-stability term.

    References:
        Zhang and Sennrich, `Root Mean Square Layer Normalization
        <https://arxiv.org/abs/1910.07467>`_, 2019.
    """

    def __init__(
        self,
        num_features: int,
        *,
        rngs: nnx.Rngs,
        epsilon: float = 1e-6,
    ):
        # No randomness is needed: the gain starts at exactly one.
        del rngs
        self.scale = nnx.Param(jnp.ones((num_features,)))
        self.epsilon = epsilon

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Applies RMS normalisation along the last axis.

        Args:
            x: Input array.

        Returns:
            Array of the same shape as ``x``.
        """
        x = jnp.asarray(x)
        # Widen narrow dtypes to float32 for the statistics; wider dtypes
        # (e.g. float64) are left as they are.
        stats_dtype = jnp.promote_types(x.dtype, jnp.float32)
        x_wide = x.astype(stats_dtype)
        mean_sq = jnp.mean(jnp.square(x_wide), axis=-1, keepdims=True)
        normed = x_wide * jax.lax.rsqrt(mean_sq + self.epsilon)
        # Return to the input's floating dtype so the result dtype after
        # scaling is the same as it was before the statistics were widened.
        if jnp.issubdtype(x.dtype, jnp.floating):
            normed = normed.astype(x.dtype)
        return normed * self.scale
[docs]
def create_norm(
    norm_type: NormKind,
    d_model: int,
    rngs: nnx.Rngs,
) -> nnx.Module:
    """Builds the normalisation layer selected by ``norm_type``.

    Args:
        norm_type: Which layer to build: ``'layer'`` selects
            :class:`flax.nnx.LayerNorm`, ``'rms'`` selects
            :class:`RMSNorm`.
        d_model: Feature dimension handed to the layer's constructor.
        rngs: Flax NNX random key container, forwarded to the layer.

    Returns:
        The freshly constructed normalisation module.

    Raises:
        ValueError: If ``norm_type`` is neither ``'layer'`` nor ``'rms'``.
    """
    if norm_type == "layer":
        module = nnx.LayerNorm(d_model, rngs=rngs)
    elif norm_type == "rms":
        module = RMSNorm(d_model, rngs=rngs)
    else:
        # Anything else is a configuration error; fail loudly.
        raise ValueError(f"Unknown norm_type: {norm_type!r}")
    return module