# defectdetection / app06.py
# nazlicanto's picture
# Resolved merge conflicts
# 4446708
# raw
# history blame
# No virus
# 1.45 kB
import streamlit as st
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from PIL import Image
import numpy as np
import torch
# Load the SegFormer model and its image processor from the local model
# directory. This runs at import time, i.e. on every Streamlit script run.
#
# NOTE: an earlier pair of loads referenced the undefined name `model_path`
# (merge-conflict residue) and raised NameError before the app could start;
# their results were either overwritten (`model`) or never used
# (`preprocessor`), so only the `model_dir`-based loads are kept.
model_dir = "/home/user/app/defectdetection/model"
model = SegformerForSemanticSegmentation.from_pretrained(model_dir)
processor = SegformerImageProcessor.from_pretrained(model_dir)

# Put the model in evaluation mode; this script only runs inference.
model.eval()
st.title("PCB Defect Detection")

# Upload image in Streamlit
uploaded_file = st.file_uploader("Upload a PCB image", type=["jpg", "png"])

# Everything below depends on an uploaded file, so it must live inside this
# branch; st.file_uploader returns None until the user provides one.
if uploaded_file:
    # Preprocess the image: force 3-channel RGB before handing it to the
    # processor, which returns PyTorch tensors ("pt").
    test_image = Image.open(uploaded_file).convert("RGB")
    inputs = processor(images=test_image, return_tensors="pt")

    # Model inference; no_grad avoids building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process: resize the prediction back to the input image size.
    # PIL reports size as (width, height) while target_sizes expects
    # (height, width), hence the reversal.
    semantic_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[test_image.size[::-1]]
    )[0]

    # The post-processor yields a torch tensor; convert it explicitly to a
    # uint8 NumPy array so it can be indexed and displayed as an image.
    semantic_map = semantic_map.cpu().numpy().astype(np.uint8)

    # Map class ids 1-4 to well-separated gray levels so they are visible in
    # the rendered mask. Every other id keeps its value (0 is presumably
    # background - confirm against the model's label map). A lookup table
    # applies the whole mapping in a single pass.
    gray_levels = np.arange(256, dtype=np.uint8)
    gray_levels[[1, 2, 3, 4]] = [255, 195, 135, 75]
    semantic_map = gray_levels[semantic_map]

    # Display the results
    st.image(test_image, caption="Uploaded Image", use_column_width=True)
    st.image(
        semantic_map,
        caption="Predicted Defects",
        use_column_width=True,
        channels="GRAY",
    )