
Perceiver IO image classifier (MNIST)

This model is a small Perceiver IO image classifier (907K parameters) trained from scratch on the MNIST dataset. It is one of the training examples of the perceiver-io library.

Model description

Like krasserm/perceiver-io-img-clf, this model uses 2D Fourier features for position encoding and cross-attends to individual pixels of an input image. Unlike that model, it uses repeated cross-attention, a configuration that was described in the original Perceiver paper but dropped in the follow-up Perceiver IO paper (see building blocks for more details).
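
As a rough illustration of 2D Fourier position features, the sketch below follows the parametrization described in the Perceiver paper: pixel coordinates are scaled to [-1, 1] per axis, multiplied by frequencies spaced linearly between 1 and max_freq / 2, passed through sin and cos, and concatenated with the raw coordinates. The function name and the values num_bands=32 and max_freq=28.0 are hypothetical; the perceiver-io library's own encoder may order or size the features differently. In the Perceiver paper these features are concatenated with each pixel's value, which the last lines illustrate with a placeholder image.

```python
import numpy as np

def fourier_position_encoding(height, width, num_bands, max_freq):
    """Hypothetical 2D Fourier position features in the style of the Perceiver paper.

    Returns an array of shape (height * width, 2 + 4 * num_bands): the raw
    (y, x) position in [-1, 1], then sin and cos features for every
    frequency band of both spatial dimensions.
    """
    # Pixel coordinates normalized to [-1, 1] along each axis.
    ys = np.linspace(-1.0, 1.0, height)
    xs = np.linspace(-1.0, 1.0, width)
    grid = np.stack(np.meshgrid(ys, xs, indexing="ij"), axis=-1)  # (H, W, 2)
    pos = grid.reshape(-1, 2)                                       # (H*W, 2)

    # Frequencies spaced linearly between 1 and the Nyquist frequency max_freq / 2.
    freqs = np.linspace(1.0, max_freq / 2.0, num_bands)             # (K,)
    scaled = np.pi * pos[:, :, None] * freqs[None, None, :]         # (H*W, 2, K)
    scaled = scaled.reshape(pos.shape[0], -1)                       # (H*W, 2K)

    return np.concatenate([pos, np.sin(scaled), np.cos(scaled)], axis=-1)

enc = fourier_position_encoding(28, 28, num_bands=32, max_freq=28.0)
print(enc.shape)  # (784, 130)

# Concatenate each pixel value with its position features (placeholder all-black image).
pixels = np.zeros((28, 28), dtype=np.float64)
inputs = np.concatenate([pixels.reshape(-1, 1), enc], axis=-1)
print(inputs.shape)  # (784, 131)
```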

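The difference between a single initial cross-attention and repeated cross-attention can be sketched with single-head attention in NumPy. This is a simplified, hypothetical encoder (no layer normalization, MLPs, or multiple heads, with the same weights reused at every repeat); all names and sizes are illustrative assumptions, not this model's configuration. With num_repeats=1 the latents read the input once, as in Perceiver IO; with num_repeats > 1 they re-attend to the pixels after each self-attention step, as in the original Perceiver. Each cross-attention computes N x M scores for N latents and M pixels, so repeating it multiplies that cost.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q_in, kv_in, wq, wk, wv):
    """Single-head scaled dot-product attention: q_in attends to kv_in."""
    q, k, v = q_in @ wq, kv_in @ wk, kv_in @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])  # (num_queries, num_keys)
    return softmax(scores) @ v

def encode(x, latents, params, num_repeats):
    """Alternate latent-to-input cross-attention with latent self-attention.

    x: (M, C) input array, e.g. one row per pixel.
    latents: (N, D) latent array.
    The same cross- and self-attention weights are reused at every repeat.
    """
    z = latents
    for _ in range(num_repeats):
        z = z + attention(z, x, *params["cross"])  # latents query the pixels: N x M scores
        z = z + attention(z, z, *params["self"])   # latents attend to each other: N x N scores
    return z

# Hypothetical sizes: 784 pixels with 8 channels each, 32 latents of width 16.
rng = np.random.default_rng(0)
M, C, N, D = 784, 8, 32, 16
x = rng.normal(size=(M, C))
latents = rng.normal(size=(N, D))
params = {
    "cross": (0.1 * rng.normal(size=(D, D)), 0.1 * rng.normal(size=(C, D)), 0.1 * rng.normal(size=(C, D))),
    "self": tuple(0.1 * rng.normal(size=(D, D)) for _ in range(3)),
}

z_once = encode(x, latents, params, num_repeats=1)      # read the input once
z_repeated = encode(x, latents, params, num_repeats=3)  # re-read the input three times
print(z_once.shape, z_repeated.shape)  # (32, 16) (32, 16)
```
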
Model training

The model was trained from randomly initialized weights on the MNIST handwritten digits dataset. Images were normalized and data augmentations were turned off. Training was done with PyTorch Lightning, and the resulting checkpoint was converted to this 🤗 model with a library-specific conversion utility.
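
The normalization step can be pictured as standardizing every pixel with fixed dataset statistics. The sketch below is illustrative only: the mean and standard deviation (0.1307 and 0.3081, values commonly quoted for MNIST pixels scaled to [0, 1]) and the function name are assumptions, not settings confirmed for this model's training run. Since augmentations were off, the same deterministic transform applies to every image.

```python
import numpy as np

# Assumed statistics, commonly quoted for MNIST pixels scaled to [0, 1];
# not confirmed as the values used to train this model.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def normalize_mnist(images_uint8):
    """Hypothetical preprocessing: scale uint8 pixels to [0, 1], then standardize.

    No random augmentation is applied, so the output depends only on the input.
    """
    x = images_uint8.astype(np.float32) / 255.0
    return (x - MNIST_MEAN) / MNIST_STD

example = normalize_mnist(np.array([0, 255], dtype=np.uint8))
print(example)  # approximately -0.424 for a black pixel and 2.821 for a white pixel
```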

Intended use and limitations

The model can be used for MNIST handwritten digit classification.

Usage examples

To use this model, you first need to install the perceiver-io library with the vision extension.

pip install perceiver-io[vision]

The model can then be used with PyTorch. To use the model and image processor directly:

from datasets import load_dataset
from transformers import AutoModelForImageClassification, AutoImageProcessor
from perceiver.model.vision import image_classifier  # auto-class registration

repo_id = "krasserm/perceiver-io-img-clf-mnist"

mnist_dataset = load_dataset("mnist", split="test")[:9]

images = mnist_dataset["image"]
labels = mnist_dataset["label"]

model = AutoModelForImageClassification.from_pretrained(repo_id)
processor = AutoImageProcessor.from_pretrained(repo_id)

inputs = processor(images, return_tensors="pt")
logits = model(**inputs).logits

print(f"Labels:      {labels}")
print(f"Predictions: {logits.argmax(dim=-1).numpy().tolist()}")
# Output:
# Labels:      [7, 2, 1, 0, 4, 1, 4, 9, 5]
# Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
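
The example above reports only the arg-max class. If per-image confidences are also of interest, the logits can be turned into softmax probabilities; because softmax is monotonic, the arg-max does not change. The helper below is a hypothetical NumPy sketch that expects a (batch, num_classes) array, for instance obtained from the example's logits with logits.detach().numpy(); the example logits are made-up values.

```python
import numpy as np

def logits_to_predictions(logits):
    """Hypothetical helper: predicted class indices and their softmax probabilities.

    logits: array of shape (batch, num_classes).
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=-1, keepdims=True)
    preds = probs.argmax(axis=-1)  # same result as logits.argmax(axis=-1)
    return preds, probs[np.arange(len(preds)), preds]

example_logits = np.array([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]])  # made-up, 3 classes
preds, confidences = logits_to_predictions(example_logits)
print(preds.tolist())  # [1, 0]
```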

Alternatively, use an image-classification pipeline:

from datasets import load_dataset
from transformers import pipeline
from perceiver.model.vision import image_classifier  # auto-class registration

repo_id = "krasserm/perceiver-io-img-clf-mnist"

mnist_dataset = load_dataset("mnist", split="test")[:9]

images = mnist_dataset["image"]
labels = mnist_dataset["label"]

classifier = pipeline("image-classification", model=repo_id)
predictions = [pred[0]["label"] for pred in classifier(images)]

print(f"Labels:      {labels}")
print(f"Predictions: {predictions}")
# Output:
# Labels:      [7, 2, 1, 0, 4, 1, 4, 9, 5]
# Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
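
To turn such outputs into a single accuracy number, the sketch below picks the highest-scoring label per image and compares it with the dataset label. It assumes the usual image-classification pipeline output format (one list of {"label": ..., "score": ...} dicts per image); the helper names and example outputs are hypothetical. Labels are compared as strings in case the pipeline returns class names as strings while the dataset stores integers.

```python
def top_labels(pipeline_outputs):
    """Pick the highest-scoring label from each per-image list of candidates."""
    return [max(candidates, key=lambda c: c["score"])["label"] for candidates in pipeline_outputs]

def accuracy(labels, predictions):
    """Fraction of positions where prediction and label match (compared as strings)."""
    if len(labels) != len(predictions) or not labels:
        raise ValueError("labels and predictions must be non-empty and of equal length")
    hits = sum(str(label) == str(pred) for label, pred in zip(labels, predictions))
    return hits / len(labels)

# Made-up pipeline-style results for two images.
outputs = [
    [{"label": "7", "score": 0.98}, {"label": "1", "score": 0.01}],
    [{"label": "3", "score": 0.35}, {"label": "2", "score": 0.60}],
]
predictions = top_labels(outputs)
print(predictions)                    # ['7', '2']
print(accuracy([7, 2], predictions))  # 1.0
```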

Checkpoint conversion

The krasserm/perceiver-io-img-clf-mnist model has been created from a training checkpoint with:

from perceiver.model.vision.image_classifier import convert_mnist_classifier_checkpoint

convert_mnist_classifier_checkpoint(
    save_dir="krasserm/perceiver-io-img-clf-mnist",
    ckpt_url="https://martin-krasser.com/perceiver/logs-0.8.0/img_clf/version_0/checkpoints/epoch=025-val_loss=0.065.ckpt",
    push_to_hub=True,
)
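
The internals of convert_mnist_classifier_checkpoint are not described here. As a generic illustration of one step such a conversion typically involves, the sketch below extracts a plain state dict from a PyTorch Lightning checkpoint, which stores module weights under the "state_dict" key with parameter names prefixed by the attribute that holds the model inside the LightningModule. The helper name, the "model." prefix, and the toy keys are assumptions; a real checkpoint would first be loaded with torch.load(path, map_location="cpu").

```python
def extract_model_state_dict(checkpoint, prefix="model."):
    """Hypothetical helper: return the wrapped model's weights from a Lightning checkpoint dict.

    Keys starting with `prefix` are kept with the prefix removed; all other
    entries (e.g. loss or metric modules on the LightningModule) are dropped.
    """
    state_dict = checkpoint["state_dict"]
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }

# Toy checkpoint; lists stand in for tensors and the keys are hypothetical.
checkpoint = {
    "optimizer_states": [],
    "state_dict": {
        "model.encoder.latent": [0.0],
        "model.decoder.linear.weight": [1.0],
        "loss.weight": [1.0],
    },
}
print(sorted(extract_model_state_dict(checkpoint)))
# ['decoder.linear.weight', 'encoder.latent']
```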