---
license: bsd-3-clause
base_model:
  - microsoft/resnet-50
pipeline_tag: image-feature-extraction
---

ResNet-50 Embeddings Only

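This is a modified version of the standard ResNet-50 architecture in which the final fully connected (classification) layer has been removed.

As a result, the model outputs the pooled 2048-dimensional feature vector instead of class logits, and you can use that vector directly as an image embedding.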

NB: For a single image the output has shape (1, 2048, 1, 1), so you may want to flatten it to (1, 2048).

Example

import onnxruntime
from PIL import Image
from torchvision import transforms


def load_and_preprocess_image(image_path):
    # Standard ImageNet evaluation preprocessing: resize, center-crop to 224x224, normalize
    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

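    # Open the image file and force 3-channel RGB
    # (grayscale or RGBA images would otherwise fail in Normalize)
    img = Image.open(image_path).convert("RGB")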

    # Preprocess the image
    img_preprocessed = preprocess(img)

    # Add batch dimension
    return img_preprocessed.unsqueeze(0).numpy()


onnx_model_path = "resnet50_embeddings.onnx"

session = onnxruntime.InferenceSession(onnx_model_path)

input_name = session.get_inputs()[0].name

# Load and preprocess an image (replace with your image path)
image_path = "disco-ball.jpg"
input_data = load_and_preprocess_image(image_path)

# Run inference
outputs = session.run(None, {input_name: input_data})

# session.run returns a list of outputs; the first (and only) one holds the embeddings
embeddings = outputs[0]

# Flatten from (batch, 2048, 1, 1) to (batch, 2048)
embeddings = embeddings.reshape(embeddings.shape[0], -1)
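A common use for these embeddings is comparing images. The sketch below computes cosine similarity between flattened embeddings with NumPy. `cosine_similarity` is a hypothetical helper (not part of onnxruntime), and the random arrays are stand-ins for the flattened `embeddings` you would get from the example above for two different images.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between rows of a (n, d) and b (m, d).

    Hypothetical helper for illustration; returns an (n, m) matrix.
    """
    # L2-normalize each row so the dot product equals the cosine of the angle.
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T


# Stand-ins for two flattened (1, 2048) embeddings from session.run(...).
rng = np.random.default_rng(0)
emb_a = rng.random((1, 2048), dtype=np.float32)
emb_b = rng.random((1, 2048), dtype=np.float32)

print(cosine_similarity(emb_a, emb_a))  # close to 1.0 (identical vectors)
print(cosine_similarity(emb_a, emb_b))
```

Because ResNet-50 applies a ReLU before the final pooling, the embedding values are non-negative, so similarities between real images fall between 0 and 1 rather than spanning -1 to 1.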