|
from __future__ import annotations |
|
import gradio as gr |
|
import logging |
|
# Layout of every log record: timestamp, source file with line number, level, message.
_LOG_FORMAT = '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'

# Configure root logging at INFO, then create this module's own logger.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
import subprocess |
|
def runcmd(command, timeout=60):
    """Run *command* through the shell and print whether it succeeded.

    Standard output and standard error are captured as UTF-8 text rather
    than being streamed to the console.

    Args:
        command: Shell command line to execute.
        timeout: Seconds to wait before the command is killed.

    Returns:
        The ``subprocess.CompletedProcess`` of the finished command, or
        ``None`` when the command did not finish within *timeout*.
    """
    try:
        ret = subprocess.run(
            command,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding="utf-8",
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as exc:
        # A slow command is reported like any other failure instead of
        # propagating and aborting whoever called us (e.g. module import).
        print("error:", exc)
        return None

    if ret.returncode == 0:
        print("success:", ret)
    else:
        print("error:", ret)
    return ret
|
# Upgrade the clueai package at start-up; this must run before the import below.
runcmd("pip3 install --upgrade clueai")

import clueai

# Shared client used by inference_gen(); created with an empty API key and
# with API-key checking switched off.
cl = clueai.Client("", check_api_key=False)
|
|
|
# Inert string literal: CSS rules for the "luck_*" buttons. It is not assigned
# to anything, so these rules are never handed to Gradio.
'''
#luck_t2i_btn_1, #luck_s2i_btn_1, #luck_i2i_btn_1, #luck_ici_btn_1{
    color: #fff;
    --tw-gradient-from: #BED336;
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
    --tw-gradient-to: #BED336;
    border-color: #BED336;
}

#luck_easy_btn_1, #luck_iti_btn_1, #luck_tsi_btn_1, #luck_isi_btn_1{
    color: #fff;
    --tw-gradient-from: #BED336;
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
    --tw-gradient-to: #BED336;
    border-color: #BED336;
}
'''
|
# Custom stylesheet passed to gr.Blocks(css=...): page width, the gradient
# colours of the generate/import buttons, and sizing of the record button.
css='''
.container { max-width: 800px; margin: auto; }
#gen_btn_1{
    color: #fff;
    --tw-gradient-from: #f44336;
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
    --tw-gradient-to: #ff9800;
    border-color: #ff9800;
}
#t2i_btn_1, #s2i_btn_1, #i2i_btn_1, #ici_btn_1, #easy_btn_1, #iti_btn_1, #tsi_btn_1, #isi_btn_1{
    color: #fff;
    --tw-gradient-from: #f44336;
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
    --tw-gradient-to: #ff9800;
    border-color: #ff9800;
}


#import_t2i_btn_1, #import_s2i_btn_1, #import_i2i_btn_1, #import_ici_btn_1{
    color: #fff;
    --tw-gradient-from: #BED336;
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
    --tw-gradient-to: #BED336;
    border-color: #BED336;
}

#import_easy_btn_1, #import_iti_btn_1, #import_tsi_btn_1, #import_isi_btn_1{
    color: #fff;
    --tw-gradient-from: #BED336;
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
    --tw-gradient-to: #BED336;
    border-color: #BED336;
}

#record_btn{

}
#record_btn > div > button > span {
    width: 2.375rem;
    height: 2.375rem;
}
#record_btn > div > button > span > span {
    width: 2.375rem;
    height: 2.375rem;
}
audio {
    margin-bottom: 10px;
}
div#record_btn > .mt-6{
    margin-top: 0!important;
}
div#record_btn > .mt-6 button {
    font-size: 1em;
    width: 100%;
    padding: 20px;
    height: 60px;
}

div#txt2img_tab {
    color: #BED336;
}

'''
|
|
|
# Decoding options that count as "untouched": inference_gen() compares the
# values coming from the UI against this mapping.
default_generate_config = dict(
    do_sample=False,
    top_p=0,
    top_k=50,
    max_length=64,
    temperature=1,
    num_beams=1,
    length_penalty=0.6,
)
|
|
|
# Registries filled by read_examples() and read by the rest of the module.
task_styles = []                # distinct task names, in first-seen order
examples_list = []              # one [task_style, example] pair per CSV row
task_style_to_task_prefix = {}  # task name -> prompt prefix

import csv

examples_set = set()            # (task_style, example) tuples for membership tests


def read_examples(input_file):
    """Load the example catalogue from a CSV file into the module registries.

    The first row is treated as a header and skipped. Every other row must
    hold exactly three columns: task style, task prefix, example text.
    Completely empty rows are ignored.

    Args:
        input_file: Path of the CSV file to read (UTF-8 encoded).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-empty row does not have exactly three columns.
    """
    # newline="" is what the csv module asks for; the explicit encoding keeps
    # the non-ASCII examples readable whatever the platform default is.
    with open(input_file, encoding="utf-8", newline="") as finput:
        rows = csv.reader(finput)
        next(rows, None)  # drop the header row (no-op on an empty file)
        for row in rows:
            if not row:
                continue
            task_style, task_prefix, example = row
            # Several examples may share one task; list each task only once.
            if task_style not in task_style_to_task_prefix:
                task_styles.append(task_style)
            task_style_to_task_prefix[task_style] = task_prefix
            examples_list.append([task_style, example])
            examples_set.add((task_style, example))
|
# Populate the example registries once, at import time, from the CSV that
# sits in the current working directory.
read_examples("./examples.csv")
|
|
|
def preprocess(text, task):
    """Wrap *text* in the prompt expected for *task*.

    For the "问答" task every question mark is turned into a colon and one
    more colon is appended. The result for any task is the task's prefix,
    the text and the answer cue, each on its own line.

    Args:
        text: Raw user input.
        task: Task name; must be a key of ``task_style_to_task_prefix``.

    Returns:
        The complete prompt string.

    Raises:
        KeyError: If *task* has no registered prefix.
    """
    if task == "问答":
        # "\uff1f" is the full-width question mark typed by Chinese input
        # methods; repeating the ASCII replacement never matched it.
        text = text.replace("?", ":").replace("\uff1f", ":")
        text = text + ":"

    return task_style_to_task_prefix[task] + "\n" + text + "\n答案:"
|
|
|
def inference_gen(text, task, do_sample, top_p, top_k, max_token, temperature, beam_size, length_penalty):
    """Generate text for *text* under *task* through the ClueAI client.

    The input is first turned into a prompt with ``preprocess``. If the
    (task, text) pair is one of the loaded examples and every decoding
    option equals ``default_generate_config``, the request is sent without
    a ``generate_config``; otherwise the chosen options are forwarded.

    The request is made up to three times for as long as the service
    answers with its blocked-content notice; the last answer is returned
    either way.

    Returns:
        The generated text, or ``None`` if the API call raised (the error
        is logged).
    """
    is_known_example = (task, text) in examples_set
    prompt = preprocess(text, task)
    generate_config = dict(
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
        max_length=max_token,
        temperature=temperature,
        num_beams=beam_size,
        length_penalty=length_penalty,
    )

    # Only ship the decoding options when the request is not a pristine example.
    request = {"model_name": 'clueai-base', "prompt": prompt}
    if not (is_known_example and generate_config == default_generate_config):
        request["generate_config"] = generate_config

    for _ in range(3):
        try:
            prediction = cl.generate(**request)
        except Exception as e:
            logger.error(f"error, {e}")
            return
        answer = prediction.generations[0].text
        if answer != "含有违规词,不予展示":
            break

    return answer
|
|
|
# Initial content of the image gallery; empty, so nothing is shown at start.
t2i_default_img_path_list = []

import base64, requests
from io import BytesIO
from PIL import Image
|
def luck_inference_image(text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale):
    """Call ``inference_image`` with the same arguments and ``luck=True``."""
    forwarded = (text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale)
    return inference_image(*forwarded, luck=True)
|
|
|
def inference_image(text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale, luck=False, timeout=300):
    """Request images from the ClueAI text-to-image endpoint.

    Args:
        text: Prompt describing the wanted image.
        n_text: Negative prompt (what should not appear).
        guidance_scale: Sent as ``guidance_scale``.
        style: Sent as ``style``.
        shape: Sent as ``shape``.
        clarity: Sent as ``clarity``.
        steps: Sent as ``num_inference_steps``.
        shape_scale: Sent as ``shape_scale``.
        luck: Sent as ``luck`` ("True"/"False").
        timeout: Seconds to wait for the HTTP response.

    Returns:
        A list of ``PIL.Image`` objects decoded from the base64 strings in
        the response's ``images`` field, or ``None`` when the request fails
        or the response cannot be used (the error is logged).
    """
    # Let requests build the query string: values are percent-encoded, so
    # characters such as "&", "#" or spaces in the prompt cannot corrupt the
    # URL, and no stray whitespace from line continuations leaks into it.
    query = {
        "text": text,
        "negative_prompt": n_text,
        "guidance_scale": guidance_scale,
        "num_inference_steps": steps,
        "style": style,
        "shape": shape,
        "clarity": clarity,
        "shape_scale": shape_scale,
        "luck": str(luck),
    }
    try:
        res = requests.get(
            "https://www.clueai.cn/clueai/hf_text2image",
            params=query,
            timeout=timeout,
        )
        res.raise_for_status()
        encoded_images = res.json()["images"]
    except (requests.RequestException, ValueError, KeyError, TypeError) as e:
        # Network failure, HTTP error status, non-JSON body or missing field.
        logger.error(f"error, {e}")
        return

    return [
        Image.open(BytesIO(base64.b64decode(encoded.encode('utf-8'))))
        for encoded in encoded_images
    ]
|
# Choices offered by the image tab's style dropdown.
image_styles = [
    '无',
    '细节大师',
    '对称美',
    '虚拟引擎',
    '空间感',
    '机械风格',
    '形状艺术',
    '治愈',
    '电影构图',
    '电影构图(治愈)',
    '荒芜感',
    '漫画',
    '逃离艺术',
    '斯皮尔伯格',
    '幻想',
    '杰作',
    '壁画',
    '朦胧',
    '黑白(3d)',
    '梵高',
    '毕加索',
    '莫奈',
    '丰子恺',
    '现代',
    '欧美',
]
|
# Gradio UI with two tabs: text generation and image generation.
# NOTE(review): the indentation of this section was lost in the copy under
# review. The nesting of rows, accordions and tabs below is a reconstruction
# (in particular whether the style dropdown sits inside the "手气不错" row) —
# confirm against the deployed layout.
with gr.Blocks(css=css, title="ClueAI") as demo:

    gr.Markdown('<h1><center><font color=red style="font-size:50px;">ClueAI全能师</font></center></h1>')
    # --- Tab 1: text generation -------------------------------------------
    with gr.TabItem("文本生成", id='_tab'):
        with gr.Row(variant="compact").style( equal_height=True):
            text = gr.Textbox("标题:俄天然气管道泄漏爆炸",
                    label="编辑内容", show_label=False, max_lines=20,
                    placeholder="在这里输入...",
                )
            task = gr.Dropdown(label="任务", show_label=True, choices=task_styles, value="标题生成文章")
            btn = gr.Button("生成",elem_id="gen_btn_1").style(full_width=False)
        # Decoding options; the initial values equal default_generate_config.
        with gr.Accordion("高级操作", open=False):
            do_sample = gr.Radio([True, False], label="是否采样", value=False)
            top_p = gr.Slider(0, 1, value=0, step=0.1, label="越大多样性越高, 按照概率采样")
            top_k = gr.Slider(1, 100, value=50, step=1, label="越大多样性越高,按照top k采样")
            max_token = gr.Slider(1, 512, value=64, step=1, label="生成的最大长度")
            temperature = gr.Slider(0,1, value=1, step=0.1, label="temperature, 越小下一个token预测概率越平滑")
            beam_size = gr.Slider(1, 4, value=1, step=1, label="beam size, 越大解码窗口越广,")
            length_penalty = gr.Slider(-1, 1, value=0.6, step=0.1, label="大于0鼓励长句子,小于0鼓励短句子")

        with gr.Row(variant="compact").style( equal_height=True):
            output_text = gr.Textbox(
                    label="输出", show_label=True, max_lines=50,
                    placeholder="在这里展示结果",
                )
        # Each example fills the task dropdown and the text box.
        gr.Examples(examples_list, [task, text], label="示例")
        # Order must match the positional parameters of inference_gen().
        input_params = [text, task, do_sample, top_p, top_k, max_token, temperature, beam_size, length_penalty]

        btn.click(inference_gen, inputs=input_params, outputs=output_text)

    # --- Tab 2: image generation ------------------------------------------
    # `text`, `btn` and `input_params` are rebound here; the handlers of the
    # first tab were already registered with the earlier objects.
    with gr.TabItem("图像生成", id='txt2img_tab'):
        with gr.Row(variant="compact").style( equal_height=True):
            text = gr.Textbox("美丽的风景",
                    label="编辑内容", show_label=False, max_lines=2,
                    placeholder="在这里输入你的描述...",
                )
            btn = gr.Button("生成图像",elem_id="t2i_btn_1").style(full_width=False)

        with gr.Row().style( equal_height=True):
            generate_prompt_btn = gr.Button("手气不错", elem_id="luck_t2i_btn_1")

        style = gr.Dropdown(label="风格", show_label=True, choices=image_styles, value="无")
        with gr.Accordion("高级操作", open=False):
            n_text = gr.Textbox("",
                    label="不想要生成的元素", show_label=True, max_lines=2,
                    placeholder="在这里输入你不需要包含的内容...",
                )
            guidance_scale = gr.Slider(1, 20, value=7.5, step=0.5, label="和你的描述匹配程度,越大越匹配")
            shape = gr.Radio(["1x1", "16x9", "手机壁纸"], label="尺寸", value="1x1")
            shape_scale = gr.Radio([1, 2, 3], label="对图放大倍数", value=1)
            steps = gr.Slider(10, 150, value=50, step=1, label="越大质量越好,生成时间越长")
            clarity = gr.Radio(["标清", "高清"], label="清晰度", value="标清")

        gr.Examples(["秋日的晚霞", "星空", "室内装修", "婚礼鲜花"], text, label="示例")

        t2i_gallery = gr.Gallery(
            t2i_default_img_path_list,
            label="生成图像",
            show_label=False).style(
                grid=[2], height="auto"
            )

        # Order must match the positional parameters of inference_image().
        input_params = [text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale]
        # The "手气不错" button goes through luck_inference_image (luck=True);
        # pressing Enter in the text box and the generate button do not.
        generate_prompt_btn.click(luck_inference_image, inputs=input_params, outputs=[t2i_gallery])
        text.submit(inference_image, inputs=input_params, outputs=t2i_gallery)
        btn.click(inference_image, inputs=input_params, outputs=t2i_gallery)

    # Visitor-tracker image rendered below the tabs.
    gr.Markdown("""
    <center><a href="https://clustrmaps.com/site/1bsr7" title="Visit tracker"><img src="//www.clustrmaps.com/map_v2.png?d=sFWwaZBlUeql7focpvpWJDpp9DHpvZfdw1kSavIAWqM&cl=ffffff" /></a></center>
    """)

# Start the Gradio server with default settings.
demo.launch()
|
|