yibolu committed 3308ae3 ("update ipadapter")
Parent: 6eca12e

lyrasd_model/module/lyrasd_ip_adapter.py
CHANGED
@@ -45,17 +45,11 @@ class LyraIPAdapter:
         image_encoder_path=None,
         num_ip_tokens=4,
         ip_projection_dim=None,
-        fp_ckpt=None,
-        num_fp_tokens=1,
-        fp_projection_dim=None,
     ):
         self.pipe = sd_pipe
         self.device = device
-        self.fp_ckpt = fp_ckpt
         self.ip_ckpt = ip_ckpt
-        self.num_fp_tokens = num_fp_tokens
         self.num_ip_tokens = num_ip_tokens
-        self.fp_projection_dim = fp_projection_dim
         self.ip_projection_dim = ip_projection_dim
         self.sdxl = sdxl
         self.ip_plus = ip_plus
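With the FaceIn arguments removed, the adapter is constructed with IP-Adapter parameters only. A minimal sketch of the post-commit call: the leading parameter names (sd_pipe, device, ip_ckpt, sdxl, ip_plus) are inferred from the assignments in the hunk body, not from the full signature (which this diff does not show), so keyword arguments are used and the paths are placeholders:

    # Sketch only: parameter names inferred from the hunk; paths are placeholders.
    adapter = LyraIPAdapter(
        sd_pipe=pipe,                                # an existing LyraSD pipeline
        sdxl=True,
        device="cuda",
        ip_ckpt="ip-adapter-plus_sdxl_vit-h.bin",    # placeholder path
        ip_plus=True,
        image_encoder_path="image_encoder/",         # placeholder path
        num_ip_tokens=16,
        ip_projection_dim=1024,
    )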
@@ -76,10 +70,6 @@ class LyraIPAdapter:
         else:
             self.image_proj_model = self.init_proj(self.ip_projection_dim, self.num_ip_tokens)
 
-        # face proj model
-        if self.fp_ckpt:
-            self.face_proj_model = self.init_proj(self.fp_projection_dim, self.num_fp_tokens)
-
         self.load_ip_adapter()
 
     def init_proj_diffuser(self, state_dict):
@@ -131,16 +121,9 @@ class LyraIPAdapter:
         pretrained_path, subfolder, weight_name = parse_ckpt_path(self.ip_ckpt)
         dir_ipadapter = os.path.join(pretrained_path, "lyra_tran", subfolder, '.'.join(weight_name.split(".")[:-1]))
         unet.load_ip_adapter(dir_ipadapter, "", 1, "fp16")
-
-        if self.fp_ckpt:
-            state_dict = torch.load(self.fp_ckpt, map_location="cpu")
-            self.face_proj_model.load_state_dict(state_dict["face_proj"])
-            pretrained_path, subfolder, weight_name = parse_ckpt_path(self.fp_ckpt)
-            dir_ipadapter = os.path.join(pretrained_path, "lyra_tran", subfolder, '.'.join(weight_name.split(".")[:-1]))
-            unet.load_facein(dir_ipadapter, "fp16")
 
     @torch.inference_mode()
-    def get_image_embeds(self, image=None, face_emb=None):
+    def get_image_embeds(self, image=None):
         image_prompt_embeds, uncond_image_prompt_embeds = None, None
 
         if image is not None:
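After this hunk, get_image_embeds no longer accepts a face embedding. A usage sketch for the surviving image-only path (variable names and the image path are hypothetical):

    from PIL import Image

    ref = Image.open("reference.png").convert("RGB")   # placeholder path
    cond_embeds, uncond_embeds = adapter.get_image_embeds(image=ref)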
@@ -160,22 +143,11 @@ class LyraIPAdapter:
             uncond_clip_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
             image_prompt_embeds = clip_image_prompt_embeds
             uncond_image_prompt_embeds = uncond_clip_image_prompt_embeds
-
-        if face_emb is not None:
-            face_embeds = face_emb.to(self.device, dtype=torch.float16)
-            face_prompt_embeds = self.face_proj_model(face_embeds)
-            uncond_face_prompt_embeds = self.face_proj_model(torch.zeros_like(face_embeds))
-            if image_prompt_embeds is None:
-                image_prompt_embeds = face_prompt_embeds
-                uncond_image_prompt_embeds = uncond_face_prompt_embeds
-            else:
-                image_prompt_embeds = torch.cat([face_prompt_embeds, image_prompt_embeds], axis=1)
-                uncond_image_prompt_embeds = torch.cat([uncond_face_prompt_embeds, uncond_image_prompt_embeds], dim=1)
 
         return image_prompt_embeds, uncond_image_prompt_embeds
 
     @torch.inference_mode()
-    def get_image_embeds_lyrasd(self, image=None, ip_image_embeds=None,
+    def get_image_embeds_lyrasd(self, image=None, ip_image_embeds=None, batch_size = 1, ip_scale=1.0, do_classifier_free_guidance=True):
         dict_tensor = {}
 
         if self.ip_ckpt and ip_scale>0:
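The simplified entry point now produces only IP-Adapter hidden states. A sketch of the post-commit call, reusing the adapter and image from the earlier sketches:

    dict_tensor = adapter.get_image_embeds_lyrasd(
        image=ref,
        batch_size=2,
        ip_scale=1.0,
        do_classifier_free_guidance=True,
    )
    # With CFG enabled, dict_tensor["ip_hidden_states"] stacks the
    # unconditional embeddings ahead of the conditional ones.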
@@ -199,91 +171,4 @@ class LyraIPAdapter:
             clip_image_embeds = torch.cat([uncond_clip_image_embeds, clip_image_embeds])
             ip_image_embeds = self.image_proj_model(clip_image_embeds)
             dict_tensor["ip_hidden_states"] = ip_image_embeds
-
-        if face_emb is not None and self.fp_ckpt and ip_scale>0:
-            face_embeds = face_emb.to(self.device, dtype=torch.float16)
-            face_prompt_embeds = self.face_proj_model(face_embeds)
-            uncond_face_prompt_embeds = self.face_proj_model(torch.zeros_like(face_embeds))
-            if do_classifier_free_guidance:
-                fp_image_embeds = torch.cat([uncond_face_prompt_embeds, face_prompt_embeds])
-            else:
-                fp_image_embeds = face_prompt_embeds
-            dict_tensor["fp_hidden_states"] = fp_image_embeds
         return dict_tensor
-
-
-if __name__ == "__main__":
-    sys.path.append("/data/home/kiokaxiao/repos/LyraSD/python/lyrasd")
-    from lyrasd_model import LyraSdXLTxt2ImgPipeline
-
-    model_path = "/data/SharedModels/SD/checkpoints/stable-diffusion-xl-base-1.0/"
-    # model_path = "/cfs-datasets/projects/VirtualIdol/models/base_model/sdxl/xxmix9realisticsdxlV1"
-    lib_path = os.environ.get("LIBLYRASD_SO")
-
-    dir_ip_adapter = "/cfs-datasets/projects/VirtualIdol/models/ip_adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin"
-    dir_facein = "/cfs-datasets/projects/VirtualIdol/models/FaceIn/v1/FaceIn_sdxl.bin"
-    image_encoder_path = "/cfs-datasets/projects/VirtualIdol/models/ip_adapter/models/image_encoder"
-
-    pipeline = LyraSdXLTxt2ImgPipeline(model_path, lib_path)
-    pipeline.load_ip_adapter(dir_ip_adapter, True, image_encoder_path, 16, 1024, dir_facein, 1, 512)
-    # pipeline.load_ip_adapter(dir_ip_adapter, True, image_encoder_path, 16, 1024, "", 1, 512)
-
-    face_emb = np.load("/data/home/kiokaxiao/repos/VidolImageDraw/girl.npy")
-    face_emb = torch.Tensor(face_emb.reshape([1, -1]))
-    ip_image = Image.open("/data/home/kiokaxiao/repos/VidolImageDraw/images/input_image.png").convert('RGB')
-
-    generator = torch.Generator("cuda").manual_seed(123)
-    batches = [2]
-    sizes = [[512, 512], [768, 768], [1024, 1024]]
-    # sizes = [[832, 640]]
-    # sizes = [[1024, 1024]]
-    running_cnt = 1
-    do_bench = False
-
-    ip_ratio = 1
-    facein_ratio = 0.6
-    extra_tensor_dict = {}
-    extra_tensor_dict = pipeline.ip_adapter_helper.get_image_embeds_lyrasd(ip_image, None, face_emb, batches[0], ip_ratio, facein_ratio)
-    param_scale_dict = {"facein_ratio": facein_ratio, "ip_ratio": ip_ratio}
-    draw_cfg = {'width': 640,
-                'num_inference_steps': 30,
-                'height': 832,
-                'negative_prompt': '(worst quality, low quality, 3d, 2d, cartoons, sketch), tooth, open mouth',
-                'guidance_scale': 7,
-                'prompt': 'xxmixgirl, masterpiece, best quality, 1girl, solo, looking at viewer, simple background, hair ornament, black eyes, portrait',
-                'output_type': 'pil',
-                'extra_tensor_dict': extra_tensor_dict,
-                "param_scale_dict": param_scale_dict}
-
-
-    def warmup(draw_cfg):
-        draw_cfg_wm = deepcopy(draw_cfg)
-        draw_cfg_wm['num_inference_steps'] = 1
-        pipeline(**draw_cfg_wm, generator=generator)
-
-    if not do_bench:
-        images = pipeline(**draw_cfg, generator=generator)
-    else:
-        for batch in batches:
-            for height, width in sizes:
-                draw_cfg['width'] = width
-                draw_cfg['height'] = height
-                draw_cfg['num_images_per_prompt'] = batch
-                draw_cfg["num_inference_steps"] = 20
-                warmup(draw_cfg)
-                time_uses = []
-                for x in range(running_cnt):
-                    start = time.perf_counter()
-                    draw_cfg['num_images_per_prompt'] = batch
-                    generator = torch.Generator("cuda").manual_seed(123)
-                    print("draw_cfg: ", draw_cfg.keys())
-                    print("draw_cfg: ", draw_cfg)
-
-                    images = pipeline(**draw_cfg, generator=generator)
-                    time_use = time.perf_counter() - start
-                    time_uses.append(time_use)
-                print("bench", batch, width, sum(time_uses)/running_cnt, get_mem_use())
-
-    print(type(images))
-    images[0].save("t.png")
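The deleted __main__ block doubled as a FaceIn + IP-Adapter demo and a small benchmark. An IP-Adapter-only equivalent against the post-commit API might look like the sketch below; the reduced pipeline.load_ip_adapter argument list is an assumption (this diff does not show it), and all paths are placeholders:

    import os
    import torch
    from PIL import Image
    from lyrasd_model import LyraSdXLTxt2ImgPipeline

    model_path = "stable-diffusion-xl-base-1.0/"         # placeholder path
    lib_path = os.environ.get("LIBLYRASD_SO")
    dir_ip_adapter = "ip-adapter-plus_sdxl_vit-h.bin"    # placeholder path
    image_encoder_path = "image_encoder/"                # placeholder path

    pipeline = LyraSdXLTxt2ImgPipeline(model_path, lib_path)
    # Assumption: the FaceIn arguments are dropped from load_ip_adapter
    # after this commit; only the IP-Adapter arguments remain.
    pipeline.load_ip_adapter(dir_ip_adapter, True, image_encoder_path, 16, 1024)

    ip_image = Image.open("input_image.png").convert("RGB")  # placeholder path
    extra_tensor_dict = pipeline.ip_adapter_helper.get_image_embeds_lyrasd(
        ip_image, None, 2, 1.0)  # image, ip_image_embeds, batch_size, ip_scale
    images = pipeline(
        prompt="xxmixgirl, masterpiece, best quality, 1girl, portrait",
        negative_prompt="(worst quality, low quality)",
        num_inference_steps=30,
        guidance_scale=7,
        width=640,
        height=832,
        output_type="pil",
        extra_tensor_dict=extra_tensor_dict,
        param_scale_dict={"ip_ratio": 1.0},
        generator=torch.Generator("cuda").manual_seed(123),
    )
    images[0].save("t.png")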