Embeddings

TokenEmbedding(*args, **kwargs)

Token embedding lookup table.

PositionalEncoding(*args, **kwargs)

Fixed sinusoidal positional encoding.

RotaryEmbedding(*args, **kwargs)

Precomputed RoPE cos/sin table.

apply_rope(x, cos, sin)

Applies RoPE rotation in the split-half formulation.

rope_cos_sin_table(max_seq_len, head_dim, ...)

Precomputes the RoPE cos/sin table.

rope_cos_sin_from_positions(position_ids, ...)

Gathers RoPE cos/sin values for given positions.

rope_inv_freq(head_dim, base, dtype)

RoPE inverse-frequency vector.

PatchEmbedding(*args, **kwargs)

Patchify an image and project each patch to d_model.

LearnedPositionalEmbedding(*args, **kwargs)

Additive learnable positional embedding.

Token and absolute positions

class attnax.TokenEmbedding(*args, **kwargs)

Bases: Module

Token embedding lookup table.

Parameters:
  • rngs – Flax NNX random key container.

  • vocab_size – Vocabulary size.

  • d_model – Embedding dimension.

  • args (Any)

  • kwargs (Any)

Return type:

Any
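A token embedding is an integer-indexed table with one d_model-dimensional row per vocabulary entry. The JAX sketch below shows only that lookup. The 0.02 normal initialisation and all variable names are illustrative assumptions, not attnax internals.

```python
import jax
import jax.numpy as jnp

# Minimal sketch of a token embedding lookup (not attnax's implementation).
vocab_size, d_model = 10, 4

# One row per token id. Normal init with std 0.02 is an illustrative assumption.
key = jax.random.PRNGKey(0)
table = 0.02 * jax.random.normal(key, (vocab_size, d_model))

# Integer token ids of shape (batch, seq_len).
token_ids = jnp.array([[1, 2, 3], [4, 5, 9]])

# The lookup is row gathering: (batch, seq_len) -> (batch, seq_len, d_model).
embeddings = jnp.take(table, token_ids, axis=0)
```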

class attnax.PositionalEncoding(*args, **kwargs)

Bases: Module

Fixed sinusoidal positional encoding.

\[\mathrm{PE}(p, 2i) = \sin\!\left(\frac{p}{10000^{2i/d}}\right), \qquad \mathrm{PE}(p, 2i+1) = \cos\!\left(\frac{p}{10000^{2i/d}}\right)\]

where \(p\) is the position, \(i\) indexes pairs of feature dimensions and \(d\) is d_model.

Parameters:
  • max_len – Maximum sequence length.

  • d_model – Embedding dimension.

  • args (Any)

  • kwargs (Any)

Return type:

Any

References

Vaswani et al., Attention Is All You Need, 2017.
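The sinusoidal formula can be tabulated directly: even feature indices take the sine and odd indices the cosine of the same angle. This is a minimal JAX sketch assuming an even d_model and additive use on a (batch, seq_len, d_model) input. The function name is hypothetical and this is not attnax's code.

```python
import jax.numpy as jnp

def sinusoidal_encoding(max_len: int, d_model: int) -> jnp.ndarray:
    """Sinusoidal table of shape (max_len, d_model); assumes d_model is even."""
    p = jnp.arange(max_len, dtype=jnp.float32)[:, None]            # positions, (max_len, 1)
    two_i = jnp.arange(0, d_model, 2, dtype=jnp.float32)[None, :]  # 2i, (1, d_model // 2)
    angles = p / (10000.0 ** (two_i / d_model))                    # p / 10000^{2i/d}
    pe = jnp.zeros((max_len, d_model), dtype=jnp.float32)
    pe = pe.at[:, 0::2].set(jnp.sin(angles))  # even feature indices 2i
    pe = pe.at[:, 1::2].set(jnp.cos(angles))  # odd feature indices 2i + 1
    return pe

pe = sinusoidal_encoding(max_len=16, d_model=8)
x = jnp.ones((2, 10, 8))        # (batch, seq_len, d_model)
y = x + pe[None, : x.shape[1]]  # additive use is an assumption
```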

Rotary positional embeddings (RoPE)

class attnax.RotaryEmbedding(*args, **kwargs)

Bases: Module

Precomputed RoPE cos/sin table.

Parameters:
  • head_dim – Per-head feature size; must be even.

  • max_positions – Length of the precomputed table.

  • base – RoPE base \(\theta\).

  • dtype – Dtype of the cos/sin buffers.

  • args (Any)

  • kwargs (Any)

Raises:

ValueError – If head_dim is odd.

Return type:

Any

cos_sin_for_positions(position_ids, *, out_dtype)

Gathers cos/sin for the given positions.

Parameters:
  • position_ids (Array) – Integer positions of shape (batch, seq_len).

  • out_dtype (dtype) – Output dtype.

Returns:

(cos, sin) each of shape (batch, seq_len, head_dim // 2).

Return type:

tuple[Array, Array]
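To make the shapes concrete, here is a hypothetical stand-in for a RotaryEmbedding-style object: it precomputes a (max_positions, head_dim // 2) table once and gathers rows per position id. The class name, the default base of 10000.0 and the float32 angle computation are assumptions; only the shapes, the even-head_dim check and the method name follow the documentation above.

```python
import jax.numpy as jnp

class RopeTableSketch:
    """Hypothetical stand-in for a precomputed RoPE cos/sin table (not attnax's class)."""

    def __init__(self, head_dim: int, max_positions: int, base: float = 10000.0,
                 dtype=jnp.float32):
        if head_dim % 2 != 0:
            raise ValueError("head_dim must be even")
        # Inverse frequencies base^{-2i/head_dim}, one per rotated pair.
        inv_freq = base ** (-jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)
        angles = jnp.arange(max_positions, dtype=jnp.float32)[:, None] * inv_freq[None, :]
        self.cos = jnp.cos(angles).astype(dtype)  # (max_positions, head_dim // 2)
        self.sin = jnp.sin(angles).astype(dtype)

    def cos_sin_for_positions(self, position_ids, *, out_dtype):
        # Row gather: (batch, seq_len) ids -> (batch, seq_len, head_dim // 2).
        cos = jnp.take(self.cos, position_ids, axis=0).astype(out_dtype)
        sin = jnp.take(self.sin, position_ids, axis=0).astype(out_dtype)
        return cos, sin

rope = RopeTableSketch(head_dim=8, max_positions=32)
position_ids = jnp.array([[0, 1, 2, 3], [5, 6, 7, 8]])  # second sequence starts at offset 5
cos, sin = rope.cos_sin_for_positions(position_ids, out_dtype=jnp.bfloat16)
```

Because this sketch looks positions up by id rather than assuming 0..seq_len-1, sequences with different offsets share one table.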

attnax.apply_rope(x, cos, sin)

Applies RoPE rotation in the split-half formulation.

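Splits the last dimension into halves \((x_1, x_2)\), rotates each pair of matching entries by the angle \(\theta\) whose cosine and sine are passed in (one angle per index of a half), and returns the concatenation

\[\bigl(x_1 \cos\theta - x_2 \sin\theta,\; x_1 \sin\theta + x_2 \cos\theta\bigr).\]

Parameters: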
  • x (Array) – Array with last dimension head_dim.

  • cos (Array) – Cosine values with last dimension head_dim // 2, broadcastable against each half of x.

  • sin (Array) – Sine values of the same shape as cos.

Returns:

Array of the same shape as x.

Return type:

Array

References

Su et al., RoFormer: Enhanced Transformer with Rotary Position Embedding, 2021.
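The split-half rotation fits in a few lines of JAX. The check at the end illustrates the property RoPE is built for: after rotating a query to position m and a key to position n, their dot product depends only on m - n. The helper names and the base of 10000.0 are assumptions, not attnax's API.

```python
import jax
import jax.numpy as jnp

def apply_rope_sketch(x, cos, sin):
    """Split-half RoPE: x has last dim head_dim; cos/sin have last dim head_dim // 2."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1[j], x2[j]) pair by the angle whose cos/sin are given.
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

def angles_for(pos, head_dim, base=10000.0):
    """Angles pos * base^{-2i/head_dim}, shape (head_dim // 2,)."""
    inv_freq = base ** (-jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)
    return pos * inv_freq

head_dim = 8
kq, kk = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(kq, (head_dim,))
k = jax.random.normal(kk, (head_dim,))

def score(m, n):
    """Dot product of q rotated to position m with k rotated to position n."""
    a_m, a_n = angles_for(m, head_dim), angles_for(n, head_dim)
    q_m = apply_rope_sketch(q, jnp.cos(a_m), jnp.sin(a_m))
    k_n = apply_rope_sketch(k, jnp.cos(a_n), jnp.sin(a_n))
    return jnp.dot(q_m, k_n)

# Same offset m - n = 3 at different absolute positions.
s1, s2 = score(5, 2), score(13, 10)
```

Each pair is rotated, not rescaled, so the rotation also preserves the norm of x.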

attnax.rope_cos_sin_table(max_seq_len, head_dim, base, dtype)

Precomputes the RoPE cos/sin table.

Parameters:
  • max_seq_len (int) – Number of positions.

  • head_dim (int) – Per-head feature size; must be even.

  • base (float) – RoPE base \(\theta\).

  • dtype (dtype) – Output dtype.

Returns:

(cos, sin) each of shape (max_seq_len, head_dim // 2).

Return type:

tuple[Array, Array]
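A cos/sin table of this shape is an outer product of positions and inverse frequencies followed by cos and sin. A minimal sketch, assuming angles are computed in float32 and cast to dtype at the end (the documentation does not say which precision attnax uses internally); the function name is hypothetical:

```python
import jax.numpy as jnp

def rope_table_sketch(max_seq_len: int, head_dim: int, base: float, dtype):
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even")
    inv_freq = base ** (-jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)
    # Outer product: angle[p, i] = p * base^{-2i/head_dim}.
    angles = jnp.outer(jnp.arange(max_seq_len, dtype=jnp.float32), inv_freq)
    return jnp.cos(angles).astype(dtype), jnp.sin(angles).astype(dtype)

cos, sin = rope_table_sketch(max_seq_len=128, head_dim=64, base=10000.0, dtype=jnp.float32)
```

Computing in float32 first matters for low-precision dtypes: bfloat16 has an 8-bit significand, so integer positions above 256 are no longer all exactly representable.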

attnax.rope_cos_sin_from_positions(position_ids, head_dim, base, table_len, out_dtype)

Gathers RoPE cos/sin values for given positions.

Parameters:
  • position_ids (Array) – Integer positions of shape (batch, seq_len).

  • head_dim (int) – Per-head feature size; must be even.

  • base (float) – RoPE base \(\theta\).

  • table_len (int) – Length of the precomputed table.

  • out_dtype (dtype) – Output dtype.

Returns:

(cos, sin) each of shape (batch, seq_len, head_dim // 2).

Return type:

tuple[Array, Array]
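Given that this function takes a table_len, one plausible realisation builds a table of that length and gathers rows by position id. The sketch below does that and checks that the gathered values match angles computed straight from the positions. How attnax treats ids at or beyond table_len is not documented here, so the sketch uses only in-range ids; the function name is hypothetical.

```python
import jax.numpy as jnp

def cos_sin_from_positions_sketch(position_ids, head_dim, base, table_len, out_dtype):
    """Hypothetical: build a (table_len, head_dim // 2) table, then gather rows."""
    inv_freq = base ** (-jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)
    angles = jnp.outer(jnp.arange(table_len, dtype=jnp.float32), inv_freq)
    cos_table, sin_table = jnp.cos(angles), jnp.sin(angles)
    # Gather: (batch, seq_len) -> (batch, seq_len, head_dim // 2).
    cos = jnp.take(cos_table, position_ids, axis=0)
    sin = jnp.take(sin_table, position_ids, axis=0)
    return cos.astype(out_dtype), sin.astype(out_dtype)

position_ids = jnp.array([[0, 1, 2], [10, 11, 12]])
cos, sin = cos_sin_from_positions_sketch(position_ids, head_dim=16, base=10000.0,
                                         table_len=64, out_dtype=jnp.float32)

# The gather agrees with computing angles directly from the positions.
inv_freq = 10000.0 ** (-jnp.arange(0, 16, 2, dtype=jnp.float32) / 16)
direct = jnp.cos(position_ids[..., None].astype(jnp.float32) * inv_freq)
```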

attnax.rope_inv_freq(head_dim, base, dtype)

RoPE inverse-frequency vector.

Returns \(b^{-2i / d}\) for \(i = 0, 1, \dots, d/2 - 1\), where \(b\) is base and \(d\) is head_dim.

Parameters:
  • head_dim (int) – Per-head feature size; must be even.

  • base (float) – RoPE base \(\theta\).

  • dtype (dtype) – Output dtype.

Returns:

Array of shape (head_dim // 2,).

Return type:

Array
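As a worked instance: with head_dim d = 8 and base b = 10000, the exponents 2i/d are 0, 0.25, 0.5 and 0.75, so the vector is 10000^0, 10000^-0.25, 10000^-0.5, 10000^-0.75 = 1, 0.1, 0.01, 0.001. A minimal JAX sketch of the same computation (hypothetical function name, not attnax's code):

```python
import jax.numpy as jnp

def rope_inv_freq_sketch(head_dim: int, base: float, dtype):
    # Exponents 2i / d for i = 0, ..., d/2 - 1.
    exponents = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim
    return (base ** -exponents).astype(dtype)

inv_freq = rope_inv_freq_sketch(head_dim=8, base=10000.0, dtype=jnp.float32)
# approximately [1.0, 0.1, 0.01, 0.001]
```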

Vision embeddings

class attnax.PatchEmbedding(*args, **kwargs)

Bases: Module

Patchify an image and project each patch to d_model.

Implemented as a strided 2-D convolution whose kernel size and stride both equal patch_size, so the patches do not overlap.

Parameters:
  • rngs – Flax NNX random key container.

  • image_size – Image size, as an int or (height, width).

  • patch_size – Patch size, as an int or (patch_height, patch_width).

  • num_channels – Number of image channels.

  • d_model – Output token dimension.

  • args (Any)

  • kwargs (Any)

Raises:

ValueError – If image_size is not divisible by patch_size.

Return type:

Any
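A convolution whose kernel size and stride both equal the patch size is the same as cutting the image into non-overlapping patches, flattening each, and applying one shared linear map. The sketch checks that equivalence with jax.lax.conv_general_dilated. The channels-last (NHWC) layout, square patches and the function names are assumptions, not attnax's API.

```python
import jax
import jax.numpy as jnp

def patch_embed_conv(images, kernel, bias, patch):
    """images: (B, H, W, C); kernel: (p, p, C, d_model) in HWIO layout."""
    b, h, w, c = images.shape
    if h % patch or w % patch:
        raise ValueError("image size must be divisible by patch size")
    out = jax.lax.conv_general_dilated(
        images, kernel,
        window_strides=(patch, patch),  # stride == kernel size: non-overlapping patches
        padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )                                   # (B, H/p, W/p, d_model)
    _, gh, gw, d = out.shape
    return out.reshape(b, gh * gw, d) + bias  # (B, num_patches, d_model)

def patch_embed_matmul(images, kernel, bias, patch):
    """Same result: cut patches explicitly, flatten each, apply one linear map."""
    b, h, w, c = images.shape
    x = images.reshape(b, h // patch, patch, w // patch, patch, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, gh, gw, p, p, C)
    x = x.reshape(b, (h // patch) * (w // patch), patch * patch * c)
    return x @ kernel.reshape(patch * patch * c, -1) + bias

key_img, key_w = jax.random.split(jax.random.PRNGKey(0))
images = jax.random.normal(key_img, (2, 8, 8, 3))  # batch of 8x8, 3-channel images
kernel = jax.random.normal(key_w, (4, 4, 3, 5))    # patch size 4, d_model 5
bias = jnp.zeros((5,))
tokens = patch_embed_conv(images, kernel, bias, patch=4)  # (2, 4, 5): a 2x2 grid of patches
```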

class attnax.LearnedPositionalEmbedding(*args, **kwargs)

Bases: Module

Additive learnable positional embedding.

Initialised with truncated-normal noise (standard deviation init_std).

Parameters:
  • rngs – Flax NNX random key container.

  • num_positions – Maximum number of positions.

  • d_model – Embedding dimension.

  • init_std – Truncated-normal standard deviation.

  • args (Any)

  • kwargs (Any)

Return type:

Any
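An additive learned positional embedding is a trainable (num_positions, d_model) table whose first seq_len rows are added to the input. A minimal sketch, assuming truncation at plus or minus 2 standard units and plain scaling by init_std; attnax's exact initialiser bounds are not documented above, and the names are hypothetical.

```python
import jax
import jax.numpy as jnp

num_positions, d_model, init_std = 16, 8, 0.02

# Learnable table, one row per position. Truncating at +/-2 and scaling by
# init_std are assumptions about the initialiser.
key = jax.random.PRNGKey(0)
pos_table = init_std * jax.random.truncated_normal(key, -2.0, 2.0, (num_positions, d_model))

def add_positions(x, pos_table):
    """x: (batch, seq_len, d_model) with seq_len <= num_positions."""
    seq_len = x.shape[1]
    return x + pos_table[None, :seq_len, :]  # broadcast over the batch

x = jnp.zeros((2, 10, d_model))
y = add_positions(x, pos_table)
```

Here init_std is the standard deviation before truncation; after truncating a normal at plus or minus 2 standard deviations, the sample standard deviation is about 0.88 times init_std.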