File size: 9,306 Bytes
11edfba
 
 
 
 
 
 
 
 
 
 
d1ddda1
 
 
 
 
 
 
 
 
11edfba
 
d1ddda1
 
 
 
 
 
 
11edfba
 
 
 
 
 
d1ddda1
 
 
 
 
 
 
 
 
 
 
 
 
11edfba
 
 
 
 
 
 
 
 
 
 
 
 
 
d1ddda1
 
11edfba
d1ddda1
11edfba
 
 
d1ddda1
 
 
 
 
 
11edfba
 
 
 
 
 
 
d1ddda1
 
 
11edfba
 
d1ddda1
 
 
 
 
11edfba
 
d1ddda1
11edfba
 
 
 
 
 
d1ddda1
11edfba
 
 
 
 
 
 
d1ddda1
 
 
11edfba
d1ddda1
11edfba
 
 
 
 
 
d1ddda1
11edfba
 
d1ddda1
 
11edfba
d1ddda1
 
 
 
11edfba
d1ddda1
 
11edfba
d1ddda1
 
 
 
11edfba
d1ddda1
 
11edfba
d1ddda1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11edfba
 
 
d1ddda1
11edfba
 
 
 
d1ddda1
11edfba
 
d1ddda1
 
 
 
 
11edfba
d1ddda1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11edfba
 
 
 
d1ddda1
11edfba
d1ddda1
 
 
 
 
 
 
11edfba
d1ddda1
11edfba
d1ddda1
11edfba
 
d1ddda1
 
 
 
 
 
 
11edfba
 
 
 
d1ddda1
11edfba
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
from typing import Union
import random
import numpy as np
import os
import time

from models import make_image_controlnet, make_inpainting
from segmentation import segment_image
from config import (
    HEIGHT,
    WIDTH,
    POS_PROMPT,
    NEG_PROMPT,
    COLOR_MAPPING,
    map_colors,
    map_colors_rgb,
)
from palette import COLOR_MAPPING_CATEGORY
from preprocessing import preprocess_seg_mask, get_image, get_mask
from explanation import (
    make_inpainting_explanation,
    make_regeneration_explanation,
    make_segmentation_explanation,
)
from colors import INTERIOR

# Use Streamlit's "wide" page layout so the before/after columns can use the
# full page width instead of the default centered layout.
st.set_page_config(layout="wide")


def on_upload() -> None:
    """Load a newly uploaded image and discard state derived from the old one.

    Used as the ``on_change`` callback of the ``input_image`` file uploader.
    The uploaded file is decoded with PIL, converted to RGB and stored under
    ``"initial_image"`` (the canvas background). Cached segmentation data and
    previously generated images are removed so that they are recomputed for
    the new upload instead of showing results for the previous image.
    """
    uploaded = st.session_state.get("input_image")
    if uploaded is None:
        return

    st.session_state["initial_image"] = Image.open(uploaded).convert("RGB")

    # "output_images" is the key generate()/make_output_image() actually use;
    # the original only cleared "output_image", so stale results survived a new
    # upload and generation was never re-triggered. "output_image" is still
    # cleared for backward compatibility with any code relying on that key.
    for stale_key in ("seg", "unique_colors", "output_image", "output_images"):
        if stale_key in st.session_state:
            del st.session_state[stale_key]


def make_image_row(image_0, image_1):
    """Show two images next to each other, each stretched to its column width."""
    left, right = st.columns(2)
    for column, picture in ((left, image_0), (right, image_1)):
        with column:
            st.image(picture, use_column_width=True)


def check_reset_state() -> bool:
    """Consume the one-shot ``reset_canvas`` flag stored in session state.

    The flag is always written back as False, so a reset request is reported
    at most once.

    Returns:
        bool: True if the flag was present and truthy before this call,
        False otherwise.
    """
    was_requested = bool(st.session_state.get("reset_canvas", False))
    st.session_state["reset_canvas"] = False
    return was_requested


def move_image(
    source: Union[str, Image.Image],
    dest: str,
    rerun: bool = True,
    remove_state: bool = True,
) -> None:
    """Store an image in session state under the key ``dest``.

    Args:
        source (Union[str, Image.Image]): a PIL image, or the session-state key
            of the image to copy.
        dest (str): session-state key to store the image under.
        rerun (bool, optional): call ``st.experimental_rerun()`` afterwards.
            Defaults to True.
        remove_state (bool, optional): request a canvas reset and drop the
            cached ``"seg"`` / ``"unique_colors"`` entries. Defaults to True.
    """
    if isinstance(source, Image.Image):
        image = source
    else:
        image = st.session_state[source]

    if remove_state:
        st.session_state["reset_canvas"] = True
        for stale_key in ("seg", "unique_colors"):
            if stale_key in st.session_state:
                del st.session_state[stale_key]

    st.session_state[dest] = image
    # NOTE(review): this writes to the literal key "dest", not to the value of
    # the ``dest`` argument; it looks unintentional — confirm nothing reads
    # st.session_state["dest"] before removing it.
    st.session_state["dest"] = image

    if rerun:
        st.experimental_rerun()


def on_change_radio() -> None:
    """Radio-button ``on_change`` callback: ask for a canvas reset on the next run."""
    reset_flag_key = "reset_canvas"
    st.session_state[reset_flag_key] = True


def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
    """Build the keyword arguments passed to ``st_canvas``.

    Args:
        canvas_color: colour used for both fill and stroke.
        brush: stroke width.
        paint_mode: forwarded as ``drawing_mode``.
        _reset_state: when truthy, the canvas starts from an empty drawing
            (discarding earlier strokes); otherwise ``initial_drawing`` is None.

    Returns:
        dict: canvas settings for a 512x512 canvas with a white background and
        the uploaded ``"initial_image"`` (if any) as background image.
    """
    if "initial_image" in st.session_state:
        background = st.session_state["initial_image"]
    else:
        background = None

    if _reset_state:
        starting_drawing = {"version": "4.4.0", "objects": []}
    else:
        starting_drawing = None

    return {
        "fill_color": canvas_color,
        "stroke_color": canvas_color,
        "background_color": "#FFFFFF",
        "background_image": background,
        "stroke_width": brush,
        "initial_drawing": starting_drawing,
        "update_streamlit": True,
        "height": 512,
        "width": 512,
        "drawing_mode": paint_mode,
        "key": "canvas",
    }


def make_output_image():
    """Render the "After" column with up to four generated images.

    Images are read from ``st.session_state["output_images"]``; a missing or
    empty entry renders only the heading. NumPy arrays are converted to PIL
    images and PIL images are resized to 512x512 (matching the 512x512 canvas)
    before being displayed at 512 px width.
    """
    st.write("#### After")

    # .get(): avoid a KeyError if generation has not stored anything yet.
    output_images = st.session_state.get("output_images") or []

    for output_image in output_images[:4]:
        # The original converted/resized into the loop variable and then threw
        # the result away; apply the conversion to what is actually displayed.
        if isinstance(output_image, np.ndarray):
            output_image = Image.fromarray(output_image)
        if isinstance(output_image, Image.Image):
            output_image = output_image.resize((512, 512))
        st.image(output_image, width=512)
    # The former `else: st.spinner()` branches were no-ops (a spinner only
    # shows when used as a context manager) and have been dropped.

def generate(num_images: int = 4):
    """Generate ``num_images`` redesigned variants of the uploaded room image.

    An inpainting mask is built from the cached segmentation
    (``st.session_state["seg"]``): segments whose label
    (``map_colors_rgb(color)``) appears among ``INTERIOR.values()`` are kept
    as-is, and every other segment is marked for regeneration.
    ``make_image_controlnet`` is then called ``num_images`` times, each with a
    fresh random seed. ``st.session_state["output_images"]`` is reset to an
    empty list first and each result is appended to it, converted to a PIL
    image when it comes back as a NumPy array.

    Args:
        num_images: number of variants to generate. Defaults to 4.
    """
    image = get_image()

    segmentation = st.session_state["seg"]
    # Same shape/dtype as the segmentation; pixels set to 1 (in every channel)
    # are the ones to regenerate.
    mask = np.zeros_like(segmentation)
    keep_labels = list(INTERIOR.values())
    for color in st.session_state["unique_colors"]:
        # Changing walls or floors would turn the result into a different
        # house, so segments whose label is listed in INTERIOR are preserved.
        if map_colors_rgb(color) in keep_labels:
            continue
        # Mark every pixel whose segmentation colour equals `color`.
        mask[np.where((segmentation == color).all(axis=2))] = 1

    positive_prompt = "a photograph of a room, interior design, 4k, high resolution"
    negative_prompt = "lowres, watermark, banner, logo, watermark, contactinfo, text, deformed, blurry, blur, out of focus, out of frame, surreal, ugly"

    # Spinner text: "Generating…".
    with st.spinner(text="生成中…"):
        st.session_state["output_images"] = []
        for _ in range(num_images):
            result_image = make_image_controlnet(
                image=image,
                mask_image=mask,
                controlnet_conditioning_image=segmentation,
                positive_prompt=positive_prompt,
                negative_prompt=negative_prompt,
                seed=random.randint(0, 100000),  # nosec
            )
            if isinstance(result_image, np.ndarray):
                result_image = Image.fromarray(result_image)
            st.session_state["output_images"].append(result_image)


def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
    """Render the "Before" column and make sure generated outputs exist.

    Draws the canvas (see ``make_canvas_dict``), then lazily computes and
    caches in session state the segmentation map (``"seg"``) of the current
    image and its distinct colours (``"unique_colors"``). Output images are
    generated automatically when none exist yet; afterwards a regenerate
    button triggers a new generation.

    ``generation_mode`` is accepted for interface compatibility but unused.
    """
    st.write("#### Before")
    st_canvas(
        **make_canvas_dict(
            canvas_color=canvas_color,
            paint_mode=paint_mode,
            brush=brush,
            _reset_state=_reset_state,
        )
    )

    state = st.session_state

    if "seg" not in state:
        # Spinner text: "Generating…".
        with st.spinner(text="生成中…"):
            source_pixels = get_image()
            state["seg"] = np.array(segment_image(Image.fromarray(source_pixels)))

    if "unique_colors" not in state:
        seg_map = state["seg"]
        # Flatten to one row per pixel, then collect the distinct colours.
        flat_pixels = seg_map.reshape(-1, seg_map.shape[2])
        state["unique_colors"] = [
            tuple(color) for color in np.unique(flat_pixels, axis=0)
        ]

    if "output_images" not in state:
        generate()
    elif st.button("再生成"):  # button label: "Regenerate"
        generate()


def main():
    """Render the app: an image uploader and, once an image is uploaded,
    side-by-side "Before" (canvas) and "After" (generated images) columns."""
    st.write("## Interior AI", unsafe_allow_html=True)

    # Uploader label: "Please select an image of a room".
    input_image = st.file_uploader(
        "部屋の画像を選択してください",
        type=["png", "jpg"],
        key="input_image",
        on_change=on_upload,
    )

    # Fixed canvas settings: fully transparent colour and a zero-width brush,
    # so drawing on the canvas has no visible effect.
    generation_mode = "Regenerate"
    color_chooser = "rgba(0, 0, 0, 0.0)"
    paint_mode = "freedraw"
    brush = 0

    # Always consume the reset flag, even when no image is uploaded yet.
    _reset_state = check_reset_state()

    if not input_image:
        return

    before_col, after_col = st.columns(2)
    with before_col:
        make_editing_canvas(
            canvas_color=color_chooser,
            brush=brush,
            _reset_state=_reset_state,
            generation_mode=generation_mode,
            paint_mode=paint_mode,
        )
    with after_col:
        make_output_image()


# Build the page when this module is executed as a script.
if __name__ == "__main__":
    main()