Source code for attnax.vision
# SPDX-License-Identifier: Apache-2.0
"""Vision Transformer."""
from __future__ import annotations
from typing import Optional
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from .blocks import EncoderBlock
from .config import VisionTransformerConfig
from .embeddings import LearnedPositionalEmbedding, PatchEmbedding
from .norms import create_norm
[docs]
class VisionTransformer(nnx.Module):
    """Vision Transformer with optional classification head.

    Forward pass: patch embedding, an optional prepended ``[CLS]`` token,
    a learnable positional embedding, dropout, ``config.num_layers``
    encoder blocks and a final normalisation layer. When
    ``config.num_classes`` is set, the tokens are then pooled
    (``config.pool``) and fed to a linear classification head.

    Args:
        rngs: Flax NNX random key container.
        config: Vision Transformer hyperparameters.

    Raises:
        ValueError: If a classification head is requested
            (``config.num_classes`` is not ``None``) with
            ``config.pool == "cls"`` while ``config.use_cls_token`` is
            ``False``. There is no ``[CLS]`` token to pool in that case.

    References:
        Dosovitskiy et al., `An Image is Worth 16x16 Words: Transformers
        for Image Recognition at Scale
        <https://arxiv.org/abs/2010.11929>`_, 2020.
    """

    def __init__(self, rngs: nnx.Rngs, config: VisionTransformerConfig):
        # Validate before allocating any parameters so a bad config fails
        # fast and the RNG stream for valid configs is left unchanged.
        # Without a [CLS] token, "cls" pooling would silently pick the
        # first patch token (index 0) instead.
        if (
            config.num_classes is not None
            and config.pool == "cls"
            and not config.use_cls_token
        ):
            raise ValueError(
                "config.pool='cls' requires config.use_cls_token=True; "
                "without a [CLS] token the first patch token would be "
                "pooled. Enable use_cls_token or choose a different pool."
            )

        self.config = config
        self.patch_embed = PatchEmbedding(
            rngs,
            image_size=config.image_size,
            patch_size=config.patch_size,
            num_channels=config.num_channels,
            d_model=config.d_model,
        )

        if config.use_cls_token:
            key = rngs.params()
            # Standard normal truncated to [-2, 2], scaled by 0.02.
            cls_init = jax.random.truncated_normal(
                key, lower=-2.0, upper=2.0, shape=(1, 1, config.d_model)
            ) * 0.02
            self.cls_token = nnx.Param(cls_init)
        else:
            self.cls_token = None

        # One positional slot per patch, plus one for the [CLS] token.
        num_positions = self.patch_embed.num_patches + int(config.use_cls_token)
        self.pos_embed = LearnedPositionalEmbedding(
            rngs, num_positions=num_positions, d_model=config.d_model
        )
        self.dropout = nnx.Dropout(rate=config.dropout_rate, rngs=rngs)

        self.layers = nnx.List(
            [
                EncoderBlock(
                    rngs,
                    d_model=config.d_model,
                    num_heads=config.num_heads,
                    d_ff=config.d_ff,
                    dropout_rate=config.dropout_rate,
                    pre_norm=config.use_pre_norm,
                    attention_type=config.attention_type,
                    attention_block_size=config.attention_block_size,
                    linear_attention_chunk_size=config.linear_attention_chunk_size,
                    norm_type=config.norm_type,
                    ff_activation=config.ff_activation,
                    num_kv_heads=config.num_kv_heads,
                    attention_window=config.attention_window,
                    # Position information comes from the learned
                    # embedding above, so rotary embeddings are disabled.
                    use_rope=False,
                )
                for _ in range(config.num_layers)
            ]
        )
        self.final_norm = create_norm(config.norm_type, config.d_model, rngs=rngs)

        if config.num_classes is not None:
            self.head = nnx.Linear(config.d_model, config.num_classes, rngs=rngs)
        else:
            self.head = None

    def __call__(
        self,
        images: jnp.ndarray,
        *,
        deterministic: Optional[bool] = None,
    ) -> jnp.ndarray:
        """Applies the Vision Transformer.

        Args:
            images: Array of shape
                ``(batch, height, width, num_channels)``.
            deterministic: If ``True``, disables dropout. The value
                (including ``None``) is forwarded unchanged to the
                embedding dropout and every encoder block. Presumably
                ``None`` means each module uses its own default; confirm
                in ``nnx.Dropout`` / ``EncoderBlock``.

        Returns:
            Logits of shape ``(batch, num_classes)`` when
            ``config.num_classes`` is set, otherwise tokens of shape
            ``(batch, num_patches + int(use_cls_token), d_model)``.
        """
        x = self.patch_embed(images)

        if self.cls_token is not None:
            batch = x.shape[0]
            # Tile the single learned [CLS] vector across the batch and
            # prepend it, so it sits at token index 0.
            cls = jnp.broadcast_to(
                self.cls_token[...], (batch, 1, self.config.d_model)
            )
            x = jnp.concatenate([cls, x], axis=1)

        x = self.pos_embed(x)
        x = self.dropout(x, deterministic=deterministic)
        for layer in self.layers:
            x = layer(x, deterministic=deterministic)
        x = self.final_norm(x)

        if self.head is None:
            return x

        if self.config.pool == "cls":
            # __init__ guarantees a [CLS] token exists at index 0 here.
            pooled = x[:, 0, :]
        else:
            # Any other pool value averages over all tokens, including
            # the [CLS] token when one is present.
            pooled = jnp.mean(x, axis=1)
        return self.head(pooled)