Shitao commited on
Commit
200a130
1 Parent(s): 84db7c2

add examples

Browse files
Files changed (1) hide show
  1. app.py +101 -24
app.py CHANGED
@@ -3,16 +3,15 @@ from PIL import Image
3
  import os
4
  import spaces
5
 
6
- # os.environ['CUDA_VISIBLE_DEVICES'] = '7'
7
-
8
  from OmniGen import OmniGenPipeline
9
 
10
- pipe = OmniGenPipeline.from_pretrained("shitao/tmp-preview")
11
- # pipe.to("cuda")
 
12
 
13
- # 示例处理函数:生成图像
14
  @spaces.GPU
15
- def generate_image(text, img1, img2, img3, height, width, guidance_scale):
 
16
  input_images = [img1, img2, img3]
17
  # 去除 None
18
  input_images = [img for img in input_images if img is not None]
@@ -24,38 +23,91 @@ def generate_image(text, img1, img2, img3, height, width, guidance_scale):
24
  input_images=input_images,
25
  height=height,
26
  width=width,
27
- guidance_scale=guidance_scale,
28
  img_guidance_scale=1.6,
 
29
  separate_cfg_infer=True,
30
- use_kv_cache=False
31
  )
32
  img = output[0]
33
  return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  # Gradio 接口
36
  with gr.Blocks() as demo:
37
  gr.Markdown("## Text + Multiple Images to Image Generator")
38
-
39
  with gr.Row():
40
  with gr.Column():
41
  # 文本输入框
42
- prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your prompt here...")
43
-
44
- # 图片上传框
45
- image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
46
- image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
47
- image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")
48
-
 
 
 
49
  # 高度和宽度滑块
50
- height_input = gr.Slider(label="Height", minimum=256, maximum=2048, value=1024, step=16)
51
- width_input = gr.Slider(label="Width", minimum=256, maximum=2048, value=1024, step=16)
52
-
 
 
 
 
53
  # 引导尺度输入
54
- guidance_scale_input = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
55
-
 
 
 
 
 
 
56
  # 生成按钮
57
  generate_button = gr.Button("Generate Image")
58
-
59
  with gr.Column():
60
  # 输出图像框
61
  output_image = gr.Image(label="Output Image")
@@ -63,8 +115,33 @@ with gr.Blocks() as demo:
63
  # 按钮点击事件
64
  generate_button.click(
65
  generate_image,
66
- inputs=[prompt_input, image_input_1, image_input_2, image_input_3, height_input, width_input, guidance_scale_input],
67
- outputs=output_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  )
69
 
70
  # 启动应用
 
3
  import os
4
  import spaces
5
 
 
 
6
  from OmniGen import OmniGenPipeline
7
 
8
+ pipe = OmniGenPipeline.from_pretrained(
9
+ "shitao/tmp-preview"
10
+ )
11
 
 
12
  @spaces.GPU
13
+ # 示例处理函数:生成图像
14
+ def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps):
15
  input_images = [img1, img2, img3]
16
  # 去除 None
17
  input_images = [img for img in input_images if img is not None]
 
23
  input_images=input_images,
24
  height=height,
25
  width=width,
26
+ guidance_scale=guidance_scale,
27
  img_guidance_scale=1.6,
28
+ num_inference_steps=inference_steps,
29
  separate_cfg_infer=True,
30
+ use_kv_cache=False,
31
  )
32
  img = output[0]
33
  return img
34
+ # def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps):
35
+ # input_images = []
36
+ # if img1:
37
+ # input_images.append(Image.open(img1))
38
+ # if img2:
39
+ # input_images.append(Image.open(img2))
40
+ # if img3:
41
+ # input_images.append(Image.open(img3))
42
+
43
+ # return input_images[0] if input_images else None
44
+
45
+
46
+ def get_example():
47
+ case = [
48
+ [
49
+ "A woman holds a bouquet of flowers and faces the camera. Thw woman is the one in <img><|image_1|></img>.",
50
+ "./imgs/test_cases/liuyifei.png",
51
+ None,
52
+ None,
53
+ 1024,
54
+ 1024,
55
+ 3.0,
56
+ 20,
57
+ ],
58
+ [
59
+ "Three zebras are standing side by side on a vibrant savannah, each showcasing unique patterns and characteristics that highlight their individuality. The zebra on the left has a strikingly bold black and white stripe pattern, with wider stripes that create a dramatic contrast against its sleek body. In the middle, the zebra features a more subtle stripe arrangement, with thinner stripes that blend seamlessly into a slightly sandy-colored coat, giving it a softer appearance. On the right, the zebra's stripes are more irregular, with a distinct patch of brown fur near its shoulder, adding a layer of uniqueness to its overall look. Together, these zebras create a captivating scene, each representing the diverse beauty of their species in the wild. The right zebras is the zebras from <img><|image_1|></img>. The center zebras is from <img><|image_2|></img>. The left zebras is the zebras from <img><|image_3|></img>.",
60
+ "./imgs/test_cases/img1.jpg",
61
+ "./imgs/test_cases/img2.jpg",
62
+ "./imgs/test_cases/img3.jpg",
63
+ 1024,
64
+ 1024,
65
+ 3.0,
66
+ 20,
67
+ ],
68
+ ]
69
+ return case
70
+
71
+ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, inference_steps):
72
+ return generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps)
73
+
74
 
75
  # Gradio 接口
76
  with gr.Blocks() as demo:
77
  gr.Markdown("## Text + Multiple Images to Image Generator")
 
78
  with gr.Row():
79
  with gr.Column():
80
  # 文本输入框
81
+ prompt_input = gr.Textbox(
82
+ label="Enter your prompt", placeholder="Type your prompt here..."
83
+ )
84
+
85
+ with gr.Row(equal_height=True):
86
+ # 图片上传框
87
+ image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
88
+ image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
89
+ image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")
90
+
91
  # 高度和宽度滑块
92
+ height_input = gr.Slider(
93
+ label="Height", minimum=256, maximum=2048, value=1024, step=16
94
+ )
95
+ width_input = gr.Slider(
96
+ label="Width", minimum=256, maximum=2048, value=1024, step=16
97
+ )
98
+
99
  # 引导尺度输入
100
+ guidance_scale_input = gr.Slider(
101
+ label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1
102
+ )
103
+
104
+ num_inference_steps = gr.Slider(
105
+ label="Inference Steps", minimum=1, maximum=50, value=50, step=1
106
+ )
107
+
108
  # 生成按钮
109
  generate_button = gr.Button("Generate Image")
110
+
111
  with gr.Column():
112
  # 输出图像框
113
  output_image = gr.Image(label="Output Image")
 
115
  # 按钮点击事件
116
  generate_button.click(
117
  generate_image,
118
+ inputs=[
119
+ prompt_input,
120
+ image_input_1,
121
+ image_input_2,
122
+ image_input_3,
123
+ height_input,
124
+ width_input,
125
+ guidance_scale_input,
126
+ num_inference_steps,
127
+ ],
128
+ outputs=output_image,
129
+ )
130
+
131
+ gr.Examples(
132
+ examples=get_example(),
133
+ fn=run_for_examples,
134
+ inputs=[
135
+ prompt_input,
136
+ image_input_1,
137
+ image_input_2,
138
+ image_input_3,
139
+ height_input,
140
+ width_input,
141
+ guidance_scale_input,
142
+ num_inference_steps,
143
+ ],
144
+ outputs=output_image,
145
  )
146
 
147
  # 启动应用