update
Browse files
app.py
CHANGED
@@ -122,15 +122,26 @@ def preprocess(text, task):
|
|
122 |
|
123 |
return task_style_to_task_prefix[task] + "\n" + text + "\n答案:"
|
124 |
|
125 |
-
def inference_gen(text, task):
|
126 |
text = preprocess(text, task)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
#print(text)
|
128 |
try:
|
129 |
prediction = cl.generate(
|
130 |
model_name='clueai-base',
|
131 |
-
prompt=text
|
|
|
132 |
except Exception as e:
|
133 |
-
logger.error(f"error, e")
|
134 |
return
|
135 |
|
136 |
return prediction.generations[0].text
|
@@ -141,7 +152,9 @@ from io import BytesIO
|
|
141 |
from PIL import Image
|
142 |
def inference_image(text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale):
|
143 |
try:
|
144 |
-
res = requests.get(f"https://www.clueai.cn/clueai/hf_text2image?text={text}
|
|
|
|
|
145 |
except Exception as e:
|
146 |
logger.error(f"error, {e}")
|
147 |
return
|
@@ -167,14 +180,21 @@ with gr.Blocks(css=css, title="ClueAI") as demo:
|
|
167 |
task = gr.Dropdown(label="任务", show_label=True, choices=task_styles, value="标题生成文章")
|
168 |
btn = gr.Button("生成",elem_id="gen_btn_1").style(full_width=False)
|
169 |
with gr.Accordion("高级操作", open=False):
|
170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
with gr.Row(variant="compact").style( equal_height=True):
|
172 |
output_text = gr.Textbox(
|
173 |
label="输出", show_label=True, max_lines=50,
|
174 |
placeholder="在这里展示结果",
|
175 |
)
|
176 |
gr.Examples(examples_list, [task, text], label="示例")
|
177 |
-
input_params = [text, task]
|
178 |
#text.submit(inference_gen, inputs=input_params, outputs=output_text)
|
179 |
btn.click(inference_gen, inputs=input_params, outputs=output_text)
|
180 |
|
|
|
122 |
|
123 |
return task_style_to_task_prefix[task] + "\n" + text + "\n答案:"
|
124 |
|
125 |
+
def inference_gen(text, task, do_sample, top_p, top_k, max_token, temperature, beam_size, length_penalty):
|
126 |
text = preprocess(text, task)
|
127 |
+
generate_config = {
|
128 |
+
"do_sample": do_sample,
|
129 |
+
"top_p": top_p,
|
130 |
+
"top_k": top_k,
|
131 |
+
"max_length": max_token,
|
132 |
+
"temperature": temperature,
|
133 |
+
"num_beams": beam_size,
|
134 |
+
"length_penalty": length_penalty
|
135 |
+
}
|
136 |
+
#print(generate_config)
|
137 |
#print(text)
|
138 |
try:
|
139 |
prediction = cl.generate(
|
140 |
model_name='clueai-base',
|
141 |
+
prompt=text,
|
142 |
+
generate_config=generate_config)
|
143 |
except Exception as e:
|
144 |
+
logger.error(f"error, {e}")
|
145 |
return
|
146 |
|
147 |
return prediction.generations[0].text
|
|
|
152 |
from PIL import Image
|
153 |
def inference_image(text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale):
|
154 |
try:
|
155 |
+
res = requests.get(f"https://www.clueai.cn/clueai/hf_text2image?text={text}&negative_prompt={n_text}\
|
156 |
+
&guidance_scale={guidance_scale}&num_inference_steps={steps}\
|
157 |
+
&style={style}&shape={shape}&clarity={clarity}&shape_scale={shape_scale}")
|
158 |
except Exception as e:
|
159 |
logger.error(f"error, {e}")
|
160 |
return
|
|
|
180 |
task = gr.Dropdown(label="任务", show_label=True, choices=task_styles, value="标题生成文章")
|
181 |
btn = gr.Button("生成",elem_id="gen_btn_1").style(full_width=False)
|
182 |
with gr.Accordion("高级操作", open=False):
|
183 |
+
do_sample = gr.Radio([True, False], label="是否采样", value=False)
|
184 |
+
top_p = gr.Slider(0, 1, value=0, step=0.1, label="越大多样性越高, 按照概率采样")
|
185 |
+
top_k = gr.Slider(1, 100, value=50, step=1, label="越大多样性越高,按照top k采样")
|
186 |
+
max_token = gr.Slider(1, 512, value=64, step=1, label="生成的最大长度")
|
187 |
+
temperature = gr.Slider(0,1, value=1, step=0.1, label="temperature, 越小下一个token预测概率越平滑")
|
188 |
+
beam_size = gr.Slider(1, 4, value=1, step=1, label="beam size, 越大解码窗口越广,")
|
189 |
+
length_penalty = gr.Slider(-1, 1, value=0.6, step=0.1, label="大于0鼓励长句子,小于0鼓励短句子")
|
190 |
+
|
191 |
with gr.Row(variant="compact").style( equal_height=True):
|
192 |
output_text = gr.Textbox(
|
193 |
label="输出", show_label=True, max_lines=50,
|
194 |
placeholder="在这里展示结果",
|
195 |
)
|
196 |
gr.Examples(examples_list, [task, text], label="示例")
|
197 |
+
input_params = [text, task, do_sample, top_p, top_k, max_token, temperature, beam_size, length_penalty]
|
198 |
#text.submit(inference_gen, inputs=input_params, outputs=output_text)
|
199 |
btn.click(inference_gen, inputs=input_params, outputs=output_text)
|
200 |
|