Spaces: Running on A10G

Commit 95878c0 (parent: ddfd056), committed by tastelikefeet

v2

Files changed:
- app.py +36 -6
- cldm/cldm.py +10 -1
- cldm/model.py +5 -1
- cldm/recognizer.py +3 -0
- ldm/modules/diffusionmodules/openaimodel.py +1 -1
- ldm/modules/diffusionmodules/util.py +2 -1
app.py
CHANGED
@@ -13,12 +13,41 @@ import re
 from gradio.components import Component
 from util import check_channels, resize_image, save_images
 import json
+import argparse
+
 
 BBOX_MAX_NUM = 8
 img_save_folder = 'SaveImages'
 load_model = True
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--use_fp32",
+        action="store_true",
+        default=False,
+        help="Whether or not to use fp32 during inference."
+    )
+    parser.add_argument(
+        "--no_translator",
+        action="store_true",
+        default=False,
+        help="Whether or not to use the CH->EN translator, which enable input Chinese prompt and cause ~4GB VRAM."
+    )
+    parser.add_argument(
+        "--font_path",
+        type=str,
+        default='font/Arial_Unicode.ttf',
+        help="path of a font file"
+    )
+    args = parser.parse_args()
+    return args
+
+
+args = parse_args()
 if load_model:
-    inference = pipeline('my-anytext-task', model='damo/cv_anytext_text_generation_editing', model_revision='v1.1.
+    inference = pipeline('my-anytext-task', model='damo/cv_anytext_text_generation_editing', model_revision='v1.1.1', use_fp16=not args.use_fp32, use_translator=not args.no_translator, font_path=args.font_path)
 
 
 def count_lines(prompt):
@@ -221,7 +250,8 @@ with block:
 [<a href="https://arxiv.org/abs/2311.03054" style="color:blue; font-size:18px;">arXiv</a>] \
 [<a href="https://github.com/tyxsspa/AnyText" style="color:blue; font-size:18px;">Code</a>] \
 [<a href="https://modelscope.cn/models/damo/cv_anytext_text_generation_editing/summary" style="color:blue; font-size:18px;">ModelScope</a>]\
-
+[<a href="https://huggingface.co/spaces/modelscope/AnyText" style="color:blue; font-size:18px;">HuggingFace</a>]\
+version: 1.1.1 </div>')
 with gr.Row(variant='compact'):
     with gr.Column():
         with gr.Accordion('🕹Instructions(说明)', open=False,):
@@ -305,7 +335,7 @@ with block:
     rect_xywh_list.extend([x, y, w, h])
 
 rect_img = gr.Image(value=create_canvas(), label="Rext Position(方框位置)", elem_id="MD-bbox-rect-t2i", show_label=False, visible=False)
-draw_img = gr.Image(value=create_canvas(), label="Draw Position(绘制位置)", visible=True, tool='sketch', show_label=False, brush_radius=
+draw_img = gr.Image(value=create_canvas(), label="Draw Position(绘制位置)", visible=True, tool='sketch', show_label=False, brush_radius=100)
 
 def re_draw():
     return [gr.Image(value=create_canvas(), tool='sketch'), gr.Slider(value=512), gr.Slider(value=512)]
@@ -357,7 +387,7 @@ with block:
 ori_img = gr.Image(label='Ori(原图)')
 
 def upload_ref(x):
-    return [gr.Image(type="numpy", brush_radius=
+    return [gr.Image(type="numpy", brush_radius=100, tool='sketch'),
            gr.Image(value=x)]
 
 def clear_ref(x):
@@ -394,8 +424,8 @@ with block:
 run_edit.click(fn=process, inputs=[gr.State('edit')] + ips, outputs=[result_gallery, result_info])
 
 block.launch(
-    #server_name='0.0.0.0' if os.getenv('GRADIO_LISTEN', '') != '' else "127.0.0.1",
-    #share=False,
+    # server_name='0.0.0.0' if os.getenv('GRADIO_LISTEN', '') != '' else "127.0.0.1",
+    # share=False,
     root_path=f"/{os.getenv('GRADIO_PROXY_PATH')}" if os.getenv('GRADIO_PROXY_PATH') else ""
 )
 # block.launch(server_name='0.0.0.0')
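For quick experimentation outside Gradio, the updated inference call can be reproduced directly; this is a minimal sketch, assuming the modelscope package is installed and the model can be downloaded (task name, keyword arguments, and defaults are taken from the diff above):

    from modelscope.pipelines import pipeline

    # fp16 and the CH->EN translator are on by default in app.py and are turned
    # off with --use_fp32 / --no_translator; the values below mirror those defaults.
    inference = pipeline(
        'my-anytext-task',
        model='damo/cv_anytext_text_generation_editing',
        model_revision='v1.1.1',
        use_fp16=True,          # i.e. not args.use_fp32
        use_translator=True,    # i.e. not args.no_translator
        font_path='font/Arial_Unicode.ttf',
    )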
cldm/cldm.py
CHANGED
@@ -32,6 +32,8 @@ class ControlledUnetModel(UNetModel):
         hs = []
         with torch.no_grad():
             t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+            if self.use_fp16:
+                t_emb = t_emb.half()
             emb = self.time_embed(t_emb)
             h = x.type(self.dtype)
             for module in self.input_blocks:
@@ -124,12 +126,12 @@ class ControlNet(nn.Module):
                 f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                 f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                 f"attention will still not be set.")
-
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.use_checkpoint = use_checkpoint
+        self.use_fp16 = use_fp16
         self.dtype = th.float16 if use_fp16 else th.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
@@ -313,6 +315,8 @@ class ControlNet(nn.Module):
 
     def forward(self, x, hint, text_info, timesteps, context, **kwargs):
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        if self.use_fp16:
+            t_emb = t_emb.half()
         emb = self.time_embed(t_emb)
 
         # guided_hint from text_info
@@ -344,6 +348,7 @@ class ControlNet(nn.Module):
 class ControlLDM(LatentDiffusion):
 
     def __init__(self, control_stage_config, control_key, glyph_key, position_key, only_mid_control, loss_alpha=0, loss_beta=0, with_step_weight=False, use_vae_upsample=False, latin_weight=1.0, embedding_manager_config=None, *args, **kwargs):
+        self.use_fp16 = kwargs.pop('use_fp16', False)
         super().__init__(*args, **kwargs)
         self.control_model = instantiate_from_config(control_stage_config)
         self.control_key = control_key
@@ -356,6 +361,7 @@ class ControlLDM(LatentDiffusion):
         self.with_step_weight = with_step_weight
         self.use_vae_upsample = use_vae_upsample
         self.latin_weight = latin_weight
+
         if embedding_manager_config is not None and embedding_manager_config.params.valid:
             self.embedding_manager = self.instantiate_embedding_manager(embedding_manager_config, self.cond_stage_model)
             for param in self.embedding_manager.embedding_parameters():
@@ -369,6 +375,7 @@ class ControlLDM(LatentDiffusion):
             args.rec_image_shape = "3, 48, 320"
             args.rec_batch_num = 6
             args.rec_char_dict_path = './ocr_recog/ppocr_keys_v1.txt'
+            args.use_fp16 = self.use_fp16
             self.cn_recognizer = TextRecognizer(args, self.text_predictor)
             for param in self.text_predictor.parameters():
                 param.requires_grad = False
@@ -433,6 +440,8 @@ class ControlLDM(LatentDiffusion):
         diffusion_model = self.model.diffusion_model
         _cond = torch.cat(cond['c_crossattn'], 1)
         _hint = torch.cat(cond['c_concat'], 1)
+        if self.use_fp16:
+            x_noisy = x_noisy.half()
         control = self.control_model(x=x_noisy, timesteps=t, context=_cond, hint=_hint, text_info=cond['text_info'])
         control = [c * scale for c, scale in zip(control, self.control_scales)]
         eps = diffusion_model(x=x_noisy, timesteps=t, context=_cond, control=control, only_mid_control=self.only_mid_control)
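Every change in this file follows one pattern: tensors created in fp32 (the sinusoidal timestep embedding, the noisy latent, the OCR settings handed to TextRecognizer) are cast with .half() before they reach modules whose weights are fp16. A standalone sketch of that pattern with toy layer sizes; time_embed below is a stand-in for the model's time-embedding MLP, not the real module:

    import torch
    import torch.nn as nn

    use_fp16 = torch.cuda.is_available()
    device = 'cuda' if use_fp16 else 'cpu'

    # stand-in for self.time_embed; converted to half precision when use_fp16 is set
    time_embed = nn.Sequential(nn.Linear(320, 1280), nn.SiLU(), nn.Linear(1280, 1280)).to(device)
    if use_fp16:
        time_embed = time_embed.half()

    t_emb = torch.randn(2, 320, device=device)   # fp32, like timestep_embedding()
    if use_fp16:
        t_emb = t_emb.half()                     # the cast added by this commit
    emb = time_embed(t_emb)
    print(emb.dtype)                             # torch.float16 on GPU, torch.float32 on CPU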
cldm/model.py
CHANGED
@@ -21,10 +21,14 @@ def load_state_dict(ckpt_path, location='cpu'):
     return state_dict
 
 
-def create_model(config_path, cond_stage_path=None):
+def create_model(config_path, cond_stage_path=None, use_fp16=False):
     config = OmegaConf.load(config_path)
     if cond_stage_path:
         config.model.params.cond_stage_config.params.version = cond_stage_path  # use pre-downloaded ckpts, in case blocked
+    if use_fp16:
+        config.model.params.use_fp16 = True
+        config.model.params.control_stage_config.params.use_fp16 = True
+        config.model.params.unet_config.params.use_fp16 = True
     model = instantiate_from_config(config.model).cpu()
     print(f'Loaded model config from [{config_path}]')
     return model
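A hypothetical caller of the extended create_model signature (the config and checkpoint paths below are placeholders, not files referenced by this commit); the use_fp16 flag only flips the three config entries shown above before the model is instantiated:

    from cldm.model import create_model, load_state_dict

    model = create_model('path/to/anytext_config.yaml', use_fp16=True)          # placeholder path
    state = load_state_dict('path/to/anytext_checkpoint.ckpt', location='cpu')  # placeholder path
    model.load_state_dict(state, strict=False)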
cldm/recognizer.py
CHANGED
@@ -132,6 +132,7 @@ class TextRecognizer(object):
         self.chars = self.get_char_dict(args.rec_char_dict_path)
         self.char2id = {x: i for i, x in enumerate(self.chars)}
         self.is_onnx = not isinstance(self.predictor, torch.nn.Module)
+        self.use_fp16 = args.use_fp16
 
     # img: CHW
     def resize_norm_img(self, img, max_wh_ratio):
@@ -188,6 +189,8 @@ class TextRecognizer(object):
            # max_wh_ratio = max(max_wh_ratio, wh_ratio)  # comment to not use different ratio
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
+               if self.use_fp16:
+                   norm_img = norm_img.half()
                norm_img = norm_img.unsqueeze(0)
                norm_img_batch.append(norm_img)
            norm_img_batch = torch.cat(norm_img_batch, dim=0)
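A self-contained illustration of the new preprocessing branch (the crop shape matches the rec_image_shape value "3, 48, 320" set in cldm/cldm.py; the random tensor is a stand-in for a real OCR crop):

    import torch

    use_fp16 = True
    norm_img = torch.rand(3, 48, 320)        # C, H, W
    if use_fp16:
        norm_img = norm_img.half()           # keep crops in the same dtype as an fp16 predictor
    norm_img_batch = torch.cat([norm_img.unsqueeze(0)], dim=0)
    print(norm_img_batch.shape, norm_img_batch.dtype)   # torch.Size([1, 3, 48, 320]) torch.float16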
ldm/modules/diffusionmodules/openaimodel.py
CHANGED
@@ -510,7 +510,7 @@ class UNetModel(nn.Module):
                 f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                 f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                 f"attention will still not be set.")
-
+        self.use_fp16 = use_fp16
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
         self.channel_mult = channel_mult
ldm/modules/diffusionmodules/util.py
CHANGED
@@ -216,7 +216,8 @@ class SiLU(nn.Module):
 
 class GroupNorm32(nn.GroupNorm):
     def forward(self, x):
-        return super().forward(x.float()).type(x.dtype)
+        # return super().forward(x.float()).type(x.dtype)
+        return super().forward(x).type(x.dtype)
 
 def conv_nd(dims, *args, **kwargs):
     """
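The old GroupNorm32.forward always upcast activations to fp32 and cast the result back; once the surrounding model runs in half precision, this commit instead normalizes in the input's own dtype. A small self-contained sketch of the class as it reads after the change, showing that fp32 behaviour is unaffected:

    import torch
    import torch.nn as nn

    class GroupNorm32(nn.GroupNorm):
        def forward(self, x):
            # normalize in the input's own dtype instead of forcing fp32
            return super().forward(x).type(x.dtype)

    gn = GroupNorm32(8, 32)
    x = torch.randn(1, 32, 4, 4)
    print(gn(x).dtype)   # torch.float32, same result as the old code for fp32 inputs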