# Spaces: Running on Zero
import gradio as gr | |
from PIL import Image | |
import os | |
import spaces | |
from OmniGen import OmniGenPipeline | |
# Build the OmniGen pipeline once, at module import time, from the
# "Shitao/OmniGen-v1" checkpoint; every request below reuses this object.
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
# Request handler: generate an image with the OmniGen pipeline.
# On ZeroGPU Spaces a GPU is only attached inside functions decorated with
# @spaces.GPU (the decorator is a no-op elsewhere). The default budget is
# 60 s; a 1024x1024, 50-step run with several reference images may need
# more, so a longer budget is requested here — tune as needed.
@spaces.GPU(duration=180)
def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed, img_guidance_scale=1.6):
    """Run the OmniGen pipeline for one request and return the first output image.

    Args:
        text: Prompt. Reference images are addressed in it with
            ``<img><|image_i|></img>`` tokens.
        img1, img2, img3: Optional reference images (file paths from the UI),
            ``None`` for empty slots.
        height, width: Output size in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        inference_steps: Forwarded to the pipeline as ``num_inference_steps``.
        seed: Seed forwarded to the pipeline. ``None`` is passed through unchanged.
        img_guidance_scale: Image guidance scale forwarded to the pipeline.
            Defaults to 1.6, the value this handler always used.

    Returns:
        The first element of the pipeline output.
    """
    # Drop empty upload slots; pass None instead of an empty list when the
    # user provided no reference image at all.
    input_images = [img for img in (img1, img2, img3) if img is not None] or None

    output = pipe(
        prompt=text,
        input_images=input_images,
        # Gradio sliders may deliver floats (e.g. 1024.0). These values are
        # pixel sizes, a step count and an RNG seed, so normalize them to int.
        height=int(height),
        width=int(width),
        guidance_scale=guidance_scale,
        img_guidance_scale=img_guidance_scale,
        num_inference_steps=int(inference_steps),
        # NOTE(review): presumably memory/speed trade-offs of the pipeline;
        # kept exactly as originally configured — confirm against OmniGen docs.
        separate_cfg_infer=True,
        use_kv_cache=False,
        seed=None if seed is None else int(seed),
    )
    return output[0]
def get_example():
    """Return the rows shown in the ``gr.Examples`` table.

    Each row holds, in order: prompt, three reference-image paths (``None``
    for unused slots), height, width, guidance scale, inference steps and
    seed. This matches the order of the components the examples are bound to.
    A fresh list is built on every call, so callers may mutate the result.
    """
    # Generation settings shared by every example:
    # height, width, guidance scale, inference steps, seed.
    default_settings = [1024, 1024, 3.0, 50, 42]

    # Prompt followed by the three image slots.
    prompts_and_images = [
        [
            "A woman holds a bouquet of flowers and faces the camera.",
            None,
            None,
            None,
        ],
        [
            "A woman holds a bouquet of flowers and faces the camera. The woman is the one in <img><|image_1|></img>.",
            "./imgs/test_cases/liuyifei.png",
            None,
            None,
        ],
        [
            "Three zebras are standing side by side on a vibrant savannah, each showcasing unique patterns and characteristics that highlight their individuality. The zebra on the left has a strikingly bold black and white stripe pattern, with wider stripes that create a dramatic contrast against its sleek body. In the middle, the zebra features a more subtle stripe arrangement, with thinner stripes that blend seamlessly into a slightly sandy-colored coat, giving it a softer appearance. On the right, the zebra's stripes are more irregular, with a distinct patch of brown fur near its shoulder, adding a layer of uniqueness to its overall look. Together, these zebras create a captivating scene, each representing the diverse beauty of their species in the wild. The right zebras is the zebras from <img><|image_1|></img>. The center zebras is from <img><|image_2|></img>. The left zebras is the zebras from <img><|image_3|></img>.",
            "./imgs/test_cases/img1.jpg",
            "./imgs/test_cases/img2.jpg",
            "./imgs/test_cases/img3.jpg",
        ],
    ]
    # `row + default_settings` builds a new list per row, so rows never share state.
    return [row + default_settings for row in prompts_and_images]
def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
    """Callback for the examples table: generate an image for the chosen row.

    Every argument is handed to ``generate_image`` unchanged, and its result
    is returned as-is.
    """
    return generate_image(
        text=text,
        img1=img1,
        img2=img2,
        img3=img3,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        inference_steps=inference_steps,
        seed=seed,
    )
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# OmniGen: Unified Image Generation [paper](https://arxiv.org/abs/2409.11340) [code](https://github.com/VectorSpaceLab/OmniGen)")
    with gr.Row():
        with gr.Column():
            # Prompt box; reference images are addressed with <img><|image_i|></img> tokens.
            prompt_input = gr.Textbox(
                label="Enter your prompt, use <img><|image_i|></img> tokens for images", placeholder="Type your prompt here..."
            )
            with gr.Row(equal_height=True):
                # Optional reference-image uploads, delivered to the handler as file paths.
                image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
                image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
                image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")
            # Output height and width sliders
            height_input = gr.Slider(
                label="Height", minimum=256, maximum=2048, value=1024, step=16
            )
            width_input = gr.Slider(
                label="Width", minimum=256, maximum=2048, value=1024, step=16
            )
            # Guidance scale
            guidance_scale_input = gr.Slider(
                label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1
            )
            num_inference_steps = gr.Slider(
                label="Inference Steps", minimum=1, maximum=100, value=50, step=1
            )
            seed_input = gr.Slider(
                label="Seed", minimum=0, maximum=2147483647, value=42, step=1
            )
            # Generate button
            generate_button = gr.Button("Generate Image")
        with gr.Column():
            # Output image display
            output_image = gr.Image(label="Output Image")

    # Single definition of the positional argument order expected by
    # generate_image / run_for_examples (and by the rows of get_example()).
    # It is shared by the button and the examples table so they cannot drift apart.
    generation_inputs = [
        prompt_input,
        image_input_1,
        image_input_2,
        image_input_3,
        height_input,
        width_input,
        guidance_scale_input,
        num_inference_steps,
        seed_input,
    ]

    # Button click handler
    generate_button.click(
        generate_image,
        inputs=generation_inputs,
        outputs=output_image,
    )

    gr.Examples(
        examples=get_example(),
        fn=run_for_examples,
        inputs=generation_inputs,
        outputs=output_image,
    )
# Start the app only when this file is executed as a script (e.g. `python app.py`),
# so that importing the module does not also start a web server.
if __name__ == "__main__":
    demo.launch()