Source code for attnax.embeddings

# SPDX-License-Identifier: Apache-2.0

"""Token, positional, and rotary (RoPE) embeddings."""

from __future__ import annotations

import jax
import jax.numpy as jnp
import flax.nnx as nnx

from .config import _pair


[docs] class TokenEmbedding(nnx.Module): """Token embedding lookup table. Args: rngs: Flax NNX random key container. vocab_size: Vocabulary size. d_model: Embedding dimension. """ def __init__(self, rngs: nnx.Rngs, vocab_size: int, d_model: int): self.embed = nnx.Embed( num_embeddings=vocab_size, features=d_model, rngs=rngs ) def __call__(self, token_ids: jnp.ndarray) -> jnp.ndarray: """Looks up embeddings. Args: token_ids: Integer ids of shape ``(batch, seq_len)``. Returns: Embeddings of shape ``(batch, seq_len, d_model)``. """ return self.embed(token_ids)
[docs] class PositionalEncoding(nnx.Module): """Fixed sinusoidal positional encoding. .. math:: \\mathrm{PE}(p, 2i) = \\sin\\!\\left(\\frac{p}{10000^{2i/d}}\\right), \\qquad \\mathrm{PE}(p, 2i+1) = \\cos\\!\\left(\\frac{p}{10000^{2i/d}}\\right) Args: max_len: Maximum sequence length. d_model: Embedding dimension. References: Vaswani et al., `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_, 2017. """ def __init__(self, max_len: int, d_model: int): self.max_len = max_len self.d_model = d_model self.positional = self._create_sinusoidal_positions(max_len, d_model) @staticmethod def _create_sinusoidal_positions(max_len: int, d_model: int) -> jnp.ndarray: positions = jnp.arange(max_len)[:, None] dims = jnp.arange(d_model)[None, :] angle_rates = 1.0 / (10000 ** (2 * (dims // 2) / d_model)) angles = positions * angle_rates return jnp.where(dims % 2 == 0, jnp.sin(angles), jnp.cos(angles)) def __call__(self, x: jnp.ndarray, start: int = 0) -> jnp.ndarray: """Adds the positional encoding to ``x``. Args: x: Embeddings of shape ``(batch, seq_len, d_model)``. start: Offset of the first token, used when decoding past a KV cache. Returns: Array of the same shape as ``x``. """ seq_len = x.shape[1] return x + self.positional[None, start : start + seq_len, :]
[docs] def rope_inv_freq(head_dim: int, base: float, dtype: jnp.dtype) -> jnp.ndarray: """RoPE inverse-frequency vector. Returns :math:`b^{-2i / d}` for :math:`i \\in [0, d/2)`. Args: head_dim: Per-head feature size; must be even. base: RoPE base :math:`\\theta`. dtype: Output dtype. Returns: Array of shape ``(head_dim // 2,)``. """ half = head_dim // 2 return 1.0 / (base ** (jnp.arange(0, half, dtype=dtype) / half))
[docs] def rope_cos_sin_table( max_seq_len: int, head_dim: int, base: float, dtype: jnp.dtype, ) -> tuple[jnp.ndarray, jnp.ndarray]: """Precomputes the RoPE cos/sin table. Args: max_seq_len: Number of positions. head_dim: Per-head feature size; must be even. base: RoPE base :math:`\\theta`. dtype: Output dtype. Returns: ``(cos, sin)`` each of shape ``(max_seq_len, head_dim // 2)``. """ inv_freq = rope_inv_freq(head_dim, base, dtype=dtype) t = jnp.arange(max_seq_len, dtype=dtype) freqs = jnp.outer(t, inv_freq) return jnp.cos(freqs).astype(dtype), jnp.sin(freqs).astype(dtype)
[docs] def apply_rope( x: jnp.ndarray, cos: jnp.ndarray, sin: jnp.ndarray, ) -> jnp.ndarray: """Applies RoPE rotation in the split-half formulation. Splits the last dimension into halves :math:`(x_1, x_2)` and returns .. math:: \\bigl(x_1 \\cos\\theta - x_2 \\sin\\theta,\\; x_1 \\sin\\theta + x_2 \\cos\\theta\\bigr). Args: x: Array with last dimension ``head_dim``. cos: Cosine values broadcastable to ``x`` with last dim ``head_dim // 2``. sin: Sine values of the same shape as ``cos``. Returns: Array of the same shape as ``x``. References: Su et al., `RoFormer: Enhanced Transformer with Rotary Position Embedding <https://arxiv.org/abs/2104.09864>`_, 2021. """ half = x.shape[-1] // 2 x1 = x[..., :half] x2 = x[..., half:] return jnp.concatenate( [x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1, )
[docs] def rope_cos_sin_from_positions( position_ids: jnp.ndarray, head_dim: int, base: float, table_len: int, out_dtype: jnp.dtype, ) -> tuple[jnp.ndarray, jnp.ndarray]: """Gathers RoPE cos/sin values for given positions. Args: position_ids: Integer positions of shape ``(batch, seq_len)``. head_dim: Per-head feature size; must be even. base: RoPE base :math:`\\theta`. table_len: Length of the precomputed table. out_dtype: Output dtype. Returns: ``(cos, sin)`` each of shape ``(batch, seq_len, head_dim // 2)``. """ cos_t, sin_t = rope_cos_sin_table(table_len, head_dim, base, dtype=out_dtype) cos = cos_t[position_ids] sin = sin_t[position_ids] return cos, sin
[docs] class RotaryEmbedding(nnx.Module): """Precomputed RoPE cos/sin table. Args: head_dim: Per-head feature size; must be even. max_positions: Length of the precomputed table. base: RoPE base :math:`\\theta`. dtype: Dtype of the cos/sin buffers. Raises: ValueError: If ``head_dim`` is odd. """ def __init__( self, head_dim: int, max_positions: int, *, base: float = 10000.0, dtype: jnp.dtype = jnp.float32, ): if head_dim % 2 != 0: raise ValueError(f"head_dim must be even for RoPE, got {head_dim}") self.head_dim = head_dim self.max_positions = max_positions self.base = base cos, sin = rope_cos_sin_table(max_positions, head_dim, base, dtype=dtype) self.cos_cached = cos self.sin_cached = sin
[docs] def cos_sin_for_positions( self, position_ids: jnp.ndarray, *, out_dtype: jnp.dtype ) -> tuple[jnp.ndarray, jnp.ndarray]: """Gathers cos/sin for the given positions. Args: position_ids: Integer positions of shape ``(batch, seq_len)``. out_dtype: Output dtype. Returns: ``(cos, sin)`` each of shape ``(batch, seq_len, head_dim // 2)``. """ cos = self.cos_cached[position_ids].astype(out_dtype) sin = self.sin_cached[position_ids].astype(out_dtype) return cos, sin
[docs] class PatchEmbedding(nnx.Module): """Patchify an image and project each patch to ``d_model``. Implemented as a strided 2-D convolution with kernel and stride equal to ``patch_size``. Args: rngs: Flax NNX random key container. image_size: ``int`` or ``(height, width)``. patch_size: ``int`` or ``(patch_height, patch_width)``. num_channels: Number of image channels. d_model: Output token dimension. Raises: ValueError: If ``image_size`` is not divisible by ``patch_size``. """ def __init__( self, rngs: nnx.Rngs, *, image_size: int | tuple[int, int], patch_size: int | tuple[int, int], num_channels: int, d_model: int, ): img_h, img_w = _pair(image_size) patch_h, patch_w = _pair(patch_size) if img_h % patch_h != 0 or img_w % patch_w != 0: raise ValueError( f"image_size {(img_h, img_w)} must be divisible by patch_size " f"{(patch_h, patch_w)} along both axes." ) self.image_size = (img_h, img_w) self.patch_size = (patch_h, patch_w) self.num_channels = num_channels self.d_model = d_model self.grid_size = (img_h // patch_h, img_w // patch_w) self.num_patches = self.grid_size[0] * self.grid_size[1] self.proj = nnx.Conv( in_features=num_channels, out_features=d_model, kernel_size=self.patch_size, strides=self.patch_size, padding="VALID", rngs=rngs, ) def __call__(self, images: jnp.ndarray) -> jnp.ndarray: """Returns patch tokens. Args: images: Array of shape ``(batch, height, width, channels)``. Returns: Array of shape ``(batch, num_patches, d_model)``. """ x = self.proj(images) batch, h, w, d = x.shape return x.reshape(batch, h * w, d)
[docs] class LearnedPositionalEmbedding(nnx.Module): """Additive learnable positional embedding. Initialised with truncated-normal noise (standard deviation ``init_std``). Args: rngs: Flax NNX random key container. num_positions: Maximum number of positions. d_model: Embedding dimension. init_std: Truncated-normal standard deviation. """ def __init__( self, rngs: nnx.Rngs, *, num_positions: int, d_model: int, init_std: float = 0.02, ): self.num_positions = num_positions self.d_model = d_model key = rngs.params() init = jax.random.truncated_normal( key, lower=-2.0, upper=2.0, shape=(num_positions, d_model) ) * init_std self.embedding = nnx.Param(init) def __call__(self, x: jnp.ndarray) -> jnp.ndarray: """Adds the learned positional embedding. Args: x: Array of shape ``(batch, seq_len, d_model)``. Returns: Array of the same shape as ``x``. Raises: ValueError: If ``seq_len > num_positions``. """ seq_len = x.shape[1] if seq_len > self.num_positions: raise ValueError( f"Input sequence length {seq_len} exceeds the positional " f"embedding table size {self.num_positions}." ) return x + self.embedding[...][None, :seq_len, :]