# defectdetection / app06.py — commit 6e3cb6a ("Update app06.py") by nazlicanto, 1.48 kB.
import streamlit as st
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor, SegformerConfig
from PIL import Image
import numpy as np
import torch
import os
# Directory holding the fine-tuned SegFormer checkpoint: config.json,
# pytorch_model.bin, and the image-processor files read by from_pretrained.
# The PCB_MODEL_PATH environment variable can override it; the default is
# the original deployment path.
model_path = os.environ.get("PCB_MODEL_PATH", "/home/user/app/defectdetection/model")
config = SegformerConfig.from_json_file(os.path.join(model_path, "config.json"))
model = SegformerForSemanticSegmentation(config=config)
# map_location="cpu": inference below runs on CPU. Without it, a checkpoint
# saved from GPU tensors cannot be loaded on a CPU-only host.
# NOTE: torch.load unpickles the file; only point model_path at trusted checkpoints.
model.load_state_dict(
    torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location="cpu")
)
# A module built directly from a config starts in training mode (unlike
# from_pretrained, which calls .eval()). Switch to eval mode so dropout /
# drop-path layers are disabled during inference; torch.no_grad() alone does
# not do that.
model.eval()
preprocessor = SegformerImageProcessor.from_pretrained(model_path, local_files_only=True)
st.title("PCB Defect Detection")

# Upload image in Streamlit. "jpeg" is listed alongside "jpg" so files that
# use the longer extension are not rejected by the uploader.
uploaded_file = st.file_uploader("Upload a PCB image", type=["jpg", "jpeg", "png"])
# Grey level used to display each predicted class id in the single-channel
# output. Class 0 and any id not listed here keep their raw value.
DEFECT_GRAY_LEVELS = {1: 255, 2: 195, 3: 135, 4: 75}


def _to_gray_mask(class_map):
    """Map a 2-D array of class ids to uint8 display grey levels.

    The array is cast to uint8 first, as before. A 256-entry lookup table
    (identity except for the ids in DEFECT_GRAY_LEVELS) is then applied in one
    vectorized indexing pass. Ids without an entry are returned unchanged.
    """
    lut = np.arange(256, dtype=np.uint8)
    for class_id, level in DEFECT_GRAY_LEVELS.items():
        lut[class_id] = level
    return lut[np.asarray(class_map).astype(np.uint8)]


if uploaded_file is not None:
    # Preprocess the image; convert to 3-channel RGB before handing it to
    # the processor.
    test_image = Image.open(uploaded_file).convert("RGB")
    inputs = preprocessor(images=test_image, return_tensors="pt")

    # Model inference, without autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process back to the uploaded image's resolution. PIL's .size is
    # (width, height); target_sizes expects (height, width), hence [::-1].
    semantic_map = preprocessor.post_process_semantic_segmentation(
        outputs, target_sizes=[test_image.size[::-1]]
    )[0]
    semantic_map = _to_gray_mask(semantic_map.cpu().numpy())

    # Display the results. A 2-D uint8 array is shown as greyscale.
    # st.image documents only "RGB"/"BGR" for `channels`, so no channels
    # argument is passed for the mask.
    st.image(test_image, caption="Uploaded Image", use_column_width=True)
    st.image(semantic_map, caption="Predicted Defects", use_column_width=True)