Masking utilities

The helpers below build boolean masks that broadcast to the attention-score shape (batch, num_heads, seq_q, seq_kv). True means the query position may attend to the key position.
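This is a minimal sketch of how such a mask might be consumed, written with plain jax.numpy rather than attnax. It fills disallowed positions in the attention logits before the softmax. The helper name `masked_softmax` and the use of the dtype's most negative finite value as the fill are assumptions of this sketch, not attnax API.

```python
import jax
import jax.numpy as jnp


def masked_softmax(logits, mask):
    """Hypothetical helper: softmax over keys, masked keys get ~zero weight.

    logits: float array of shape (batch, num_heads, seq_q, seq_kv).
    mask: boolean array broadcastable to logits.shape; True means attend.
    """
    # Fill disallowed positions with the most negative finite value so that
    # exp(fill - row_max) underflows to zero inside the softmax.
    fill = jnp.finfo(logits.dtype).min
    masked_logits = jnp.where(mask, logits, fill)
    return jax.nn.softmax(masked_logits, axis=-1)


# A (1, 1, 3, 3) lower-triangular mask broadcasts over batch=2, num_heads=4.
logits = jnp.zeros((2, 4, 3, 3))
mask = (jnp.arange(3)[None, :] <= jnp.arange(3)[:, None])[None, None]
weights = masked_softmax(logits, mask)
print(weights.shape)  # (2, 4, 3, 3)
```

Note that a query row whose keys are all masked ends up with uniform weights, not zeros, because every logit in that row equals the fill value; callers typically make sure each row keeps at least one visible key.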

make_padding_mask(input_ids[, pad_token_id])

Returns a key-padding mask.

make_causal_mask(seq_len)

Returns a lower-triangular causal mask.

make_sliding_window_mask(seq_q[, seq_kv], *, window_size[, causal])

Returns a sliding-window attention mask.

make_document_mask(document_ids)

Returns a document-boundary mask for sequence packing.

combine_masks(*masks)

Element-wise logical AND of boolean masks.

attnax.make_padding_mask(input_ids, pad_token_id=0)

Returns a key-padding mask.

Parameters:
  • input_ids (Array) – Integer ids of shape (batch, seq_len).

  • pad_token_id (int) – Token id treated as padding.

Returns:

Boolean array of shape (batch, 1, 1, seq_len); True for non-padding positions.

Return type:

Array
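The documented behavior can be reproduced with a few jax.numpy operations. The function below is a reference sketch of those semantics, not attnax's implementation; the name `padding_mask_sketch` is hypothetical.

```python
import jax.numpy as jnp


def padding_mask_sketch(input_ids, pad_token_id=0):
    """Hypothetical reference sketch of make_padding_mask's documented output."""
    # True where the token is real, False where it is padding: (batch, seq_len).
    keep = input_ids != pad_token_id
    # Insert head and query axes so the mask broadcasts over
    # (batch, num_heads, seq_q, seq_kv): result is (batch, 1, 1, seq_len).
    return keep[:, None, None, :]


input_ids = jnp.array([[5, 7, 0, 0],
                       [3, 0, 0, 0]])
mask = padding_mask_sketch(input_ids)
print(mask.shape)  # (2, 1, 1, 4)
```

Because the query axis has size 1, every query position (including padded ones) sees the same key mask; the mask only hides padded keys, and outputs at padded query positions are usually discarded downstream.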

attnax.make_causal_mask(seq_len)

Returns a lower-triangular causal mask.

Parameters:

seq_len (int) – Sequence length.

Returns:

Boolean array of shape (1, 1, seq_len, seq_len); True at positions (i, j) with j <= i.

Return type:

Array
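A sketch of the documented semantics, built by comparing row and column indices directly so the j <= i condition is explicit. `causal_mask_sketch` is a hypothetical name; attnax may construct the mask differently (for example with `jnp.tril`).

```python
import jax.numpy as jnp


def causal_mask_sketch(seq_len):
    """Hypothetical reference sketch of make_causal_mask's documented output."""
    i = jnp.arange(seq_len)[:, None]  # query positions, shape (seq_len, 1)
    j = jnp.arange(seq_len)[None, :]  # key positions, shape (1, seq_len)
    # True at (i, j) with j <= i: on or below the main diagonal.
    tri = j <= i
    # Add broadcast axes for batch and heads: (1, 1, seq_len, seq_len).
    return tri[None, None, :, :]


mask = causal_mask_sketch(4)
print(mask.shape)  # (1, 1, 4, 4)
```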

attnax.make_sliding_window_mask(seq_q, seq_kv=None, *, window_size, causal=True)

Returns a sliding-window attention mask.

Parameters:
  • seq_q (int) – Query sequence length.

  • seq_kv (int | None) – Key/value sequence length. Defaults to seq_q.

  • window_size (int) – Maximum query-key distance (positive integer).

  • causal (bool) – If True, restrict to past keys.

Returns:

Boolean array of shape (1, 1, seq_q, seq_kv).

Raises:

ValueError – If window_size <= 0.

Return type:

Array
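The exact window boundary is not spelled out above, so the sketch below makes two assumptions: query i and key j share the same position origin, and a key is visible when |i - j| <= window_size (reading "maximum query-key distance" literally). With causal=True it additionally requires j <= i. The name `sliding_window_mask_sketch` is hypothetical; check attnax's own convention before relying on the boundary behavior.

```python
import jax.numpy as jnp


def sliding_window_mask_sketch(seq_q, seq_kv=None, *, window_size, causal=True):
    """Hypothetical sketch; assumes a key is visible when |i - j| <= window_size."""
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")
    if seq_kv is None:
        seq_kv = seq_q
    i = jnp.arange(seq_q)[:, None]   # query positions, shape (seq_q, 1)
    j = jnp.arange(seq_kv)[None, :]  # key positions, shape (1, seq_kv)
    mask = jnp.abs(i - j) <= window_size
    if causal:
        mask = mask & (j <= i)  # additionally forbid keys after the query
    return mask[None, None, :, :]  # (1, 1, seq_q, seq_kv)


mask = sliding_window_mask_sketch(5, window_size=2)
print(mask.shape)  # (1, 1, 5, 5)
```

Under these assumptions, with causal=True each query sees itself plus at most window_size earlier keys.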

attnax.make_document_mask(document_ids)

Returns a document-boundary mask for sequence packing.

mask[b, 0, i, j] is True iff document_ids[b, i] == document_ids[b, j].

Parameters:

document_ids (Array) – Integer array of shape (batch, seq_len).

Returns:

Boolean array of shape (batch, 1, seq_len, seq_len).

Return type:

Array
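The iff condition above is a pairwise equality, which broadcasting expresses directly. The function below is a hypothetical reference sketch (`document_mask_sketch`), not attnax's implementation.

```python
import jax.numpy as jnp


def document_mask_sketch(document_ids):
    """Hypothetical reference sketch of make_document_mask's documented output."""
    # Compare position i with position j within each batch row:
    # (batch, seq_len, 1) == (batch, 1, seq_len) -> (batch, seq_len, seq_len).
    same_doc = document_ids[:, :, None] == document_ids[:, None, :]
    # Insert a singleton head axis: (batch, 1, seq_len, seq_len).
    return same_doc[:, None, :, :]


# One packed row: tokens 0-2 belong to document 0, tokens 3-4 to document 1.
document_ids = jnp.array([[0, 0, 0, 1, 1]])
mask = document_mask_sketch(document_ids)
print(mask.shape)  # (1, 1, 5, 5)
```

On its own this mask is bidirectional within each document; for packed autoregressive training it is normally ANDed with a causal mask.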

attnax.combine_masks(*masks)

Element-wise logical AND of boolean masks.

None arguments are skipped; returns None if every argument is None.

Parameters:

*masks (Array | None) – Boolean arrays or None.

Returns:

Combined boolean array, or None.

Return type:

Array | None
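A sketch of the documented combining rule, folding `jnp.logical_and` over the non-None arguments so that shapes broadcast as in any jax.numpy binary op. `combine_masks_sketch` is a hypothetical name; the example combines a (1, 1, 4, 4) causal mask with a (2, 1, 1, 4) padding mask, built inline to match the shapes documented above.

```python
import functools

import jax.numpy as jnp


def combine_masks_sketch(*masks):
    """Hypothetical reference sketch of combine_masks's documented behavior."""
    present = [m for m in masks if m is not None]  # skip None arguments
    if not present:
        return None
    # Element-wise AND; (1, 1, S, S) & (B, 1, 1, S) broadcasts to (B, 1, S, S).
    return functools.reduce(jnp.logical_and, present)


causal = (jnp.arange(4)[None, :] <= jnp.arange(4)[:, None])[None, None]
padding = jnp.array([[True, True, True, False],
                     [True, True, False, False]])[:, None, None, :]
mask = combine_masks_sketch(causal, None, padding)
print(mask.shape)  # (2, 1, 4, 4)
```

Passing None for an unused mask keeps call sites uniform: the caller can always call the combiner and only gets None back when no mask applies at all.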