Shitao committed
Commit 730f5fd
Parent(s): 0fe01cf

more examples

OmniGen/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/OmniGen/__pycache__/__init__.cpython-310.pyc and b/OmniGen/__pycache__/__init__.cpython-310.pyc differ
 
OmniGen/__pycache__/model.cpython-310.pyc CHANGED
Binary files a/OmniGen/__pycache__/model.cpython-310.pyc and b/OmniGen/__pycache__/model.cpython-310.pyc differ
 
OmniGen/__pycache__/pipeline.cpython-310.pyc CHANGED
Binary files a/OmniGen/__pycache__/pipeline.cpython-310.pyc and b/OmniGen/__pycache__/pipeline.cpython-310.pyc differ
 
OmniGen/__pycache__/processor.cpython-310.pyc CHANGED
Binary files a/OmniGen/__pycache__/processor.cpython-310.pyc and b/OmniGen/__pycache__/processor.cpython-310.pyc differ
 
OmniGen/__pycache__/scheduler.cpython-310.pyc CHANGED
Binary files a/OmniGen/__pycache__/scheduler.cpython-310.pyc and b/OmniGen/__pycache__/scheduler.cpython-310.pyc differ
 
OmniGen/__pycache__/transformer.cpython-310.pyc CHANGED
Binary files a/OmniGen/__pycache__/transformer.cpython-310.pyc and b/OmniGen/__pycache__/transformer.cpython-310.pyc differ
 
OmniGen/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/OmniGen/__pycache__/utils.cpython-310.pyc and b/OmniGen/__pycache__/utils.cpython-310.pyc differ
 
OmniGen/pipeline.py CHANGED
@@ -16,6 +16,7 @@ from diffusers.utils import (
     scale_lora_layers,
     unscale_lora_layers,
 )
+from safetensors.torch import load_file
 
 from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
 
@@ -59,12 +60,12 @@ class OmniGenPipeline:
 
     @classmethod
     def from_pretrained(cls, model_name, vae_path: str=None):
-        if not os.path.exists(model_name):
+        if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"):
             logger.info("Model not found, downloading...")
             cache_folder = os.getenv('HF_HUB_CACHE')
             model_name = snapshot_download(repo_id=model_name,
                                            cache_dir=cache_folder,
-                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
+                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'])
             logger.info(f"Downloaded model to {model_name}")
         model = OmniGen.from_pretrained(model_name)
         processor = OmniGenProcessor.from_pretrained(model_name)
@@ -82,6 +83,8 @@ class OmniGenPipeline:
     def merge_lora(self, lora_path: str):
         model = PeftModel.from_pretrained(self.model, lora_path)
         model.merge_and_unload()
+
+
         self.model = model
 
     def to(self, device: Union[str, torch.device]):
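
For reference, a minimal usage sketch of the updated download path (the repo id "Shitao/OmniGen-v1" appears in the check above; importing OmniGenPipeline from the package root follows app.py, everything else is an assumption):

# Sketch only: exercises the new from_pretrained fallback.
from OmniGen import OmniGenPipeline

# If the local path or model.safetensors is missing, the snapshot is fetched
# from the Hub with the flax/rust/tf/model.pt weights excluded.
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")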
OmniGen/train_helper/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .data import DatasetFromJson, TrainDataCollator
+from .loss import training_losses
OmniGen/train_helper/data.py ADDED
@@ -0,0 +1,116 @@
+import os
+import datasets
+from datasets import load_dataset, ClassLabel, concatenate_datasets
+import torch
+import numpy as np
+import random
+from PIL import Image
+import json
+import copy
+# import torchvision.transforms as T
+from torchvision import transforms
+import pickle
+import re
+
+from OmniGen import OmniGenProcessor
+from OmniGen.processor import OmniGenCollator
+
+
+class DatasetFromJson(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        json_file: str,
+        image_path: str,
+        processer: OmniGenProcessor,
+        image_transform,
+        max_input_length_limit: int = 18000,
+        condition_dropout_prob: float = 0.1,
+        keep_raw_resolution: bool = True,
+    ):
+        self.image_transform = image_transform
+        self.processer = processer
+        self.condition_dropout_prob = condition_dropout_prob
+        self.max_input_length_limit = max_input_length_limit
+        self.keep_raw_resolution = keep_raw_resolution
+
+        self.data = load_dataset('json', data_files=json_file)['train']
+        self.image_path = image_path
+
+    def process_image(self, image_file):
+        if self.image_path is not None:
+            image_file = os.path.join(self.image_path, image_file)
+        image = Image.open(image_file).convert('RGB')
+        return self.image_transform(image)
+
+    def get_example(self, index):
+        example = self.data[index]
+
+        instruction, input_images, output_image = example['instruction'], example['input_images'], example['output_image']
+        # Randomly drop the condition for classifier-free guidance training
+        if random.random() < self.condition_dropout_prob:
+            instruction = '<cfg>'
+            input_images = None
+        if input_images is not None:
+            input_images = [self.process_image(x) for x in input_images]
+        mllm_input = self.processer.process_multi_modal_prompt(instruction, input_images)
+
+        output_image = self.process_image(output_image)
+
+        return (mllm_input, output_image)
+
+    def __getitem__(self, index):
+        # Retry with random indices so a single bad sample does not stop training
+        for _ in range(8):
+            try:
+                mllm_input, output_image = self.get_example(index)
+                if len(mllm_input['input_ids']) > self.max_input_length_limit:
+                    raise RuntimeError(f"cur number of tokens={len(mllm_input['input_ids'])}, larger than max_input_length_limit={self.max_input_length_limit}")
+                return mllm_input, output_image
+            except Exception as e:
+                print("error when loading data: ", e)
+                print(self.data[index])
+                index = random.randint(0, len(self.data)-1)
+        raise RuntimeError("Too many bad data.")
+
+    def __len__(self):
+        return len(self.data)
+
+
+class TrainDataCollator(OmniGenCollator):
+    def __init__(self, pad_token_id: int, hidden_size: int, keep_raw_resolution: bool):
+        self.pad_token_id = pad_token_id
+        self.hidden_size = hidden_size
+        self.keep_raw_resolution = keep_raw_resolution
+
+    def __call__(self, features):
+        mllm_inputs = [f[0] for f in features]
+
+        output_images = [f[1].unsqueeze(0) for f in features]
+        target_img_size = [[x.size(-2), x.size(-1)] for x in output_images]
+
+        all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
+
+        if not self.keep_raw_resolution:
+            output_images = torch.cat(output_images, dim=0)
+            if len(all_pixel_values) > 0:
+                all_pixel_values = torch.cat(all_pixel_values, dim=0)
+            else:
+                all_pixel_values = None
+
+        data = {"input_ids": all_padded_input_ids,
+                "attention_mask": all_attention_mask,
+                "position_ids": all_position_ids,
+                "input_pixel_values": all_pixel_values,
+                "input_image_sizes": all_image_sizes,
+                "padding_images": all_padding_images,
+                "output_images": output_images,
+                }
+        return data
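
A minimal sketch of wiring the new dataset and collator together: the JSON record layout follows the instruction/input_images/output_image keys read in get_example, while the file paths, transform, pad_token_id lookup, and hidden_size below are illustrative assumptions, not values from the commit:

import torch
from torchvision import transforms
from OmniGen import OmniGenProcessor
from OmniGen.train_helper import DatasetFromJson, TrainDataCollator

processor = OmniGenProcessor.from_pretrained("Shitao/OmniGen-v1")
image_transform = transforms.Compose([
    transforms.Resize((1024, 1024)),   # assumed training resolution
    transforms.ToTensor(),
])

dataset = DatasetFromJson(
    json_file="train_data.jsonl",      # assumed path; records carry the
    image_path="./train_images",       # instruction/input_images/output_image keys
    processer=processor,
    image_transform=image_transform,
)
collator = TrainDataCollator(
    pad_token_id=processor.text_tokenizer.pad_token_id,  # assumed tokenizer attribute
    hidden_size=3072,                  # assumed model hidden size
    keep_raw_resolution=True,
)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collator)
batch = next(iter(loader))             # dict with input_ids, attention_mask, output_images, ...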
OmniGen/train_helper/loss.py ADDED
@@ -0,0 +1,68 @@
+import torch
+
+
+def sample_x0(x1):
+    """Sample x0 based on the shape of x1 (if needed)
+    Args:
+        x1 - data point; [batch, *dim]
+    """
+    if isinstance(x1, (list, tuple)):
+        x0 = [torch.randn_like(img_start) for img_start in x1]
+    else:
+        x0 = torch.randn_like(x1)
+
+    return x0
+
+def sample_timestep(x1):
+    # Sample t in (0, 1) as a logistic-normal: the sigmoid of a standard normal
+    u = torch.normal(mean=0.0, std=1.0, size=(len(x1),))
+    t = 1 / (1 + torch.exp(-u))
+    t = t.to(x1[0])
+    return t
+
+
+def training_losses(model, x1, model_kwargs=None, snr_type='uniform'):
+    """Loss for training the score model
+    Args:
+    - model: backbone model; could be score, noise, or velocity
+    - x1: datapoint
+    - model_kwargs: additional arguments for the model
+    """
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    B = len(x1)
+
+    x0 = sample_x0(x1)
+    t = sample_timestep(x1)
+
+    # Linear interpolation between noise x0 and data x1; the regression target
+    # is the velocity x1 - x0
+    if isinstance(x1, (list, tuple)):
+        xt = [t[i] * x1[i] + (1 - t[i]) * x0[i] for i in range(B)]
+        ut = [x1[i] - x0[i] for i in range(B)]
+    else:
+        dims = [1] * (len(x1.size()) - 1)
+        t_ = t.view(t.size(0), *dims)
+        xt = t_ * x1 + (1 - t_) * x0
+        ut = x1 - x0
+
+    model_output = model(xt, t, **model_kwargs)
+
+    terms = {}
+
+    if isinstance(x1, (list, tuple)):
+        assert len(model_output) == len(ut) == len(x1)
+        terms["loss"] = torch.stack(
+            [((ut[i] - model_output[i]) ** 2).mean() for i in range(B)],
+            dim=0,
+        )
+    else:
+        terms["loss"] = mean_flat(((model_output - ut) ** 2))
+
+    return terms
+
+
+def mean_flat(x):
+    """
+    Take the mean over all non-batch dimensions.
+    """
+    return torch.mean(x, dim=list(range(1, len(x.size()))))
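
As a quick sanity check of the objective above, a sketch with a dummy velocity predictor (the zero model and tensor shapes are illustrative assumptions; real training plugs in the OmniGen backbone):

import torch
from OmniGen.train_helper import training_losses

def dummy_model(xt, t, **kwargs):
    # Stand-in "velocity" predictor; a real model would estimate x1 - x0.
    return torch.zeros_like(xt)

x1 = torch.randn(4, 8, 64, 64)      # fake batch of latents
terms = training_losses(dummy_model, x1)
print(terms["loss"].shape)          # torch.Size([4]): one MSE value per sample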
app.py CHANGED
@@ -11,7 +11,7 @@ pipe = OmniGenPipeline.from_pretrained(
 
 @spaces.GPU(duration=180)
 # Example handler: generate an image
-def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
+def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed):
     input_images = [img1, img2, img3]
     # Drop None entries
     input_images = [img for img in input_images if img is not None]
@@ -26,7 +26,7 @@ def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
         guidance_scale=guidance_scale,
-        img_guidance_scale=1.6,
+        img_guidance_scale=img_guidance_scale,
         num_inference_steps=inference_steps,
-        separate_cfg_infer=True,
+        separate_cfg_infer=True,  # setting this to False can speed up inference
         use_kv_cache=False,
         seed=seed,
     )
@@ -47,26 +47,28 @@ def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
 def get_example():
     case = [
         [
-            "A vintage camera placed on the ground, ejecting a swirling cloud of Polaroid-style photographs into the air. The photos, showing landscapes, wildlife, and travel scenes, seem to defy gravity, floating upward in a vortex of motion. The camera emits a glowing, smoky light from within, enhancing the magical, surreal atmosphere. The dark background contrasts with the illuminated photos and camera, creating a dreamlike, nostalgic scene filled with vibrant colors and dynamic movement. Scattered photos are visible on the ground, further contributing to the idea of an explosion of captured memories.",
+            "A curly-haired man in a red shirt is drinking tea.",
             None,
             None,
             None,
             1024,
             1024,
             2.5,
+            1.6,
             50,
             0,
         ],
         [
-            "A woman <img><|image_1|></img> in a wedding dress. Next to her is a black-haired man.",
+            "The woman in <img><|image_1|></img> waves her hand happily in the crowd",
             "./imgs/test_cases/yifei2.png",
             None,
             None,
             1024,
             1024,
             2.5,
+            1.9,
             50,
-            0,
+            128,
         ],
         [
             "A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
@@ -76,17 +78,55 @@ def get_example():
             1024,
             1024,
             2.5,
+            1.6,
             50,
             0,
         ],
         [
-            "Two men are celebrating with raised glasses in a restaurant. A man is <img><|image_1|></img>. The other man is <img><|image_2|></img>.",
-            "./imgs/test_cases/young_musk.jpg",
-            "./imgs/test_cases/young_trump.jpeg",
+            "Two women are raising fried chicken legs in a bar. One woman is <img><|image_1|></img>. The other woman is <img><|image_2|></img>.",
+            "./imgs/test_cases/mckenna.jpg",
+            "./imgs/test_cases/Amanda.jpg",
+            None,
+            1024,
+            1024,
+            2.5,
+            1.8,
+            50,
+            168,
+        ],
+        [
+            "A man and a short-haired woman with a wrinkled face are standing in front of a bookshelf in a library. The man is the man in the middle of <img><|image_1|></img>, and the woman is the oldest woman in <img><|image_2|></img>",
+            "./imgs/test_cases/1.jpg",
+            "./imgs/test_cases/2.jpg",
+            None,
+            1024,
+            1024,
+            2.5,
+            1.6,
+            50,
+            60,
+        ],
+        [
+            "A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <img><|image_1|></img>. The woman is the woman on the left of <img><|image_2|></img>",
+            "./imgs/test_cases/3.jpg",
+            "./imgs/test_cases/4.jpg",
             None,
             1024,
             1024,
             2.5,
+            1.8,
+            50,
+            66,
+        ],
+        [
+            "The flower <img><|image_1|></img> is placed in the vase which is in the middle of <img><|image_2|></img> on a wooden table of a living room",
+            "./imgs/test_cases/rose.jpg",
+            "./imgs/test_cases/vase.jpg",
+            None,
+            1024,
+            1024,
+            2.5,
+            1.6,
             50,
             0,
         ],
@@ -98,6 +138,7 @@ def get_example():
             1024,
             1024,
             2.5,
+            1.6,
             50,
             222,
         ],
@@ -109,6 +150,7 @@ def get_example():
             1024,
             1024,
             2.0,
+            1.6,
             50,
             0,
         ],
@@ -120,6 +162,7 @@ def get_example():
             1024,
             1024,
             2,
+            1.6,
             50,
             42,
         ],
@@ -131,9 +174,22 @@ def get_example():
             1024,
             1024,
             2.0,
+            1.6,
             50,
             123,
         ],
+        [
+            "Following the depth mapping of this image <img><|image_1|></img>, generate a new photo: A young girl is sitting on a sofa in the library, holding a book. Her hair is neatly combed, and a faint smile plays on her lips, with a few freckles scattered across her cheeks. The library is quiet, with rows of shelves filled with books stretching out behind her.",
+            "./imgs/demo_cases/edit.png",
+            None,
+            None,
+            1024,
+            1024,
+            2.0,
+            1.6,
+            50,
+            1,
+        ],
         [
             "<img><|image_1|></img> What item can be used to see the current time? Please remove it.",
             "./imgs/test_cases/watch.jpg",
@@ -142,25 +198,27 @@ def get_example():
             1024,
             1024,
             2.5,
+            1.6,
             50,
             0,
         ],
         [
-            "Three guitars are displayed side by side on a rustic wooden stage, each showcasing its unique character and style. The left guitar is <img><|image_1|></img>. The middle guitar is <img><|image_2|></img>. The right guitar is <img><|image_3|></img>.",
-            "./imgs/test_cases/guitar1.png",
-            "./imgs/test_cases/guitar1.png",
-            "./imgs/test_cases/guitar1.png",
+            "According to the following examples, generate an output for the input.\nInput: <img><|image_1|></img>\nOutput: <img><|image_2|></img>\n\nInput: <img><|image_3|></img>\nOutput: ",
+            "./imgs/test_cases/icl1.jpg",
+            "./imgs/test_cases/icl2.jpg",
+            "./imgs/test_cases/icl3.jpg",
             1024,
             1024,
             2.5,
+            1.6,
             50,
-            0,
+            1,
         ],
     ]
     return case
 
-def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
-    return generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed)
+def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed):
+    return generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed)
 
 description = """
 OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, Identity-Preserving Generation, and image-conditioned generation.
@@ -168,6 +226,13 @@ OmniGen is a unified image generation model that you can use to perform various
 For multi-modal to image generation, you should pass a string as `prompt`, and a list of image paths as `input_images`. The placeholder in the prompt should be in the format of `<img><|image_*|></img>` (for the first image, the placeholder is <img><|image_1|></img>; for the second image, the placeholder is <img><|image_2|></img>).
 For example, use an image of a woman to generate a new image:
 prompt = "A woman holds a bouquet of flowers and faces the camera. The woman is \<img\>\<|image_1|\>\</img\>."
+
+Tips:
+- Oversaturated: If the image appears oversaturated, please reduce the `guidance_scale`.
+- Low quality: A more detailed prompt will lead to better results.
+- Anime style: If the generated image looks animated, you can try adding `photo` to the prompt.
+- Editing a generated image: If you generate an image with OmniGen and then want to edit it, you cannot reuse the same seed. For example, use seed=0 to generate the image and seed=1 to edit it.
+- For image editing tasks, we recommend placing the image before the editing instruction. For example, use `<img><|image_1|></img> remove suit`, rather than `remove suit <img><|image_1|></img>`.
 """
 
 # Gradio interface
@@ -197,7 +262,11 @@ with gr.Blocks() as demo:
 
         # Guidance scale input
         guidance_scale_input = gr.Slider(
-            label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1
+            label="Guidance Scale", minimum=1.0, maximum=5.0, value=2.5, step=0.1
+        )
+
+        img_guidance_scale_input = gr.Slider(
+            label="img_guidance_scale", minimum=1.0, maximum=2.0, value=1.6, step=0.1
         )
 
         num_inference_steps = gr.Slider(
@@ -226,6 +295,7 @@ with gr.Blocks() as demo:
             height_input,
             width_input,
             guidance_scale_input,
+            img_guidance_scale_input,
             num_inference_steps,
             seed_input,
         ],
@@ -243,6 +313,7 @@ with gr.Blocks() as demo:
             height_input,
             width_input,
             guidance_scale_input,
+            img_guidance_scale_input,
             num_inference_steps,
             seed_input,
         ],
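
Tying the UI changes together, a hedged sketch of the equivalent direct pipeline call: the kwargs mirror generate_image above, and the prompt, image path, and seed come from the second example case; treating the return value as a list of PIL images is an assumption.

from OmniGen import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
images = pipe(
    prompt="The woman in <img><|image_1|></img> waves her hand happily in the crowd",
    input_images=["./imgs/test_cases/yifei2.png"],
    height=1024,
    width=1024,
    guidance_scale=2.5,
    img_guidance_scale=1.9,     # the new per-example image guidance control
    num_inference_steps=50,
    separate_cfg_infer=True,    # set False to speed up inference
    use_kv_cache=False,
    seed=128,
)
images[0].save("output.png")    # assuming a list of PIL images is returned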
imgs/demo_cases/edit.png CHANGED

Git LFS Details (before)

  • SHA256: 889f3da462745ffdb4c3300099ea941d0a75db3259e8eaa1dbc95f8a28f11c70
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB

Git LFS Details (after)

  • SHA256: a83fc3b2ab185a93cb10d207a8776f3a04dc187739d87816cfb33f52d46af502
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
imgs/demo_cases/entity.png CHANGED

Git LFS Details (before)

  • SHA256: f99af692215c6292ab64b26450bb1eb2b7bc6d5c2f450f21a9953e37455fe1b5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB

Git LFS Details (after)

  • SHA256: 5e18387fa43989515fd18dcb4ce8edeab0e32aa539d6c14ce374cb5790d8f64b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.28 MB
imgs/demo_cases/reasoning.png CHANGED

Git LFS Details (before)

  • SHA256: 0df9b7aab47e59d792bf5148f1797eea966064dbf6fec3ba8b169facf51c09d7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB

Git LFS Details (after)

  • SHA256: eb510edcb5628c0def3871cef2e0351acc578a1ceef445ebbd72f8b6eb92fc9d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
imgs/demo_cases/same_pose.png CHANGED

Git LFS Details (before)

  • SHA256: 2359fe0b95233c418f554e2dd98c39d9bc79053be0242569286f90d50201313a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.36 MB

Git LFS Details (after)

  • SHA256: beccbeabfc408f319661d9af1063005cbc21c977ba50b910491611ca3babd876
  • Pointer size: 132 Bytes
  • Size of remote file: 1.36 MB
imgs/demo_cases/skeletal.png CHANGED

Git LFS Details (before)

  • SHA256: 7f05a79ce9ae0179fe4948568cb71dc3fbc8e9f6b2ecf5a839fa255dfce0aaf6
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB

Git LFS Details (after)

  • SHA256: 30c7937855228adec69da7d9bc3170c9f434a6b159feaf02d362033c1901a671
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
imgs/demo_cases/skeletal2img.png CHANGED

Git LFS Details (before)

  • SHA256: 3d538579f45fce26053f2af83a7aea5be12f2270fba2282a17e9d3bae0c2e91d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.32 MB

Git LFS Details (after)

  • SHA256: 86c21341018bb633f364d40afbf361b5e5690bf1e6539b99150e4aea0ed695b6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.28 MB
imgs/{demo_cases.png → test_cases/1.jpg} RENAMED
File without changes
imgs/{overall.jpg → test_cases/2.jpg} RENAMED
File without changes
imgs/test_cases/3.jpg ADDED

Git LFS Details

  • SHA256: c8fef6b304efc3fc189991ec28b83bbe15c391af55b2bfd85276eb19d49194c9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.2 MB
imgs/test_cases/4.jpg ADDED

Git LFS Details

  • SHA256: 222e844198656a13facbf0f0afe327b074641a7f20d4120418fa1302e61db538
  • Pointer size: 132 Bytes
  • Size of remote file: 1.74 MB
imgs/test_cases/Amanda.jpg ADDED

Git LFS Details

  • SHA256: c20a508b8619fca4d963f574bca51c7460f274218507c97c2853fa6eaea6d0cb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.65 MB
imgs/test_cases/icl1.jpg ADDED

Git LFS Details

  • SHA256: 0e2b2086ad903c43aee1cc98902b7c53864765c6c99758acf39618ac1ad54b0e
  • Pointer size: 130 Bytes
  • Size of remote file: 76.6 kB
imgs/test_cases/icl2.jpg ADDED

Git LFS Details

  • SHA256: 48bafc52d6721c636e1aec9ebcd1a76c017cc926909bf03270993dd423bc49f9
  • Pointer size: 130 Bytes
  • Size of remote file: 86.3 kB
imgs/test_cases/icl3.jpg ADDED

Git LFS Details

  • SHA256: 077a7c69f7ca24808922e5acc7762a62182d51031f4f9e0d035ce80b09a81d5e
  • Pointer size: 130 Bytes
  • Size of remote file: 77.8 kB
imgs/test_cases/mckenna.jpg ADDED

Git LFS Details

  • SHA256: bd20a5841f84114859e46c4000d9b8035a40378b5d40fbb2b559864813cd402f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.78 MB
imgs/test_cases/rose.jpg ADDED

Git LFS Details

  • SHA256: 2578fcc252f8b3240f9bc621f9c22f7b186a959ee1b2a037d2c9c31be99fae91
  • Pointer size: 130 Bytes
  • Size of remote file: 68.8 kB
imgs/test_cases/vase.jpg ADDED

Git LFS Details

  • SHA256: ab4e3e4b1228d85e7a9c4979bb9d825817d88799fe187f12216a22e2c3ceaa93
  • Pointer size: 130 Bytes
  • Size of remote file: 31.5 kB
imgs/test_cases/zhang.png ADDED

Git LFS Details

  • SHA256: 020925b411e9e053354876116e92722e7a5ee003d45070a2cf58b1902f2162cd
  • Pointer size: 131 Bytes
  • Size of remote file: 158 kB