Embeddings

- TokenEmbedding: Token embedding lookup table.
- PositionalEncoding: Fixed sinusoidal positional encoding.
- RotaryEmbedding: Precomputed RoPE cos/sin table.
- apply_rope: Applies RoPE rotation in the split-half formulation.
- rope_cos_sin_table: Precomputes the RoPE cos/sin table.
- rope_cos_sin_from_positions: Gathers RoPE cos/sin values for given positions.
- RoPE inverse-frequency vector.
- PatchEmbedding: Patchify an image and project each patch to d_model.
- LearnedPositionalEmbedding: Additive learnable positional embedding.
Token and absolute positions
- class attnax.TokenEmbedding(*args, **kwargs)
Bases: Module
Token embedding lookup table.
- Parameters:
rngs – Flax NNX random key container.
vocab_size – Vocabulary size.
d_model – Embedding dimension.
args (Any)
kwargs (Any)
- Return type:
Any
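A token embedding lookup selects one row of a (vocab_size, d_model) table per integer token id. The minimal jax.numpy sketch below shows only that gather. It uses a fixed array and a hypothetical helper name, embed_tokens, and is not the attnax source. TokenEmbedding presumably holds the table as a trainable parameter initialised from rngs.

```python
import jax.numpy as jnp

def embed_tokens(table, token_ids):
    """Hypothetical helper: look up one d_model-sized row per token id.

    table: (vocab_size, d_model) embedding matrix.
    token_ids: integer array of any shape, e.g. (batch, seq_len).
    Returns an array of shape token_ids.shape + (d_model,).
    """
    # Gather rows along axis 0; equivalent to table[token_ids] for in-range ids.
    return jnp.take(table, token_ids, axis=0)

vocab_size, d_model = 5, 3
# Row i of this toy table is [3i, 3i + 1, 3i + 2], which makes lookups easy to check.
table = jnp.arange(vocab_size * d_model, dtype=jnp.float32).reshape(vocab_size, d_model)
token_ids = jnp.array([[0, 4], [2, 2]])  # (batch=2, seq_len=2)
embeddings = embed_tokens(table, token_ids)
print(embeddings.shape)  # (2, 2, 3)
```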
- class attnax.PositionalEncoding(*args, **kwargs)
Bases: Module
Fixed sinusoidal positional encoding.
\[\mathrm{PE}(p, 2i) = \sin\!\left(\frac{p}{10000^{2i/d}}\right), \qquad \mathrm{PE}(p, 2i+1) = \cos\!\left(\frac{p}{10000^{2i/d}}\right)\]
where \(p\) is the position, \(i\) indexes feature pairs, and \(d\) is d_model.
- Parameters:
max_len – Maximum sequence length.
d_model – Embedding dimension.
args (Any)
kwargs (Any)
- Return type:
Any
References
Vaswani et al., Attention Is All You Need, 2017.
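The formula can be tabulated directly. The minimal jax.numpy sketch below assumes an even d_model and uses the sin-on-even, cos-on-odd interleaving given by the formula above. sinusoidal_table is a hypothetical name, and the code is not the attnax source. As in Vaswani et al., the table is typically added to the token embeddings rather than learned.

```python
import jax.numpy as jnp

def sinusoidal_table(max_len, d_model):
    """Hypothetical helper: (max_len, d_model) table following the PE formula."""
    if d_model % 2:
        raise ValueError("this sketch assumes an even d_model")
    p = jnp.arange(max_len, dtype=jnp.float32)[:, None]             # positions, (max_len, 1)
    two_i = jnp.arange(0, d_model, 2, dtype=jnp.float32)[None, :]   # 2i, (1, d_model // 2)
    angles = p / (10000.0 ** (two_i / d_model))                     # (max_len, d_model // 2)
    pe = jnp.zeros((max_len, d_model), dtype=jnp.float32)
    pe = pe.at[:, 0::2].set(jnp.sin(angles))  # even feature indices 2i: sine
    pe = pe.at[:, 1::2].set(jnp.cos(angles))  # odd feature indices 2i + 1: cosine
    return pe

pe = sinusoidal_table(max_len=50, d_model=8)
# Typical use: x + pe[:seq_len] for x of shape (batch, seq_len, d_model).
```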
Rotary positional embeddings (RoPE)
- class attnax.RotaryEmbedding(*args, **kwargs)
Bases: Module
Precomputed RoPE cos/sin table.
- Parameters:
head_dim – Per-head feature size; must be even.
max_positions – Length of the precomputed table.
base – RoPE base \(\theta\).
dtype – Dtype of the cos/sin buffers.
args (Any)
kwargs (Any)
- Raises:
ValueError – If head_dim is odd.
- Return type:
Any
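A RoPE cos/sin table is usually built from the RoFormer inverse frequencies \(\theta^{-2i/\text{head\_dim}}\) for \(i = 0, \dots, \text{head\_dim}/2 - 1\). Each position is multiplied by every frequency, and the cosine and sine of the resulting angles are stored. The sketch below follows that standard construction, including the odd-head_dim check. It assumes a default base of 10000, as in RoFormer. The helper names are hypothetical and are not the attnax API.

```python
import jax.numpy as jnp

def rope_inverse_frequencies(head_dim, base=10000.0):
    """Hypothetical helper: inv_freq[i] = base ** (-2i / head_dim), i in [0, head_dim // 2)."""
    if head_dim % 2:
        raise ValueError(f"head_dim must be even, got {head_dim}")
    exponents = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim
    return 1.0 / (base ** exponents)

def rope_table(head_dim, max_positions, base=10000.0, dtype=jnp.float32):
    """Hypothetical helper: (cos, sin), each of shape (max_positions, head_dim // 2)."""
    inv_freq = rope_inverse_frequencies(head_dim, base)
    positions = jnp.arange(max_positions, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)  # angle for every (position, frequency) pair
    return jnp.cos(angles).astype(dtype), jnp.sin(angles).astype(dtype)

cos, sin = rope_table(head_dim=64, max_positions=128)
```

The first frequency is always 1, so column 0 of the table rotates by exactly one radian per position. Higher columns rotate progressively more slowly.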
- attnax.apply_rope(x, cos, sin)
Applies RoPE rotation in the split-half formulation.
Splits the last dimension of x into halves \((x_1, x_2)\) and returns the concatenation, along that axis, of
\[\bigl(x_1 \cos\theta - x_2 \sin\theta,\; x_1 \sin\theta + x_2 \cos\theta\bigr),\]
with \(\cos\theta\) and \(\sin\theta\) supplied elementwise by cos and sin.
x (Array) – Array with last dimension head_dim.
cos (Array) – Cosine values with last dimension head_dim // 2, broadcastable against each half of x.
sin (Array) – Sine values of the same shape as cos.
- Returns:
Array of the same shape as x.
- Return type:
Array
References
Su et al., RoFormer: Enhanced Transformer with Rotary Position Embedding, 2021.
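The split-half formula maps directly onto a split, four multiplies and a concatenate. The minimal sketch below is not the attnax source. rotate_split_half, angles_at and rotated are hypothetical names, and the angles assume the standard RoFormer inverse frequencies with base 10000. The checks at the end illustrate the property that motivates RoPE. The rotation preserves norms, and once queries and keys are rotated by their positions, their dot product depends only on the relative offset.

```python
import jax.numpy as jnp

def rotate_split_half(x, cos, sin):
    """Split-half RoPE rotation: x is (..., head_dim); cos, sin broadcast to (..., head_dim // 2)."""
    x1, x2 = jnp.split(x, 2, axis=-1)  # first and second halves of the feature axis
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

def angles_at(position, head_dim, base=10000.0):
    # Hypothetical helper: per-pair angles position * base ** (-2i / head_dim).
    inv_freq = 1.0 / base ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)
    return position * inv_freq

head_dim = 8
q = jnp.linspace(-1.0, 1.0, head_dim)
k = jnp.cos(jnp.arange(head_dim, dtype=jnp.float32))

def rotated(v, pos):
    a = angles_at(pos, head_dim)
    return rotate_split_half(v, jnp.cos(a), jnp.sin(a))

# Both query/key pairs below are 3 positions apart, so their scores should match.
score_a = jnp.dot(rotated(q, 5.0), rotated(k, 2.0))
score_b = jnp.dot(rotated(q, 13.0), rotated(k, 10.0))
```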
- attnax.rope_cos_sin_table(max_seq_len, head_dim, base, dtype)
Precomputes the RoPE cos/sin table.
- attnax.rope_cos_sin_from_positions(position_ids, head_dim, base, table_len, out_dtype)
Gathers RoPE cos/sin values for given positions.
- Parameters:
- Returns:
(cos, sin), each of shape (batch, seq_len, head_dim // 2).
- Return type:
tuple[Array, Array]
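Gathering by explicit positions lets different sequences in a batch use different offsets. The minimal sketch below assumes the function builds a table of table_len rows and then indexes it with position_ids. It mirrors the documented parameter names but is a hypothetical helper, cos_sin_for_positions, not the attnax implementation. position_ids are assumed to lie in [0, table_len).

```python
import jax.numpy as jnp

def cos_sin_for_positions(position_ids, head_dim, base, table_len, out_dtype):
    """Hypothetical helper: gather (cos, sin) rows for integer position_ids of shape (batch, seq_len)."""
    inv_freq = 1.0 / base ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)
    positions = jnp.arange(table_len, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)  # (table_len, head_dim // 2)
    cos_table, sin_table = jnp.cos(angles), jnp.sin(angles)
    # Indexing with a (batch, seq_len) integer array yields (batch, seq_len, head_dim // 2).
    cos = cos_table[position_ids].astype(out_dtype)
    sin = sin_table[position_ids].astype(out_dtype)
    return cos, sin

# Two sequences that start at different offsets (for example, one continuing after 5 tokens).
position_ids = jnp.array([[0, 1, 2, 3], [5, 6, 7, 8]])
cos, sin = cos_sin_for_positions(position_ids, 16, 10000.0, 32, jnp.float32)
```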
Vision embeddings
- class attnax.PatchEmbedding(*args, **kwargs)
Bases: Module
Patchify an image and project each patch to d_model.
Implemented as a strided 2-D convolution with kernel size and stride both equal to patch_size.
- Parameters:
rngs – Flax NNX random key container.
image_size – int or (height, width).
patch_size – int or (patch_height, patch_width).
num_channels – Number of image channels.
d_model – Output token dimension.
args (Any)
kwargs (Any)
- Raises:
ValueError – If image_size is not divisible by patch_size.
- Return type:
Any
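When the kernel size equals the stride, each output position of the convolution sees exactly one non-overlapping patch. The convolution therefore equals flattening every patch and applying one shared linear projection. The sketch below checks that equivalence with jax.lax.conv_general_dilated, assuming channels-last (NHWC) images and omitting the bias. patchify and patch_embed_conv are hypothetical helpers, not the attnax implementation.

```python
import jax
import jax.numpy as jnp

def patchify(images, patch_size):
    """(batch, H, W, C) -> (batch, num_patches, ph * pw * C), patches in row-major grid order."""
    b, h, w, c = images.shape
    ph, pw = patch_size
    if h % ph or w % pw:
        raise ValueError("image size must be divisible by patch_size")
    x = images.reshape(b, h // ph, ph, w // pw, pw, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (batch, grid_h, grid_w, ph, pw, C)
    return x.reshape(b, (h // ph) * (w // pw), ph * pw * c)

def patch_embed_conv(images, kernel, patch_size):
    """Strided conv with kernel size == stride == patch_size; kernel is (ph, pw, C, d_model)."""
    out = jax.lax.conv_general_dilated(
        images, kernel,
        window_strides=patch_size,
        padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )  # (batch, grid_h, grid_w, d_model)
    b, gh, gw, d = out.shape
    return out.reshape(b, gh * gw, d)  # one token per patch

key_img, key_w = jax.random.split(jax.random.PRNGKey(0))
images = jax.random.normal(key_img, (2, 8, 12, 3))  # batch=2, 8x12 images, 3 channels
patch_size = (4, 4)
kernel = jax.random.normal(key_w, (4, 4, 3, 16))    # d_model=16
tokens_conv = patch_embed_conv(images, kernel, patch_size)
# Same tokens from flattening each patch and applying one shared linear map.
tokens_linear = patchify(images, patch_size) @ kernel.reshape(-1, 16)
```

An 8x12 image with 4x4 patches yields a 2x3 grid, so each image becomes 6 tokens of width d_model.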
- class attnax.LearnedPositionalEmbedding(*args, **kwargs)
Bases: Module
Additive learnable positional embedding.
Initialised with truncated-normal noise (standard deviation init_std).
- Parameters:
rngs – Flax NNX random key container.
num_positions – Maximum number of positions.
d_model – Embedding dimension.
init_std – Truncated-normal standard deviation.
args (Any)
kwargs (Any)
- Return type:
Any
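A learned positional embedding is a (num_positions, d_model) table that is added to the input. The minimal sketch below makes several assumptions that the page does not state. Truncation is at two standard deviations, init_std defaults to 0.02, and the first seq_len rows are added to an input of shape (batch, seq_len, d_model). init_position_table and add_positions are hypothetical names. In a Flax NNX module the table would presumably be wrapped as a trainable parameter. In this parametrisation, init_std scales a standard truncated normal, so the samples' actual spread is somewhat below init_std.

```python
import jax
import jax.numpy as jnp

def init_position_table(key, num_positions, d_model, init_std=0.02):
    """Hypothetical helper: truncated normal cut at +/-2, scaled by init_std."""
    return init_std * jax.random.truncated_normal(key, -2.0, 2.0, (num_positions, d_model))

def add_positions(x, table):
    """Add the first seq_len rows of the table to x of shape (batch, seq_len, d_model)."""
    seq_len = x.shape[1]
    return x + table[:seq_len]  # broadcasts over the batch axis

table = init_position_table(jax.random.PRNGKey(0), num_positions=16, d_model=8)
x = jnp.zeros((2, 5, 8))
y = add_positions(x, table)
```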