Personal_Color / app.py
kikuepi's picture
Update app.py
bebf967 verified
raw
history blame contribute delete
No virus
1.99 kB
from torchvision import models, transforms
from PIL import Image
import torch
import torch.nn as nn
import io
import streamlit as st
import time
st.title("パーソナルカラー診断AI")

# Preprocessing applied to every uploaded image before inference.
# MEAN/STD are the standard ImageNet per-channel (R, G, B) statistics used by
# torchvision's pretrained models; presumably they (and SIZE) match the
# preprocessing used when the checkpoint was trained — confirm before changing.
SIZE = 224  # side length in pixels of the square network input
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

transform = transforms.Compose(
    [
        # (h, w) tuple -> exact output size; aspect ratio is not preserved.
        transforms.Resize((SIZE,) * 2),
        # PIL image -> float tensor (C, H, W) with values scaled to [0, 1].
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ]
)
# Streamlit re-executes this whole script on every user interaction.
# st.cache_resource (Streamlit >= 1.18) keeps a single loaded model per server
# process; on older Streamlit versions fall back to a no-op decorator, which
# reloads on every run exactly as before.
_cache_resource = getattr(st, "cache_resource", lambda func: func)

n_classes = 4  # size of the classification head; predictions are indices 0..3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Trained weights file, opened relative to the current working directory.
CHECKPOINT_PATH = 'Resnet_2024_0214_version1'


@_cache_resource
def _load_model():
    """Build ResNet-152 with an ``n_classes``-way head and load the trained checkpoint.

    The backbone is created without ImageNet weights. The strict
    ``load_state_dict`` below overwrites every parameter and buffer from the
    checkpoint, so downloading pretrained weights first (the old
    ``pretrained=True``) was pure start-up overhead.

    Returns:
        torch.nn.Module: the model moved to ``device`` and set to eval mode.
    """
    net = models.resnet152()  # default: no pretrained weights, no download
    net.fc = nn.Linear(net.fc.in_features, n_classes)
    # NOTE(review): torch.load unpickles the file. This is acceptable only
    # because the checkpoint ships with the app and is trusted.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()  # inference mode for BatchNorm/Dropout layers
    return net


model = _load_model()
# Input width of the classification head. The value is the same as before
# the head was replaced; kept as a module-level name for compatibility.
num_ftrs = model.fc.in_features
# Not referenced in the visible code; kept in case other code reads them.
view_flag = True
skip = False
def predict_image(img):
    """Classify a PIL image and return the predicted class index.

    The image is converted to RGB, preprocessed with the module-level
    ``transform``, and run through ``model`` on ``device`` as a batch of one.

    Args:
        img: a PIL ``Image`` in any mode that ``convert('RGB')`` accepts.

    Returns:
        int: index of the highest-scoring output of ``model``.
    """
    rgb_image = img.convert('RGB')
    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    batch = transform(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        scores = model(batch)
    return torch.max(scores, 1).indices.item()
# Class index -> season label, as mapped by the original if/elif chain:
# 0 = 秋 (autumn), 1 = 春 (spring), 2 = 夏 (summer), 3 = 冬 (winter).
# The order presumably mirrors the label order used in training — confirm
# before retraining or reordering.
SEASON_LABELS = ("秋", "春", "夏", "冬")


def _open_uploaded_image(file):
    """Decode an uploaded file into a PIL image.

    Args:
        file: the file-like object returned by ``st.file_uploader``.

    Returns:
        The decoded ``PIL.Image.Image``, or ``None`` if Pillow cannot read it
        (not an image, truncated data, or an oversized decompression bomb).
    """
    try:
        image = Image.open(file)
        image.load()  # Image.open is lazy; decode now so bad data fails here
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError (unrecognised format) subclasses OSError.
        return None
    return image


# "jpeg" is listed explicitly so .jpeg files are accepted even on Streamlit
# versions that do not treat jpg/jpeg as equivalent.
uploaded_file = st.file_uploader('Choose an image...', type=['jpg', 'jpeg', 'png'])
if uploaded_file:
    img = _open_uploaded_image(uploaded_file)
    if img is None:
        # Show a readable message instead of a crash traceback.
        st.error("画像を読み込めませんでした。別の画像ファイルをアップロードしてください。")
    else:
        st.image(img, caption="Uploaded Image", use_column_width=True)
        pred = predict_image(img)
        season_type = SEASON_LABELS[pred]
        # Per-session flags; only "result" is written in the visible code.
        for key in ("show_video", "skip", "result"):
            if key not in st.session_state:
                st.session_state[key] = False
        st.write(f"パーソナルカラー診断結果:{season_type} ")
        st.write("あなたにおすすめの色はこちらです")
        st.session_state.result = True
        st.image(f"{season_type}.png")  # colour palette image for the season
        st.write("\nあなたにおすすめの商品はこちらです\n")
        st.image("服.png")