# Hugging Face Space: Streamlit "Interior AI" app (page status at capture time: "Runtime error").
import logging
import os
import random
import time
from typing import Union

import numpy as np
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from models import make_image_controlnet, make_inpainting
from segmentation import segment_image
from config import (
    HEIGHT,
    WIDTH,
    POS_PROMPT,
    NEG_PROMPT,
    COLOR_MAPPING,
    map_colors,
    map_colors_rgb,
)
from palette import COLOR_MAPPING_CATEGORY
from preprocessing import preprocess_seg_mask, get_image, get_mask
from explanation import (
    make_inpainting_explanation,
    make_regeneration_explanation,
    make_segmentation_explanation,
)
from colors import INTERIOR
# Use Streamlit's wide page layout so the "Before" canvas and the "After"
# images (rendered in two st.columns in main()) can sit side by side.
# NOTE(review): Streamlit documents set_page_config as having to run before
# other st.* commands, so keep this above any other Streamlit call.
st.set_page_config(layout="wide")
def on_upload() -> None:
    """Load a newly uploaded file as the canvas background and reset derived state.

    Used as the ``on_change`` callback of the ``input_image`` file uploader.
    When a file is present it is decoded with PIL, converted to RGB and stored
    under ``initial_image``. State computed from the previous image (the
    segmentation ``seg``, its ``unique_colors`` and the generated
    ``output_images``) is dropped so the next run recomputes it for the new
    image instead of showing stale results.
    """
    uploaded = st.session_state.get("input_image")
    if uploaded is None:
        return
    st.session_state["initial_image"] = Image.open(uploaded).convert("RGB")
    # "output_images" (plural) is the key the generation/display code uses;
    # the singular "output_image" is still cleared as before.
    for key in ("seg", "unique_colors", "output_images", "output_image"):
        if key in st.session_state:
            del st.session_state[key]
def make_image_row(image_0, image_1):
    """Render two images next to each other, each filling its own column."""
    for column, picture in zip(st.columns(2), (image_0, image_1)):
        with column:
            st.image(picture, use_column_width=True)
def check_reset_state() -> bool:
    """Consume the one-shot ``reset_canvas`` flag.

    Returns:
        bool: True if a reset was requested (the flag was set and truthy),
        False otherwise. Either way the flag is left as False afterwards.
    """
    requested = bool(st.session_state.get("reset_canvas", False))
    st.session_state["reset_canvas"] = False
    return requested
def move_image(
    source: Union[str, Image.Image],
    dest: str,
    rerun: bool = True,
    remove_state: bool = True,
) -> None:
    """Copy an image into a session-state slot, optionally resetting the canvas.

    Args:
        source (Union[str, Image.Image]): the image itself, or the
            session-state key under which it is stored.
        dest (str): session-state key to store the image under.
        rerun (bool, optional): trigger a Streamlit rerun afterwards.
            Defaults to True.
        remove_state (bool, optional): request a canvas reset and drop the
            cached segmentation (``seg``) and its ``unique_colors`` so they
            get recomputed. Defaults to True.
    """
    source_image = (
        source if isinstance(source, Image.Image) else st.session_state[source]
    )
    if remove_state:
        st.session_state["reset_canvas"] = True
        for key in ("seg", "unique_colors"):
            if key in st.session_state:
                del st.session_state[key]
    # Store only under the caller-supplied key; the literal "dest" key that
    # was also written before is not read anywhere in this file.
    st.session_state[dest] = source_image
    if rerun:
        # Newer Streamlit releases provide st.rerun and deprecate/remove
        # st.experimental_rerun; fall back to the old name on older installs.
        rerun_app = getattr(st, "rerun", None) or st.experimental_rerun
        rerun_app()
def on_change_radio() -> None:
    """Flag that the canvas should be reset on the next run (radio on_change callback)."""
    st.session_state.reset_canvas = True
def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
    """Assemble the keyword arguments for ``st_canvas``.

    Args:
        canvas_color: color used for both fill and stroke.
        brush: stroke width.
        paint_mode: drawing mode passed to the canvas.
        _reset_state: when truthy, pass an empty drawing as
            ``initial_drawing``; otherwise pass None.

    Returns:
        dict: arguments for a 512x512 canvas with widget key ``"canvas"``,
        a white background color and, if one was uploaded, the
        ``initial_image`` from session state as background image.
    """
    background = st.session_state.get("initial_image")
    empty_drawing = {"version": "4.4.0", "objects": []}
    return {
        "fill_color": canvas_color,
        "stroke_color": canvas_color,
        "background_color": "#FFFFFF",
        "background_image": background,
        "stroke_width": brush,
        "initial_drawing": empty_drawing if _reset_state else None,
        "update_streamlit": True,
        "height": 512,
        "width": 512,
        "drawing_mode": paint_mode,
        "key": "canvas",
    }
def make_output_image(max_images: int = 4):
    """Render the "After" column: up to ``max_images`` generated images.

    Images are read from ``st.session_state["output_images"]``; a missing or
    empty entry shows only the heading. NumPy arrays are converted to PIL
    images, and PIL images are resized to 512x512 before being displayed.

    Args:
        max_images (int, optional): maximum number of images to show.
            Defaults to 4, the number the original fixed layout displayed.
    """
    st.write("#### After")
    output_images = st.session_state.get("output_images") or []
    for output_image in output_images[:max_images]:
        if isinstance(output_image, np.ndarray):
            output_image = Image.fromarray(output_image)
        if isinstance(output_image, Image.Image):
            # Match the 512x512 "Before" canvas so both columns line up.
            output_image = output_image.resize((512, 512))
        st.image(output_image, width=512)
def generate(num_images: int = 4):
    """Regenerate the uploaded room image with ControlNet.

    Builds a mask from the cached segmentation ``st.session_state["seg"]``.
    Every color in ``st.session_state["unique_colors"]`` is marked with 1 in
    the mask, except colors whose name (via ``map_colors_rgb``) is one of the
    ``INTERIOR`` values; those regions are left out of the mask.
    ``make_image_controlnet`` is then called ``num_images`` times with fresh
    random seeds. Results are stored as PIL images (ndarrays are converted)
    in ``st.session_state["output_images"]``.

    Args:
        num_images (int, optional): number of variations to generate.
            Defaults to 4.
    """
    logger = logging.getLogger(__name__)
    image = get_image()
    segmentation = st.session_state["seg"]
    mask = np.zeros_like(segmentation)
    interior = list(INTERIOR.values())
    logger.debug("interior classes: %s", interior)
    for color in st.session_state["unique_colors"]:
        color_name = map_colors_rgb(color)
        logger.debug("segment color %s -> %s", color, color_name)
        # Changing the walls or floor would turn the result into a different
        # house, so classes listed in INTERIOR (presumably those) are kept.
        if color_name in interior:
            continue
        # Pixels whose RGB value equals this color are set to 1 in the mask.
        mask[np.where((segmentation == color).all(axis=2))] = 1

    positive_prompt = "a photograph of a room, interior design, 4k, high resolution"
    negative_prompt = "lowres, watermark, banner, logo, watermark, contactinfo, text, deformed, blurry, blur, out of focus, out of frame, surreal, ugly"
    with st.spinner(text="生成中…"):
        st.session_state["output_images"] = []
        for _ in range(num_images):
            result_image = make_image_controlnet(
                image=image,
                mask_image=mask,
                controlnet_conditioning_image=segmentation,
                positive_prompt=positive_prompt,
                negative_prompt=negative_prompt,
                seed=random.randint(0, 100000),  # nosec
            )
            if isinstance(result_image, np.ndarray):
                result_image = Image.fromarray(result_image)
            # Append as each image completes so earlier results are kept
            # even if a later call fails.
            st.session_state["output_images"].append(result_image)
def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
    """Render the "Before" column and trigger generation when needed.

    Draws the canvas, then computes and caches on first use:
    - the segmentation of the uploaded image (``seg``);
    - its distinct colors (``unique_colors``).
    Outputs are generated automatically when ``output_images`` does not exist
    yet, and again whenever the "再生成" ("regenerate") button is pressed.

    Args:
        canvas_color: fill/stroke color for the canvas.
        brush: stroke width.
        _reset_state: whether to start the canvas from an empty drawing.
        generation_mode: currently unused; kept for interface compatibility.
        paint_mode: canvas drawing mode.
    """
    st.write("#### Before")
    # The canvas's return value is not used here.
    st_canvas(
        **make_canvas_dict(
            canvas_color=canvas_color,
            paint_mode=paint_mode,
            brush=brush,
            _reset_state=_reset_state,
        )
    )

    if "seg" not in st.session_state:
        with st.spinner(text="生成中…"):
            image = get_image()
            st.session_state["seg"] = np.array(segment_image(Image.fromarray(image)))

    if "unique_colors" not in st.session_state:
        real_seg = st.session_state["seg"]
        # Flatten to one row per pixel (one column per channel) and keep the
        # distinct rows, i.e. the distinct pixel colors.
        unique_colors = np.unique(real_seg.reshape(-1, real_seg.shape[2]), axis=0)
        st.session_state["unique_colors"] = [tuple(color) for color in unique_colors]

    if "output_images" not in st.session_state:
        generate()
    elif st.button("再生成"):
        generate()
def main():
    """Streamlit entry point: upload a room photo, then show before/after columns."""
    st.write("## Interior AI", unsafe_allow_html=True)
    # Uploader label: "Please select an image of the room".
    input_image = st.file_uploader(
        "部屋の画像を選択してください", type=["png", "jpg"], key="input_image", on_change=on_upload
    )
    # Consume the reset flag on every run, even before an image is uploaded.
    _reset_state = check_reset_state()
    if not input_image:
        return

    left, right = st.columns(2)
    with left:
        # Fully transparent color with a zero-width brush; drawing appears to
        # be effectively disabled in this version of the app.
        make_editing_canvas(
            canvas_color="rgba(0, 0, 0, 0.0)",
            brush=0,
            _reset_state=_reset_state,
            generation_mode="Regenerate",
            paint_mode="freedraw",
        )
    with right:
        make_output_image()
# Script entry point.
# NOTE(review): relies on `streamlit run` executing this file as __main__.
if __name__ == "__main__":
    main()