import streamlit as st
import torch
from torch import nn
import torchvision.transforms as transforms
from PIL import Image, UnidentifiedImageError
import numpy as np

# Path to the trained weights (a state_dict saved with torch.save).
MODEL_PATH = 'cifar10_model.pth'

# CIFAR-10 label names, indexed by the model's output position.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']


class SimpleCNN(nn.Module):
    """Small two-conv-layer CNN producing 10 class logits for 3x32x32 inputs.

    Layer names and shapes must match the saved state_dict, so they are kept
    exactly as in the training script.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return logits of shape (batch, 10) for x of shape (batch, 3, 32, 32)."""
        x = self.pool(torch.relu(self.conv1(x)))  # -> (batch, 32, 16, 16)
        x = self.pool(torch.relu(self.conv2(x)))  # -> (batch, 64, 8, 8)
        # Flatten per sample. Unlike view(-1, 64 * 8 * 8), a wrongly sized
        # input now fails loudly in fc1 instead of being silently re-batched.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


@st.cache_resource
def load_model(path=MODEL_PATH):
    """Build SimpleCNN, load its weights from *path* onto the CPU, and return it in eval mode.

    Decorated with st.cache_resource so the checkpoint is not reloaded on every
    Streamlit rerun.
    """
    model = SimpleCNN()
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    # On torch >= 1.13, passing weights_only=True restricts loading to tensors.
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Resize to the 32x32 input the network expects, then map each channel from
# [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def predict(model, image):
    """Classify a PIL image.

    Returns:
        (predicted class name, 1-D tensor of softmax probabilities per class).
    """
    # Force 3 channels: grayscale, CMYK, or RGBA images would otherwise not
    # match the 3-channel Normalize / conv1 and raise at inference time.
    input_tensor = transform(image.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        logits = model(input_tensor)[0]
    probabilities = torch.nn.functional.softmax(logits, dim=0)
    predicted_idx = int(torch.argmax(logits))
    return class_names[predicted_idx], probabilities


def main():
    """Render the app: upload an image, show it, and display the model's prediction."""
    st.title('CIFAR-10 Image Classification')
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file)
    except UnidentifiedImageError:
        st.error("The uploaded file could not be read as an image.")
        return

    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; kept for compatibility with older versions.
    st.image(image, caption='Uploaded Image.', use_column_width=True)

    predicted_class, probabilities = predict(load_model(), image)

    st.write(f"Prediction: {predicted_class}")
    st.write("Class Probabilities:")
    for name, prob in zip(class_names, probabilities):
        st.write(f"{name}: {prob.item():.2%}")


# `streamlit run` executes this script as __main__, so the UI still renders,
# while a plain import (e.g. from tests) has no side effects.
if __name__ == "__main__":
    main()