File size: 9,732 Bytes
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8c7d9d
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cae6b6
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ea3bd8
0cae6b6
 
6e3a1b8
0cae6b6
5a67d9b
 
2379311
5a67d9b
 
 
 
 
2379311
 
8604dfb
2379311
5a67d9b
8604dfb
3d4d894
 
 
f6b9c19
 
 
 
 
 
 
 
3d4d894
 
 
 
d0613b1
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
from typing import Union
import random
import numpy as np
import os
import time

# Environment-variable values that count as "off" for a boolean flag.
_FALSE_ENV_VALUES = frozenset({"", "0", "false", "no", "off"})


def env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean flag from the environment.

    Args:
        name: environment variable to read.
        default: value returned when the variable is not set at all.

    Returns:
        bool: ``default`` if unset; otherwise False when the stripped,
        lower-cased value is one of "", "0", "false", "no", "off", else True.
    """
    raw_value = os.environ.get(name)
    if raw_value is None:
        return default
    return raw_value.strip().lower() not in _FALSE_ENV_VALUES


# When EMULATED is on, the stub implementations from models_stub are imported
# instead of the real models. Parsed explicitly because the raw string is truthy
# for any non-empty value, which made EMULATED=0 / EMULATED=false select the stubs.
EMULATED = env_flag('EMULATED')
print(EMULATED)

if not EMULATED:
    from models import make_image_controlnet, make_inpainting, segment_image
else:
    from models_stub import make_image_controlnet, make_inpainting, segment_image

from config import HEIGHT, WIDTH, POS_PROMPT, NEG_PROMPT, COLOR_MAPPING, map_colors, map_colors_rgb
from palette import COLOR_MAPPING_CATEGORY
from preprocessing import preprocess_seg_mask, get_image, get_mask

# Use the full browser width so the input canvas and the output image can sit
# in two side-by-side columns (see main()).
st.set_page_config(layout='wide')


def on_upload() -> None:
    """Load the newly uploaded file as the canvas background image.

    Callback for the sidebar file uploader (session-state key ``'input_image'``).
    The upload is decoded with PIL, converted to RGB and stored under
    ``'initial_image'``, which ``make_canvas_dict`` uses as the canvas background.

    Session-state entries derived from the previous image are discarded so that
    they are rebuilt for the new one:

    - the cached segmentation map (``'seg'``);
    - its colour list (``'unique_colors'``);
    - the multiselect choice drawn from that list (``'chosen_colors'``).

    The canvas drawing is also flagged for reset. Does nothing when the uploader
    has been cleared.
    """
    uploaded_file = st.session_state.get('input_image')
    if uploaded_file is None:
        return

    with Image.open(uploaded_file) as image:
        # convert() loads the pixel data into a new image, so the result stays
        # usable after the opened image is closed.
        st.session_state['initial_image'] = image.convert('RGB')

    # make_editing_canvas only computes these when the keys are missing; without
    # dropping them, segmentation mode would keep using the old image's map.
    for stale_key in ('seg', 'unique_colors', 'chosen_colors'):
        st.session_state.pop(stale_key, None)
    # Strokes drawn over the previous image do not apply to the new one.
    st.session_state['reset_canvas'] = True


def check_reset_state() -> bool:
    """Consume the one-shot ``'reset_canvas'`` flag from session state.

    The flag is always left as False afterwards, so a requested reset is reported
    exactly once.

    Returns:
        bool: True if a reset had been requested (flag present and truthy),
        False otherwise.
    """
    must_reset = bool(st.session_state.get('reset_canvas', False))
    st.session_state['reset_canvas'] = False
    return must_reset


def move_image(source: Union[str, Image.Image],
               dest: str,
               rerun: bool = True,
               remove_state: bool = True) -> None:
    """Store an image under a session-state key.

    Args:
        source (Union[str, Image.Image]): a PIL image to store as-is, or the
            session-state key whose current value should be copied.
        dest (str): session-state key that receives the image.
        rerun (bool, optional): call ``st.experimental_rerun()`` afterwards so
            the page redraws with the new image. Defaults to True.
        remove_state (bool, optional): set the ``'reset_canvas'`` flag (read by
            ``check_reset_state``) so the canvas drawing is reset on the next
            run. Defaults to True.
    """
    # Resolve the image first, so a missing source key fails before any state changes.
    if isinstance(source, Image.Image):
        image_to_store = source
    else:
        image_to_store = st.session_state[source]

    if remove_state:
        st.session_state['reset_canvas'] = True

    st.session_state[dest] = image_to_store

    if rerun:
        st.experimental_rerun()


def on_change_radio() -> None:
    """Request a canvas reset after the generation mode changes.

    Used as the ``on_change`` callback of the "Generation mode" selectbox. It only
    raises the ``'reset_canvas'`` flag, which ``check_reset_state`` consumes on the
    next script run.
    """
    st.session_state.reset_canvas = True


def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
    """Build the keyword arguments passed to ``st_canvas``.

    Args:
        canvas_color: colour used for both the fill and the stroke.
        brush: stroke width.
        paint_mode: ``st_canvas`` drawing mode (e.g. "freedraw" or "polygon").
        _reset_state: when True an empty initial drawing is supplied (so the
            canvas starts blank); when False ``initial_drawing`` is None.

    Returns:
        dict: a 512x512 canvas configuration on a white background, using
        ``st.session_state['initial_image']`` as the background image when it
        exists, with widget key ``"canvas"``.
    """
    background = st.session_state.get('initial_image')
    empty_drawing = {'version': '4.4.0', 'objects': []} if _reset_state else None

    return {
        'fill_color': canvas_color,
        'stroke_color': canvas_color,
        'background_color': "#FFFFFF",
        'background_image': background,
        'stroke_width': brush,
        'initial_drawing': empty_drawing,
        'update_streamlit': True,
        'height': 512,
        'width': 512,
        'drawing_mode': paint_mode,
        'key': "canvas",
    }

def make_prompt_row():
    """Render the positive and negative prompt text inputs side by side.

    Their values are read elsewhere via ``st.session_state['positive_prompt']``
    and ``st.session_state['negative_prompt']``.
    """
    left_column, right_column = st.columns(2)
    prompt_fields = (
        (left_column, "Positive prompt", 'positive_prompt'),
        (right_column, "Negative prompt", 'negative_prompt'),
    )
    for column, field_label, state_key in prompt_fields:
        with column:
            st.text_input(label=field_label, value="", key=state_key)

def make_sidebar():
    """Render the sidebar controls and return the user's choices.

    Returns:
        tuple: ``(input_image, generation_mode, brush, color_chooser, paint_mode)``
        where:

        - ``input_image`` is the file-uploader value;
        - ``brush`` is the stroke-width slider value in "freedraw" mode, or a
          fixed 5 otherwise;
        - ``color_chooser`` is the chosen palette colour in
          "Segmentation conditioning" mode, or ``"#000000"`` otherwise.
    """
    with st.sidebar:
        input_image = st.file_uploader("", type=["png", "jpg"], key='input_image', on_change=on_upload)
        generation_mode = st.selectbox(
            "Generation mode",
            ["Segmentation conditioning", "Inpainting"],
            on_change=on_change_radio,
        )

        paint_mode = st.sidebar.selectbox("Painting mode", ("freedraw", "polygon"))
        # Only freehand drawing gets a width slider; other modes use a thin fixed stroke.
        brush = st.slider("Stroke width", 5, 140, 100, key='slider_seg') if paint_mode == "freedraw" else 5

        # Outside segmentation mode the canvas paints in black.
        color_chooser = "#000000"
        if generation_mode == "Segmentation conditioning":
            category = st.sidebar.selectbox(
                "Filter on category",
                list(COLOR_MAPPING_CATEGORY.keys()),
                index=0,
                key='category_chooser',
            )
            palette_colors = list(COLOR_MAPPING_CATEGORY[category].keys())
            color_chooser = st.sidebar.selectbox(
                "Choose a color",
                palette_colors,
                index=0,
                format_func=map_colors,
                key='color_chooser',
            )
    return input_image, generation_mode, brush, color_chooser, paint_mode


def make_output_image():
    """Render the most recent generated image and the "Move to input image" button.

    Shows ``st.session_state['output_image']``, converting numpy arrays to PIL and
    resizing PIL images to 512x512. When nothing has been generated yet, a blank
    white 512x512 placeholder is shown instead.

    Clicking the button copies the output into ``'initial_image'`` via
    ``move_image``, which also flags a canvas reset and triggers a rerun. If there
    is no output yet, a warning is shown instead; previously this failed on the
    missing session-state key.
    """
    has_output = 'output_image' in st.session_state
    if has_output:
        output_image = st.session_state['output_image']
        if isinstance(output_image, np.ndarray):
            output_image = Image.fromarray(output_image)

        if isinstance(output_image, Image.Image):
            output_image = output_image.resize((512, 512))
    else:
        output_image = Image.new('RGB', (512, 512), (255, 255, 255))

    st.write("#### Output image")
    st.image(output_image, width=512)
    if st.button("Move to input image"):
        if has_output:
            # NOTE(review): the cached segmentation ('seg' / 'unique_colors') is not
            # invalidated here; if get_image() reads 'initial_image', segmentation
            # mode keeps using the previous map after a move. Confirm.
            move_image('output_image', 'initial_image', remove_state=True, rerun=True)
        else:
            st.warning("Generate an image before moving it to the input.")

def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
    """Render the drawable input canvas and the generate controls for the active mode.

    Args:
        canvas_color: fill/stroke colour for the canvas brush.
        brush: stroke width.
        _reset_state: when True the canvas is given an empty initial drawing.
        generation_mode: "Segmentation conditioning" or "Inpainting"; any other
            value renders only the heading.
        paint_mode: ``st_canvas`` drawing mode.

    A successful generation stores its result (as a PIL image when the model
    returns a numpy array) in ``st.session_state['output_image']``.
    """
    st.write("#### Input image")
    canvas_dict = make_canvas_dict(
        canvas_color=canvas_color,
        paint_mode=paint_mode,
        brush=brush,
        _reset_state=_reset_state
    )

    if generation_mode == "Segmentation conditioning":
        _render_segmentation_mode(canvas_dict)
    elif generation_mode == "Inpainting":
        _render_inpainting_mode(canvas_dict)


def _to_pil(image):
    """Return *image* converted to a PIL image if it is a numpy array, else unchanged."""
    return Image.fromarray(image) if isinstance(image, np.ndarray) else image


def _mask_from_colors(segmentation, colors):
    """Build a mask that is 1 wherever a pixel of *segmentation* equals one of *colors*.

    The mask has the same shape and dtype as *segmentation*. Every channel of a
    matching pixel is set to 1, and all other entries stay 0.
    """
    mask = np.zeros_like(segmentation)
    for color in colors:
        # Pixel matches only if all channels (last axis) equal the colour.
        mask[(segmentation == color).all(axis=2)] = 1
    return mask


def _render_segmentation_mode(canvas_dict):
    """Segmentation-conditioned generation: repaint the segment colours the user picks."""
    st_canvas(
        **canvas_dict,
    )

    # The segmentation and its colour list are computed once and cached in session state.
    if 'seg' not in st.session_state:
        with st.spinner(text="Preparing image segmentation"):
            image = get_image()
            st.session_state['seg'] = np.array(segment_image(Image.fromarray(image)))

    if 'unique_colors' not in st.session_state:
        real_seg = st.session_state['seg']
        unique_colors = np.unique(real_seg.reshape(-1, real_seg.shape[2]), axis=0)
        st.session_state['unique_colors'] = [tuple(color) for color in unique_colors]

    chosen_colors = st.multiselect(
        label="Choose colors",
        options=st.session_state['unique_colors'],
        key='chosen_colors',
        format_func=map_colors_rgb
    )
    print(st.session_state['unique_colors'])

    if st.button("generate image", key='generate_button'):
        image = get_image()
        print(chosen_colors)

        segmentation = st.session_state['seg']
        mask = _mask_from_colors(segmentation, chosen_colors)
        print(mask)

        with st.spinner(text="Generating image"):
            result_image = make_image_controlnet(image=image,
                                                 mask_image=mask,
                                                 controlnet_conditioning_image=segmentation,
                                                 positive_prompt=st.session_state['positive_prompt'],
                                                 negative_prompt=st.session_state['negative_prompt'],
                                                 seed=random.randint(0, 100000)  # nosec
                                                 )[0]
            st.session_state['output_image'] = _to_pil(result_image)


def _render_inpainting_mode(canvas_dict):
    """Inpainting: regenerate the region the user draws on the canvas."""
    # get_image() is opaque here, so keep calling it before st_canvas as before.
    image = get_image()

    canvas = st_canvas(
        **canvas_dict,
    )

    if st.button("generate images", key='generate_button'):
        # st_canvas can return image_data=None (e.g. before the component has sent
        # any data); np.array(None) would reach get_mask as a 0-d object array.
        if canvas.image_data is None:
            st.warning("Draw a mask on the canvas before generating.")
            return
        mask = get_mask(np.asarray(canvas.image_data))

        with st.spinner(text="Generating new images"):
            print("Making image")
            result_image = make_inpainting(positive_prompt=st.session_state['positive_prompt'],
                                           image=image,
                                           mask_image=mask,
                                           negative_prompt=st.session_state['negative_prompt'],
                                           )[0]
            st.session_state['output_image'] = _to_pil(result_image)

def main():
    """Entry point: draw the sidebar, then either an upload hint or the editor.

    Until a file has been uploaded, only the title, the sidebar controls and an
    upload hint are shown. Afterwards the prompt inputs are rendered above two
    columns: the editable input canvas on the left and the output image on the
    right.
    """
    # Page title: a plain markdown heading (no HTML is actually emitted).
    st.write("## Controlnet sprint - interior design", unsafe_allow_html=True)

    _, generation_mode, brush, color_chooser, paint_mode = make_sidebar()

    # Nothing to edit until the uploader holds a file.
    if st.session_state.get('input_image') is None:
        print("Image not present")
        st.success("Upload an image to start")
        return

    make_prompt_row()

    _reset_state = check_reset_state()

    canvas_column, output_column = st.columns(2)
    with canvas_column:
        make_editing_canvas(
            canvas_color=color_chooser,
            brush=brush,
            _reset_state=_reset_state,
            generation_mode=generation_mode,
            paint_mode=paint_mode,
        )

    with output_column:
        make_output_image()


if __name__ == "__main__":
    main()