SuSy / test_patch.py
pbernabeu
Release Model
1051963
raw
history blame
578 Bytes
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
# Load the TorchScript model. map_location="cpu" keeps the weights on the CPU
# even if the archive was saved from a GPU. That matches the CPU input tensor
# built below and the `.numpy()` call used to print the results.
model = torch.jit.load("SuSy.pt", map_location="cpu")
model.eval()

# Load the patch. The context manager closes the underlying file once the
# pixels have been read. convert("RGB") turns palette / greyscale / RGBA PNGs
# into 3 channels.
# NOTE(review): assumes the model expects 3-channel RGB input -- confirm.
with Image.open("midjourney-images-example-patch0.png") as img:
    rgb_patch = img.convert("RGB")

# PIL image -> uint8 tensor (C, H, W) -> add a leading batch dim of size 1
# -> float tensor scaled to [0, 1].
patch = transforms.PILToTensor()(rgb_patch).unsqueeze(0) / 255.

# Predict the patch without tracking gradients.
with torch.no_grad():
    preds = model(patch)

# Print the raw model outputs, one column per class label below.
# Assumes preds is 2-D (batch, len(classes)); DataFrame raises otherwise.
classes = ['authentic', 'dalle-3-images', 'diffusiondb', 'midjourney-images', 'midjourney_tti', 'realisticSDXL']
result = pd.DataFrame(preds.cpu().numpy(), columns=classes)
print(result)