UniPortrait / src /generation.py
Junjie96's picture
add nsfw safety checker
6475329 verified
import json
import os
import time
import gradio as gr
import requests
from src.log import logger
from src.util import download_images
def _build_payload(data, model_id, batch_size, repeat_index):
    """Return a per-task copy of ``data`` with model, batch size and seed filled in.

    The caller's ``data`` (and its ``"parameters"`` dict) are left unmodified;
    only the top level and ``"parameters"`` are copied, since those are the only
    levels written to.

    Raises:
        KeyError: if ``data`` lacks ``"parameters"`` or ``"parameters"["seed"]``.
    """
    params = dict(data["parameters"])
    params["n"] = batch_size
    if params["seed"] != -1:
        # Derive each repeat's seed from the caller's seed, not from the previous
        # repeat's already-scaled value (which would compound as s, 2s, 6s, ...).
        # NOTE(review): a seed of 0 yields 0 for every repeat — confirm intended.
        params["seed"] = params["seed"] * (repeat_index + 1)
    payload = dict(data)
    payload["model"] = model_id
    payload["parameters"] = params
    return payload


def _wait_for_task(url_query, task_id, headers, poll_interval, poll_timeout,
                   request_timeout):
    """Poll ``{url_query}/{task_id}`` until the task finishes; return its image URLs.

    Raises:
        gr.Error: on a non-200 query response, a terminal failure status
            (FAILED / CANCELED / UNKNOWN), or when ``poll_timeout`` elapses.
    """
    deadline = None if poll_timeout is None else time.monotonic() + poll_timeout
    while True:
        # NOTE(review): the status query is sent as POST, as in the original code;
        # confirm the URL_QUERY endpoint expects POST rather than GET.
        res_ = requests.post(f'{url_query}/{task_id}', headers=headers,
                             timeout=request_timeout)
        if res_.status_code != 200:
            logger.error(f'task_id: {task_id}: Fail to query task result: {res_.content}')
            raise gr.Error("Fail to query task result.")

        res = json.loads(res_.content.decode())
        status = res['output']['task_status']
        if status == "SUCCEEDED":
            logger.info(f"task_id: {task_id}: Generation task query success.")
            logger.info(f"task_id: {task_id}: {res}")
            return [x['url'] for x in res['output']['results']]
        if status in ("FAILED", "CANCELED", "UNKNOWN"):
            # These states never transition to SUCCEEDED, so polling further
            # would loop forever.
            logger.error(f"task_id: {task_id}: task ended with status {status}: {res}")
            raise gr.Error(
                "Fail to get results from Generation task. Make sure all the ID images have a clear face. If it still doesn't work, you can contact us or open an issue.")
        if deadline is not None and time.monotonic() >= deadline:
            logger.error(f"task_id: {task_id}: still {status} after {poll_timeout}s, giving up.")
            raise gr.Error("Generation task timed out.")

        logger.debug(f"task_id: {task_id}: query result...")
        time.sleep(poll_interval)


def call_generation(data, *, batch_size=4, repeat_times=1, poll_interval=1.0,
                    poll_timeout=None, request_timeout=None):
    """Run asynchronous generation task(s) and return the downloaded images.

    Reads the endpoint configuration from the environment (``URL_TASK``,
    ``URL_QUERY``, ``API_KEY_GENERATION``, ``MODEL_ID``). Submits
    ``repeat_times`` tasks to ``URL_TASK``, each asking for ``batch_size``
    images. Then polls each task at ``URL_QUERY/<task_id>`` and downloads its
    result URLs with ``download_images``.

    Args:
        data: Request payload. Must contain a ``"parameters"`` dict with a
            ``"seed"`` entry (``-1`` leaves the seed untouched). It is not
            modified; each task is sent a copy with ``"model"``,
            ``"parameters"["n"]`` and the per-repeat seed filled in.
        batch_size: Images requested per task (``parameters.n``).
        repeat_times: Number of tasks to submit. Unless the seed is -1, task
            ``i`` uses ``seed * (i + 1)``.
        poll_interval: Seconds to sleep between status queries.
        poll_timeout: Maximum seconds to poll a single task; ``None`` waits
            indefinitely (the original behavior).
        request_timeout: ``timeout`` passed to every ``requests.post`` call;
            ``None`` means no timeout (the original behavior).

    Returns:
        The concatenation of ``download_images`` results for all tasks.

    Raises:
        gr.Error: if a task cannot be created or queried, ends in a failure
            state, exceeds ``poll_timeout``, or the total number of images is
            not ``repeat_times * batch_size``.
    """
    url_task = os.getenv("URL_TASK")
    api_key = os.getenv("API_KEY_GENERATION")
    model_id = os.getenv("MODEL_ID")
    url_query = os.getenv("URL_QUERY")
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": f"Bearer {api_key}",
        "X-DashScope-Async": "enable",
    }

    # Submit every task before polling any of them; keep each task's own payload
    # so the log line below reports what was actually sent for that task.
    submitted = []
    for i in range(repeat_times):
        payload = _build_payload(data, model_id, batch_size, i)
        res_ = requests.post(url_task, data=json.dumps(payload), headers=headers,
                             timeout=request_timeout)
        submitted.append((payload, res_))

    all_image_data = []
    for payload, res_ in submitted:
        if res_.status_code != 200:
            logger.error(f'Fail to create Generation task: {res_.content}')
            raise gr.Error("Fail to create Generation task.")
        task_id = json.loads(res_.content.decode())['output']['task_id']
        logger.info(f"task_id: {task_id}: Create request success. Params: {payload}")

        img_urls = _wait_for_task(url_query, task_id, headers, poll_interval,
                                  poll_timeout, request_timeout)

        logger.info(f"task_id: {task_id}: download generated images.")
        img_data = download_images(img_urls, batch_size)
        logger.info(f"task_id: {task_id}: Generate done.")
        all_image_data += img_data

    if len(all_image_data) != repeat_times * batch_size:
        raise gr.Error("Fail to Generation.")
    return all_image_data