akhaliq committed
Commit ce3d5f3
Parent: 9d5ad98

Create app.py

Files changed (1)
app.py +202 -0
app.py ADDED
@@ -0,0 +1,202 @@
import tempfile  # used in chat() to hand generated images to gr.Chatbot

from PIL import Image
import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoImageProcessor,
    AutoModel,
)
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation import (
    LogitsProcessorList,
    PrefixConstrainedLogitsProcessor,
    UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from emu3.mllm.processing_emu3 import Emu3Processor
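
# Note: Emu3Processor is not a PyPI package; it ships with the BAAI Emu3 code
# release (assumed here to be vendored alongside this app, e.g. from
# https://github.com/baaivision/Emu3). The flash_attention_2 option below
# additionally requires the flash-attn package and a compatible GPU.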

# Model paths
EMU_GEN_HUB = "BAAI/Emu3-Gen"
EMU_CHAT_HUB = "BAAI/Emu3-Chat"
VQ_HUB = "BAAI/Emu3-VisionTokenizer"

# Prepare models and processors
# Emu3-Gen model and processor
gen_model = AutoModelForCausalLM.from_pretrained(
    EMU_GEN_HUB,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)

gen_tokenizer = AutoTokenizer.from_pretrained(EMU_GEN_HUB, trust_remote_code=True)
gen_image_processor = AutoImageProcessor.from_pretrained(
    VQ_HUB, trust_remote_code=True
)
gen_image_tokenizer = AutoModel.from_pretrained(
    VQ_HUB, device_map="cuda:0", trust_remote_code=True
).eval()
gen_processor = Emu3Processor(gen_image_processor, gen_image_tokenizer, gen_tokenizer)

# Emu3-Chat model and processor
chat_model = AutoModelForCausalLM.from_pretrained(
    EMU_CHAT_HUB,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)

chat_tokenizer = AutoTokenizer.from_pretrained(EMU_CHAT_HUB, trust_remote_code=True)
chat_image_processor = AutoImageProcessor.from_pretrained(
    VQ_HUB, trust_remote_code=True
)
chat_image_tokenizer = AutoModel.from_pretrained(
    VQ_HUB, device_map="cuda:0", trust_remote_code=True
).eval()
chat_processor = Emu3Processor(
    chat_image_processor, chat_image_tokenizer, chat_tokenizer
)
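
# Note: both pipelines load the same BAAI/Emu3-VisionTokenizer weights twice;
# loading it once and sharing the instance between the two processors is a
# possible memory optimization.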


def generate_image(prompt):
    POSITIVE_PROMPT = " masterpiece, film grained, best quality."
    NEGATIVE_PROMPT = (
        "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, "
        "fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, "
        "signature, watermark, username, blurry."
    )

    classifier_free_guidance = 3.0
    full_prompt = prompt + POSITIVE_PROMPT

    kwargs = dict(
        mode="G",
        ratio="1:1",
        image_area=gen_model.config.image_area,
        return_tensors="pt",
    )
    pos_inputs = gen_processor(text=full_prompt, **kwargs)
    neg_inputs = gen_processor(text=NEGATIVE_PROMPT, **kwargs)

    # Prepare hyperparameters
    GENERATION_CONFIG = GenerationConfig(
        use_cache=True,
        eos_token_id=gen_model.config.eos_token_id,
        pad_token_id=gen_model.config.pad_token_id,
        max_new_tokens=40960,
        do_sample=True,
        top_k=2048,
    )

    # Constrain decoding to valid image tokens for an h x w token grid
    h, w = pos_inputs.image_size[0]
    constrained_fn = gen_processor.build_prefix_constrained_fn(h, w)
    logits_processor = LogitsProcessorList(
        [
            UnbatchedClassifierFreeGuidanceLogitsProcessor(
                classifier_free_guidance,
                gen_model,
                unconditional_ids=neg_inputs.input_ids.to("cuda:0"),
            ),
            PrefixConstrainedLogitsProcessor(
                constrained_fn,
                num_beams=1,
            ),
        ]
    )

    # Generate
    outputs = gen_model.generate(
        pos_inputs.input_ids.to("cuda:0"),
        generation_config=GENERATION_CONFIG,
        logits_processor=logits_processor,
    )

    # decode() yields a list that may mix text and images; return the first image
    mm_list = gen_processor.decode(outputs[0])
    for im in mm_list:
        if isinstance(im, Image.Image):
            return im
    return None
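
# For reference: UnbatchedClassifierFreeGuidanceLogitsProcessor blends the
# conditional and unconditional distributions at each decoding step, roughly
#   logits = uncond + guidance_scale * (cond - uncond)
# so classifier_free_guidance = 3.0 above steers sampling toward the positive
# prompt and away from NEGATIVE_PROMPT.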


def vision_language_understanding(image, text):
    inputs = chat_processor(
        text=text,
        image=image,
        mode="U",
        padding_side="left",
        padding="longest",
        return_tensors="pt",
    )

    # Prepare hyperparameters
    GENERATION_CONFIG = GenerationConfig(
        pad_token_id=chat_tokenizer.pad_token_id,
        bos_token_id=chat_tokenizer.bos_token_id,
        eos_token_id=chat_tokenizer.eos_token_id,
        max_new_tokens=320,
    )

    # Generate
    outputs = chat_model.generate(
        inputs.input_ids.to("cuda:0"),
        generation_config=GENERATION_CONFIG,
    )

    # Strip the prompt tokens and decode only the newly generated ones
    outputs = outputs[:, inputs.input_ids.shape[-1]:]
    response = chat_processor.batch_decode(outputs, skip_special_tokens=True)[0]
    return response


def chat(history, user_input, user_image):
    if user_image is not None:
        # Use Emu3-Chat for vision-language understanding
        response = vision_language_understanding(user_image, user_input)
        # Append the user input and response to the history
        history = history + [(user_input, response)]
    else:
        # Use Emu3-Gen for image generation
        generated_image = generate_image(user_input)
        if generated_image is not None:
            # gr.Chatbot renders images from file paths, so save the PIL image
            # to a temporary file and append its path to the history
            tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
            generated_image.save(tmp.name)
            history = history + [(user_input, (tmp.name,))]
        else:
            # If image generation failed, respond with an error message
            history = history + [
                (user_input, "Sorry, I could not generate an image.")
            ]
    return history, history, gr.update(value=None)


def clear_input():
    return gr.update(value="")
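
# UI note: both the Send button and pressing Enter in the textbox route
# through chat(), and the uploaded image is cleared after every turn (the
# gr.update(value=None) returned above), so an upload only applies to the
# message it accompanies.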


# Gradio 3.x API (.style() and gr.Image(source=...) were removed in Gradio 4)
with gr.Blocks() as demo:
    gr.Markdown("# Emu3 Chatbot Demo")
    gr.Markdown(
        "This is a chatbot demo for image generation and vision-language understanding using Emu3 models."
    )

    chatbot = gr.Chatbot()
    state = gr.State([])
    with gr.Row():
        with gr.Column(scale=0.85):
            user_input = gr.Textbox(
                show_label=False, placeholder="Type your message here...", lines=2
            ).style(container=False)
        with gr.Column(scale=0.15, min_width=0):
            submit_btn = gr.Button("Send")
    user_image = gr.Image(
        source="upload", type="pil", label="Upload an image (optional)"
    )

    submit_btn.click(
        chat,
        inputs=[state, user_input, user_image],
        outputs=[chatbot, state, user_image],
    ).then(fn=clear_input, inputs=[], outputs=user_input)
    user_input.submit(
        chat,
        inputs=[state, user_input, user_image],
        outputs=[chatbot, state, user_image],
    ).then(fn=clear_input, inputs=[], outputs=user_input)

demo.launch()