Vision Transformer
VisionTransformer | Vision Transformer with optional classification head.
VisionTransformerConfig | Vision Transformer hyperparameters.
- class attnax.VisionTransformer(*args, **kwargs)
Bases: Module
Vision Transformer with optional classification head.
The model consists of a patch embedding, an optional [CLS] token, a learnable positional embedding, config.num_layers encoder blocks, a final normalisation, and an optional linear classification head.
- Parameters:
rngs – Flax NNX random key container.
config – Vision Transformer hyperparameters.
args (Any)
kwargs (Any)
- Return type:
Any
References
Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, 2020.
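The stages named in the class description fix the shape of the output. The following is a minimal shape-tracing sketch in plain Python, not attnax code: the helper trace_vit_shapes and its argument names are hypothetical, and it assumes that pooling reduces the token axis before the linear head, as the pool parameter description suggests.

```python
def trace_vit_shapes(batch, image_hw, patch_hw, d_model,
                     use_cls_token=True, num_classes=None):
    """Return (stage, shape) pairs for a ViT forward pass.

    Hypothetical helper for illustration only; it is not part of attnax.
    """
    (height, width), (patch_h, patch_w) = image_hw, patch_hw
    # One token per non-overlapping patch.
    num_patches = (height // patch_h) * (width // patch_w)
    stages = [("patch embedding", (batch, num_patches, d_model))]

    # The optional [CLS] token adds one position to the sequence.
    num_tokens = num_patches + (1 if use_cls_token else 0)
    if use_cls_token:
        stages.append(("prepend [CLS]", (batch, num_tokens, d_model)))

    # Positional embedding, encoder blocks, and final norm keep the shape.
    stages.append(("add positional embedding", (batch, num_tokens, d_model)))
    stages.append(("encoder blocks + final norm", (batch, num_tokens, d_model)))

    # Assumption: with a head, tokens are pooled and then projected to classes.
    if num_classes is not None:
        stages.append(("pool + linear head", (batch, num_classes)))
    return stages


for stage, shape in trace_vit_shapes(8, (224, 224), (16, 16), 768, num_classes=1000):
    print(f"{stage:28s} {shape}")
```

With num_classes=None the last entry is the token sequence itself, which matches the documented behaviour of returning the token sequence when the head is disabled.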
- class attnax.VisionTransformerConfig(*, image_size=224, patch_size=16, num_channels=3, num_classes=None, use_cls_token=True, pool='cls', d_model=768, num_heads=12, num_layers=12, d_ff=3072, dropout_rate=0.0, use_pre_norm=True, ff_activation='gelu', norm_type='layer', attention_type=AttentionType.STANDARD, attention_block_size=512, linear_attention_chunk_size=64, num_kv_heads=None, attention_window=None)
Bases: object
Vision Transformer hyperparameters.
Defaults reproduce ViT-Base/16 on 224x224 RGB images.
- Parameters:
image_size (int | tuple[int, int]) – int or (height, width). Must be divisible by patch_size.
patch_size (int | tuple[int, int]) – int or (patch_height, patch_width).
num_channels (int) – Number of image channels.
num_classes (int | None) – Output classes. None disables the head and returns the token sequence.
use_cls_token (bool) – Prepend a learnable [CLS] token.
pool (Literal['cls', 'mean']) – Pooling strategy before the classification head.
d_model (int) – Hidden size.
num_heads (int) – Number of query heads.
num_layers (int) – Number of transformer blocks.
d_ff (int) – Feed-forward hidden width.
dropout_rate (float) – Dropout probability.
use_pre_norm (bool) – Pre-norm when True, post-norm otherwise.
ff_activation (Literal['gelu', 'relu', 'swiglu', 'geglu']) – Feed-forward activation.
norm_type (Literal['layer', 'rms']) – Normalisation kind.
attention_type (AttentionType) – Attention backend.
attention_block_size (int) – Block size for memory-efficient, flash and pallas_flash backends.
linear_attention_chunk_size (int) – Chunk size for the linear backend.
num_kv_heads (int | None) – Number of key/value heads for GQA/MQA.
attention_window (int | None) – Causal sliding-window size.
- Raises:
ValueError – If image_size is not divisible by patch_size, or if pool == "cls" while use_cls_token is False.
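The two documented error conditions can be illustrated with a small stand-in dataclass. This is a sketch under stated assumptions, not the attnax implementation: SketchViTConfig and as_pair are hypothetical names, only four of the documented fields are modelled, and divisibility is assumed to be checked per axis when tuples are given.

```python
from dataclasses import dataclass
from typing import Tuple, Union

Size = Union[int, Tuple[int, int]]


def as_pair(value: Size) -> Tuple[int, int]:
    """Expand an int to (value, value); pass a 2-tuple through."""
    return (value, value) if isinstance(value, int) else (value[0], value[1])


@dataclass(frozen=True)
class SketchViTConfig:
    """Hypothetical stand-in, NOT attnax.VisionTransformerConfig."""

    image_size: Size = 224
    patch_size: Size = 16
    use_cls_token: bool = True
    pool: str = "cls"

    def __post_init__(self):
        (height, width) = as_pair(self.image_size)
        (patch_h, patch_w) = as_pair(self.patch_size)
        # Rule 1: each image axis must be a whole number of patches.
        if height % patch_h or width % patch_w:
            raise ValueError(
                f"image_size {(height, width)} is not divisible by "
                f"patch_size {(patch_h, patch_w)}"
            )
        # Rule 2: 'cls' pooling needs a [CLS] token to read from.
        if self.pool == "cls" and not self.use_cls_token:
            raise ValueError('pool="cls" requires use_cls_token=True')


SketchViTConfig()                                  # defaults pass both checks
SketchViTConfig(use_cls_token=False, pool="mean")  # mean pooling needs no [CLS]
try:
    SketchViTConfig(image_size=225)
except ValueError as err:
    print(err)
```

Validating at construction time means an inconsistent configuration fails before any parameters are allocated.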
References
Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, 2020.
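The statement that the defaults reproduce ViT-Base/16 can be sanity-checked with a back-of-the-envelope parameter count. The sketch below counts parameters for the standard ViT layout from the cited paper and makes several assumptions that may not match attnax internals: square images and patches, biased linear layers, LayerNorm with scale and bias, full multi-head attention (num_kv_heads=None), a two-matrix GELU feed-forward block, and no classification head. The function vit_param_count is hypothetical.

```python
def vit_param_count(image_size=224, patch_size=16, num_channels=3,
                    d_model=768, num_layers=12, d_ff=3072, use_cls_token=True):
    """Parameter count for a standard ViT encoder without a head.

    Hypothetical helper; assumes biased projections and LayerNorm with
    scale and bias, which may differ from the attnax implementation.
    """
    num_patches = (image_size // patch_size) ** 2
    num_tokens = num_patches + int(use_cls_token)

    # Linear projection of each flattened patch, plus bias.
    patch_embed = patch_size * patch_size * num_channels * d_model + d_model
    cls_token = d_model if use_cls_token else 0
    pos_embed = num_tokens * d_model

    # Query, key, value, and output projections, each with bias.
    attention = 4 * (d_model * d_model + d_model)
    # Two feed-forward projections, each with bias.
    feed_forward = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
    # Two LayerNorms per block, each with scale and bias.
    norms = 2 * (2 * d_model)
    per_block = attention + feed_forward + norms

    final_norm = 2 * d_model
    return patch_embed + cls_token + pos_embed + num_layers * per_block + final_norm


print(vit_param_count())  # 85798656
```

Under these assumptions the default hyperparameters give roughly 86 million parameters, the size usually quoted for ViT-Base/16.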