# Source code for attnax.kernels.pallas

# SPDX-License-Identifier: Apache-2.0

"""Pallas FlashAttention kernel."""

from __future__ import annotations

import math
from typing import Optional

import jax
import jax.numpy as jnp

from ._api import ScoreMod
from .attention import memory_efficient_attention


def _supports_pallas() -> bool:
  """Returns whether Pallas is importable on a GPU/TPU backend."""
  try:
    import jax.experimental.pallas as _pl  # noqa: F401
  except ImportError:
    return False
  backend = jax.default_backend()
  return backend in ("gpu", "tpu")


def pallas_flash_attention(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    *,
    mask: Optional[jnp.ndarray] = None,
    score_mod: Optional[ScoreMod] = None,
    dropout_rng: Optional[jax.Array] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = True,
    block_q: int = 128,
    block_kv: int = 128,
    force_fallback: bool = False,
) -> jnp.ndarray:
  """FlashAttention forward pass compiled with Pallas.

  Dispatches to a Pallas-lowered kernel on GPU/TPU and falls back to
  :func:`~attnax.kernels.memory_efficient_attention` on CPU or when the
  kernel fails to lower. Forward pass only; the gradient is computed by
  tracing the kernel with ``jvp`` rather than by a hand-written backward.

  When the Pallas path raises, the error is not propagated: a
  ``RuntimeWarning`` describing it is emitted and the pure-JAX fallback is
  used instead. This includes the ``ValueError`` raised for sequence
  lengths that are not divisible by the block sizes.

  On the Pallas path dropout is applied to the kernel output, and only when
  ``deterministic`` is ``False``, ``dropout_rate > 0`` and ``dropout_rng``
  is provided. On the fallback path the dropout arguments are forwarded
  unchanged to ``memory_efficient_attention``.

  Args:
    query: Array of shape ``(batch, num_heads, seq_q, head_dim)``.
    key: Array of shape ``(batch, num_heads, seq_kv, head_dim)``.
    value: Array of shape ``(batch, num_heads, seq_kv, head_dim)``.
    mask: Boolean mask broadcastable to
      ``(batch, num_heads, seq_q, seq_kv)``.
    score_mod: Callable traced into the inner loop; position indices are
      global.
    dropout_rng: PRNG key for output dropout.
    dropout_rate: Output dropout probability.
    deterministic: If ``True``, disables dropout.
    block_q: Queries per kernel program. ``seq_q`` must be divisible by
      ``block_q``.
    block_kv: Keys per inner-loop block. ``seq_kv`` must be divisible by
      ``block_kv``.
    force_fallback: If ``True``, always use the pure-JAX fallback.

  Returns:
    Array of shape ``(batch, num_heads, seq_q, head_dim)``.

  References:
    Dao et al., `FlashAttention: Fast and Memory-Efficient Exact Attention
    with IO-Awareness <https://arxiv.org/abs/2205.14135>`_, 2022.
  """

  def _fallback() -> jnp.ndarray:
    # Single definition of the pure-JAX path so both fallback sites agree.
    return memory_efficient_attention(
        query,
        key,
        value,
        mask=mask,
        score_mod=score_mod,
        dropout_rng=dropout_rng,
        dropout_rate=dropout_rate,
        deterministic=deterministic,
        block_size=max(block_q, block_kv),
    )

  if force_fallback or not _supports_pallas():
    return _fallback()

  try:
    output = _pallas_forward(
        query,
        key,
        value,
        mask=mask,
        score_mod=score_mod,
        block_q=block_q,
        block_kv=block_kv,
    )
  except Exception as err:  # pylint: disable=broad-except
    # Any Pallas failure (unsupported head_dim, unsupported dtype, missing
    # toolchain, non-divisible block sizes, ...) degrades to the pure-JAX
    # path. The degradation is deliberate, but it must not be silent:
    # otherwise a kernel that never lowers looks like a working one.
    import warnings

    warnings.warn(
        "pallas_flash_attention: Pallas kernel failed "
        f"({type(err).__name__}: {err}); falling back to "
        "memory_efficient_attention.",
        RuntimeWarning,
        stacklevel=2,
    )
    return _fallback()

  if not deterministic and dropout_rate > 0.0 and dropout_rng is not None:
    # Inverted dropout on the attention output: kept entries are rescaled
    # by 1 / keep_prob, dropped entries become zero.
    keep_prob = 1.0 - dropout_rate
    keep = jax.random.bernoulli(dropout_rng, keep_prob, output.shape)
    output = jnp.where(keep, output / keep_prob, 0.0)
  return output
def _pallas_forward( query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray, *, mask: Optional[jnp.ndarray], score_mod: Optional[ScoreMod], block_q: int, block_kv: int, ) -> jnp.ndarray: """Compiles and dispatches the Pallas FlashAttention kernel.""" import jax.experimental.pallas as pl batch, num_heads, seq_q, head_dim = query.shape _, _, seq_kv, _ = key.shape scale = 1.0 / math.sqrt(float(head_dim)) if seq_q % block_q != 0: raise ValueError( f"seq_q ({seq_q}) must be divisible by block_q ({block_q}); pad " "the sequence or pick a different block_q." ) if seq_kv % block_kv != 0: raise ValueError( f"seq_kv ({seq_kv}) must be divisible by block_kv ({block_kv}); " "pad the sequence or pick a different block_kv." ) num_kv_blocks = seq_kv // block_kv def kernel(q_ref, k_ref, v_ref, mask_ref, o_ref): program_b = pl.program_id(0) program_h = pl.program_id(1) program_q = pl.program_id(2) q_block = q_ref[...] q_start = program_q * block_q large_neg = jnp.finfo(q_block.dtype).min m_i = jnp.full((block_q,), large_neg, dtype=jnp.float32) l_i = jnp.zeros((block_q,), dtype=jnp.float32) acc = jnp.zeros((block_q, head_dim), dtype=jnp.float32) def body(i, carry): m_i, l_i, acc = carry k_block = pl.load( k_ref, (pl.ds(i * block_kv, block_kv), slice(None)) ) v_block = pl.load( v_ref, (pl.ds(i * block_kv, block_kv), slice(None)) ) scores = jnp.dot(q_block, k_block.T) * scale if score_mod is not None: q_idx = (q_start + jnp.arange(block_q, dtype=jnp.int32)) kv_idx = (i * block_kv + jnp.arange(block_kv, dtype=jnp.int32)) scores_4d = scores[None, None, :, :] b_idx = jnp.asarray(program_b, dtype=jnp.int32)[ None, None, None, None ] h_idx = jnp.asarray(program_h, dtype=jnp.int32)[ None, None, None, None ] scores_4d = score_mod( scores_4d, b_idx, h_idx, q_idx[None, None, :, None], kv_idx[None, None, None, :], ) scores = scores_4d[0, 0] if mask_ref is not None: m_block = pl.load( mask_ref, ( pl.ds(0, block_q), pl.ds(i * block_kv, block_kv), ), ) scores = jnp.where(m_block, scores, 
large_neg) m_block_max = jnp.max(scores, axis=-1) m_new = jnp.maximum(m_i, m_block_max) alpha = jnp.exp(m_i - m_new) p = jnp.exp(scores - m_new[:, None]) l_new = l_i * alpha + jnp.sum(p, axis=-1) acc = acc * alpha[:, None] + jnp.dot(p, v_block).astype(acc.dtype) return m_new, l_new, acc _, l_final, acc_final = jax.lax.fori_loop( 0, num_kv_blocks, body, (m_i, l_i, acc) ) o_ref[...] = (acc_final / l_final[:, None]).astype(o_ref.dtype) out_shape = jax.ShapeDtypeStruct(query.shape, query.dtype) in_specs = [ pl.BlockSpec((1, 1, block_q, head_dim), lambda b, h, q: (b, h, q, 0)), pl.BlockSpec((1, 1, seq_kv, head_dim), lambda b, h, q: (b, h, 0, 0)), pl.BlockSpec((1, 1, seq_kv, head_dim), lambda b, h, q: (b, h, 0, 0)), ] out_spec = pl.BlockSpec( (1, 1, block_q, head_dim), lambda b, h, q: (b, h, q, 0) ) args = [query, key, value] if mask is not None: mask = jnp.broadcast_to(mask, (batch, num_heads, seq_q, seq_kv)) in_specs.append( pl.BlockSpec( (1, 1, block_q, seq_kv), lambda b, h, q: (b, h, q, 0) ) ) args.append(mask) else: # The kernel signature must be stable, so we pad with a dummy # zero-sized tensor when no mask is supplied. Pallas tolerates # zero-sized BlockSpecs as long as the kernel never reads from # them; the inner ``if mask_ref is not None`` guards access. in_specs.append(None) args.append(None) return pl.pallas_call( kernel, grid=(batch, num_heads, seq_q // block_q), in_specs=in_specs, out_specs=out_spec, out_shape=out_shape, )(*args)