BAAI Space (running on L40S)

ryanzhangfan committed
Commit: 058e220
Parent: 542fa16

Update app.py

Files changed (1): app.py (+48, -60)
app.py CHANGED
@@ -17,24 +17,22 @@ from transformers.generation import (
 )
 import torch
 from emu3.mllm.processing_emu3 import Emu3Processor
-import spaces
 
 import io
 import base64
 
+subprocess.run(
+    "pip3 install flash-attn --no-build-isolation",
+    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+    shell=True,
+)
+
 def image2str(image):
     buf = io.BytesIO()
     image.save(buf, format="PNG")
     i_str = base64.b64encode(buf.getvalue()).decode()
     return f'<div style="float:left"><img src="data:image/png;base64, {i_str}"></div>'
 
-# Install flash attention, skipping CUDA build if necessary
-subprocess.run(
-    "pip install flash-attn --no-build-isolation",
-    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
-    shell=True,
-)
-
 print(gr.__version__)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
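Note on this hunk: the runtime flash-attn install now runs before any model code and calls pip3 instead of pip, and the ZeroGPU import (`import spaces`) is dropped. A possible hardening of that install, shown only as a sketch and not part of the commit: skip pip when flash_attn is already importable, merge the existing environment so PATH and CUDA variables survive, and fail loudly if the install errors out.

# Sketch only (not from app.py): guard the runtime flash-attn install.
import importlib.util
import os
import subprocess

if importlib.util.find_spec("flash_attn") is None:
    subprocess.run(
        "pip3 install flash-attn --no-build-isolation",
        # merge os.environ so PATH and CUDA settings are not wiped by the override
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=True,  # raise if the install fails instead of continuing silently
    )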
@@ -46,7 +44,6 @@ VQ_HUB = "BAAI/Emu3-VisionTokenizer"
 
 
 # uncomment to use gen model
-"""
 # Prepare models and processors
 # Emu3-Gen model and processor
 gen_model = AutoModelForCausalLM.from_pretrained(
@@ -55,7 +52,15 @@ gen_model = AutoModelForCausalLM.from_pretrained(
     torch_dtype=torch.bfloat16,
     attn_implementation="flash_attention_2",
     trust_remote_code=True,
-)
+).eval()
+
+chat_model = AutoModelForCausalLM.from_pretrained(
+    EMU_CHAT_HUB,
+    device_map="cpu",
+    torch_dtype=torch.bfloat16,
+    attn_implementation="flash_attention_2",
+    trust_remote_code=True,
+).eval()
 
 tokenizer = AutoTokenizer.from_pretrained(EMU_CHAT_HUB, trust_remote_code=True)
 image_processor = AutoImageProcessor.from_pretrained(
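Note on this hunk: both Emu3-Gen and Emu3-Chat are now loaded in eval mode on the CPU (device_map="cpu") and, as the later hunks show, each is moved onto the GPU only while it is generating. A rough, purely illustrative estimate of why keeping both resident on one 48 GB L40S would be tight, assuming the roughly 8B-parameter size of the Emu3 models:

# Back-of-the-envelope memory estimate (illustrative, not from app.py).
def weight_gib(num_params: float, bytes_per_param: int = 2) -> float:
    """Approximate weight footprint in GiB for a bfloat16 model."""
    return num_params * bytes_per_param / 1024**3

one_model = weight_gib(8e9)                    # ~14.9 GiB per ~8B-parameter model
print(f"one model:  {one_model:.1f} GiB")
print(f"two models: {2 * one_model:.1f} GiB")  # ~29.8 GiB before KV cache, activations, VQ tokenizer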
@@ -66,14 +71,12 @@ image_tokenizer = AutoModel.from_pretrained(
 ).eval()
 
 print(device)
-gen_model.to(device)
 image_tokenizer.to(device)
 
 processor = Emu3Processor(
     image_processor, image_tokenizer, tokenizer
 )
 
-@spaces.GPU(duration=300)
 def generate_image(prompt):
     POSITIVE_PROMPT = " masterpiece, film grained, best quality."
     NEGATIVE_PROMPT = (
@@ -104,6 +107,9 @@ def generate_image(prompt):
         top_k=2048,
     )
 
+    torch.cuda.empty_cache()
+    gen_model.to(device)
+
     h, w = pos_inputs.image_size[0]
     constrained_fn = processor.build_prefix_constrained_fn(h, w)
     logits_processor = LogitsProcessorList(
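This hunk, and the matching ones in vision_language_understanding further down, repeat the same move-to-GPU, generate, move-back-to-CPU pattern. One way that pattern could be factored out is sketched below; `on_gpu` is a hypothetical helper written for illustration, not code from the commit, and it assumes the models are kept on CPU between requests.

# Hypothetical refactor sketch (not part of the commit).
from contextlib import contextmanager

import torch

@contextmanager
def on_gpu(model, device):
    torch.cuda.empty_cache()   # reclaim memory left over from the previous request
    model.to(device)
    try:
        yield model
    finally:
        model.cpu()            # park the weights back in host memory
        torch.cuda.empty_cache()

# Usage sketch:
# with on_gpu(gen_model, device):
#     outputs = gen_model.generate(...)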
@@ -128,54 +134,17 @@ def generate_image(prompt):
     )
 
     mm_list = processor.decode(outputs[0])
+    result = None
     for idx, im in enumerate(mm_list):
         if isinstance(im, Image.Image):
-            return im
-    return None
+            result = im
+            break
 
-def chat(history, user_input, user_image):
-    if user_image is not None:
-        history = history + [(image2str(user_image) + "<br>" + user_input, "Sorry, gen model do not accept image input")]
-    else:
-        # Use Emu3-Gen for image generation
-        generated_image = generate_image(user_input)
-        if generated_image is not None:
-            # Append the user input and generated image to the history
-            history = history + [(user_input, image2str(generated_image))]
-        else:
-            # If image generation failed, respond with an error message
-            history = history + [
-                (user_input, "Sorry, I could not generate an image.")
-            ]
-    return history, history, gr.update(value=None)
-"""
+    gen_model.cpu()
+    torch.cuda.empty_cache()
+
+    return result
 
-# Emu3-Chat model and processor
-chat_model = AutoModelForCausalLM.from_pretrained(
-    EMU_CHAT_HUB,
-    device_map="cpu",
-    torch_dtype=torch.bfloat16,
-    attn_implementation="flash_attention_2",
-    trust_remote_code=True,
-)
-
-tokenizer = AutoTokenizer.from_pretrained(EMU_CHAT_HUB, trust_remote_code=True)
-image_processor = AutoImageProcessor.from_pretrained(
-    VQ_HUB, trust_remote_code=True
-)
-image_tokenizer = AutoModel.from_pretrained(
-    VQ_HUB, device_map="cpu", trust_remote_code=True
-).eval()
-
-print(device)
-chat_model.to(device)
-image_tokenizer.to(device)
-
-processor = Emu3Processor(
-    image_processor, image_tokenizer, tokenizer
-)
-
-@spaces.GPU
 def vision_language_understanding(image, text):
     inputs = processor(
         text=text,
@@ -194,6 +163,9 @@ def vision_language_understanding(image, text):
         max_new_tokens=320,
     )
 
+    torch.cuda.empty_cache()
+    chat_model.to(device)
+
     # Generate
     outputs = chat_model.generate(
         inputs.input_ids.to(device),
@@ -203,8 +175,13 @@ def vision_language_understanding(image, text):
 
     outputs = outputs[:, inputs.input_ids.shape[-1] :]
     response = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+
+    chat_model.cpu()
+    torch.cuda.empty_cache()
+
     return response
 
+
 def chat(history, user_input, user_image):
     if user_image is not None:
         # Use Emu3-Chat for vision-language understanding
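For reference, the slice kept as context above drops the prompt tokens before decoding, since generate() returns prompt plus completion ids. A standalone toy illustration, with made-up tensors that are not from app.py:

# Toy example of the prompt-stripping slice (illustrative only).
import torch

prompt_ids = torch.tensor([[101, 102, 103]])        # stand-in prompt token ids
outputs = torch.tensor([[101, 102, 103, 7, 8, 9]])  # stand-in generate() output
new_tokens = outputs[:, prompt_ids.shape[-1]:]
print(new_tokens)  # tensor([[7, 8, 9]]) -- only the completion gets decoded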
@@ -212,21 +189,32 @@ def chat(history, user_input, user_image):
         # Append the user input and response to the history
         history = history + [(image2str(user_image) + "<br>" + user_input, response)]
     else:
-        history = history + [(user_input, "Sorry, please specify a valid image for vl understanding.")]
+        # Use Emu3-Gen for image generation
+        generated_image = generate_image(user_input)
+        if generated_image is not None:
+            # Append the user input and generated image to the history
+            history = history + [(user_input, image2str(generated_image))]
+        else:
+            # If image generation failed, respond with an error message
+            history = history + [
+                (user_input, "Sorry, I could not generate an image.")
+            ]
 
     return history, history, gr.update(value=None)
 
-# uncomment to here to disable chat
-# """
-
+
 def clear_input():
     return gr.update(value="")
 
+
 with gr.Blocks() as demo:
     gr.Markdown("# Emu3 Chatbot Demo")
     gr.Markdown(
         "This is a chatbot demo for image generation and vision-language understanding using Emu3 models."
     )
+    gr.Markdown(
+        "Please pass only text input for image generation and both image and text for vision-language understanding"
+    )
 
     chatbot = gr.Chatbot()
     state = gr.State([])
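With this last hunk, chat() routes text-only messages to generate_image and image-plus-text messages to vision_language_understanding. The diff ends at `state = gr.State([])`, so the component wiring below is only a sketch of how such a function is typically hooked up in gr.Blocks; the component names (user_input, user_image, send_btn, demo_sketch) are assumptions, not code from app.py.

# Hypothetical wiring sketch (component names are assumptions, not from app.py).
import gradio as gr

with gr.Blocks() as demo_sketch:
    chatbot = gr.Chatbot()
    state = gr.State([])
    user_image = gr.Image(type="pil", label="Optional image")
    user_input = gr.Textbox(label="Prompt")
    send_btn = gr.Button("Send")

    # chat(history, user_input, user_image) returns (history, history, cleared textbox),
    # which maps onto (chatbot, state, user_input) below.
    send_btn.click(
        chat,
        inputs=[state, user_input, user_image],
        outputs=[chatbot, state, user_input],
    )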
 