Spaces:
Runtime error
Runtime error
import numpy as np | |
import pandas as pd | |
from PIL import Image | |
from collections import defaultdict | |
import streamlit as st | |
from streamlit_drawable_canvas import st_canvas | |
import matplotlib as mpl | |
from model import device, segment_image, inpaint | |
# define utils and helpers | |
def closest_number(n, m=8):
    """Round ``n`` down to a multiple of ``m``, but never below ``m``.

    Despite the name, this floors rather than rounding to the nearest multiple,
    so the result never exceeds ``n`` (for ``n >= m``). The lower bound of ``m``
    keeps small inputs from collapsing to 0, which would otherwise yield a
    zero-sized dimension.

    Args:
        n: Non-negative number to snap; may be a float (e.g. ``width / sf``).
        m: Positive integer step. Defaults to 8.

    Returns:
        int: the largest multiple of ``m`` that is <= ``n``, or ``m`` when ``n < m``.
    """
    return max(int(n // m) * m, m)
def get_mask_from_rectangles(image, mask, height, width, drawing_mode='rect'):
    """Let the user draw boxes over ``image`` and paint them into ``mask``.

    Renders a drawable canvas with ``image`` as its background. Each object
    drawn on the canvas is treated as an axis-aligned box built from its
    ``left``/``top``/``width``/``height`` (scaled by ``scaleX``/``scaleY`` when
    present). The box is set to 255 in ``mask``, and the matching crop of
    ``image`` is displayed.

    Args:
        image: PIL image used as the canvas background and for preview crops.
        mask: 2-D uint8 numpy array, modified in place. Presumably shaped
            ``(height, width)`` to match the canvas; boxes are clipped to its
            actual shape.
        height: Canvas height in pixels.
        width: Canvas width in pixels.
        drawing_mode: Drawing tool passed through to ``st_canvas``.

    Returns:
        The same ``mask`` array, with every drawn box set to 255.
    """
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=2,
        stroke_color="#000000",
        background_image=image,
        update_streamlit=True,
        height=height,
        width=width,
        drawing_mode=drawing_mode,
        point_display_radius=5,
        key="canvas",
    )

    if canvas_result.json_data is None:
        return mask
    objects = pd.json_normalize(canvas_result.json_data["objects"])
    if objects.empty:
        return mask

    # Canvas geometry arrives as floats, and a resize (transform mode) is
    # stored in scaleX/scaleY rather than in width/height.
    ones = np.ones(len(objects))
    scale_x = objects["scaleX"].fillna(1.0).to_numpy(dtype=float) if "scaleX" in objects else ones
    scale_y = objects["scaleY"].fillna(1.0).to_numpy(dtype=float) if "scaleY" in objects else ones
    left = objects["left"].to_numpy(dtype=float)
    top = objects["top"].to_numpy(dtype=float)
    right = left + objects["width"].to_numpy(dtype=float) * scale_x
    bottom = top + objects["height"].to_numpy(dtype=float) * scale_y

    # numpy slicing needs integer bounds. Clipping also stops negative values
    # from wrapping around to the far side of the mask.
    rows, cols = mask.shape[:2]
    left = np.clip(np.rint(left), 0, cols).astype(int).tolist()
    right = np.clip(np.rint(right), 0, cols).astype(int).tolist()
    top = np.clip(np.rint(top), 0, rows).astype(int).tolist()
    bottom = np.clip(np.rint(bottom), 0, rows).astype(int).tolist()

    for x0, y0, x1, y1 in zip(left, top, right, bottom):
        if x1 <= x0 or y1 <= y0:
            continue  # empty after rounding/clipping: nothing to show or mask
        st.image(image.crop((x0, y0, x1, y1)))
        mask[y0:y1, x0:x1] = 255

    st.header("Mask Created!")
    st.image(mask)
    return mask
def get_mask(image, edit_method, height, width):
    """Build an inpainting mask for ``image`` with the chosen edit method.

    Args:
        image: PIL image being edited.
        edit_method: ``"AutoSegment Area"`` to pick segments produced by
            ``segment_image``, or ``"Draw Custom Area"`` to draw boxes by hand.
            Any other value yields an all-zero mask.
        height: Mask height in pixels.
        width: Mask width in pixels.

    Returns:
        An all-zero ``(height, width)`` uint8 array when nothing is selected.
        A PIL ``L`` image (0/255) when segments are selected in AutoSegment mode.
        The uint8 array returned by ``get_mask_from_rectangles`` in draw mode.
    """
    mask = np.zeros((height, width), dtype=np.uint8)

    if edit_method == "AutoSegment Area":
        seg_prediction, segment_labels = segment_image(image)
        seg = seg_prediction['segmentation'].cpu().numpy()

        # matplotlib treats integer input as direct indices into the colormap's
        # lookup table. So it needs max(seg) + 1 entries for every id up to the
        # max to get its own colour, and at least 1 entry for a single-segment
        # image. NOTE(review): assumes seg holds integer segment ids.
        n_colors = max(int(seg.max()) + 1, 1)
        viridis = mpl.colormaps['viridis'].resampled(n_colors)
        seg_image = Image.fromarray(np.uint8(viridis(seg) * 255))
        st.image(seg_image)

        # Options are (segment id, presumably label name) pairs. Only the id,
        # which is what appears in `seg`, is used to build the mask.
        seg_selections = st.multiselect("Choose segments", list(segment_labels.items()))
        if seg_selections:
            tgts = [s[0] for s in seg_selections]
            mask = Image.fromarray(np.isin(seg, tgts).astype(np.uint8) * 255)
            st.header("Mask Created!")
            st.image(mask)

    elif edit_method == "Draw Custom Area":
        mask = get_mask_from_rectangles(image, mask, height, width)

    return mask
if __name__ == '__main__':
    # Streamlit entry point: upload an image, downsize it, build a mask, then
    # inpaint the masked area from a text prompt.
    st.title("Stable Edit")
    st.title("Edit your photos with Stable Diffusion!")
    st.write(f"Device found: {device}")

    # Integer factor by which the uploaded image is downsized before editing.
    # Non-integer or non-positive input falls back to 2; non-positive values
    # would otherwise divide by zero or produce nonsense sizes below.
    sf = st.text_input("Please enter resizing scale factor to downsize image (default=2)", value="2")
    try:
        sf = int(sf)
    except ValueError:
        sf = 0
    if sf < 1:
        st.write("Error with input scale factor, setting to default value of 2, please re-enter above to change it")
        sf = 2

    filename = st.file_uploader("upload an image")
    if filename:
        # Normalize palette/alpha uploads to 3 channels.
        # NOTE(review): assumes the models in `model` expect RGB input.
        image = Image.open(filename).convert("RGB")
        width, height = image.size
        width, height = closest_number(width / sf), closest_number(height / sf)
        image = image.resize((width, height))
        st.image(image)

        edit_method = st.selectbox("Select Edit Method", ("AutoSegment Area", "Draw Custom Area"))
        if edit_method:
            mask = get_mask(image, edit_method, height, width)

            prompt = st.text_input("Please enter prompt for image inpainting", value="")
            if prompt:
                st.write("Inpainting Images, patience is a virtue and this will take a while to run on a CPU :)")
                images = inpaint(image, mask, width, height, prompt=prompt, seed=0, guidance_scale=17.5, num_samples=3)

                st.write("Original Image")
                st.image(image)
                for i, img in enumerate(images, 1):
                    st.write(f"result: {i}")
                    st.image(img)