AIM: Autoregressive Image Models

Alaaeldin El-Nouby, Michal Klein, Shuangfei Zhai, Miguel Angel Bautista, Alexander Toshev, Vaishaal Shankar, Joshua M Susskind, and Armand Joulin

To appear at ICML 2024

[Paper] [BibTeX]

This software project accompanies the research paper, Scalable Pre-training of Large Autoregressive Image Models.

We introduce AIM, a collection of vision models pre-trained with an autoregressive generative objective. We show that autoregressive pre-training of image features exhibits scaling properties similar to those of its textual counterpart (i.e., Large Language Models). Specifically, we highlight two findings:

  1. the model capacity can be trivially scaled to billions of parameters, and
  2. AIM effectively leverages large collections of uncurated image data.
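To make "autoregressive generative objective" concrete for images, here is a toy NumPy sketch of a next-patch regression loss of this kind. It assumes raster-order patches, per-patch normalized pixel targets and a mean-squared-error loss. The network that would produce `predictions` (a causally masked Transformer in practice) is omitted, and `patchify`, `normalize_patches` and `next_patch_loss` are illustrative names, not functions from the `aim` package.

```python
import numpy as np


def patchify(image: np.ndarray, patch: int) -> np.ndarray:
    """Split an (H, W, C) image into non-overlapping patches in raster order.

    Returns an array of shape (num_patches, patch * patch * C), where patches
    are ordered row by row, left to right.
    """
    h, w, c = image.shape
    assert h % patch == 0 and w % patch == 0, "image must tile evenly"
    gh, gw = h // patch, w // patch
    x = image.reshape(gh, patch, gw, patch, c)
    x = x.transpose(0, 2, 1, 3, 4)  # (grid_row, grid_col, patch_row, patch_col, C)
    return x.reshape(gh * gw, patch * patch * c)


def normalize_patches(patches: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Standardize each patch independently (per-patch mean and variance)."""
    mean = patches.mean(axis=-1, keepdims=True)
    var = patches.var(axis=-1, keepdims=True)
    return (patches - mean) / np.sqrt(var + eps)


def next_patch_loss(predictions: np.ndarray, patches: np.ndarray) -> float:
    """Mean squared error between the prediction at position k and patch k + 1.

    `predictions` has shape (num_patches, patch_dim); the prediction made at the
    last position has no next patch to compare against and is ignored.
    """
    targets = normalize_patches(patches)[1:]
    return float(np.mean((predictions[:-1] - targets) ** 2))
```

For example, a 32×32 RGB image with 8×8 patches yields 16 patches of dimension 192, and each position is trained to regress the normalized pixels of the patch that follows it.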

Installation

Please install PyTorch using the official installation instructions. Afterward, install the package as:

```shell
pip install git+https://git@github.com/apple/ml-aim.git
```

We also offer MLX backend support for research and experimentation on Apple silicon. To enable MLX support, simply run:

```shell
pip install mlx
```
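If the same script should run both on Apple silicon and on other machines, the backend string can be chosen at runtime. The helper below is a hypothetical sketch, not part of the `aim` package: it assumes MLX is only wanted on macOS arm64 machines where the `mlx` package is importable, and falls back to PyTorch otherwise.

```python
import importlib.util
import platform


def pick_backend() -> str:
    """Return "mlx" on Apple silicon when MLX is installed, else "torch".

    Hypothetical helper; the returned string is meant for the `backend`
    argument of `load_pretrained`.
    """
    on_apple_silicon = platform.system() == "Darwin" and platform.machine() == "arm64"
    if on_apple_silicon and importlib.util.find_spec("mlx") is not None:
        return "mlx"
    return "torch"
```

Note that the input conversion still differs per backend: with MLX, the transformed tensor has to be converted with `mx.array(inp.numpy())` before calling the model.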

Usage

Below we provide an example of usage in PyTorch:

```python
from PIL import Image

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="torch")
transform = val_transforms()

# Preprocess and add a batch dimension: (C, H, W) -> (1, C, H, W).
inp = transform(img).unsqueeze(0)
logits, features = model(inp)
```

and in MLX:

```python
from PIL import Image
import mlx.core as mx

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="mlx")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
# Convert the PyTorch tensor to an MLX array via NumPy.
inp = mx.array(inp.numpy())
logits, features = model(inp)
```

and in JAX:

```python
from PIL import Image
import jax.numpy as jnp

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model, params = load_pretrained("aim-600M-2B-imgs", backend="jax")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
inp = jnp.array(inp)
# With mutable=[...], Flax's apply returns (outputs, updated_collections).
(logits, features), _ = model.apply(params, inp, mutable=['batch_stats'])
```
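All three snippets return `logits` alongside `features`. Assuming `logits` holds one row of class scores per image (1,000 columns for the ImageNet-1k heads), a backend-agnostic way to read off the top predictions is to convert to NumPy and apply a softmax. The `top_k` helper below is a hypothetical sketch, not part of the `aim` package; the conversions assumed are `logits.detach().cpu().numpy()` for PyTorch and `np.array(logits)` for MLX and JAX arrays.

```python
import numpy as np


def top_k(logits, k: int = 5):
    """Return (class_indices, probabilities) of the k highest-scoring classes.

    Assumes `logits` is array-like with shape (batch, num_classes).
    """
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract the row maximum before exponentiating for a numerically stable softmax.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    idx = np.argsort(-probs, axis=-1)[:, :k]  # highest probability first
    return idx, np.take_along_axis(probs, idx, axis=-1)
```

Mapping the returned indices to human-readable ImageNet-1k class names requires a separate label list, which is not included here.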

Pre-trained checkpoints

The pre-trained models can be accessed via PyTorch Hub as:

```python
import torch

aim_600m = torch.hub.load("apple/ml-aim", "aim_600M")
aim_1b   = torch.hub.load("apple/ml-aim", "aim_1B")
aim_3b   = torch.hub.load("apple/ml-aim", "aim_3B")
aim_7b   = torch.hub.load("apple/ml-aim", "aim_7B")
```

or via the Hugging Face Hub as:

```python
from aim.torch.models import AIMForImageClassification

aim_600m = AIMForImageClassification.from_pretrained("apple/aim-600M")
aim_1b   = AIMForImageClassification.from_pretrained("apple/aim-1B")
aim_3b   = AIMForImageClassification.from_pretrained("apple/aim-3B")
aim_7b   = AIMForImageClassification.from_pretrained("apple/aim-7B")
```

Pre-trained backbones

The following table contains pre-trained backbones used in our paper.

| model | #params | IN-1k top-1, attention probe (best layer) | backbone, SHA256 |
|---|---|---|---|
| AIM-0.6B | 0.6B | 79.4% | link, 0d6f6b8f |
| AIM-1B | 1B | 82.3% | link, d254ecd3 |
| AIM-3B | 3B | 83.3% | link, 8475ce4e |
| AIM-7B | 7B | 84.0% | link, 184ed94c |
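The SHA256 column can be used to check a downloaded checkpoint before loading it. A minimal sketch, assuming the 8-character values in the tables are the leading characters of each file's full SHA256 hex digest; `sha256_prefix_matches` is a hypothetical helper, not part of the `aim` package.

```python
import hashlib
from pathlib import Path


def sha256_prefix_matches(path, expected_prefix: str, chunk_size: int = 1 << 20) -> bool:
    """Hash a file in chunks and check that its hex digest starts with a prefix.

    Reading in chunks keeps memory use flat even for multi-GB checkpoints.
    """
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest().startswith(expected_prefix.lower())
```

For example, `sha256_prefix_matches("/path/to/backbone_ckpt.pth", "184ed94c")` would be the check for the AIM-7B backbone listed above.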

Pre-trained attention heads

The table below lists top-1 classification accuracy on the ImageNet-1k validation set, together with the corresponding attention-head checkpoints.

| model | top-1 IN-1k (last layer) | top-1 IN-1k (best layer) | attention head, SHA256 (last layer) | attention head, SHA256 (best layer) |
|---|---|---|---|---|
| AIM-0.6B | 78.5% | 79.4% | link, 5ce5a341 | link, ebd45c05 |
| AIM-1B | 80.6% | 82.3% | link, db3be2ad | link, f1ed7852 |
| AIM-3B | 82.2% | 83.3% | link, 5c057b30 | link, ad380e16 |
| AIM-7B | 82.4% | 84.0% | link, 1e5c99ba | link, 73ecd732 |

Reproducing the IN-1k classification results

The command below reproduces the attention-probe results on the ImageNet-1k validation set. We run the evaluation on 1 node with 8 GPUs:

```shell
torchrun --standalone --nnodes=1 --nproc-per-node=8 main_attnprobe.py \
  --model=aim-7B \
  --batch-size=64 \
  --data-path=/path/to/imagenet \
  --probe-layers=best \
  --backbone-ckpt-path=/path/to/backbone_ckpt.pth \
  --head-ckpt-path=/path/to/head_ckpt.pth
```

By default, we probe features from the 6 intermediate layers that provide the best performance. To probe the last layer instead, pass `--probe-layers=last` and point `--head-ckpt-path` to the corresponding last-layer attention head.
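An attention probe trains a lightweight head on top of frozen backbone features: a learnable query attends over the patch tokens, and the pooled vector is fed to a linear classifier. The NumPy sketch below is a simplified, single-head version for illustration only; the head used by `main_attnprobe.py` may differ (for example, multiple heads or normalization layers), and all parameter names and shapes here are assumptions.

```python
import numpy as np


def attention_probe(features, query, w_k, w_v, w_cls, b_cls):
    """Single-head attention pooling followed by a linear classifier.

    features: (batch, tokens, dim) patch features from a frozen backbone
    query:    (dim,) learnable query vector
    w_k, w_v: (dim, dim) key and value projections
    w_cls:    (dim, num_classes) classifier weights; b_cls: (num_classes,)
    Returns logits of shape (batch, num_classes).
    """
    keys = features @ w_k                              # (batch, tokens, dim)
    values = features @ w_v                            # (batch, tokens, dim)
    scores = keys @ query / np.sqrt(query.shape[0])    # (batch, tokens)
    scores -= scores.max(axis=-1, keepdims=True)       # stabilize the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)     # softmax over tokens
    pooled = np.einsum("bt,btd->bd", weights, values)  # attention-weighted average
    return pooled @ w_cls + b_cls
```

Only the query, projections and classifier would be trained; the backbone features stay fixed. With a zero query, the attention weights are uniform and the probe reduces to average pooling followed by a linear layer.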

Citation

If you find our work useful, please consider citing us as:

```bibtex
@article{el2024scalable,
  title={Scalable Pre-training of Large Autoregressive Image Models},
  author={El-Nouby, Alaaeldin and Klein, Michal and Zhai, Shuangfei and Bautista, Miguel Angel and Toshev, Alexander and Shankar, Vaishaal and Susskind, Joshua M and Joulin, Armand},
  journal={International Conference on Machine Learning},
  year={2024}
}
```