m7mdal7aj committed
Commit
ca90c3f
1 Parent(s): 2a9f27a

Update app.py

Files changed (1)
app.py +4 -4
app.py CHANGED
@@ -11,13 +11,13 @@ from transformers import Blip2Processor, Blip2ForConditionalGeneration, Instruct
 def load_caption_model(blip2=True, instructblip=False):
 
     if blip2:
-        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True, torch_dtype=torch.float16, device_map="cuda")
-        model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True, torch_dtype=torch.float16, device_map="cuda")
+        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True, torch_dtype=torch.float16)
+        model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True, torch_dtype=torch.float16)
         #model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
 
     if instructblip:
-        model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True, torch_dtype=torch.float16, device_map="cuda")
-        processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True, torch_dtype=torch.float16, device_map="cuda")
+        model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True, torch_dtype=torch.float16)
+        processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True, torch_dtype=torch.float16)
 
     return model, processor
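
For context, a minimal sketch of how the updated loader might be used to caption an image. It assumes a CUDA runtime (load_in_8bit=True requires a GPU) and that load_caption_model is the new version from app.py above; the image URL, prompt-free call, and max_new_tokens value are illustrative, not taken from the commit.

# Usage sketch (assumptions: GPU runtime; example URL and generation settings are illustrative)
import requests
import torch
from PIL import Image

# Load the BLIP-2 model and processor as defined in app.py
model, processor = load_caption_model(blip2=True, instructblip=False)

# Fetch an example image and convert it to RGB
image = Image.open(requests.get("https://example.com/sample.jpg", stream=True).raw).convert("RGB")

# With load_in_8bit=True the model weights are already placed on the GPU by bitsandbytes,
# so only the inputs need to be moved and cast to float16
inputs = processor(images=image, return_tensors="pt").to("cuda", torch.float16)

generated_ids = model.generate(**inputs, max_new_tokens=30)
caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(caption)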