# interior-ai / app.py
# Author: TakahashiShotaro
# Last commit: d1ddda1 ("Update app.py")
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
from typing import Union
import random
import numpy as np
import os
import time
from models import make_image_controlnet, make_inpainting
from segmentation import segment_image
from config import (
HEIGHT,
WIDTH,
POS_PROMPT,
NEG_PROMPT,
COLOR_MAPPING,
map_colors,
map_colors_rgb,
)
from palette import COLOR_MAPPING_CATEGORY
from preprocessing import preprocess_seg_mask, get_image, get_mask
from explanation import (
make_inpainting_explanation,
make_regeneration_explanation,
make_segmentation_explanation,
)
from colors import INTERIOR
# Page-wide layout: gives the Before/After columns built in main() more
# horizontal room than Streamlit's default centered layout.
PAGE_LAYOUT = "wide"
st.set_page_config(layout=PAGE_LAYOUT)
def on_upload() -> None:
    """Load the uploaded file as the canvas background and reset cached state.

    Callback for the ``input_image`` file uploader. When a file is present it
    is opened as an RGB PIL image and stored in
    ``st.session_state["initial_image"]``. The cached segmentation
    (``seg`` / ``unique_colors``) and any previous generation results are
    dropped so they are recomputed for the new image.
    """
    if (
        "input_image" not in st.session_state
        or st.session_state["input_image"] is None
    ):
        # The uploader was cleared; nothing to load.
        return
    image = Image.open(st.session_state["input_image"]).convert("RGB")
    st.session_state["initial_image"] = image
    # "output_images" is the key generate() writes and make_output_image()
    # reads. Previously only "output_image" was deleted, so results for the
    # old photo stayed on screen and generation was not re-triggered.
    # "output_image" is still cleared for backward compatibility.
    for key in ("seg", "unique_colors", "output_images", "output_image"):
        if key in st.session_state:
            del st.session_state[key]
def make_image_row(image_0, image_1):
    """Display two images side by side, each stretched to its column width."""
    left, right = st.columns(2)
    for column, picture in ((left, image_0), (right, image_1)):
        with column:
            st.image(picture, use_column_width=True)
def check_reset_state() -> bool:
    """Consume the pending canvas-reset flag.

    Returns:
        bool: True if ``st.session_state["reset_canvas"]`` was set and truthy.
        In every case the flag is left set to False afterwards.
    """
    requested = (
        "reset_canvas" in st.session_state
        and bool(st.session_state["reset_canvas"])
    )
    st.session_state["reset_canvas"] = False
    return requested
def move_image(
    source: Union[str, Image.Image],
    dest: str,
    rerun: bool = True,
    remove_state: bool = True,
) -> None:
    """Store an image under ``st.session_state[dest]``.

    Args:
        source (Union[str, Image.Image]): a PIL image, or a session-state key
            whose value is the image to move.
        dest (str): session-state key to store the image under.
        rerun (bool, optional): trigger a Streamlit rerun afterwards.
            Defaults to True.
        remove_state (bool, optional): flag the canvas for reset and drop the
            cached segmentation (``seg`` / ``unique_colors``). Defaults to True.
    """
    source_image = (
        source if isinstance(source, Image.Image) else st.session_state[source]
    )
    if remove_state:
        st.session_state["reset_canvas"] = True
        for key in ("seg", "unique_colors"):
            if key in st.session_state:
                del st.session_state[key]
    # Store only under the caller's key. The previous version also wrote to the
    # literal key "dest" (a typo for the variable), keeping a stray duplicate
    # reference and overwriting any unrelated "dest" entry.
    st.session_state[dest] = source_image
    if rerun:
        # st.experimental_rerun is deprecated (and removed in newer Streamlit);
        # prefer st.rerun when this Streamlit version provides it.
        rerun_app = getattr(st, "rerun", None) or st.experimental_rerun
        rerun_app()
def on_change_radio() -> None:
    """Radio-button callback: request that the canvas be cleared on the next run."""
    st.session_state.reset_canvas = True
def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
    """Build the keyword arguments passed to ``st_canvas``.

    The uploaded image (``st.session_state["initial_image"]``, if any) is used
    as the background of a fixed 512x512 canvas. When ``_reset_state`` is
    truthy an empty drawing is supplied so previous strokes are cleared;
    otherwise ``initial_drawing`` is None.
    """
    background = None
    if "initial_image" in st.session_state:
        background = st.session_state["initial_image"]

    if _reset_state:
        drawing = {"version": "4.4.0", "objects": []}
    else:
        drawing = None

    return {
        "fill_color": canvas_color,
        "stroke_color": canvas_color,
        "background_color": "#FFFFFF",
        "background_image": background,
        "stroke_width": brush,
        "initial_drawing": drawing,
        "update_streamlit": True,
        "height": 512,
        "width": 512,
        "drawing_mode": paint_mode,
        "key": "canvas",
    }
def _to_display_image(image):
    """Convert an ndarray to a PIL image, then resize any PIL image to 512x512."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if isinstance(image, Image.Image):
        # Match the 512x512 "Before" canvas.
        image = image.resize((512, 512))
    return image


def make_output_image() -> None:
    """Render the "After" column with up to four generated images.

    Reads ``st.session_state["output_images"]`` (filled by ``generate``).
    If the key is missing or the list is empty, only the heading is shown.
    """
    st.write("#### After")
    # .get(): don't raise KeyError if generation never initialised the key.
    output_images = st.session_state.get("output_images") or []
    # The old per-item conversion loop rebound only its loop variable, so its
    # conversion/resize was discarded. Apply it to what is actually displayed.
    # The old bare `st.spinner()` calls (not used as context managers)
    # displayed nothing and are dropped.
    for output_image in output_images[:4]:
        st.image(_to_display_image(output_image), width=512)
def _regeneration_mask(segmentation, unique_colors, protected_labels):
    """Return a mask (same shape/dtype as ``segmentation``) of regions to redraw.

    Pixels whose colour appears in ``unique_colors`` are set to 1 on every
    channel, unless ``map_colors_rgb(color)`` is in ``protected_labels``.
    All other pixels stay 0.
    """
    mask = np.zeros_like(segmentation)
    for color in unique_colors:
        # Changing structural surfaces such as walls or floors would turn the
        # room into a different house, so the INTERIOR labels stay untouched.
        if map_colors_rgb(color) in protected_labels:
            continue
        # Select pixels whose colour equals `color` on all channels.
        mask[(segmentation == color).all(axis=2)] = 1
    return mask


def generate(num_images: int = 4) -> None:
    """Regenerate the uploaded room ``num_images`` times and store the results.

    Reads the current image via ``get_image()`` and the cached
    ``st.session_state["seg"]`` / ``["unique_colors"]``, and builds the
    inpainting mask. It then calls ``make_image_controlnet`` once per image
    with a fresh random seed. Results (converted to PIL when the model returns
    ndarrays) are appended to ``st.session_state["output_images"]``, which is
    reset first.

    Args:
        num_images: number of variations to generate. Defaults to 4, which is
            the number ``make_output_image`` displays.
    """
    image = get_image()
    segmentation = st.session_state["seg"]
    mask = _regeneration_mask(
        segmentation,
        st.session_state["unique_colors"],
        list(INTERIOR.values()),
    )
    positive_prompt = "a photograph of a room, interior design, 4k, high resolution"
    negative_prompt = "lowres, watermark, banner, logo, watermark, contactinfo, text, deformed, blurry, blur, out of focus, out of frame, surreal, ugly"
    with st.spinner(text="生成中…"):
        st.session_state["output_images"] = []
        for _ in range(num_images):
            result_image = make_image_controlnet(
                image=image,
                mask_image=mask,
                controlnet_conditioning_image=segmentation,
                positive_prompt=positive_prompt,
                negative_prompt=negative_prompt,
                seed=random.randint(0, 100000),  # nosec
            )
            if isinstance(result_image, np.ndarray):
                result_image = Image.fromarray(result_image)
            # Append as each image finishes so completed results are kept
            # even if a later call fails.
            st.session_state["output_images"].append(result_image)
def _ensure_segmentation() -> None:
    """Compute and cache the segmentation map and its distinct colours if missing."""
    if "seg" not in st.session_state:
        with st.spinner(text="生成中…"):
            image = get_image()
            st.session_state["seg"] = np.array(segment_image(Image.fromarray(image)))
    if "unique_colors" not in st.session_state:
        real_seg = st.session_state["seg"]
        # Flatten to one row per pixel and keep the distinct colour rows.
        unique_colors = np.unique(real_seg.reshape(-1, real_seg.shape[2]), axis=0)
        st.session_state["unique_colors"] = [tuple(color) for color in unique_colors]


def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
    """Render the "Before" column: canvas, segmentation, and generation trigger.

    On the first run for an image (no ``output_images`` in session state)
    generation starts automatically. Afterwards a "再生成" (regenerate) button
    is shown instead.

    ``generation_mode`` is accepted for interface compatibility but is not
    used here.
    """
    st.write("#### Before")
    # The canvas return value is not used; the call is kept for rendering.
    st_canvas(
        **make_canvas_dict(
            canvas_color=canvas_color,
            paint_mode=paint_mode,
            brush=brush,
            _reset_state=_reset_state,
        )
    )
    _ensure_segmentation()
    if "output_images" not in st.session_state:
        generate()
    elif st.button("再生成"):
        generate()
def main() -> None:
    """Render the page: title, image uploader, and the Before/After columns.

    The columns are shown only once a file has been uploaded. Each run also
    consumes the pending canvas-reset flag via ``check_reset_state``.
    """
    st.write("## Interior AI", unsafe_allow_html=True)
    uploaded = st.file_uploader(
        "部屋の画像を選択してください",
        type=["png", "jpg"],
        key="input_image",
        on_change=on_upload,
    )
    reset_requested = check_reset_state()
    if not uploaded:
        return
    before_column, after_column = st.columns(2)
    with before_column:
        make_editing_canvas(
            canvas_color="rgba(0, 0, 0, 0.0)",
            brush=0,
            _reset_state=reset_requested,
            generation_mode="Regenerate",
            paint_mode="freedraw",
        )
    with after_column:
        make_output_image()


if __name__ == "__main__":
    main()