Vision Transformer


VisionTransformer(*args, **kwargs)

Vision Transformer with optional classification head.

VisionTransformerConfig(*[, image_size, ...])

Vision Transformer hyperparameters.

class attnax.VisionTransformer(*args, **kwargs)

Bases: Module

Vision Transformer with optional classification head.

The model consists of a patch embedding, an optional [CLS] token, a learnable positional embedding, config.num_layers encoder blocks, a final normalisation layer, and an optional linear classification head.
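The embedding stage described above can be sketched with NumPy. This is a minimal illustration, not attnax's implementation. It assumes channels-last (batch, height, width, channels) input, and the names `patchify` and `embed` are hypothetical. Random arrays stand in for the learned patch projection, [CLS] token and positional embedding. An actual model may implement the projection as a convolution whose kernel size and stride both equal the patch size, which is equivalent to this per-patch linear map.

```python
import numpy as np


def patchify(images: np.ndarray, patch_size: tuple[int, int]) -> np.ndarray:
    """Split (batch, height, width, channels) images into flattened patches.

    Returns an array of shape (batch, num_patches, patch_h * patch_w * channels),
    with patches ordered row by row across the image.
    """
    b, h, w, c = images.shape
    ph, pw = patch_size
    if h % ph or w % pw:
        raise ValueError("image size must be divisible by patch size")
    # (b, h/ph, ph, w/pw, pw, c) -> (b, h/ph, w/pw, ph, pw, c)
    x = images.reshape(b, h // ph, ph, w // pw, pw, c).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, (h // ph) * (w // pw), ph * pw * c)


def embed(images, patch_size, d_model, use_cls_token=True, seed=0):
    """Hypothetical embedding stage: patch projection, optional [CLS], positional embedding."""
    rng = np.random.default_rng(seed)
    patches = patchify(images, patch_size)
    b, _, patch_dim = patches.shape
    # Random weights stand in for the learned linear patch projection.
    w_proj = rng.normal(scale=0.02, size=(patch_dim, d_model))
    tokens = patches @ w_proj  # (b, num_patches, d_model)
    if use_cls_token:
        # Stand-in for the learnable [CLS] token, shared across the batch.
        cls = rng.normal(scale=0.02, size=(1, 1, d_model))
        tokens = np.concatenate([np.broadcast_to(cls, (b, 1, d_model)), tokens], axis=1)
    # Stand-in for the learnable positional embedding, one vector per position.
    pos = rng.normal(scale=0.02, size=(1, tokens.shape[1], d_model))
    return tokens + pos


images = np.zeros((2, 224, 224, 3), dtype=np.float32)
print(embed(images, (16, 16), d_model=768).shape)  # (2, 197, 768)
```

The resulting (batch, sequence, d_model) array is what the encoder blocks consume.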

Parameters:
  • rngs – Flax NNX random key container.

  • config – Vision Transformer hyperparameters.

  • args (Any)

  • kwargs (Any)

Return type:

Any

References

Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, 2020.

class attnax.VisionTransformerConfig(*, image_size=224, patch_size=16, num_channels=3, num_classes=None, use_cls_token=True, pool='cls', d_model=768, num_heads=12, num_layers=12, d_ff=3072, dropout_rate=0.0, use_pre_norm=True, ff_activation='gelu', norm_type='layer', attention_type=AttentionType.STANDARD, attention_block_size=512, linear_attention_chunk_size=64, num_kv_heads=None, attention_window=None)

Bases: object

Vision Transformer hyperparameters.

Defaults reproduce ViT-Base/16 on 224x224 RGB images.
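With these defaults, each side of the image is split into 224 / 16 = 14 patches, giving 196 patches plus one [CLS] token, so the encoder sees 197 tokens. Each flattened patch holds 16 × 16 × 3 = 768 values. The helpers below are hypothetical (`sequence_length` and `patch_dim` are not attnax functions); they reproduce this arithmetic for any int or (height, width) setting.

```python
def to_pair(x):
    """Normalise an int or (height, width) pair to a tuple."""
    return (x, x) if isinstance(x, int) else tuple(x)


def sequence_length(image_size=224, patch_size=16, use_cls_token=True):
    """Number of tokens entering the encoder: patch grid plus optional [CLS]."""
    h, w = to_pair(image_size)
    ph, pw = to_pair(patch_size)
    num_patches = (h // ph) * (w // pw)
    return num_patches + (1 if use_cls_token else 0)


def patch_dim(patch_size=16, num_channels=3):
    """Number of values in one flattened patch."""
    ph, pw = to_pair(patch_size)
    return ph * pw * num_channels


print(sequence_length())  # 197: a 14x14 grid of patches plus [CLS]
print(patch_dim())        # 768 values per flattened patch
```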

Parameters:
  • image_size (int | tuple[int, int]) – Image size, as an int for square images or as (height, width). Must be divisible by patch_size.

  • patch_size (int | tuple[int, int]) – Patch size, as an int for square patches or as (patch_height, patch_width).

  • num_channels (int) – Number of image channels.

  • num_classes (int | None) – Number of output classes. None disables the classification head, and the model returns the token sequence instead.

  • use_cls_token (bool) – Prepend a learnable [CLS] token.

  • pool (Literal['cls', 'mean']) – Pooling strategy applied before the classification head: 'cls' uses the [CLS] token, 'mean' averages over the token sequence.

  • d_model (int) – Hidden size.

  • num_heads (int) – Number of query heads.

  • num_layers (int) – Number of transformer blocks.

  • d_ff (int) – Feed-forward hidden width.

  • dropout_rate (float) – Dropout probability.

  • use_pre_norm (bool) – Pre-norm when True, post-norm otherwise.

  • ff_activation (Literal['gelu', 'relu', 'swiglu', 'geglu']) – Feed-forward activation.

  • norm_type (Literal['layer', 'rms']) – Normalisation type: LayerNorm ('layer') or RMSNorm ('rms').

  • attention_type (AttentionType) – Attention backend.

  • attention_block_size (int) – Block size for memory-efficient, flash and pallas_flash backends.

  • linear_attention_chunk_size (int) – Chunk size for the linear backend.

  • num_kv_heads (int | None) – Number of key/value heads for grouped-query or multi-query attention (GQA/MQA).

  • attention_window (int | None) – Causal sliding-window size.
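The interaction between num_classes and pool can be sketched as follows. This is a hypothetical illustration (`pool_and_classify` is not an attnax function) that assumes the [CLS] token sits at position 0 of the encoded sequence. The 'mean' branch here averages every position, including [CLS] when present; whether attnax excludes the [CLS] token from the mean is not stated in this reference. A random matrix stands in for the learned linear head.

```python
import numpy as np


def pool_and_classify(tokens, pool="cls", num_classes=None, seed=0):
    """Hypothetical head over encoded tokens of shape (batch, seq_len, d_model)."""
    if num_classes is None:
        return tokens  # no head: return the full token sequence unchanged
    if pool == "cls":
        pooled = tokens[:, 0]  # assumes [CLS] is the first position
    elif pool == "mean":
        pooled = tokens.mean(axis=1)  # averages every position, [CLS] included
    else:
        raise ValueError(f"unknown pool: {pool!r}")
    rng = np.random.default_rng(seed)
    # Random weights stand in for the learned linear classification head.
    w_head = rng.normal(scale=0.02, size=(tokens.shape[-1], num_classes))
    return pooled @ w_head  # (batch, num_classes)


tokens = np.ones((2, 197, 768))
print(pool_and_classify(tokens, pool="mean", num_classes=10).shape)  # (2, 10)
print(pool_and_classify(tokens).shape)  # (2, 197, 768)
```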

Raises:

ValueError – If image_size is not divisible by patch_size, or if pool == "cls" while use_cls_token is False.
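The two documented failure modes can be mirrored in plain Python. `ViTConfigSketch` is a hypothetical stand-in covering only the fields these checks touch, not attnax's class. It assumes the divisibility check applies to each dimension separately when either size is a tuple, and that validation runs when the config is constructed.

```python
from dataclasses import dataclass
from typing import Literal, Tuple, Union

Size = Union[int, Tuple[int, int]]


def _pair(x: Size) -> Tuple[int, int]:
    """Normalise an int or (height, width) pair to a tuple."""
    return (x, x) if isinstance(x, int) else (x[0], x[1])


@dataclass(frozen=True)
class ViTConfigSketch:
    """Hypothetical subset of the config fields involved in validation."""

    image_size: Size = 224
    patch_size: Size = 16
    use_cls_token: bool = True
    pool: Literal["cls", "mean"] = "cls"

    def __post_init__(self) -> None:
        (h, w), (ph, pw) = _pair(self.image_size), _pair(self.patch_size)
        # Each image dimension must split into a whole number of patches.
        if h % ph or w % pw:
            raise ValueError(f"image_size {(h, w)} is not divisible by patch_size {(ph, pw)}")
        # 'cls' pooling needs a [CLS] token to pool from.
        if self.pool == "cls" and not self.use_cls_token:
            raise ValueError('pool == "cls" requires use_cls_token=True')


ViTConfigSketch()                                  # valid: 224 is divisible by 16
ViTConfigSketch(use_cls_token=False, pool="mean")  # valid: mean pooling needs no [CLS]
try:
    ViTConfigSketch(image_size=225)
except ValueError as err:
    print(err)
```

Disabling the [CLS] token therefore also requires switching pool to 'mean'.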

References

Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, 2020.