File size: 5,437 Bytes
f8a748e
7f48662
 
1876385
f8a748e
7f48662
f8a748e
200a130
 
 
f8a748e
1876385
200a130
44bc074
7f48662
 
 
 
 
f8a748e
7f48662
 
 
f8a748e
7f48662
200a130
7f48662
200a130
7f48662
200a130
44bc074
7f48662
 
 
200a130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44bc074
200a130
 
 
 
 
 
 
 
 
 
44bc074
200a130
 
 
 
44bc074
 
200a130
7f48662
 
 
44bc074
7f48662
 
 
200a130
44bc074
200a130
 
 
 
 
 
 
 
7f48662
200a130
 
 
 
 
 
 
7f48662
200a130
 
 
 
 
 
 
 
44bc074
 
 
 
7f48662
 
200a130
7f48662
 
 
 
 
 
 
200a130
 
 
 
 
 
 
 
 
44bc074
200a130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44bc074
200a130
 
f8a748e
 
7f48662
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import gradio as gr
from PIL import Image
import os
import spaces

from OmniGen import OmniGenPipeline

# Model identifier/path handed to OmniGenPipeline.from_pretrained.
_OMNIGEN_CHECKPOINT = "shitao/tmp-preview"

# Built once at import time; generate_image() calls this module-level object.
pipe = OmniGenPipeline.from_pretrained(_OMNIGEN_CHECKPOINT)

@spaces.GPU
def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
    """Generate one image from a prompt and up to three optional reference images.

    Args:
        text: Prompt text. It may refer to the reference images with
            ``<img><|image_i|></img>`` placeholders.
        img1, img2, img3: Optional reference images. Any slot that is ``None``
            (left empty in the UI) is skipped.
        height, width: Requested output size, coerced to ``int``.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        inference_steps: Forwarded as ``num_inference_steps``, coerced to ``int``.
        seed: Forwarded as ``seed``, coerced to ``int``.

    Returns:
        The first element of the pipeline's output.
    """
    # Keep only the filled upload slots. The pipeline receives None (not an
    # empty list) when no reference image was provided.
    input_images = [img for img in (img1, img2, img3) if img is not None] or None

    output = pipe(
        prompt=text,
        input_images=input_images,
        # Gradio sliders are float-typed and may deliver e.g. 1024.0. Coerce
        # the values that must be integral (pixel sizes, step count, seed).
        height=int(height),
        width=int(width),
        guidance_scale=guidance_scale,
        # The image guidance scale is fixed for this demo; it has no UI control.
        img_guidance_scale=1.6,
        num_inference_steps=int(inference_steps),
        separate_cfg_infer=True,
        use_kv_cache=False,
        seed=int(seed),
    )
    return output[0]


def get_example():
    """Return the example rows shown in the ``gr.Examples`` gallery.

    Each row follows the parameter order of ``run_for_examples``:
    prompt, image 1 path, image 2 path, image 3 path (``None`` for an empty
    slot), height, width, guidance scale, inference steps, seed.

    A new list is built on every call, so callers may mutate the result.
    """
    case = [
        [
            "A woman holds a bouquet of flowers and faces the camera. The woman is the one in <img><|image_1|></img>.",
            "./imgs/test_cases/liuyifei.png",
            None,
            None,
            1024,
            1024,
            3.0,
            20,
            42,
        ],
        [
            "Three zebras are standing side by side on a vibrant savannah, each showcasing unique patterns and characteristics that highlight their individuality. The zebra on the left has a strikingly bold black and white stripe pattern, with wider stripes that create a dramatic contrast against its sleek body. In the middle, the zebra features a more subtle stripe arrangement, with thinner stripes that blend seamlessly into a slightly sandy-colored coat, giving it a softer appearance. On the right, the zebra's stripes are more irregular, with a distinct patch of brown fur near its shoulder, adding a layer of uniqueness to its overall look. Together, these zebras create a captivating scene, each representing the diverse beauty of their species in the wild. The right zebras is the zebras from <img><|image_1|></img>. The center zebras is from <img><|image_2|></img>. The left zebras is the zebras from <img><|image_3|></img>.",
            "./imgs/test_cases/img1.jpg",
            "./imgs/test_cases/img2.jpg",
            "./imgs/test_cases/img3.jpg",
            1024,
            1024,
            3.0,
            20,
            42,
        ],
    ]
    return case

def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
    """Adapter for ``gr.Examples``: forward every argument, in order, to ``generate_image``."""
    forwarded = (text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed)
    return generate_image(*forwarded)


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# OmniGen: Unified Image Generation")
    with gr.Row():
        # Left column: prompt, reference-image slots and generation settings.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Enter your prompt, use <img><|image_i|></img> tokens for images",
                placeholder="Type your prompt here...",
            )

            # Three optional reference-image upload slots, delivered as file paths.
            with gr.Row(equal_height=True):
                image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
                image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
                image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")

            # Output size in pixels.
            height_input = gr.Slider(label="Height", minimum=256, maximum=2048, value=1024, step=16)
            width_input = gr.Slider(label="Width", minimum=256, maximum=2048, value=1024, step=16)

            # Sampling settings.
            guidance_scale_input = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
            num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, value=50, step=1)
            seed_input = gr.Slider(label="Seed", minimum=0, maximum=2147483647, value=42, step=1)

            generate_button = gr.Button("Generate Image")

        # Right column: the generated image.
        with gr.Column():
            output_image = gr.Image(label="Output Image")

    # One definition of the input wiring, shared by the button and the examples.
    # The order must match the parameters of generate_image / run_for_examples.
    generation_inputs = (
        prompt_input,
        image_input_1,
        image_input_2,
        image_input_3,
        height_input,
        width_input,
        guidance_scale_input,
        num_inference_steps,
        seed_input,
    )

    generate_button.click(
        generate_image,
        inputs=list(generation_inputs),
        outputs=output_image,
    )

    gr.Examples(
        examples=get_example(),
        fn=run_for_examples,
        inputs=list(generation_inputs),
        outputs=output_image,
    )

# Start the app
demo.launch()