# Source code for attnax.masking

# SPDX-License-Identifier: Apache-2.0

"""Attention masking utilities."""

from __future__ import annotations

import jax.numpy as jnp


def make_padding_mask(
    input_ids: jnp.ndarray, pad_token_id: int = 0
) -> jnp.ndarray:
    """Returns a key-padding mask.

    Args:
        input_ids: Integer ids of shape ``(batch, seq_len)``. Any array-like
            (JAX array, NumPy array, nested Python list) is accepted and
            converted with ``jnp.asarray``.
        pad_token_id: Token id treated as padding.

    Returns:
        Boolean array of shape ``(batch, 1, 1, seq_len)``; ``True`` for
        non-padding positions.

    Raises:
        ValueError: If ``input_ids`` is not two-dimensional.
    """
    # Convert up front: ``list != int`` would collapse to a single Python bool,
    # and this also guarantees the result is always a JAX array.
    ids = jnp.asarray(input_ids)
    if ids.ndim != 2:
        # Without this check a 3-D input would silently produce a 5-D mask.
        raise ValueError(
            f"input_ids must have shape (batch, seq_len), got shape {ids.shape}"
        )
    mask = ids != pad_token_id
    # Insert two singleton axes between batch and seq_len so the mask can
    # broadcast over them (presumably heads and query positions).
    return mask[:, None, None, :]
def make_causal_mask(seq_len: int) -> jnp.ndarray:
    """Builds a lower-triangular causal attention mask.

    Query position ``i`` may attend to key position ``j`` only when
    ``j <= i``.

    Args:
        seq_len: Length of the (square) query/key sequence.

    Returns:
        Boolean array of shape ``(1, 1, seq_len, seq_len)`` that is ``True``
        on and below the main diagonal.
    """
    all_true = jnp.ones((seq_len, seq_len), dtype=bool)
    lower_triangle = jnp.tril(all_true)
    # Prepend two singleton axes to reach the 4-D ``(1, 1, L, L)`` layout.
    return jnp.expand_dims(lower_triangle, axis=(0, 1))
def make_sliding_window_mask(
    seq_q: int,
    seq_kv: int | None = None,
    *,
    window_size: int,
    causal: bool = True,
) -> jnp.ndarray:
    """Returns a sliding-window attention mask.

    Query ``i`` may attend to key ``j`` iff their distance is strictly less
    than ``window_size``:

    * ``causal=True``: ``0 <= i - j < window_size``. This is the key at the
      query's own position plus the ``window_size - 1`` keys before it.
    * ``causal=False``: ``|i - j| < window_size``.

    Query and key indices both start at 0, so the mask is aligned to the
    top-left corner. When ``seq_q != seq_kv``, query ``i`` is compared with
    key ``j`` directly, without any positional offset.

    Args:
        seq_q: Query sequence length.
        seq_kv: Key/value sequence length. Defaults to ``seq_q``.
        window_size: Number of key positions in the window (positive
            integer). The largest permitted query-key distance is
            ``window_size - 1``; e.g. ``window_size=1`` keeps only the
            diagonal.
        causal: If ``True``, restrict to keys at or before the query position.

    Returns:
        Boolean array of shape ``(1, 1, seq_q, seq_kv)``; ``True`` where
        attention is allowed.

    Raises:
        ValueError: If ``window_size <= 0``.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")
    if seq_kv is None:
        seq_kv = seq_q
    q_idx = jnp.arange(seq_q)[:, None]
    kv_idx = jnp.arange(seq_kv)[None, :]
    # Signed distance; positive when the key precedes the query.
    distance = q_idx - kv_idx
    if causal:
        mask = (distance >= 0) & (distance < window_size)
    else:
        mask = jnp.abs(distance) < window_size
    return mask[None, None, :, :]
def make_document_mask(document_ids: jnp.ndarray) -> jnp.ndarray:
    """Builds a mask that blocks attention across packed-document boundaries.

    Entry ``[b, 0, i, j]`` is ``True`` exactly when positions ``i`` and ``j``
    of batch row ``b`` carry the same document id.

    Args:
        document_ids: Integer array of shape ``(batch, seq_len)``.

    Returns:
        Boolean array of shape ``(batch, 1, seq_len, seq_len)``.
    """
    ids = jnp.asarray(document_ids)
    query_ids = ids[:, :, None]  # (batch, seq_len, 1)
    key_ids = ids[:, None, :]  # (batch, 1, seq_len)
    same_document = jnp.equal(query_ids, key_ids)  # (batch, seq_len, seq_len)
    # Add a singleton axis at position 1 to produce the 4-D layout.
    return same_document[:, None]
def combine_masks(*masks: jnp.ndarray | None) -> jnp.ndarray | None:
    """ANDs together every mask that is not ``None``.

    Masks are combined left to right with ``&``. If only one mask is present
    it is returned unchanged (the same object); if none are, ``None`` is
    returned.

    Args:
        *masks: Boolean arrays or ``None``.

    Returns:
        Combined boolean array, or ``None`` when no mask was supplied.
    """
    present = [candidate for candidate in masks if candidate is not None]
    if not present:
        return None
    combined = present[0]
    for other in present[1:]:
        combined = combined & other
    return combined