longlian committed
Commit 2e20888
1 Parent(s): 66f32ec

Update lvd_pipeline.py

Files changed (1)
  1. lvd_pipeline.py +14 -10
lvd_pipeline.py CHANGED
@@ -742,15 +742,18 @@ class GroundedTextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMix
         lvd_gligen_phrases_frame = lvd_gligen_phrases_frame[:max_objs]
         lvd_gligen_boxes_frame = lvd_gligen_boxes_frame[:max_objs]
 
-        # prepare batched input to the PositionNet (boxes, phrases, mask)
-        # Get tokens for the phrases from the pre-trained CLIPTokenizer
-        tokenizer_inputs = self.tokenizer(
-            lvd_gligen_phrases_frame, padding=True, return_tensors="pt").to(device)
-        # For the tokens, we use the same pre-trained text encoder
-        # to obtain their text features
-        _text_embeddings = self.text_encoder(
-            **tokenizer_inputs).pooler_output
-        n_objs = len(lvd_gligen_boxes_frame)
+        n_objs = len(lvd_gligen_boxes_frame)
+
+        if n_objs:
+            # prepare batched input to the PositionNet (boxes, phrases, mask)
+            # Get tokens for the phrases from the pre-trained CLIPTokenizer
+            tokenizer_inputs = self.tokenizer(
+                lvd_gligen_phrases_frame, padding=True, return_tensors="pt").to(device)
+            # For the tokens, we use the same pre-trained text encoder
+            # to obtain their text features
+            _text_embeddings = self.text_encoder(
+                **tokenizer_inputs).pooler_output
+
         # Each entity described in phrases is denoted with a bounding box;
         # we represent the location information as (xmin, ymin, xmax, ymax)
         boxes = torch.zeros(max_objs, 4, device=device,
@@ -759,7 +762,8 @@ class GroundedTextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMix
         text_embeddings = torch.zeros(
             max_objs, self.unet.cross_attention_dim, device=device, dtype=self.text_encoder.dtype
         )
-        text_embeddings[:n_objs] = _text_embeddings
+        if n_objs:
+            text_embeddings[:n_objs] = _text_embeddings
         # Generate a mask for each object that is an entity described by the phrases
         masks = torch.zeros(max_objs, device=device,
                             dtype=self.text_encoder.dtype)
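For illustration only, a minimal standalone sketch of the padding logic this commit guards, outside the pipeline class. The function name (prepare_grounding_inputs) and the tokenizer/text_encoder callables are hypothetical stand-ins for the pipeline's CLIP components; the point is that with the new n_objs check, a frame without grounded boxes never invokes the tokenizer or text encoder and simply keeps the zero-initialized boxes, text embeddings, and masks.

import torch


def prepare_grounding_inputs(phrases, boxes, max_objs, cross_attention_dim,
                             tokenizer=None, text_encoder=None,
                             device="cpu", dtype=torch.float32):
    """Pad per-frame grounding inputs to max_objs, skipping the text encoder
    when the frame has no grounded objects (the case this commit guards)."""
    phrases = phrases[:max_objs]
    boxes = boxes[:max_objs]
    n_objs = len(boxes)

    padded_boxes = torch.zeros(max_objs, 4, device=device, dtype=dtype)
    text_embeddings = torch.zeros(max_objs, cross_attention_dim, device=device, dtype=dtype)
    masks = torch.zeros(max_objs, device=device, dtype=dtype)

    if n_objs:
        # Only tokenize/encode when there is at least one phrase; with zero
        # phrases the zero-initialized padding tensors are returned unchanged.
        tokenizer_inputs = tokenizer(phrases, padding=True, return_tensors="pt").to(device)
        text_embeddings[:n_objs] = text_encoder(**tokenizer_inputs).pooler_output
        padded_boxes[:n_objs] = torch.tensor(boxes, device=device, dtype=dtype)
        masks[:n_objs] = 1.0

    return padded_boxes, text_embeddings, masks


# A frame with no boxes never touches the tokenizer/encoder, so passing None is safe:
b, e, m = prepare_grounding_inputs([], [], max_objs=30, cross_attention_dim=768)
assert m.sum() == 0 and e.abs().sum() == 0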