Spaces: BAAI / Running on L40S

ryanzhangfan committed
Commit db312d6 (parent: 9f2b36a)

Update app.py

Files changed (1): app.py (+52 -27)
app.py CHANGED
@@ -44,6 +44,9 @@ EMU_GEN_HUB = "BAAI/Emu3-Gen"
 EMU_CHAT_HUB = "BAAI/Emu3-Chat"
 VQ_HUB = "BAAI/Emu3-VisionTokenizer"
 
+
+# uncomment to use gen model
+"""
 # Prepare models and processors
 # Emu3-Gen model and processor
 gen_model = AutoModelForCausalLM.from_pretrained(
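This first hunk disables the Emu3-Gen pipeline by opening a bare triple-quoted string (closed in a later hunk), so everything up to the closing """ becomes an unused string literal. A minimal sketch of an equivalent toggle, assuming a hypothetical ENABLE_GEN environment variable (nothing like it exists in this commit):

import os

import torch
from transformers import AutoModelForCausalLM

EMU_GEN_HUB = "BAAI/Emu3-Gen"

# Hypothetical toggle (not in this commit): gate the optional Emu3-Gen
# pipeline behind an environment variable instead of a bare string literal.
ENABLE_GEN = os.environ.get("ENABLE_GEN", "0") == "1"

if ENABLE_GEN:
    # Argument list mirrors the chat-model load shown later in this diff;
    # only the tail of the gen model's call is visible here, so treating the
    # two as identical is an assumption.
    gen_model = AutoModelForCausalLM.from_pretrained(
        EMU_GEN_HUB,
        device_map="cpu",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
    )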
@@ -54,15 +57,6 @@ gen_model = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=True,
 )
 
-# Emu3-Chat model and processor
-chat_model = AutoModelForCausalLM.from_pretrained(
-    EMU_CHAT_HUB,
-    device_map="cpu",
-    torch_dtype=torch.bfloat16,
-    attn_implementation="flash_attention_2",
-    trust_remote_code=True,
-)
-
 tokenizer = AutoTokenizer.from_pretrained(EMU_CHAT_HUB, trust_remote_code=True)
 image_processor = AutoImageProcessor.from_pretrained(
     VQ_HUB, trust_remote_code=True
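The deletion above is a move, not a removal: the Emu3-Chat load reappears after the disabled gen block in a later hunk. One caveat with these arguments: attn_implementation="flash_attention_2" needs the flash-attn package and a supported GPU, and transformers raises ImportError when it is unavailable. A hedged fallback sketch, an assumption rather than anything app.py does:

import torch
from transformers import AutoModelForCausalLM

EMU_CHAT_HUB = "BAAI/Emu3-Chat"

try:
    # Preferred path: FlashAttention-2 kernels (requires the flash-attn package).
    chat_model = AutoModelForCausalLM.from_pretrained(
        EMU_CHAT_HUB,
        device_map="cpu",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
    )
except ImportError:
    # Fallback: let transformers pick its default attention implementation.
    chat_model = AutoModelForCausalLM.from_pretrained(
        EMU_CHAT_HUB,
        device_map="cpu",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )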
@@ -70,16 +64,16 @@ image_processor = AutoImageProcessor.from_pretrained(
 image_tokenizer = AutoModel.from_pretrained(
     VQ_HUB, device_map="cpu", trust_remote_code=True
 ).eval()
-processor = Emu3Processor(
-    image_processor, image_tokenizer, tokenizer
-)
 
 print(device)
 gen_model.to(device)
-chat_model.to(device)
 image_tokenizer.to(device)
 
-@spaces.GPU(duration=120)
+processor = Emu3Processor(
+    image_processor, image_tokenizer, tokenizer
+)
+
+@spaces.GPU(duration=300)
 def generate_image(prompt):
     POSITIVE_PROMPT = " masterpiece, film grained, best quality."
     NEGATIVE_PROMPT = (
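Besides moving the Emu3Processor construction after the .to(device) calls, this hunk raises the ZeroGPU allocation for generation from 120 to 300 seconds. In the spaces package, the duration argument bounds how long a GPU stays attached per call:

import spaces

# ZeroGPU attaches a GPU only while a decorated function runs; `duration`
# is the per-call budget in seconds. Raising it from 120 to 300 gives slow
# autoregressive image generation room to finish instead of being cut off
# when the allocation expires.
@spaces.GPU(duration=300)
def generate_image(prompt):
    ...  # body unchanged; see the hunk above and the next one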
@@ -139,6 +133,48 @@ def generate_image(prompt):
         return im
     return None
 
+def chat(history, user_input, user_image):
+    if user_image is not None:
+        history = history + [("", "Sorry, gen model do not accept image input")]
+    else:
+        # Use Emu3-Gen for image generation
+        generated_image = generate_image(user_input)
+        if generated_image is not None:
+            # Append the user input and generated image to the history
+            history = history + [(user_input, image2str(generated_image))]
+        else:
+            # If image generation failed, respond with an error message
+            history = history + [
+                (user_input, "Sorry, I could not generate an image.")
+            ]
+    return history, history, gr.update(value=None)
+"""
+
+# Emu3-Chat model and processor
+chat_model = AutoModelForCausalLM.from_pretrained(
+    EMU_CHAT_HUB,
+    device_map="cpu",
+    torch_dtype=torch.bfloat16,
+    attn_implementation="flash_attention_2",
+    trust_remote_code=True,
+)
+
+tokenizer = AutoTokenizer.from_pretrained(EMU_CHAT_HUB, trust_remote_code=True)
+image_processor = AutoImageProcessor.from_pretrained(
+    VQ_HUB, trust_remote_code=True
+)
+image_tokenizer = AutoModel.from_pretrained(
+    VQ_HUB, device_map="cpu", trust_remote_code=True
+).eval()
+
+print(device)
+chat_model.to(device)
+image_tokenizer.to(device)
+
+processor = Emu3Processor(
+    image_processor, image_tokenizer, tokenizer
+)
+
 @spaces.GPU
 def vision_language_understanding(image, text):
     inputs = processor(
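This hunk closes the disabled string (the gen-mode chat() above the closing """ is dead code) and re-adds the Emu3-Chat stack in its new position. The diff then cuts off inside vision_language_understanding, but its shape is the usual pack-generate-decode loop. A sketch under stated assumptions: the mode="U" flag, argument names, and generation settings are illustrative, not copied from app.py:

# Relies on processor, chat_model, tokenizer, and device defined above.
@spaces.GPU
def vision_language_understanding(image, text):
    # Emu3Processor tokenizes the image with the VQ tokenizer and packs it
    # with the text prompt ("U" for understanding mode is an assumption; it
    # is not visible in this diff).
    inputs = processor(
        text=text,
        image=image,
        mode="U",
        return_tensors="pt",
        padding="longest",
    )
    outputs = chat_model.generate(
        inputs.input_ids.to(device),
        max_new_tokens=320,  # illustrative budget
        do_sample=False,
    )
    # Strip the prompt tokens before decoding (illustrative).
    answer = outputs[:, inputs.input_ids.shape[-1]:]
    return tokenizer.decode(answer[0], skip_special_tokens=True)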
@@ -176,19 +212,8 @@ def chat(history, user_input, user_image):
         # Append the user input and response to the history
         history = history + [(image2str(user_image) + "<br>" + user_input, response)]
     else:
-        # history = history + [(user_input, "Currently do not support image genration, please provide an valid image.")]
-        # """
-        # Use Emu3-Gen for image generation
-        generated_image = generate_image(user_input)
-        if generated_image is not None:
-            # Append the user input and generated image to the history
-            history = history + [(user_input, image2str(generated_image))]
-        else:
-            # If image generation failed, respond with an error message
-            history = history + [
-                (user_input, "Sorry, I could not generate an image.")
-            ]
-        # """
+        history = history + [(user_input, "Sorry, please specify a valid image for vl understanding.")]
+
     return history, history, gr.update(value=None)
 
 def clear_input():
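Both code paths embed images into the Chatbot history through image2str, which the diff references but never defines. A common implementation, offered purely as an assumption, base64-encodes the PIL image into an inline HTML <img> tag:

import base64
from io import BytesIO

# image2str is used but not defined in this diff; this sketch assumes the
# usual approach of inlining the image so it renders inside gr.Chatbot.
def image2str(image):
    buf = BytesIO()
    image.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode()
    return f'<img src="data:image/png;base64,{b64}" style="max-width:400px;">'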
 
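For context on the three return values of chat(): they map onto a rendered chatbot, the conversation state, and a cleared image input. A hedged sketch of Gradio Blocks wiring that matches that signature (all component names are assumptions, not from this diff):

import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    state = gr.State([])
    with gr.Row():
        user_image = gr.Image(type="pil", label="Image (optional)")
        user_input = gr.Textbox(label="Message")
    user_input.submit(
        chat,
        inputs=[state, user_input, user_image],
        outputs=[chatbot, state, user_image],  # third output clears the image
    )

demo.launch()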