File size: 955 Bytes
1c1d081 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
import torch.nn.functional as F
@torch.no_grad()
def project_face_embs(pipeline, face_embs):
    """Project face embeddings through the pipeline's text encoder.

    Each row of ``face_embs`` is zero-padded to the text encoder's hidden size
    and written over the input token embedding of the ``"id"`` placeholder in
    the fixed prompt ``"photo of a id person"``. The modified token embeddings
    are then encoded by ``pipeline.text_encoder``.

    Args:
        pipeline: Object exposing ``tokenizer``, ``text_encoder`` and
            ``device``. ``text_encoder`` must accept the keyword arguments
            ``return_token_embs`` and ``input_token_embs`` (presumably a custom
            wrapper around a CLIP text model -- TODO confirm against caller).
        face_embs: 2-D tensor ``(N, E)`` with ``E <= hidden_size``; intended
            to hold normalized ArcFace embeddings, typically ``(N, 512)``.

    Returns:
        Element ``[0]`` of the text encoder's output for the ``N`` prompts
        (presumably the last hidden state, ``(N, L, hidden_size)``).

    Raises:
        ValueError: if ``face_embs`` is not 2-D or is wider than the text
            encoder's hidden size.
    """
    if face_embs.dim() != 2:
        raise ValueError(
            f"face_embs must be 2-D (N, E), got shape {tuple(face_embs.shape)}"
        )
    hidden_size = pipeline.text_encoder.config.hidden_size
    emb_dim = face_embs.shape[-1]
    if emb_dim > hidden_size:
        raise ValueError(
            f"face_embs width {emb_dim} exceeds text encoder hidden size {hidden_size}"
        )

    tokenizer = pipeline.tokenizer
    # Token id of the placeholder word whose input embedding gets replaced.
    arcface_token_id = tokenizer.encode("id", add_special_tokens=False)[0]

    input_ids = tokenizer(
        "photo of a id person",
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids.to(pipeline.device)
    # One copy of the prompt per face so ids, mask and embeddings all share
    # the same batch dimension N.
    batch_ids = input_ids.repeat(face_embs.shape[0], 1)

    token_embs = pipeline.text_encoder(input_ids=batch_ids, return_token_embs=True)

    # Zero-pad E -> hidden_size and match the token embeddings' dtype/device:
    # boolean-mask assignment requires identical dtypes.
    face_embs_padded = F.pad(face_embs, (0, hidden_size - emb_dim), "constant", 0).to(
        dtype=token_embs.dtype, device=token_embs.device
    )
    # The mask must have the same (N, L) shape as the leading dims of
    # token_embs (a (1, L) mask raises IndexError for N > 1). Selected rows come
    # out in row-major order, so prompt i receives face_embs_padded[i],
    # assuming the placeholder occurs once per prompt.
    token_embs[batch_ids == arcface_token_id] = face_embs_padded

    prompt_embeds = pipeline.text_encoder(
        input_ids=batch_ids,
        input_token_embs=token_embs,
    )[0]
    return prompt_embeds