"""Classify a single image patch with the SuSy TorchScript model.

Loads ``SuSy.pt`` and one example patch from the current working directory,
runs a forward pass, and prints the model output as a one-row table whose
columns are labelled with the class names.
"""

import pandas as pd
import torch
from PIL import Image
from torchvision import transforms

# Load the serialized TorchScript model and put it in inference mode.
model = torch.jit.load("SuSy.pt")
model.eval()

# Load the patch and turn it into a float tensor with a leading batch axis.
# PILToTensor does not rescale pixel values, hence the division by 255
# (assumes an 8-bit image). The context manager closes the underlying file
# as soon as the pixels have been decoded.
# NOTE(review): the image mode is used as-is (no RGB conversion) -- confirm
# the patch has the channel layout the model expects.
with Image.open("midjourney-images-example-patch0.png") as image:
    patch = transforms.PILToTensor()(image).unsqueeze(0) / 255.0

# Predict the patch; gradients are not needed for inference.
with torch.no_grad():
    preds = model(patch)

# Print results as a table, one column per class.
# NOTE(review): assumes the model emits one score per class in exactly this
# order -- verify against the model's training configuration.
classes = [
    'authentic',
    'dalle-3-images',
    'diffusiondb',
    'midjourney-images',
    'midjourney_tti',
    'realisticSDXL',
]
result = pd.DataFrame(preds.numpy(), columns=classes)
print(result)