Spaces:
Runtime error
Runtime error
jiwan-chung
commited on
Commit
•
8a76135
1
Parent(s):
a2217bb
example update & better descriptions
Browse files- app.py +14 -8
- arguments.py +1 -1
- run.py +8 -7
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from pathlib import Path
|
|
|
2 |
|
3 |
import gdown
|
4 |
# from PIL import Image
|
@@ -14,7 +15,10 @@ assert Path('./data/esper_demo').is_dir()
|
|
14 |
|
15 |
# example image from COCO data
|
16 |
image_urls = {
|
17 |
-
'108953': 'https://farm8.staticflickr.com/7160/6484651991_9d1eaa557a_z.jpg'
|
|
|
|
|
|
|
18 |
}
|
19 |
images = {}
|
20 |
for k, url in image_urls.items():
|
@@ -32,12 +36,14 @@ for k, v in images.items():
|
|
32 |
images[k] = image
|
33 |
'''
|
34 |
|
|
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
20,
|
40 |
-
False
|
41 |
-
] for v in images.values()]
|
42 |
|
43 |
-
|
|
|
|
|
|
|
|
|
|
1 |
from pathlib import Path
|
2 |
+
from itertools import chain
|
3 |
|
4 |
import gdown
|
5 |
# from PIL import Image
|
|
|
15 |
|
16 |
# example image from COCO data
|
17 |
image_urls = {
|
18 |
+
# '108953': 'https://farm8.staticflickr.com/7160/6484651991_9d1eaa557a_z.jpg'
|
19 |
+
# '394905': 'http://farm5.staticflickr.com/4151/4955575345_6c6bdfae9d_z.jpg',
|
20 |
+
'330341': 'https://farm4.staticflickr.com/3018/3451080626_4a43435f4b_z.jpg',
|
21 |
+
'396820': 'https://farm1.staticflickr.com/148/351466274_8d7174e11b_z.jpg'
|
22 |
}
|
23 |
images = {}
|
24 |
for k, url in image_urls.items():
|
|
|
36 |
images[k] = image
|
37 |
'''
|
38 |
|
39 |
+
prompts = ['blog:', 'dialogue:', 'This is my favorite poem:']
|
40 |
|
41 |
+
title = 'Demo for ESPER'
|
42 |
+
description = 'backbone: style-finetuned GPT-2-base'
|
43 |
+
prompt_label = f'Prompt (try pretrained styles such as "blog:" or "dialogue:" or unseen prompts such as "{prompts[-1]}")'
|
|
|
|
|
|
|
44 |
|
45 |
+
examples = [[[v, prompt, 20, False ] for prompt in prompts]
|
46 |
+
for v in images.values()]
|
47 |
+
examples = list(chain(*examples))
|
48 |
+
|
49 |
+
launch(examples, title=title, description=description, prompt_label=prompt_label)
|
arguments.py
CHANGED
@@ -20,7 +20,7 @@ def get_args():
|
|
20 |
parser.add_argument(
|
21 |
'--label_path', type=str, default='./data/esper_demo/labels_all.json', help='style label info file path')
|
22 |
parser.add_argument(
|
23 |
-
'--checkpoint', type=str, default='./data/esper_demo/ckpt/
|
24 |
|
25 |
parser.add_argument(
|
26 |
'--prefix_length', type=int, default=10, help='prefix length for the visual mapper')
|
|
|
20 |
parser.add_argument(
|
21 |
'--label_path', type=str, default='./data/esper_demo/labels_all.json', help='style label info file path')
|
22 |
parser.add_argument(
|
23 |
+
'--checkpoint', type=str, default='./data/esper_demo/ckpt/gpt2_style_2', help='checkpoint file path')
|
24 |
|
25 |
parser.add_argument(
|
26 |
'--prefix_length', type=int, default=10, help='prefix length for the visual mapper')
|
run.py
CHANGED
@@ -140,21 +140,22 @@ img, _, text = run(sample_img, 'There lies', 50, 20, sample=False)
|
|
140 |
print('test_run:', text)
|
141 |
'''
|
142 |
|
143 |
-
def launch(examples=None):
|
144 |
args = get_args()
|
145 |
inferer = prepare(args)
|
146 |
runner = Runner(inferer)
|
147 |
|
148 |
iface = gr.Interface(
|
149 |
-
title=
|
|
|
150 |
fn=runner.__call__,
|
151 |
-
inputs=[gr.components.Image(shape=(224, 224)),
|
152 |
-
gr.components.Textbox(label=
|
153 |
-
gr.components.Slider(20,
|
154 |
# gr.components.Slider(10, 100, step=1, label='window_size'),
|
155 |
gr.components.Checkbox(label='do sample')],
|
156 |
-
outputs=[gr.components.Textbox(label='
|
157 |
-
gr.components.Textbox(label='
|
158 |
examples=examples
|
159 |
)
|
160 |
if args.port is not None:
|
|
|
140 |
print('test_run:', text)
|
141 |
'''
|
142 |
|
143 |
+
def launch(examples=None, title='Demo for ESPER', description=None, prompt_label='Prompt'):
|
144 |
args = get_args()
|
145 |
inferer = prepare(args)
|
146 |
runner = Runner(inferer)
|
147 |
|
148 |
iface = gr.Interface(
|
149 |
+
title=title,
|
150 |
+
description=description,
|
151 |
fn=runner.__call__,
|
152 |
+
inputs=[gr.components.Image(shape=(224, 224), label='Image'),
|
153 |
+
gr.components.Textbox(label=prompt_label),
|
154 |
+
gr.components.Slider(20, 40, step=1, label='Length'),
|
155 |
# gr.components.Slider(10, 100, step=1, label='window_size'),
|
156 |
gr.components.Checkbox(label='do sample')],
|
157 |
+
outputs=[gr.components.Textbox(label='Prompt'),
|
158 |
+
gr.components.Textbox(label='Generation')],
|
159 |
examples=examples
|
160 |
)
|
161 |
if args.port is not None:
|