Jyothirmai committed on
Commit
897dd2f
1 Parent(s): 8b9889c

Update clipGPT.py

Browse files
Files changed (1) hide show
  1. clipGPT.py +5 -5
clipGPT.py CHANGED
@@ -70,11 +70,11 @@ class ClipGPT2Model(nn.Module):
70
  def generate_beam(
71
  model,
72
  tokenizer,
 
 
 
73
  beam_size: int = 10,
74
  prompt=None,
75
- embed=None,
76
- entry_length=76,
77
- temperature=0.9,
78
  stop_token: str = ".",
79
  ):
80
 
@@ -144,7 +144,7 @@ def generate_beam(
144
 
145
 
146
 
147
- def generate_caption_clipgpt(img):
148
 
149
  prefix_length = 10
150
  model = ClipGPT2Model(prefix_length)
@@ -164,7 +164,7 @@ def generate_caption_clipgpt(img):
164
  with torch.no_grad():
165
  prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
166
  prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
167
- beam_caption = generate_beam(model, tokenizer, embed=prefix_embed)[0]
168
 
169
  end_time = time.time()
170
  print("--- Time taken to generate: %s seconds ---" % (end_time - start_time))
 
70
  def generate_beam(
71
  model,
72
  tokenizer,
73
+ entry_length,
74
+ temperature,
75
+ embed=None,
76
  beam_size: int = 10,
77
  prompt=None,
 
 
 
78
  stop_token: str = ".",
79
  ):
80
 
 
144
 
145
 
146
 
147
+ def generate_caption_clipgpt(img, entry_length, temperature):
148
 
149
  prefix_length = 10
150
  model = ClipGPT2Model(prefix_length)
 
164
  with torch.no_grad():
165
  prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
166
  prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
167
+ beam_caption = generate_beam(model, tokenizer, entry_length, temperature, embed=prefix_embed)[0]
168
 
169
  end_time = time.time()
170
  print("--- Time taken to generate: %s seconds ---" % (end_time - start_time))