# SPDX-License-Identifier: Apache-2.0
"""Pure-JAX attention kernels."""
from __future__ import annotations
import math
from typing import Optional
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from ._api import ScoreMod, _apply_score_mod
[docs]
def standard_attention(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    *,
    mask: Optional[jnp.ndarray] = None,
    score_mod: Optional[ScoreMod] = None,
    dropout_rng: Optional[jax.Array] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = True,
) -> jnp.ndarray:
    """Softmax attention that materialises the full score matrix.

    .. math::

        \\mathrm{Attention}(Q, K, V) = \\mathrm{softmax}\\!\\left(
            \\frac{QK^\\top}{\\sqrt{d_k}} + \\Delta\\right) V

    The scaled scores are handed to ``_apply_score_mod`` together with
    ``score_mod`` (the :math:`\\Delta` term) before masking. Positions
    where ``mask`` is false are replaced by the most negative finite value
    of the score dtype prior to the softmax. The whole
    ``(seq_q, seq_kv)`` score matrix is held in memory.

    Args:
        query: Array of shape ``(batch, num_heads, seq_q, head_dim)``.
        key: Array of shape ``(batch, num_heads, seq_kv, head_dim)``.
        value: Array of shape ``(batch, num_heads, seq_kv, head_dim)``.
        mask: Boolean mask broadcastable to
            ``(batch, num_heads, seq_q, seq_kv)``; ``True`` means attend.
        score_mod: Optional callable forwarded to ``_apply_score_mod``.
            See :mod:`attnax.kernels.score_mods`.
        dropout_rng: PRNG key for dropout on the attention weights.
        dropout_rate: Probability of zeroing an attention weight.
        deterministic: If ``True``, dropout is skipped.

    Returns:
        Array of shape ``(batch, num_heads, seq_q, head_dim)``.
    """
    inv_sqrt_dim = jax.lax.rsqrt(
        jnp.asarray(query.shape[-1], dtype=query.dtype)
    )
    logits = jnp.einsum("bhqd,bhkd->bhqk", query, key) * inv_sqrt_dim
    logits = _apply_score_mod(logits, score_mod)

    if mask is not None:
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)

    probs = jax.nn.softmax(logits, axis=-1)

    # Dropout needs all three: training mode, a positive rate and a key.
    dropout_active = (
        not deterministic and dropout_rate > 0.0 and dropout_rng is not None
    )
    if dropout_active:
        retain = 1.0 - dropout_rate
        kept = jax.random.bernoulli(dropout_rng, retain, probs.shape)
        # Inverted dropout: surviving weights are rescaled by 1 / retain.
        probs = jnp.where(kept, probs / retain, 0.0)

    return jnp.einsum("bhqk,bhkd->bhqd", probs, value)
[docs]
def _slice_mask_tile(
    mask: jnp.ndarray,
    q_start: int,
    q_end: int,
    kv_start: int,
    kv_end: int,
) -> jnp.ndarray:
    """Cut the part of a broadcastable mask that covers one score tile.

    Axes of length 1 are left unsliced so they keep broadcasting against
    the tile; slicing them with a non-zero start would produce a
    zero-length axis that no longer broadcasts.
    """
    q_sel = slice(None) if mask.shape[-2] == 1 else slice(q_start, q_end)
    kv_sel = slice(None) if mask.shape[-1] == 1 else slice(kv_start, kv_end)
    return mask[..., q_sel, kv_sel]


def memory_efficient_attention(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    *,
    mask: Optional[jnp.ndarray] = None,
    score_mod: Optional[ScoreMod] = None,
    dropout_rng: Optional[jax.Array] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = True,
    block_size: int = 512,
) -> jnp.ndarray:
    """Block-wise attention using the online-softmax recurrence.

    Tiles the query and key/value sequences into blocks of size
    ``block_size`` and accumulates a running max, softmax denominator,
    and output, so the full ``(seq_q, seq_kv)`` score matrix is never
    materialised. Falls back to :func:`standard_attention` when both
    sequence axes fit in one block.

    Args:
        query: Array of shape ``(batch, num_heads, seq_q, head_dim)``.
        key: Array of shape ``(batch, num_heads, seq_kv, head_dim)``.
        value: Array of shape ``(batch, num_heads, seq_kv, head_dim)``.
        mask: Boolean mask broadcastable to
            ``(batch, num_heads, seq_q, seq_kv)``. Size-1 (broadcast)
            axes and masks with fewer than four dimensions are accepted.
        score_mod: Callable applied block-by-block to the pre-softmax
            scores; position indices are global.
        dropout_rng: PRNG key for dropout.
        dropout_rate: Dropout probability. On the blocked path dropout is
            applied to the final output; on the single-block fallback the
            dropout arguments are forwarded to
            :func:`standard_attention`.
        deterministic: If ``True``, disables dropout.
        block_size: Number of positions per query / key block.

    Returns:
        Array of shape ``(batch, num_heads, seq_q, head_dim)``.

    Raises:
        ValueError: If ``block_size`` is smaller than 1.

    References:
        Milakov and Gimelshein, `Online normalizer calculation for softmax
        <https://arxiv.org/abs/1805.02867>`_, 2018.

        Dao et al., `FlashAttention: Fast and Memory-Efficient Exact
        Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`_,
        2022.
    """
    if block_size < 1:
        raise ValueError(
            f"block_size must be a positive integer; got {block_size}"
        )

    batch, num_heads, seq_q, head_dim = query.shape
    _, _, seq_kv, _ = key.shape

    if seq_q <= block_size and seq_kv <= block_size:
        return standard_attention(
            query,
            key,
            value,
            mask=mask,
            score_mod=score_mod,
            dropout_rng=dropout_rng,
            dropout_rate=dropout_rate,
            deterministic=deterministic,
        )

    if mask is not None:
        mask = jnp.asarray(mask)
        if mask.ndim < 2:
            # Give scalar / 1-D masks explicit query and key axes so the
            # per-tile slicing below can index them.
            mask = mask.reshape((1,) * (2 - mask.ndim) + mask.shape)

    scale = jax.lax.rsqrt(jnp.asarray(head_dim, dtype=query.dtype))
    large_neg_scalar = jnp.finfo(query.dtype).min

    num_q_blocks = (seq_q + block_size - 1) // block_size
    num_kv_blocks = (seq_kv + block_size - 1) // block_size

    b_idx_full = jnp.arange(batch, dtype=jnp.int32)[:, None, None, None]
    h_idx_full = jnp.arange(num_heads, dtype=jnp.int32)[None, :, None, None]

    output = jnp.zeros_like(query)

    for q_block in range(num_q_blocks):
        q_start = q_block * block_size
        q_end = min(q_start + block_size, seq_q)
        bq = q_end - q_start
        query_block = query[:, :, q_start:q_end, :]

        # Per-query running statistics of the online softmax.
        m_running = jnp.full(
            (batch, num_heads, bq, 1), large_neg_scalar, dtype=query.dtype
        )
        l_running = jnp.zeros((batch, num_heads, bq, 1), dtype=query.dtype)
        o_running = jnp.zeros(
            (batch, num_heads, bq, head_dim), dtype=query.dtype
        )

        for kv_block in range(num_kv_blocks):
            kv_start = kv_block * block_size
            kv_end = min(kv_start + block_size, seq_kv)
            key_block = key[:, :, kv_start:kv_end, :]
            value_block = value[:, :, kv_start:kv_end, :]

            scores = (
                jnp.einsum("bhqd,bhkd->bhqk", query_block, key_block) * scale
            )

            if score_mod is not None:
                q_idx_block = jnp.arange(q_start, q_end, dtype=jnp.int32)
                kv_idx_block = jnp.arange(kv_start, kv_end, dtype=jnp.int32)
                scores = score_mod(
                    scores,
                    b_idx_full,
                    h_idx_full,
                    q_idx_block[None, None, :, None],
                    kv_idx_block[None, None, None, :],
                )

            if mask is not None:
                mask_block = _slice_mask_tile(
                    mask, q_start, q_end, kv_start, kv_end
                )
                scores = jnp.where(mask_block, scores, large_neg_scalar)

            # Online-softmax recurrence: rescale the accumulator by
            # exp(m_running - m_new) before adding the new block's
            # contribution. See Dao et al. (2022), Algorithm 1.
            m_block = jnp.max(scores, axis=-1, keepdims=True)
            m_new = jnp.maximum(m_running, m_block)
            alpha = jnp.exp(m_running - m_new)
            p_block = jnp.exp(scores - m_new)

            l_running = l_running * alpha + jnp.sum(
                p_block, axis=-1, keepdims=True
            )
            o_running = o_running * alpha + jnp.einsum(
                "bhqk,bhkd->bhqd", p_block, value_block
            )
            m_running = m_new

        block_output = o_running / (l_running + 1e-10)
        output = output.at[:, :, q_start:q_end, :].set(block_output)

    if not deterministic and dropout_rate > 0.0 and dropout_rng is not None:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, output.shape)
        output = jnp.where(keep, output / keep_prob, 0.0)

    return output
[docs]
def flash_attention(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    *,
    mask: Optional[jnp.ndarray] = None,
    score_mod: Optional[ScoreMod] = None,
    dropout_rng: Optional[jax.Array] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = True,
    block_size: int = 512,
) -> jnp.ndarray:
    """Hardware-dispatched scaled dot-product attention.

    Dispatches to :func:`jax.nn.dot_product_attention` when the default
    backend is ``"gpu"``, that function exists, and ``score_mod`` is
    ``None``. Every other case is handled by
    :func:`memory_efficient_attention`.

    Args:
        query: Array of shape ``(batch, num_heads, seq_q, head_dim)``.
        key: Array of shape ``(batch, num_heads, seq_kv, head_dim)``.
        value: Array of shape ``(batch, num_heads, seq_kv, head_dim)``.
        mask: Boolean mask broadcastable to
            ``(batch, num_heads, seq_q, seq_kv)``.
        score_mod: Callable applied to the pre-softmax scores. Forces the
            memory-efficient fallback when set.
        dropout_rng: PRNG key for dropout. On the fused GPU path dropout
            is applied to the attention output; on the fallback path the
            dropout arguments are forwarded to
            :func:`memory_efficient_attention`.
        dropout_rate: Dropout probability.
        deterministic: If ``True``, disables dropout.
        block_size: Block size for the fallback path.

    Returns:
        Array of shape ``(batch, num_heads, seq_q, head_dim)``.

    References:
        Dao et al., `FlashAttention: Fast and Memory-Efficient Exact
        Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`_,
        2022.
    """
    use_fused = (
        score_mod is None
        and jax.default_backend() == "gpu"
        and hasattr(jax.nn, "dot_product_attention")
    )
    if not use_fused:
        return memory_efficient_attention(
            query,
            key,
            value,
            mask=mask,
            score_mod=score_mod,
            dropout_rng=dropout_rng,
            dropout_rate=dropout_rate,
            deterministic=deterministic,
            block_size=block_size,
        )

    batch, num_heads, seq_q, head_dim = query.shape

    # jax.nn.dot_product_attention takes (batch, seq, heads, dim) operands,
    # so swap the heads and sequence axes on the way in and out.
    q = jnp.transpose(query, (0, 2, 1, 3))
    k = jnp.transpose(key, (0, 2, 1, 3))
    v = jnp.transpose(value, (0, 2, 1, 3))

    if mask is not None:
        mask = jnp.broadcast_to(
            mask, (batch, num_heads, seq_q, key.shape[2])
        )

    scale = 1.0 / math.sqrt(float(head_dim))
    output = jax.nn.dot_product_attention(
        q, k, v, bias=None, mask=mask, scale=scale
    )
    output = jnp.transpose(output, (0, 2, 1, 3))

    # The fused kernel has no dropout argument here; without this step the
    # dropout settings would be silently ignored on GPU only. Output
    # dropout mirrors the blocked path of memory_efficient_attention.
    if not deterministic and dropout_rate > 0.0 and dropout_rng is not None:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, output.shape)
        output = jnp.where(keep, output / keep_prob, 0.0)

    return output
def lite_attention(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    gate_proj: nnx.Linear,
    *,
    mask: Optional[jnp.ndarray] = None,
    dropout_rng: Optional[jax.Array] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = True,
) -> jnp.ndarray:
    """Element-wise gated attention.

    Uses the Hadamard product :math:`Q \\odot K` in place of the
    :math:`QK^\\top` matmul. ``gate_proj`` maps each product row to one
    gate logit; the logits are softmax-normalised along the query axis and
    scale ``value`` element-wise. When ``seq_q != seq_kv`` only the first
    key / value position is used, repeated along the query axis. Because
    it takes the trainable ``gate_proj``, this function does not conform
    to the :data:`AttentionFn` protocol; it is selected through
    :attr:`~attnax.AttentionType.LITE`.

    Args:
        query: Array of shape ``(batch, num_heads, seq_q, head_dim)``.
        key: Array of shape ``(batch, num_heads, seq_kv, head_dim)``.
        value: Array of shape ``(batch, num_heads, seq_kv, head_dim)``.
        gate_proj: Linear layer mapping ``(head_dim,)`` to a scalar gate
            logit.
        mask: Boolean mask broadcastable to
            ``(batch, num_heads, seq_q, seq_kv)``. Only its first entry
            along the last axis is used; it multiplies the gate weights
            after the softmax.
        dropout_rng: PRNG key for output dropout.
        dropout_rate: Dropout probability.
        deterministic: If ``True``, disables dropout.

    Returns:
        Array of shape ``(batch, num_heads, seq_q, head_dim)``.
    """
    batch, num_heads, seq_q, head_dim = query.shape
    _kb, _kh, seq_kv, _kd = key.shape

    if seq_q == seq_kv:
        key_operand, value_operand = key, value
    else:
        # Lengths differ: tile the first key / value row over all queries.
        key_operand = jnp.repeat(key[:, :, :1, :], seq_q, axis=2)
        value_operand = jnp.repeat(value[:, :, :1, :], seq_q, axis=2)

    # One gate logit per (batch, head, query) row of the product.
    gate_logits = gate_proj((query * key_operand).reshape(-1, head_dim))
    gate_logits = gate_logits.reshape(batch, num_heads, seq_q, 1)
    gate = jax.nn.softmax(gate_logits, axis=-2)

    if mask is not None:
        gate = gate * mask[..., :1]

    gated = gate * value_operand

    dropout_active = (
        not deterministic and dropout_rate > 0.0 and dropout_rng is not None
    )
    if dropout_active:
        retain = 1.0 - dropout_rate
        kept = jax.random.bernoulli(dropout_rng, retain, gated.shape)
        gated = jnp.where(kept, gated / retain, 0.0)

    return gated
def _phi(x: jnp.ndarray) -> jnp.ndarray:
    """Feature map :math:`\\phi(x) = \\mathrm{elu}(x) + 1`."""
    return jax.nn.elu(x) + 1.0


def _first_row_key_mask(
    mask: jnp.ndarray,
    batch: int,
    num_heads: int,
    seq_len: int,
) -> jnp.ndarray:
    """Read the first query row of ``mask`` as a key-padding mask.

    Returns an array of shape ``(batch, num_heads, seq_len, 1)``. The row
    is taken before broadcasting, so no
    ``(batch, num_heads, seq_len, seq_len)`` array is built.
    """
    mask = jnp.asarray(mask)
    if mask.ndim < 4:
        # Same leading-axis padding that broadcasting would apply.
        mask = mask.reshape((1,) * (4 - mask.ndim) + mask.shape)
    first_row = jnp.broadcast_to(
        mask[:, :, :1, :], (batch, num_heads, 1, seq_len)
    )
    return first_row.reshape(batch, num_heads, seq_len, 1)


def _linear_output_dropout(
    out: jnp.ndarray,
    dropout_rng: Optional[jax.Array],
    dropout_rate: float,
    deterministic: bool,
) -> jnp.ndarray:
    """Apply inverted dropout to ``out``; a no-op unless dropout is active."""
    if not deterministic and dropout_rate > 0.0 and dropout_rng is not None:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, out.shape)
        out = jnp.where(keep, out / keep_prob, 0.0)
    return out


def _linear_attention_non_causal(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    *,
    mask: Optional[jnp.ndarray],
    dropout_rng: Optional[jax.Array],
    dropout_rate: float,
    deterministic: bool,
) -> jnp.ndarray:
    """Non-causal linear attention computed from global key/value sums.

    Masked keys (first query row of ``mask`` false) have both their
    feature-mapped key and their value zeroed, so they add nothing to the
    numerator or the normaliser.
    """
    q_phi = _phi(query)
    k_phi = _phi(key)

    if mask is not None:
        batch, num_heads, seq_q, _ = query.shape
        key_mask = _first_row_key_mask(mask, batch, num_heads, seq_q)
        k_phi = jnp.where(key_mask, k_phi, 0.0)
        value = jnp.where(key_mask, value, 0.0)

    s = jnp.einsum("bhkd,bhke->bhde", k_phi, value)
    z = jnp.sum(k_phi, axis=2)
    numer = jnp.einsum("bhqd,bhde->bhqe", q_phi, s)
    denom = jnp.einsum("bhqd,bhd->bhq", q_phi, z)
    out = numer / (denom[..., None] + 1e-6)

    return _linear_output_dropout(
        out, dropout_rng, dropout_rate, deterministic
    )


def linear_attention(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    *,
    mask: Optional[jnp.ndarray] = None,
    score_mod: Optional[ScoreMod] = None,
    dropout_rng: Optional[jax.Array] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = True,
    causal: bool = True,
    chunk_size: int = 64,
) -> jnp.ndarray:
    """Chunkwise-parallel linear attention.

    Softmax-free attention with the recurrence

    .. math::

        S_t = S_{t-1} + \\phi(K_t)^\\top V_t,
        \\qquad
        Y_t = \\frac{\\phi(Q_t)\\,S_t}{\\phi(Q_t)\\,Z_t + \\epsilon},

    where :math:`\\phi(x) = \\mathrm{elu}(x) + 1` and :math:`Z_t` is the
    running sum of :math:`\\phi(K)`. Inside a chunk, attention is a
    softmax-free matmul; across chunks, a ``(head_dim, head_dim)`` state
    and a ``(head_dim,)`` normaliser are propagated with
    :func:`jax.lax.scan`. With ``causal=False`` the whole sequence is
    reduced in one step and ``chunk_size`` is not used.

    Args:
        query: Array of shape ``(batch, num_heads, seq_q, head_dim)``.
        key: Array of shape ``(batch, num_heads, seq_kv, head_dim)``.
            ``seq_kv`` must equal ``seq_q``.
        value: Array of shape ``(batch, num_heads, seq_kv, head_dim)``.
        mask: Boolean mask broadcastable to
            ``(batch, num_heads, seq_q, seq_kv)``. Only the first row along
            the query axis is consulted (treated as a key-padding mask).
        score_mod: Must be ``None``.
        dropout_rng: PRNG key for output dropout.
        dropout_rate: Dropout probability applied to the output.
        deterministic: If ``True``, disables dropout.
        causal: If ``True``, intra-chunk attention is causal and
            inter-chunk state propagates left-to-right.
        chunk_size: Tokens per chunk on the causal path; ``seq_q`` must be
            divisible by ``chunk_size``.

    Returns:
        Array of shape ``(batch, num_heads, seq_q, head_dim)``.

    Raises:
        NotImplementedError: If ``score_mod`` is set.
        ValueError: If ``seq_q != seq_kv``, or, on the causal path, if
            ``chunk_size`` is smaller than 1 or ``seq_q`` is not divisible
            by ``chunk_size``.

    References:
        Katharopoulos et al., `Transformers are RNNs: Fast Autoregressive
        Transformers with Linear Attention
        <https://arxiv.org/abs/2006.16236>`_, 2020.
    """
    if score_mod is not None:
        raise NotImplementedError(
            "linear_attention is softmax-free and has no scores to bias; "
            "use standard_attention / memory_efficient_attention with the "
            "desired score_mod, or write a custom AttentionFn."
        )

    batch, num_heads, seq_q, head_dim = query.shape
    if key.shape[2] != seq_q or value.shape[2] != seq_q:
        raise ValueError(
            "linear_attention requires self-attention with matching Q/K/V "
            f"sequence lengths; got Q={seq_q}, K={key.shape[2]}, "
            f"V={value.shape[2]}"
        )

    if not causal:
        return _linear_attention_non_causal(
            query,
            key,
            value,
            mask=mask,
            dropout_rng=dropout_rng,
            dropout_rate=dropout_rate,
            deterministic=deterministic,
        )

    if chunk_size < 1:
        raise ValueError(
            f"chunk_size must be a positive integer; got {chunk_size}"
        )
    if seq_q % chunk_size != 0:
        raise ValueError(
            f"seq_q ({seq_q}) must be divisible by chunk_size "
            f"({chunk_size}); pad the sequence to a multiple of chunk_size"
        )

    q_phi = _phi(query)
    k_phi = _phi(key)

    if mask is not None:
        key_mask = _first_row_key_mask(mask, batch, num_heads, seq_q)
        k_phi = jnp.where(key_mask, k_phi, 0.0)
        value = jnp.where(key_mask, value, 0.0)

    num_chunks = seq_q // chunk_size
    chunked_shape = (batch, num_heads, num_chunks, chunk_size, head_dim)
    q_chunks = q_phi.reshape(chunked_shape)
    k_chunks = k_phi.reshape(chunked_shape)
    v_chunks = value.reshape(chunked_shape)

    # Lower-triangular (diagonal included) weights for attention inside a
    # chunk.
    positions = jnp.arange(chunk_size)
    causal_intra = (
        positions[:, None] >= positions[None, :]
    ).astype(query.dtype)

    def step(state, inputs):
        s, z = state
        q_c, k_c, v_c = inputs

        # Contribution of all earlier chunks, summarised by the state.
        inter_out = jnp.einsum("bhcd,bhde->bhce", q_c, s)
        inter_z = jnp.einsum("bhcd,bhd->bhc", q_c, z)

        # Causal contribution of the current chunk.
        qk = jnp.einsum("bhcd,bhed->bhce", q_c, k_c) * causal_intra
        intra_out = jnp.einsum("bhce,bhed->bhcd", qk, v_c)
        intra_z = jnp.sum(qk, axis=-1)

        out_chunk = (inter_out + intra_out) / (
            inter_z[..., None] + intra_z[..., None] + 1e-6
        )

        s_next = s + jnp.einsum("bhcd,bhce->bhde", k_c, v_c)
        z_next = z + jnp.sum(k_c, axis=2)
        return (s_next, z_next), out_chunk

    init_state = (
        jnp.zeros((batch, num_heads, head_dim, head_dim), dtype=query.dtype),
        jnp.zeros((batch, num_heads, head_dim), dtype=query.dtype),
    )
    # lax.scan iterates over the leading axis, so move the chunk axis first.
    scan_inputs = (
        jnp.transpose(q_chunks, (2, 0, 1, 3, 4)),
        jnp.transpose(k_chunks, (2, 0, 1, 3, 4)),
        jnp.transpose(v_chunks, (2, 0, 1, 3, 4)),
    )
    _, out_per_chunk = jax.lax.scan(step, init_state, scan_inputs)

    out = jnp.transpose(out_per_chunk, (1, 2, 0, 3, 4)).reshape(
        batch, num_heads, seq_q, head_dim
    )

    return _linear_output_dropout(
        out, dropout_rng, dropout_rate, deterministic
    )
[docs]
def _bound_axis_size(axis_name: Optional[str]) -> Optional[int]:
    """Return the size of mapped axis ``axis_name``, or ``None``.

    ``None`` is returned when ``axis_name`` is ``None`` or when JAX
    reports the name as unbound, which it signals with ``NameError``. Any
    other exception propagates, so an unrelated failure cannot silently
    switch the caller to its single-device path.
    """
    if axis_name is None:
        return None
    size_fn = getattr(jax.lax, "axis_size", None)
    try:
        if size_fn is not None:
            return size_fn(axis_name)
        # JAX releases without ``jax.lax.axis_size``: summing the constant
        # 1 over the axis yields its size.
        return jax.lax.psum(1, axis_name)
    except NameError:
        return None


def ring_attention(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    *,
    mask: Optional[jnp.ndarray] = None,
    score_mod: Optional[ScoreMod] = None,
    dropout_rng: Optional[jax.Array] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = True,
    axis_name: Optional[str] = None,
    block_size: int = 512,
) -> jnp.ndarray:
    """Sequence-parallel ring attention.

    Each device holds a shard of ``Q``, ``K``, ``V`` along the sequence
    axis. The kernel rotates the local ``K`` / ``V`` shard around a ring
    of devices via :func:`jax.lax.ppermute`, accumulating online-softmax
    statistics so that each device has attended to every key/value
    after one full rotation.

    Falls back to :func:`memory_efficient_attention` when ``axis_name``
    is ``None`` or the named axis is unbound.

    Args:
        query: Local shard of shape
            ``(batch, num_heads, seq_q_local, head_dim)``.
        key: Local shard of shape
            ``(batch, num_heads, seq_kv_local, head_dim)``.
        value: Local shard of shape
            ``(batch, num_heads, seq_kv_local, head_dim)``.
        mask: Boolean mask broadcastable to the local score shape
            ``(batch, num_heads, seq_q_local, seq_kv_local)``. The same
            mask is applied at every ring step; it is not rotated with
            the key/value shard.
        score_mod: Callable applied to the pre-softmax scores; position
            indices are global.
        dropout_rng: PRNG key for output dropout. On the ring path the
            device's axis index is folded into the key so that shards do
            not share one dropout pattern.
        dropout_rate: Output dropout probability.
        deterministic: If ``True``, disables dropout.
        axis_name: Mesh axis to ring-permute over. When ``None``, runs on
            the current device.
        block_size: Block size for the single-device fallback.

    Returns:
        Local output shard of shape
        ``(batch, num_heads, seq_q_local, head_dim)``.

    References:
        Liu et al., `Ring Attention with Blockwise Transformers for
        Near-Infinite Context <https://arxiv.org/abs/2310.01889>`_, 2023.
    """
    axis_size = _bound_axis_size(axis_name)
    if axis_size is None:
        return memory_efficient_attention(
            query,
            key,
            value,
            mask=mask,
            score_mod=score_mod,
            dropout_rng=dropout_rng,
            dropout_rate=dropout_rate,
            deterministic=deterministic,
            block_size=block_size,
        )

    axis_idx = jax.lax.axis_index(axis_name)
    batch, num_heads, seq_q_local, head_dim = query.shape
    seq_kv_local = key.shape[2]

    scale = jax.lax.rsqrt(jnp.asarray(head_dim, dtype=query.dtype))
    large_neg = jnp.finfo(query.dtype).min

    m_running = jnp.full(
        (batch, num_heads, seq_q_local, 1), large_neg, dtype=query.dtype
    )
    l_running = jnp.zeros(
        (batch, num_heads, seq_q_local, 1), dtype=query.dtype
    )
    o_running = jnp.zeros_like(query)

    if score_mod is not None:
        # These index arrays do not change between ring steps.
        b_idx = jnp.arange(batch, dtype=jnp.int32)[:, None, None, None]
        h_idx = jnp.arange(num_heads, dtype=jnp.int32)[None, :, None, None]
        q_idx = (
            jnp.arange(seq_q_local, dtype=jnp.int32)
            + axis_idx * seq_q_local
        )[None, None, :, None]
        kv_offsets = jnp.arange(seq_kv_local, dtype=jnp.int32)

    # Every device passes its current K/V shard to its right neighbour.
    perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]

    def step(carry, i):
        m_prev, l_prev, o_prev, k_cur, v_cur = carry

        scores = jnp.einsum("bhqd,bhkd->bhqk", query, k_cur) * scale

        if score_mod is not None:
            # After i rotations this device holds the shard that started
            # on device (axis_idx - i) mod axis_size.
            kv_owner = (axis_idx - i) % axis_size
            kv_idx = (kv_offsets + kv_owner * seq_kv_local)[
                None, None, None, :
            ]
            scores = score_mod(scores, b_idx, h_idx, q_idx, kv_idx)

        if mask is not None:
            scores = jnp.where(mask, scores, large_neg)

        # Online-softmax update, rescaling earlier partial results.
        m_block = jnp.max(scores, axis=-1, keepdims=True)
        m_new = jnp.maximum(m_prev, m_block)
        alpha = jnp.exp(m_prev - m_new)
        p_block = jnp.exp(scores - m_new)

        l_new = l_prev * alpha + jnp.sum(p_block, axis=-1, keepdims=True)
        o_new = o_prev * alpha + jnp.einsum(
            "bhqk,bhkd->bhqd", p_block, v_cur
        )

        k_next = jax.lax.ppermute(k_cur, axis_name=axis_name, perm=perm)
        v_next = jax.lax.ppermute(v_cur, axis_name=axis_name, perm=perm)
        return (m_new, l_new, o_new, k_next, v_next), None

    (m_running, l_running, o_running, _, _), _ = jax.lax.scan(
        step,
        (m_running, l_running, o_running, key, value),
        jnp.arange(axis_size),
    )

    output = o_running / (l_running + 1e-10)

    if not deterministic and dropout_rate > 0.0 and dropout_rng is not None:
        # With one shared key every shard would draw the same keep-mask.
        shard_rng = jax.random.fold_in(dropout_rng, axis_idx)
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(shard_rng, keep_prob, output.shape)
        output = jnp.where(keep, output / keep_prob, 0.0)

    return output
[docs]
def ring_attention_reference(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    *,
    mask: Optional[jnp.ndarray] = None,
    score_mod: Optional[ScoreMod] = None,
) -> jnp.ndarray:
    """Single-device reference for :func:`ring_attention`.

    Delegates to :func:`standard_attention` on the unsharded inputs, with
    dropout left at its defaults. Intended for numerical equivalence
    tests.

    Args:
        query: Array of shape ``(batch, num_heads, seq_q, head_dim)``.
        key: Array of shape ``(batch, num_heads, seq_kv, head_dim)``.
        value: Array of shape ``(batch, num_heads, seq_kv, head_dim)``.
        mask: Boolean mask broadcastable to
            ``(batch, num_heads, seq_q, seq_kv)``.
        score_mod: Callable applied to the pre-softmax scores.

    Returns:
        Array of shape ``(batch, num_heads, seq_q, head_dim)``.
    """
    reference_output = standard_attention(
        query,
        key,
        value,
        mask=mask,
        score_mod=score_mod,
    )
    return reference_output