"""Streamlit front end for Interior AI: upload a room photo, regenerate furnishings."""
import logging
import os
import random
import time
from typing import Any, Union

import numpy as np
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

# Local imports keep their original order in case any module has import-time
# side effects (e.g. loading models).
from models import make_image_controlnet, make_inpainting
from segmentation import segment_image
from config import (
    HEIGHT,
    WIDTH,
    POS_PROMPT,
    NEG_PROMPT,
    COLOR_MAPPING,
    map_colors,
    map_colors_rgb,
)
from palette import COLOR_MAPPING_CATEGORY
from preprocessing import preprocess_seg_mask, get_image, get_mask
from explanation import (
    make_inpainting_explanation,
    make_regeneration_explanation,
    make_segmentation_explanation,
)
from colors import INTERIOR

logger = logging.getLogger(__name__)

# Side length (px) of the square "Before" canvas and of each "After" image.
_DISPLAY_SIZE = 512
# Number of candidate images produced per generation run.
_NUM_OUTPUT_IMAGES = 4
# Upper bound (inclusive) for the random seed passed to the generator.
_MAX_SEED = 100000

# wide layout
st.set_page_config(layout="wide")


def _clear_state(*keys: str) -> None:
    """Delete each of ``keys`` from ``st.session_state`` if it is present."""
    for key in keys:
        if key in st.session_state:
            del st.session_state[key]


def on_upload() -> None:
    """Load the uploaded file onto the canvas and drop state tied to the old image.

    Used as the ``on_change`` callback of the file uploader (key ``"input_image"``).
    The image is stored, converted to RGB, under ``"initial_image"``. The
    segmentation, its colour list and any generated results are removed so
    that ``make_editing_canvas`` recomputes them for the new image.
    """
    uploaded = st.session_state.get("input_image")
    if uploaded is None:
        return
    st.session_state["initial_image"] = Image.open(uploaded).convert("RGB")
    # "output_images" is the key generate() writes. Without clearing it the
    # previous image's results would be shown and no new generation would
    # run automatically. "output_image" is cleared too, as before.
    _clear_state("seg", "unique_colors", "output_image", "output_images")


def make_image_row(image_0, image_1):
    """Render two images side by side in two equal-width columns."""
    col_0, col_1 = st.columns(2)
    with col_0:
        st.image(image_0, use_column_width=True)
    with col_1:
        st.image(image_1, use_column_width=True)


def check_reset_state() -> bool:
    """Check whether the UI elements need to be reset.

    Consumes the ``"reset_canvas"`` flag: it is always left as False afterwards.

    Returns:
        bool: True if the UI elements need to be reset, False otherwise.
    """
    needs_reset = bool(st.session_state.get("reset_canvas", False))
    st.session_state["reset_canvas"] = False
    return needs_reset


def move_image(
    source: Union[str, Image.Image],
    dest: str,
    rerun: bool = True,
    remove_state: bool = True,
) -> None:
    """Move an image from ``source`` to the session-state key ``dest``.

    Args:
        source (Union[str, Image.Image]): the image itself, or the session-state
            key holding it.
        dest (str): session-state key that receives the image.
        rerun (bool, optional): rerun the Streamlit script afterwards.
            Defaults to True.
        remove_state (bool, optional): request a canvas reset and drop the
            cached segmentation. Defaults to True.
    """
    source_image = (
        source if isinstance(source, Image.Image) else st.session_state[source]
    )
    if remove_state:
        st.session_state["reset_canvas"] = True
        _clear_state("seg", "unique_colors")
    st.session_state[dest] = source_image
    if rerun:
        st.experimental_rerun()


def on_change_radio() -> None:
    """Reset the UI elements when the radio button is changed."""
    st.session_state["reset_canvas"] = True


def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
    """Build the keyword arguments for ``st_canvas``.

    Args:
        canvas_color: colour used for both fill and stroke.
        brush: stroke width.
        paint_mode: drawing mode passed to the canvas (e.g. ``"freedraw"``).
        _reset_state: when truthy, start from an empty drawing.

    Returns:
        dict: keyword arguments for ``st_canvas``.
    """
    return dict(
        fill_color=canvas_color,
        stroke_color=canvas_color,
        background_color="#FFFFFF",
        background_image=st.session_state.get("initial_image"),
        stroke_width=brush,
        initial_drawing={"version": "4.4.0", "objects": []} if _reset_state else None,
        update_streamlit=True,
        height=_DISPLAY_SIZE,
        width=_DISPLAY_SIZE,
        drawing_mode=paint_mode,
        key="canvas",
    )


def _to_display_image(image: Any) -> Any:
    """Convert a generated image to PIL and resize it to the canvas size.

    Arrays are converted with ``Image.fromarray``. PIL images are resized to
    ``_DISPLAY_SIZE`` x ``_DISPLAY_SIZE`` so they match the "Before" canvas.
    Any other value is returned unchanged.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if isinstance(image, Image.Image):
        image = image.resize((_DISPLAY_SIZE, _DISPLAY_SIZE))
    return image


def make_output_image() -> None:
    """Render the "After" column: every image stored under ``"output_images"``."""
    st.write("#### After")
    output_images = st.session_state.get("output_images") or []
    for output_image in output_images:
        st.image(_to_display_image(output_image), width=_DISPLAY_SIZE)


def _build_edit_mask(segmentation: np.ndarray, unique_colors) -> np.ndarray:
    """Return a mask marking the pixels of every segment that may be regenerated.

    The mask has the same shape and dtype as ``segmentation``. It is 1 on
    every channel of pixels whose segment colour is not one of the
    ``INTERIOR`` labels, and 0 elsewhere.
    """
    interior = list(INTERIOR.values())
    logger.debug("interior labels: %s", interior)
    mask = np.zeros_like(segmentation)
    for color in unique_colors:
        label = map_colors_rgb(color)
        logger.debug("segment label: %s", label)
        # Keep walls, floors, etc.: changing them would make it a different house.
        if label in interior:
            continue
        # Boolean (H, W) selector: marks all channels of matching pixels.
        mask[(segmentation == color).all(axis=2)] = 1
    return mask


def generate(num_images: int = _NUM_OUTPUT_IMAGES) -> None:
    """Regenerate the non-interior segments and store the results.

    Reads ``"seg"`` and ``"unique_colors"`` from session state. It calls
    ``make_image_controlnet`` ``num_images`` times, each with a fresh random
    seed, and appends each result (as a PIL image when it is an array) to
    ``st.session_state["output_images"]``.

    Args:
        num_images: number of candidate images to generate. Defaults to 4.
    """
    image = get_image()
    segmentation = st.session_state["seg"]
    mask = _build_edit_mask(segmentation, st.session_state["unique_colors"])

    positive_prompt = "a photograph of a room, interior design, 4k, high resolution"
    negative_prompt = "lowres, watermark, banner, logo, watermark, contactinfo, text, deformed, blurry, blur, out of focus, out of frame, surreal, ugly"
    with st.spinner(text="生成中…"):
        st.session_state["output_images"] = []
        for _ in range(num_images):
            result_image = make_image_controlnet(
                image=image,
                mask_image=mask,
                controlnet_conditioning_image=segmentation,
                positive_prompt=positive_prompt,
                negative_prompt=negative_prompt,
                seed=random.randint(0, _MAX_SEED),  # nosec
            )
            if isinstance(result_image, np.ndarray):
                result_image = Image.fromarray(result_image)
            st.session_state["output_images"].append(result_image)


def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
    """Render the "Before" canvas and make sure segmentation and outputs exist.

    Computes the segmentation and its unique colours on first use and caches
    them in session state. If no outputs exist yet it generates them at once.
    Otherwise it shows a "再生成" (regenerate) button. ``generation_mode`` is
    accepted for interface compatibility but is not used here.
    """
    st.write("#### Before")
    canvas_dict = make_canvas_dict(
        canvas_color=canvas_color,
        paint_mode=paint_mode,
        brush=brush,
        _reset_state=_reset_state,
    )
    st_canvas(**canvas_dict)

    if "seg" not in st.session_state:
        with st.spinner(text="生成中…"):
            # get_image() result is passed to Image.fromarray, so presumably
            # an RGB ndarray of the uploaded image — TODO confirm.
            image = get_image()
            st.session_state["seg"] = np.array(segment_image(Image.fromarray(image)))

    if "unique_colors" not in st.session_state:
        real_seg = st.session_state["seg"]
        # One row per pixel, one column per channel -> distinct segment colours.
        unique_colors = np.unique(real_seg.reshape(-1, real_seg.shape[2]), axis=0)
        st.session_state["unique_colors"] = [tuple(color) for color in unique_colors]

    if "output_images" not in st.session_state:
        generate()
    elif st.button("再生成"):
        generate()


def main():
    """Entry point: file uploader plus the Before/After two-column layout."""
    # center text
    st.write("## Interior AI", unsafe_allow_html=True)
    input_image = st.file_uploader(
        "部屋の画像を選択してください", type=["png", "jpg"], key="input_image", on_change=on_upload
    )

    generation_mode = "Regenerate"
    color_chooser = "rgba(0, 0, 0, 0.0)"
    paint_mode = "freedraw"
    brush = 0

    _reset_state = check_reset_state()

    if input_image:
        col1, col2 = st.columns(2)
        with col1:
            make_editing_canvas(
                canvas_color=color_chooser,
                brush=brush,
                _reset_state=_reset_state,
                generation_mode=generation_mode,
                paint_mode=paint_mode,
            )
        with col2:
            make_output_image()


if __name__ == "__main__":
    main()