Update lvd_pipeline.py
lvd_pipeline.py  CHANGED  (+14 -10)
@@ -742,15 +742,18 @@ class GroundedTextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         lvd_gligen_phrases_frame = lvd_gligen_phrases_frame[:max_objs]
         lvd_gligen_boxes_frame = lvd_gligen_boxes_frame[:max_objs]
 
-
-
-
-
-
-
-
-
-
+        n_objs = len(lvd_gligen_boxes_frame)
+
+        if n_objs:
+            # prepare batched input to the PositionNet (boxes, phrases, mask)
+            # Get tokens for phrases from pre-trained CLIPTokenizer
+            tokenizer_inputs = self.tokenizer(
+                lvd_gligen_phrases_frame, padding=True, return_tensors="pt").to(device)
+            # For the token, we use the same pre-trained text encoder
+            # to obtain its text feature
+            _text_embeddings = self.text_encoder(
+                **tokenizer_inputs).pooler_output
+
         # For each entity, described in phrases, is denoted with a bounding box,
         # we represent the location information as (xmin,ymin,xmax,ymax)
         boxes = torch.zeros(max_objs, 4, device=device,
@@ -759,7 +762,8 @@ class GroundedTextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         text_embeddings = torch.zeros(
             max_objs, self.unet.cross_attention_dim, device=device, dtype=self.text_encoder.dtype
         )
-
+        if n_objs:
+            text_embeddings[:n_objs] = _text_embeddings
         # Generate a mask for each object that is entity described by phrases
         masks = torch.zeros(max_objs, device=device,
                             dtype=self.text_encoder.dtype)
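In short, the change guards the per-frame grounding preparation against frames that contain no boxes: the CLIP tokenizer and text encoder are only called when n_objs is non-zero, and only the first n_objs of the max_objs zero-padded slots are filled. Below is a minimal standalone sketch of that padding pattern. It is an illustration, not the pipeline's actual method: prepare_frame_grounding_inputs, tokenizer, text_encoder, and embed_dim are stand-ins for the pipeline's self.tokenizer, self.text_encoder, and self.unet.cross_attention_dim, and the masks[:n_objs] = 1 step is an assumption about what happens after the shown hunks, not part of this diff.

import torch

def prepare_frame_grounding_inputs(
    phrases, boxes_frame, max_objs, embed_dim,
    tokenizer, text_encoder, device="cpu", dtype=torch.float32,
):
    # Keep at most max_objs entities per frame, as in the diff above.
    phrases = phrases[:max_objs]
    boxes_frame = boxes_frame[:max_objs]
    n_objs = len(boxes_frame)

    # Fixed-size, zero-padded buffers: one slot per possible object.
    # Boxes are (xmin, ymin, xmax, ymax) in normalized coordinates.
    boxes = torch.zeros(max_objs, 4, device=device, dtype=dtype)
    text_embeddings = torch.zeros(max_objs, embed_dim, device=device, dtype=dtype)
    masks = torch.zeros(max_objs, device=device, dtype=dtype)

    if n_objs:
        # Only call the tokenizer/encoder when the frame actually has objects;
        # an empty phrase list would otherwise make this call fail.
        tokenizer_inputs = tokenizer(phrases, padding=True, return_tensors="pt").to(device)
        phrase_embeddings = text_encoder(**tokenizer_inputs).pooler_output

        # Fill only the first n_objs slots; the rest stay zero as padding.
        boxes[:n_objs] = torch.tensor(boxes_frame, device=device, dtype=dtype)
        text_embeddings[:n_objs] = phrase_embeddings.to(dtype)
        masks[:n_objs] = 1  # assumed convention: mark which slots hold real objects

    return boxes, text_embeddings, masks

The point of the guard is that a frame with zero grounded objects now skips the tokenizer and text-encoder calls entirely and simply passes all-zero boxes, embeddings, and masks downstream, instead of erroring on an empty phrase list.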