Source code for attnax.kernels.score_mods

# SPDX-License-Identifier: Apache-2.0

"""Prebuilt :data:`ScoreMod` and :data:`MaskMod` constructors."""

from __future__ import annotations

import math

import jax.numpy as jnp

from ._api import MaskMod, ScoreMod, compose_score_mods as compose_score_mods  # noqa: F401


def alibi_slopes(num_heads: int) -> jnp.ndarray:
    """Per-head ALiBi slopes.

    Geometric slope schedule from the ALiBi paper: for a power-of-two head
    count ``n`` the slopes are ``r, r**2, ..., r**n`` with ``r = 2 ** (-8 / n)``.
    For any other head count the schedule of the largest power of two not
    exceeding ``num_heads`` is used first, and the remaining heads take every
    other slope (starting with the first) from the schedule of twice that
    power of two.

    Args:
        num_heads: Number of attention heads. Must be at least 1.

    Returns:
        ``float32`` array of shape ``(num_heads,)`` with positive slopes
        :math:`m_h`. The bias added to the score between query position
        :math:`i` and key position :math:`j` is :math:`-m_h |i - j|`.

    Raises:
        ValueError: If ``num_heads`` is not positive.

    References:
        Press et al., `Train Short, Test Long: Attention with Linear Biases
        Enables Input Length Extrapolation
        <https://arxiv.org/abs/2108.12409>`_, 2022.
    """
    # Fail early with a readable message instead of the "math domain error"
    # that math.log2 would otherwise raise for zero or negative counts.
    if num_heads < 1:
        raise ValueError(f"num_heads must be positive, got {num_heads}")

    def _power_of_two_slopes(n: int) -> list[float]:
        # ``n`` is always a power of two here, so 8 / n is exact and equals
        # 2 ** -(log2(n) - 3).
        start = 2.0 ** (-8.0 / n)
        return [start * (start**i) for i in range(n)]

    # A power of two has exactly one bit set.
    if (num_heads & (num_heads - 1)) == 0:
        return jnp.asarray(_power_of_two_slopes(num_heads), dtype=jnp.float32)

    closest = 2 ** math.floor(math.log2(num_heads))
    base = _power_of_two_slopes(closest)
    # Fill the remaining heads with the even-indexed slopes of the next
    # larger power-of-two schedule.
    extra = _power_of_two_slopes(2 * closest)[0::2][: num_heads - closest]
    return jnp.asarray(base + extra, dtype=jnp.float32)
def alibi_mod(num_heads: int) -> ScoreMod:
    """ALiBi additive bias score-mod.

    Adds :math:`-m_h |q\\_idx - kv\\_idx|` to the pre-softmax scores, where
    :math:`m_h` is the per-head slope returned by :func:`alibi_slopes`.

    Args:
        num_heads: Number of attention heads.

    Returns:
        A :data:`ScoreMod` callable. The slopes are reshaped to
        ``(1, num_heads, 1, 1)`` and broadcast against the scores, so the
        head axis is expected at position 1.

    References:
        Press et al., `Train Short, Test Long: Attention with Linear Biases
        Enables Input Length Extrapolation
        <https://arxiv.org/abs/2108.12409>`_, 2022.
    """
    slopes = alibi_slopes(num_heads)[None, :, None, None]

    def _mod(
        scores: jnp.ndarray,
        b_idx: jnp.ndarray,
        h_idx: jnp.ndarray,
        q_idx: jnp.ndarray,
        kv_idx: jnp.ndarray,
    ) -> jnp.ndarray:
        del b_idx, h_idx
        # Take the distance on integers: casting positions to a 16-bit float
        # first rounds them (bfloat16 cannot tell 1000 from 1001), which would
        # collapse the distance between nearby tokens at long context.
        distance = jnp.abs(
            q_idx.astype(jnp.int32) - kv_idx.astype(jnp.int32)
        )
        # Form the bias in at least float32 and round once to the score dtype.
        compute_dtype = jnp.promote_types(scores.dtype, jnp.float32)
        bias = slopes.astype(compute_dtype) * distance.astype(compute_dtype)
        return scores - bias.astype(scores.dtype)

    return _mod
def causal_mod() -> ScoreMod:
    """Causal score-mod.

    Masks every entry with ``kv_idx > q_idx``. Masked entries are replaced by
    the most negative finite value of ``scores.dtype`` (standing in for
    :math:`-\\infty`); all other entries pass through unchanged.

    Returns:
        A :data:`ScoreMod` callable.
    """

    def _apply_causal(
        scores: jnp.ndarray,
        b_idx: jnp.ndarray,
        h_idx: jnp.ndarray,
        q_idx: jnp.ndarray,
        kv_idx: jnp.ndarray,
    ) -> jnp.ndarray:
        # Batch and head indices play no role in causality.
        del b_idx, h_idx
        not_future = q_idx >= kv_idx
        fill_value = jnp.finfo(scores.dtype).min
        return jnp.where(not_future, scores, fill_value)

    return _apply_causal
def sliding_window_mod(
    window_size: int, *, causal: bool = True
) -> ScoreMod:
    """Sliding-window score-mod.

    Keeps only keys that lie within ``window_size`` positions of the query.
    With ``causal=True`` the kept keys satisfy
    ``q_idx - window_size < kv_idx <= q_idx``; with ``causal=False`` they
    satisfy ``|q_idx - kv_idx| < window_size``. Every other entry is replaced
    by the most negative finite value of ``scores.dtype``.

    Args:
        window_size: Maximum query-key distance (positive integer).
        causal: If ``True``, restrict to past keys.

    Returns:
        A :data:`ScoreMod` callable.

    Raises:
        ValueError: If ``window_size <= 0``.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")

    def _apply_window(
        scores: jnp.ndarray,
        b_idx: jnp.ndarray,
        h_idx: jnp.ndarray,
        q_idx: jnp.ndarray,
        kv_idx: jnp.ndarray,
    ) -> jnp.ndarray:
        del b_idx, h_idx
        offset = q_idx - kv_idx
        if causal:
            # Past-or-present keys that are fewer than window_size steps back.
            keep = (kv_idx <= q_idx) & (offset < window_size)
        else:
            keep = jnp.abs(offset) < window_size
        return jnp.where(keep, scores, jnp.finfo(scores.dtype).min)

    return _apply_window
def prefix_lm_mod(prefix_lengths: jnp.ndarray) -> ScoreMod:
    """Prefix-LM score-mod.

    A key is visible when it lies inside the sequence's prefix
    (``kv_idx < prefix_lengths[b]``) or when it is not in the query's future
    (``kv_idx <= q_idx``). All other entries are replaced by the most negative
    finite value of ``scores.dtype``.

    Args:
        prefix_lengths: Integer array of shape ``(batch,)`` with the
            per-sequence prefix length. Converted to ``int32``.

    Returns:
        A :data:`ScoreMod` callable. It reads the batch index from
        ``b_idx[..., 0, 0, 0]``, so ``b_idx`` must have at least three
        trailing axes.
    """
    lengths = jnp.asarray(prefix_lengths, dtype=jnp.int32)

    def _apply_prefix_lm(
        scores: jnp.ndarray,
        b_idx: jnp.ndarray,
        h_idx: jnp.ndarray,
        q_idx: jnp.ndarray,
        kv_idx: jnp.ndarray,
    ) -> jnp.ndarray:
        del h_idx
        # One batch id per leading entry of b_idx, then its prefix length
        # reshaped so it broadcasts over head, query and key axes.
        batch_ids = b_idx[..., 0, 0, 0]
        boundary = lengths[batch_ids][:, None, None, None]
        visible = (kv_idx < boundary) | (kv_idx <= q_idx)
        return jnp.where(visible, scores, jnp.finfo(scores.dtype).min)

    return _apply_prefix_lm
def document_mask_mod(document_ids: jnp.ndarray) -> ScoreMod:
    """Document-boundary score-mod for sequence packing.

    Masks attention between tokens in different documents. A query at
    position :math:`i` only attends to keys at positions :math:`j` with
    ``document_ids[b, j] == document_ids[b, i]``. Masked entries are replaced
    by the most negative finite value of ``scores.dtype``.

    The same-document mask is built inside the returned callable from the
    positions it is given, so no ``(batch, seq_len, seq_len)`` tensor is
    materialised when the score-mod is constructed.

    Args:
        document_ids: Integer array of shape ``(batch, seq_len)``. Converted
            to ``int32``.

    Returns:
        A :data:`ScoreMod` callable. Query positions are read from
        ``q_idx[0, 0, :, 0]`` and key positions from ``kv_idx[0, 0, 0, :]``,
        so both index arrays must be rank 4.
    """
    doc_ids = jnp.asarray(document_ids, dtype=jnp.int32)

    def _mod(
        scores: jnp.ndarray,
        b_idx: jnp.ndarray,
        h_idx: jnp.ndarray,
        q_idx: jnp.ndarray,
        kv_idx: jnp.ndarray,
    ) -> jnp.ndarray:
        del b_idx, h_idx
        q_positions = q_idx[0, 0, :, 0]
        kv_positions = kv_idx[0, 0, 0, :]
        # Gather the document id of every query / key position for the whole
        # batch: shapes (batch, n_q) and (batch, n_kv).
        q_docs = doc_ids[:, q_positions]
        kv_docs = doc_ids[:, kv_positions]
        # Pairwise comparison, with a singleton head axis for broadcasting:
        # (batch, 1, n_q, n_kv).
        allowed = (q_docs[:, :, None] == kv_docs[:, None, :])[:, None, :, :]
        large_neg = jnp.finfo(scores.dtype).min
        return jnp.where(allowed, scores, large_neg)

    return _mod
def additive_bias_mod(bias: jnp.ndarray) -> ScoreMod:
    """Additive score-mod from a precomputed bias tensor.

    Args:
        bias: Float array (or array-like, e.g. a NumPy array or nested list)
            broadcastable to ``(batch, num_heads, seq_q, seq_kv)``.

    Returns:
        A :data:`ScoreMod` callable that returns ``scores + bias``, with the
        bias cast to ``scores.dtype`` first. The index arguments are ignored.
    """
    # Normalise once so any array-like is accepted, as the other constructors
    # in this module do for their array arguments.
    bias_array = jnp.asarray(bias)

    def _mod(
        scores: jnp.ndarray,
        b_idx: jnp.ndarray,
        h_idx: jnp.ndarray,
        q_idx: jnp.ndarray,
        kv_idx: jnp.ndarray,
    ) -> jnp.ndarray:
        del b_idx, h_idx, q_idx, kv_idx
        return scores + bias_array.astype(scores.dtype)

    return _mod
def mask_mod_to_score_mod(mask_mod: MaskMod) -> ScoreMod:
    """Converts a :data:`MaskMod` into an equivalent :data:`ScoreMod`.

    Wherever ``mask_mod`` returns ``False`` the score is replaced by the most
    negative finite value of ``scores.dtype`` (standing in for
    :math:`-\\infty`); entries where it returns ``True`` are left untouched.

    Args:
        mask_mod: Callable conforming to :data:`MaskMod`. It is called as
            ``mask_mod(b_idx, h_idx, q_idx, kv_idx)``.

    Returns:
        A :data:`ScoreMod` callable.
    """

    def _masked_scores(
        scores: jnp.ndarray,
        b_idx: jnp.ndarray,
        h_idx: jnp.ndarray,
        q_idx: jnp.ndarray,
        kv_idx: jnp.ndarray,
    ) -> jnp.ndarray:
        fill_value = jnp.finfo(scores.dtype).min
        return jnp.where(
            mask_mod(b_idx, h_idx, q_idx, kv_idx), scores, fill_value
        )

    return _masked_scores