Spaces:
Sleeping
Sleeping
ishworrsubedii
commited on
Commit
•
36cd99b
1
Parent(s):
7277638
Updated the latest changes
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .github/workflows/dockerhub.yaml +39 -0
- .gitignore +61 -0
- Dockerfile +22 -0
- fooocus_api_version.py +1 -0
- fooocusapi/api.py +41 -0
- fooocusapi/args.py +20 -0
- fooocusapi/base_args.py +27 -0
- fooocusapi/configs/default.py +92 -0
- fooocusapi/models/common/base.py +189 -0
- fooocusapi/models/common/image_meta.py +118 -0
- fooocusapi/models/common/requests.py +132 -0
- fooocusapi/models/common/response.py +90 -0
- fooocusapi/models/common/task.py +60 -0
- fooocusapi/models/requests_v1.py +274 -0
- fooocusapi/models/requests_v2.py +50 -0
- fooocusapi/parameters.py +94 -0
- fooocusapi/routes/__init__.py +0 -0
- fooocusapi/routes/generate_v1.py +186 -0
- fooocusapi/routes/generate_v2.py +199 -0
- fooocusapi/routes/query.py +135 -0
- fooocusapi/sql_client.py +269 -0
- fooocusapi/task_queue.py +323 -0
- fooocusapi/utils/api_utils.py +291 -0
- fooocusapi/utils/call_worker.py +97 -0
- fooocusapi/utils/file_utils.py +143 -0
- fooocusapi/utils/img_utils.py +198 -0
- fooocusapi/utils/logger.py +132 -0
- fooocusapi/utils/lora_manager.py +71 -0
- fooocusapi/utils/model_loader.py +46 -0
- fooocusapi/utils/tools.py +159 -0
- fooocusapi/worker.py +1044 -0
- main.py +213 -0
- mannequin_to_model.py +90 -0
- predict.py +316 -0
- repositories/Fooocus/__init__.py +4 -0
- repositories/Fooocus/args_manager.py +55 -0
- repositories/Fooocus/extras/BLIP/configs/bert_config.json +21 -0
- repositories/Fooocus/extras/BLIP/configs/caption_coco.yaml +33 -0
- repositories/Fooocus/extras/BLIP/configs/med_config.json +21 -0
- repositories/Fooocus/extras/BLIP/configs/nlvr.yaml +21 -0
- repositories/Fooocus/extras/BLIP/configs/nocaps.yaml +15 -0
- repositories/Fooocus/extras/BLIP/configs/pretrain.yaml +27 -0
- repositories/Fooocus/extras/BLIP/configs/retrieval_coco.yaml +34 -0
- repositories/Fooocus/extras/BLIP/configs/retrieval_flickr.yaml +34 -0
- repositories/Fooocus/extras/BLIP/configs/retrieval_msrvtt.yaml +12 -0
- repositories/Fooocus/extras/BLIP/configs/vqa.yaml +25 -0
- repositories/Fooocus/extras/BLIP/models/bert_tokenizer/config.json +23 -0
- repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer.json +0 -0
- repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer_config.json +3 -0
- repositories/Fooocus/extras/BLIP/models/bert_tokenizer/vocab.txt +0 -0
.github/workflows/dockerhub.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Publish MannequinToModel Docker image
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches: [ main ]
|
6 |
+
|
7 |
+
jobs:
|
8 |
+
push_to_registry:
|
9 |
+
name: Push Docker image to Docker Hub
|
10 |
+
runs-on: ubuntu-latest
|
11 |
+
permissions:
|
12 |
+
packages: write
|
13 |
+
contents: read
|
14 |
+
attestations: write
|
15 |
+
steps:
|
16 |
+
- name: Check out the repo
|
17 |
+
uses: actions/checkout@v4
|
18 |
+
|
19 |
+
- name: Log in to Docker Hub
|
20 |
+
uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a
|
21 |
+
with:
|
22 |
+
username: ${{ secrets.DOCKER_USERNAME }}
|
23 |
+
password: ${{ secrets.DOCKER_PASSWORD }}
|
24 |
+
|
25 |
+
- name: Extract metadata (tags, labels) for Docker
|
26 |
+
id: meta
|
27 |
+
uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
|
28 |
+
with:
|
29 |
+
images: techconsp/mannequin_to_model
|
30 |
+
|
31 |
+
- name: Build and push Docker image
|
32 |
+
id: push
|
33 |
+
uses: docker/build-push-action@3b5e8027fcad23fda98b2e3ac259d8d67585f671
|
34 |
+
with:
|
35 |
+
context: .
|
36 |
+
file: ./Dockerfile
|
37 |
+
push: true
|
38 |
+
tags: ${{ steps.meta.outputs.tags }}
|
39 |
+
labels: ${{ steps.meta.outputs.labels }}
|
.gitignore
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
*.egg-info/
|
24 |
+
.installed.cfg
|
25 |
+
*.egg
|
26 |
+
|
27 |
+
# Jupyter Notebook
|
28 |
+
.ipynb_checkpoints/
|
29 |
+
*.ipynb_checkpoints/
|
30 |
+
# Exclude if necessary for security reasons:
|
31 |
+
#*.ipynb
|
32 |
+
|
33 |
+
# Virtual environments
|
34 |
+
venv/
|
35 |
+
env/
|
36 |
+
ENV/
|
37 |
+
.venv/
|
38 |
+
.ENV/
|
39 |
+
|
40 |
+
# IDEs
|
41 |
+
.idea/
|
42 |
+
.vscode/
|
43 |
+
*.sublime-project
|
44 |
+
*.sublime-workspace
|
45 |
+
|
46 |
+
htmlcov/
|
47 |
+
.coverage
|
48 |
+
.coverage.*
|
49 |
+
.cache/
|
50 |
+
|
51 |
+
|
52 |
+
.env
|
53 |
+
|
54 |
+
yolov8m-seg.pt
|
55 |
+
.yoloface
|
56 |
+
logs
|
57 |
+
experiments
|
58 |
+
artifacts
|
59 |
+
examples
|
60 |
+
resources
|
61 |
+
*.safetensors
|
Dockerfile
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10-slim
|
2 |
+
|
3 |
+
WORKDIR /mannequin_to_model
|
4 |
+
|
5 |
+
COPY . /mannequin_to_model
|
6 |
+
|
7 |
+
RUN apt-get update && apt-get install -y \
|
8 |
+
libgl1-mesa-glx \
|
9 |
+
ffmpeg \
|
10 |
+
libsm6 \
|
11 |
+
libxext6 \
|
12 |
+
build-essential \
|
13 |
+
git \
|
14 |
+
&& apt-get clean
|
15 |
+
|
16 |
+
RUN pip install --no-cache-dir --upgrade pip==23.0
|
17 |
+
|
18 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
19 |
+
|
20 |
+
EXPOSE 8000
|
21 |
+
|
22 |
+
CMD ["python", "main.py"]
|
fooocus_api_version.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
version = '0.4.1.1'
|
fooocusapi/api.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Entry for startup fastapi server
|
3 |
+
"""
|
4 |
+
from fastapi import FastAPI
|
5 |
+
from fastapi.staticfiles import StaticFiles
|
6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
7 |
+
|
8 |
+
import uvicorn
|
9 |
+
|
10 |
+
from fooocusapi.utils import file_utils
|
11 |
+
from fooocusapi.routes.generate_v1 import secure_router as generate_v1
|
12 |
+
from fooocusapi.routes.generate_v2 import secure_router as generate_v2
|
13 |
+
from fooocusapi.routes.query import secure_router as query
|
14 |
+
from mannequin_to_model import secure_router as mannequin_to_model
|
15 |
+
|
16 |
+
app = FastAPI()
|
17 |
+
|
18 |
+
app.add_middleware(
|
19 |
+
CORSMiddleware,
|
20 |
+
allow_origins=["*"], # Allow access from all sources
|
21 |
+
allow_credentials=True,
|
22 |
+
allow_methods=["*"], # Allow all HTTP methods
|
23 |
+
allow_headers=["*"], # Allow all request headers
|
24 |
+
)
|
25 |
+
|
26 |
+
app.mount("/files", StaticFiles(directory=file_utils.output_dir), name="files")
|
27 |
+
|
28 |
+
app.include_router(query)
|
29 |
+
app.include_router(generate_v1)
|
30 |
+
app.include_router(generate_v2)
|
31 |
+
app.include_router(mannequin_to_model)
|
32 |
+
|
33 |
+
|
34 |
+
def start_app(args):
|
35 |
+
"""Start the FastAPI application"""
|
36 |
+
file_utils.STATIC_SERVER_BASE = args.base_url + "/files/"
|
37 |
+
uvicorn.run(
|
38 |
+
app="fooocusapi.api:app",
|
39 |
+
host="0.0.0.0",
|
40 |
+
port=8000,
|
41 |
+
log_level=args.log_level)
|
fooocusapi/args.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Do not modify the import order
|
3 |
+
"""
|
4 |
+
from fooocusapi.base_args import add_base_args
|
5 |
+
import ldm_patched.modules.args_parser as args_parser
|
6 |
+
|
7 |
+
# Add Fooocus-API args to parser
|
8 |
+
add_base_args(args_parser.parser, False)
|
9 |
+
|
10 |
+
# Apply Fooocus args
|
11 |
+
from args_manager import args_parser
|
12 |
+
|
13 |
+
# Override the port default value
|
14 |
+
args_parser.parser.set_defaults(
|
15 |
+
port=8888
|
16 |
+
)
|
17 |
+
|
18 |
+
# Execute args parse again
|
19 |
+
args_parser.args = args_parser.parser.parse_args()
|
20 |
+
args = args_parser.args
|
fooocusapi/base_args.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
base_args.py
|
3 |
+
"""
|
4 |
+
from argparse import ArgumentParser
|
5 |
+
|
6 |
+
|
7 |
+
def add_base_args(parser: ArgumentParser, before_prepared: bool):
|
8 |
+
"""
|
9 |
+
Add base args for fooocusapi
|
10 |
+
Args:
|
11 |
+
parser: ArgumentParser
|
12 |
+
before_prepared: before prepare environment
|
13 |
+
Returns:
|
14 |
+
"""
|
15 |
+
if before_prepared:
|
16 |
+
parser.add_argument("--port", type=int, default=8888, help="Set the listen port, default: 8888")
|
17 |
+
|
18 |
+
parser.add_argument("--host", type=str, default='127.0.0.1', help="Set the listen host, default: 127.0.0.1")
|
19 |
+
parser.add_argument("--base-url", type=str, default=None, help="Set base url for outside visit, default is http://host:port")
|
20 |
+
parser.add_argument("--log-level", type=str, default='info', help="Log info for Uvicorn, default: info")
|
21 |
+
parser.add_argument("--skip-pip", default=False, action="store_true", help="Skip automatic pip install when setup")
|
22 |
+
parser.add_argument("--preload-pipeline", default=False, action="store_true", help="Preload pipeline before start http server")
|
23 |
+
parser.add_argument("--queue-size", type=int, default=100, help="Working queue size, default: 100, generation requests exceeding working queue size will return failure")
|
24 |
+
parser.add_argument("--queue-history", type=int, default=0, help="Finished jobs reserve size, tasks exceeding the limit will be deleted, including output image files, default: 0, means no limit")
|
25 |
+
parser.add_argument('--webhook-url', type=str, default=None, help='The URL to send a POST request when a job is finished')
|
26 |
+
parser.add_argument('--persistent', default=False, action="store_true", help="Store history to db")
|
27 |
+
parser.add_argument("--apikey", type=str, default=None, help="API key for authenticating requests")
|
fooocusapi/configs/default.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Static variables for Fooocus API
|
3 |
+
"""
|
4 |
+
img_generate_responses = {
|
5 |
+
"200": {
|
6 |
+
"description": "PNG bytes if request's 'Accept' header is 'image/png', otherwise JSON",
|
7 |
+
"content": {
|
8 |
+
"application/json": {
|
9 |
+
"example": [{
|
10 |
+
"base64": "...very long string...",
|
11 |
+
"seed": "1050625087",
|
12 |
+
"finish_reason": "SUCCESS",
|
13 |
+
}]
|
14 |
+
},
|
15 |
+
"application/json async": {
|
16 |
+
"example": {
|
17 |
+
"job_id": 1,
|
18 |
+
"job_type": "Text to Image"
|
19 |
+
}
|
20 |
+
},
|
21 |
+
"image/png": {
|
22 |
+
"example": "PNG bytes, what did you expect?"
|
23 |
+
},
|
24 |
+
},
|
25 |
+
}
|
26 |
+
}
|
27 |
+
|
28 |
+
default_inpaint_engine_version = "v2.6"
|
29 |
+
|
30 |
+
default_styles = ["Fooocus V2", "Fooocus Enhance", "Fooocus Sharp"]
|
31 |
+
default_base_model_name = "juggernautXL_v8Rundiffusion.safetensors"
|
32 |
+
default_refiner_model_name = "None"
|
33 |
+
default_refiner_switch = 0.5
|
34 |
+
default_loras = [[True, "sd_xl_offset_example-lora_1.0.safetensors", 0.1]]
|
35 |
+
default_cfg_scale = 7.0
|
36 |
+
default_prompt_negative = ""
|
37 |
+
default_aspect_ratio = "1152*896"
|
38 |
+
default_sampler = "dpmpp_2m_sde_gpu"
|
39 |
+
default_scheduler = "karras"
|
40 |
+
|
41 |
+
available_aspect_ratios = [
|
42 |
+
"704*1408",
|
43 |
+
"704*1344",
|
44 |
+
"768*1344",
|
45 |
+
"768*1280",
|
46 |
+
"832*1216",
|
47 |
+
"832*1152",
|
48 |
+
"896*1152",
|
49 |
+
"896*1088",
|
50 |
+
"960*1088",
|
51 |
+
"960*1024",
|
52 |
+
"1024*1024",
|
53 |
+
"1024*960",
|
54 |
+
"1088*960",
|
55 |
+
"1088*896",
|
56 |
+
"1152*896",
|
57 |
+
"1152*832",
|
58 |
+
"1216*832",
|
59 |
+
"1280*768",
|
60 |
+
"1344*768",
|
61 |
+
"1344*704",
|
62 |
+
"1408*704",
|
63 |
+
"1472*704",
|
64 |
+
"1536*640",
|
65 |
+
"1600*640",
|
66 |
+
"1664*576",
|
67 |
+
"1728*576",
|
68 |
+
]
|
69 |
+
|
70 |
+
uov_methods = [
|
71 |
+
"Disabled",
|
72 |
+
"Vary (Subtle)",
|
73 |
+
"Vary (Strong)",
|
74 |
+
"Upscale (1.5x)",
|
75 |
+
"Upscale (2x)",
|
76 |
+
"Upscale (Fast 2x)",
|
77 |
+
"Upscale (Custom)",
|
78 |
+
]
|
79 |
+
|
80 |
+
outpaint_expansions = ["Left", "Right", "Top", "Bottom"]
|
81 |
+
|
82 |
+
|
83 |
+
def get_aspect_ratio_value(label: str) -> str:
|
84 |
+
"""
|
85 |
+
Get aspect ratio
|
86 |
+
Args:
|
87 |
+
label: str, aspect ratio
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
|
91 |
+
"""
|
92 |
+
return label.split(" ")[0].replace("×", "*")
|
fooocusapi/models/common/base.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Common models"""
|
2 |
+
from typing import List, Tuple
|
3 |
+
from enum import Enum
|
4 |
+
from fastapi import UploadFile
|
5 |
+
from fastapi.exceptions import RequestValidationError
|
6 |
+
from pydantic import (
|
7 |
+
ValidationError,
|
8 |
+
ConfigDict,
|
9 |
+
BaseModel,
|
10 |
+
TypeAdapter,
|
11 |
+
Field
|
12 |
+
)
|
13 |
+
from pydantic_core import InitErrorDetails
|
14 |
+
|
15 |
+
from fooocusapi.configs.default import default_loras
|
16 |
+
|
17 |
+
|
18 |
+
class PerformanceSelection(str, Enum):
|
19 |
+
"""Performance selection"""
|
20 |
+
speed = 'Speed'
|
21 |
+
quality = 'Quality'
|
22 |
+
extreme_speed = 'Extreme Speed'
|
23 |
+
lightning = 'Lightning'
|
24 |
+
hyper_sd = 'Hyper-SD'
|
25 |
+
|
26 |
+
|
27 |
+
class Lora(BaseModel):
|
28 |
+
"""Common params lora model"""
|
29 |
+
enabled: bool
|
30 |
+
model_name: str
|
31 |
+
weight: float = Field(default=0.5, ge=-2, le=2)
|
32 |
+
|
33 |
+
model_config = ConfigDict(
|
34 |
+
protected_namespaces=('protect_me_', 'also_protect_')
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
LoraList = TypeAdapter(List[Lora])
|
39 |
+
default_loras_model = []
|
40 |
+
for lora in default_loras:
|
41 |
+
if lora[0] != 'None':
|
42 |
+
default_loras_model.append(
|
43 |
+
Lora(
|
44 |
+
enabled=lora[0],
|
45 |
+
model_name=lora[1],
|
46 |
+
weight=lora[2])
|
47 |
+
)
|
48 |
+
default_loras_json = LoraList.dump_json(default_loras_model)
|
49 |
+
|
50 |
+
|
51 |
+
class UpscaleOrVaryMethod(str, Enum):
|
52 |
+
"""Upscale or Vary method"""
|
53 |
+
subtle_variation = 'Vary (Subtle)'
|
54 |
+
strong_variation = 'Vary (Strong)'
|
55 |
+
upscale_15 = 'Upscale (1.5x)'
|
56 |
+
upscale_2 = 'Upscale (2x)'
|
57 |
+
upscale_fast = 'Upscale (Fast 2x)'
|
58 |
+
upscale_custom = 'Upscale (Custom)'
|
59 |
+
|
60 |
+
|
61 |
+
class OutpaintExpansion(str, Enum):
|
62 |
+
"""Outpaint expansion"""
|
63 |
+
left = 'Left'
|
64 |
+
right = 'Right'
|
65 |
+
top = 'Top'
|
66 |
+
bottom = 'Bottom'
|
67 |
+
|
68 |
+
|
69 |
+
class ControlNetType(str, Enum):
|
70 |
+
"""ControlNet Type"""
|
71 |
+
cn_ip = "ImagePrompt"
|
72 |
+
cn_ip_face = "FaceSwap"
|
73 |
+
cn_canny = "PyraCanny"
|
74 |
+
cn_cpds = "CPDS"
|
75 |
+
|
76 |
+
|
77 |
+
class ImagePrompt(BaseModel):
|
78 |
+
"""Common params object ImagePrompt"""
|
79 |
+
cn_img: UploadFile | None = Field(default=None)
|
80 |
+
cn_stop: float | None = Field(default=None, ge=0, le=1)
|
81 |
+
cn_weight: float | None = Field(default=None, ge=0, le=2, description="None for default value")
|
82 |
+
cn_type: ControlNetType = Field(default=ControlNetType.cn_ip)
|
83 |
+
|
84 |
+
|
85 |
+
class DescribeImageType(str, Enum):
|
86 |
+
"""Image type for image to prompt"""
|
87 |
+
photo = 'Photo'
|
88 |
+
anime = 'Anime'
|
89 |
+
|
90 |
+
|
91 |
+
class ImageMetaScheme(str, Enum):
|
92 |
+
"""Scheme for save image meta
|
93 |
+
Attributes:
|
94 |
+
Fooocus: json format
|
95 |
+
A111: string
|
96 |
+
"""
|
97 |
+
Fooocus = 'fooocus'
|
98 |
+
A111 = 'a111'
|
99 |
+
|
100 |
+
|
101 |
+
def style_selection_parser(style_selections: str | List[str]) -> List[str]:
|
102 |
+
"""
|
103 |
+
Parse style selections, Convert to list
|
104 |
+
Args:
|
105 |
+
style_selections: str, comma separated Fooocus style selections
|
106 |
+
e.g. Fooocus V2, Fooocus Enhance, Fooocus Sharp
|
107 |
+
Returns:
|
108 |
+
List[str]
|
109 |
+
"""
|
110 |
+
style_selection_arr: List[str] = []
|
111 |
+
if style_selections is None or len(style_selections) == 0:
|
112 |
+
return []
|
113 |
+
for part in style_selections:
|
114 |
+
if len(part) > 0:
|
115 |
+
for s in part.split(','):
|
116 |
+
style = s.strip()
|
117 |
+
style_selection_arr.append(style)
|
118 |
+
return style_selection_arr
|
119 |
+
|
120 |
+
|
121 |
+
def lora_parser(loras: str) -> List[Lora]:
|
122 |
+
"""
|
123 |
+
Parse lora config, Convert to list
|
124 |
+
Args:
|
125 |
+
loras: a json string for loras
|
126 |
+
Returns:
|
127 |
+
List[Lora]
|
128 |
+
"""
|
129 |
+
loras_model: List[Lora] = []
|
130 |
+
if loras is None or len(loras) == 0:
|
131 |
+
return loras_model
|
132 |
+
try:
|
133 |
+
loras_model = LoraList.validate_json(loras)
|
134 |
+
return loras_model
|
135 |
+
except ValidationError as ve:
|
136 |
+
errs = ve.errors()
|
137 |
+
raise RequestValidationError from errs
|
138 |
+
|
139 |
+
|
140 |
+
def outpaint_selections_parser(outpaint_selections: str | list[str]) -> List[OutpaintExpansion]:
|
141 |
+
"""
|
142 |
+
Parse outpaint selections, Convert to list
|
143 |
+
Args:
|
144 |
+
outpaint_selections: str, comma separated Left, Right, Top, Bottom
|
145 |
+
e.g. Left, Right, Top, Bottom
|
146 |
+
Returns:
|
147 |
+
List[OutpaintExpansion]
|
148 |
+
"""
|
149 |
+
outpaint_selections_arr: List[OutpaintExpansion] = []
|
150 |
+
if outpaint_selections is None or len(outpaint_selections) == 0:
|
151 |
+
return []
|
152 |
+
for part in outpaint_selections:
|
153 |
+
if len(part) > 0:
|
154 |
+
for s in part.split(','):
|
155 |
+
try:
|
156 |
+
expansion = OutpaintExpansion(s)
|
157 |
+
outpaint_selections_arr.append(expansion)
|
158 |
+
except ValueError:
|
159 |
+
errs = InitErrorDetails(
|
160 |
+
type='enum',
|
161 |
+
loc=tuple('outpaint_selections'),
|
162 |
+
input=outpaint_selections,
|
163 |
+
ctx={
|
164 |
+
'expected': "str, comma separated Left, Right, Top, Bottom"
|
165 |
+
})
|
166 |
+
raise RequestValidationError from errs
|
167 |
+
return outpaint_selections_arr
|
168 |
+
|
169 |
+
|
170 |
+
def image_prompt_parser(image_prompts_config: List[Tuple]) -> List[ImagePrompt]:
|
171 |
+
"""
|
172 |
+
Image prompt parser, Convert to List[ImagePrompt]
|
173 |
+
Args:
|
174 |
+
image_prompts_config: List[Tuple]
|
175 |
+
e.g. ('image1.jpg', 0.5, 1.0, 'normal'), ('image2.jpg', 0.5, 1.0, 'normal')
|
176 |
+
returns:
|
177 |
+
List[ImagePrompt]
|
178 |
+
"""
|
179 |
+
image_prompts: List[ImagePrompt] = []
|
180 |
+
if image_prompts_config is None or len(image_prompts_config) == 0:
|
181 |
+
return []
|
182 |
+
for config in image_prompts_config:
|
183 |
+
cn_img, cn_stop, cn_weight, cn_type = config
|
184 |
+
image_prompts.append(ImagePrompt(
|
185 |
+
cn_img=cn_img,
|
186 |
+
cn_stop=cn_stop,
|
187 |
+
cn_weight=cn_weight,
|
188 |
+
cn_type=cn_type))
|
189 |
+
return image_prompts
|
fooocusapi/models/common/image_meta.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Image meta schema
|
3 |
+
"""
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from fooocus_version import version
|
7 |
+
from pydantic import BaseModel
|
8 |
+
|
9 |
+
|
10 |
+
class ImageMeta(BaseModel):
|
11 |
+
"""
|
12 |
+
Image meta data model
|
13 |
+
"""
|
14 |
+
|
15 |
+
metadata_scheme: str = "fooocus"
|
16 |
+
|
17 |
+
base_model: str
|
18 |
+
base_model_hash: str
|
19 |
+
|
20 |
+
prompt: str
|
21 |
+
full_prompt: List[str]
|
22 |
+
prompt_expansion: str
|
23 |
+
|
24 |
+
negative_prompt: str
|
25 |
+
full_negative_prompt: List[str]
|
26 |
+
|
27 |
+
performance: str
|
28 |
+
|
29 |
+
style: str
|
30 |
+
|
31 |
+
refiner_model: str = "None"
|
32 |
+
refiner_switch: float = 0.5
|
33 |
+
|
34 |
+
loras: List[list]
|
35 |
+
|
36 |
+
resolution: str
|
37 |
+
|
38 |
+
sampler: str = "dpmpp_2m_sde_gpu"
|
39 |
+
scheduler: str = "karras"
|
40 |
+
seed: str
|
41 |
+
adm_guidance: str
|
42 |
+
guidance_scale: float
|
43 |
+
sharpness: float
|
44 |
+
steps: int
|
45 |
+
vae_name: str
|
46 |
+
|
47 |
+
version: str = version
|
48 |
+
|
49 |
+
def __repr__(self):
|
50 |
+
return ""
|
51 |
+
|
52 |
+
|
53 |
+
def loras_parser(loras: list) -> list:
|
54 |
+
"""
|
55 |
+
Parse lora list
|
56 |
+
"""
|
57 |
+
return [
|
58 |
+
[
|
59 |
+
lora[0].rsplit('.', maxsplit=1)[:1][0],
|
60 |
+
lora[1],
|
61 |
+
"hash_not_calculated",
|
62 |
+
] for lora in loras if lora[0] != 'None' and lora[0] is not None]
|
63 |
+
|
64 |
+
|
65 |
+
def image_parse(
|
66 |
+
async_tak: object,
|
67 |
+
task: dict
|
68 |
+
) -> dict | str:
|
69 |
+
"""
|
70 |
+
Parse image meta data
|
71 |
+
Generate meta data for image from task and async task object
|
72 |
+
Args:
|
73 |
+
async_tak: async task obj
|
74 |
+
task: task obj
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
dict: image meta data
|
78 |
+
"""
|
79 |
+
req_param = async_tak.req_param
|
80 |
+
meta = ImageMeta(
|
81 |
+
metadata_scheme=req_param.meta_scheme,
|
82 |
+
base_model=req_param.base_model_name.rsplit('.', maxsplit=1)[:1][0],
|
83 |
+
base_model_hash='',
|
84 |
+
prompt=req_param.prompt,
|
85 |
+
full_prompt=task['positive'],
|
86 |
+
prompt_expansion=task['expansion'],
|
87 |
+
negative_prompt=req_param.negative_prompt,
|
88 |
+
full_negative_prompt=task['negative'],
|
89 |
+
performance=req_param.performance_selection,
|
90 |
+
style=str(req_param.style_selections),
|
91 |
+
refiner_model=req_param.refiner_model_name,
|
92 |
+
refiner_switch=req_param.refiner_switch,
|
93 |
+
loras=loras_parser(req_param.loras),
|
94 |
+
resolution=str(tuple([int(n) for n in req_param.aspect_ratios_selection.split('*')])),
|
95 |
+
sampler=req_param.advanced_params.sampler_name,
|
96 |
+
scheduler=req_param.advanced_params.scheduler_name,
|
97 |
+
seed=str(task['task_seed']),
|
98 |
+
adm_guidance=str((
|
99 |
+
req_param.advanced_params.adm_scaler_positive,
|
100 |
+
req_param.advanced_params.adm_scaler_negative,
|
101 |
+
req_param.advanced_params.adm_scaler_end)),
|
102 |
+
guidance_scale=req_param.guidance_scale,
|
103 |
+
sharpness=req_param.sharpness,
|
104 |
+
steps=-1,
|
105 |
+
vae_name=req_param.advanced_params.vae_name,
|
106 |
+
version=version
|
107 |
+
)
|
108 |
+
if meta.metadata_scheme not in ["fooocus", "a111"]:
|
109 |
+
meta.metadata_scheme = "fooocus"
|
110 |
+
if meta.metadata_scheme == "fooocus":
|
111 |
+
meta_dict = meta.model_dump()
|
112 |
+
for i, lora in enumerate(meta.loras):
|
113 |
+
attr_name = f"lora_combined_{i+1}"
|
114 |
+
lr = [str(x) for x in lora]
|
115 |
+
meta_dict[attr_name] = f"{lr[0]} : {lr[1]}"
|
116 |
+
else:
|
117 |
+
meta_dict = meta.model_dump()
|
118 |
+
return meta_dict
|
fooocusapi/models/common/requests.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Common model for requests"""
|
2 |
+
from typing import List
|
3 |
+
from pydantic import (
|
4 |
+
BaseModel,
|
5 |
+
Field,
|
6 |
+
ValidationError
|
7 |
+
)
|
8 |
+
|
9 |
+
from modules.config import (
|
10 |
+
default_sampler,
|
11 |
+
default_scheduler,
|
12 |
+
default_prompt,
|
13 |
+
default_prompt_negative,
|
14 |
+
default_aspect_ratio,
|
15 |
+
default_base_model_name,
|
16 |
+
default_refiner_model_name,
|
17 |
+
default_refiner_switch,
|
18 |
+
default_cfg_scale,
|
19 |
+
default_styles,
|
20 |
+
default_overwrite_step,
|
21 |
+
default_inpaint_engine_version,
|
22 |
+
default_overwrite_switch,
|
23 |
+
default_cfg_tsnr,
|
24 |
+
default_sample_sharpness,
|
25 |
+
default_vae,
|
26 |
+
default_clip_skip
|
27 |
+
)
|
28 |
+
|
29 |
+
from modules.flags import clip_skip_max
|
30 |
+
|
31 |
+
from fooocusapi.models.common.base import (
|
32 |
+
PerformanceSelection,
|
33 |
+
Lora,
|
34 |
+
default_loras_model
|
35 |
+
)
|
36 |
+
|
37 |
+
default_aspect_ratio = default_aspect_ratio.split(" ")[0].replace("×", "*")
|
38 |
+
|
39 |
+
|
40 |
+
class QueryJobRequest(BaseModel):
|
41 |
+
"""Query job request"""
|
42 |
+
job_id: str = Field(description="Job ID to query")
|
43 |
+
require_step_preview: bool = Field(
|
44 |
+
default=False,
|
45 |
+
description="Set to true will return preview image of generation steps at current time")
|
46 |
+
|
47 |
+
|
48 |
+
class AdvancedParams(BaseModel):
|
49 |
+
"""Common params object AdvancedParams"""
|
50 |
+
disable_preview: bool = Field(False, description="Disable preview during generation")
|
51 |
+
disable_intermediate_results: bool = Field(False, description="Disable intermediate results")
|
52 |
+
disable_seed_increment: bool = Field(False, description="Disable Seed Increment")
|
53 |
+
adm_scaler_positive: float = Field(1.5, description="Positive ADM Guidance Scaler", ge=0.1, le=3.0)
|
54 |
+
adm_scaler_negative: float = Field(0.8, description="Negative ADM Guidance Scaler", ge=0.1, le=3.0)
|
55 |
+
adm_scaler_end: float = Field(0.3, description="ADM Guidance End At Step", ge=0.0, le=1.0)
|
56 |
+
adaptive_cfg: float = Field(default_cfg_tsnr, description="CFG Mimicking from TSNR", ge=1.0, le=30.0)
|
57 |
+
clip_skip: int = Field(default_clip_skip, description="Clip Skip", ge=1, le=clip_skip_max)
|
58 |
+
sampler_name: str = Field(default_sampler, description="Sampler")
|
59 |
+
scheduler_name: str = Field(default_scheduler, description="Scheduler")
|
60 |
+
overwrite_step: int = Field(default_overwrite_step, description="Forced Overwrite of Sampling Step", ge=-1, le=200)
|
61 |
+
overwrite_switch: float = Field(default_overwrite_switch, description="Forced Overwrite of Refiner Switch Step", ge=-1, le=1)
|
62 |
+
overwrite_width: int = Field(-1, description="Forced Overwrite of Generating Width", ge=-1, le=2048)
|
63 |
+
overwrite_height: int = Field(-1, description="Forced Overwrite of Generating Height", ge=-1, le=2048)
|
64 |
+
overwrite_vary_strength: float = Field(-1, description='Forced Overwrite of Denoising Strength of "Vary"', ge=-1, le=1.0)
|
65 |
+
overwrite_upscale_strength: float = Field(-1, description='Forced Overwrite of Denoising Strength of "Upscale"', ge=-1, le=1.0)
|
66 |
+
mixing_image_prompt_and_vary_upscale: bool = Field(False, description="Mixing Image Prompt and Vary/Upscale")
|
67 |
+
mixing_image_prompt_and_inpaint: bool = Field(False, description="Mixing Image Prompt and Inpaint")
|
68 |
+
debugging_cn_preprocessor: bool = Field(False, description="Debug Preprocessors")
|
69 |
+
skipping_cn_preprocessor: bool = Field(False, description="Skip Preprocessors")
|
70 |
+
canny_low_threshold: int = Field(64, description="Canny Low Threshold", ge=1, le=255)
|
71 |
+
canny_high_threshold: int = Field(128, description="Canny High Threshold", ge=1, le=255)
|
72 |
+
refiner_swap_method: str = Field('joint', description="Refiner swap method")
|
73 |
+
controlnet_softness: float = Field(0.25, description="Softness of ControlNet", ge=0.0, le=1.0)
|
74 |
+
freeu_enabled: bool = Field(False, description="FreeU enabled")
|
75 |
+
freeu_b1: float = Field(1.01, description="FreeU B1")
|
76 |
+
freeu_b2: float = Field(1.02, description="FreeU B2")
|
77 |
+
freeu_s1: float = Field(0.99, description="FreeU B3")
|
78 |
+
freeu_s2: float = Field(0.95, description="FreeU B4")
|
79 |
+
debugging_inpaint_preprocessor: bool = Field(False, description="Debug Inpaint Preprocessing")
|
80 |
+
inpaint_disable_initial_latent: bool = Field(False, description="Disable initial latent in inpaint")
|
81 |
+
inpaint_engine: str = Field(default_inpaint_engine_version, description="Inpaint Engine")
|
82 |
+
inpaint_strength: float = Field(1.0, description="Inpaint Denoising Strength", ge=0.0, le=1.0)
|
83 |
+
inpaint_respective_field: float = Field(1.0, description="Inpaint Respective Field", ge=0.0, le=1.0)
|
84 |
+
inpaint_mask_upload_checkbox: bool = Field(False, description="Upload Mask")
|
85 |
+
invert_mask_checkbox: bool = Field(False, description="Invert Mask")
|
86 |
+
inpaint_erode_or_dilate: int = Field(0, description="Mask Erode or Dilate", ge=-64, le=64)
|
87 |
+
black_out_nsfw: bool = Field(False, description="Block out NSFW")
|
88 |
+
vae_name: str = Field(default_vae, description="VAE name")
|
89 |
+
|
90 |
+
|
91 |
+
class CommonRequest(BaseModel):
|
92 |
+
"""All generate request based on this model"""
|
93 |
+
prompt: str = default_prompt
|
94 |
+
negative_prompt: str = default_prompt_negative
|
95 |
+
style_selections: List[str] = default_styles
|
96 |
+
performance_selection: PerformanceSelection = PerformanceSelection.speed
|
97 |
+
aspect_ratios_selection: str = default_aspect_ratio
|
98 |
+
image_number: int = Field(default=1, description="Image number", ge=1, le=32)
|
99 |
+
image_seed: int = Field(default=-1, description="Seed to generate image, -1 for random")
|
100 |
+
sharpness: float = Field(default=default_sample_sharpness, ge=0.0, le=30.0)
|
101 |
+
guidance_scale: float = Field(default=default_cfg_scale, ge=1.0, le=30.0)
|
102 |
+
base_model_name: str = default_base_model_name
|
103 |
+
refiner_model_name: str = default_refiner_model_name
|
104 |
+
refiner_switch: float = Field(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0)
|
105 |
+
loras: List[Lora] = Field(default=default_loras_model)
|
106 |
+
advanced_params: AdvancedParams = AdvancedParams()
|
107 |
+
save_meta: bool = Field(default=True, description="Save meta data")
|
108 |
+
meta_scheme: str = Field(default='fooocus', description="Meta data scheme, one of [fooocus, a111]")
|
109 |
+
save_extension: str = Field(default='png', description="Save extension, one of [png, jpg, webp]")
|
110 |
+
save_name: str = Field(default='', description="Image name for output image, default is job id + seq")
|
111 |
+
read_wildcards_in_order: bool = Field(default=False, description="Read wildcards in order")
|
112 |
+
require_base64: bool = Field(default=False, description="Return base64 data of generated image")
|
113 |
+
async_process: bool = Field(default=False, description="Set to true will run async and return job info for retrieve generation result later")
|
114 |
+
webhook_url: str | None = Field(default='', description="Optional URL for a webhook callback. If provided, the system will send a POST request to this URL upon task completion or failure."
|
115 |
+
" This allows for asynchronous notification of task status.")
|
116 |
+
|
117 |
+
|
118 |
+
def advanced_params_parser(advanced_params: str | None) -> AdvancedParams:
|
119 |
+
"""
|
120 |
+
Parse advanced params, Convert to AdvancedParams
|
121 |
+
Args:
|
122 |
+
advanced_params: str, json format
|
123 |
+
Returns:
|
124 |
+
AdvancedParams object, if validate error return default value
|
125 |
+
"""
|
126 |
+
if advanced_params is not None and len(advanced_params) > 0:
|
127 |
+
try:
|
128 |
+
advanced_params_obj = AdvancedParams.__pydantic_validator__.validate_json(advanced_params)
|
129 |
+
return AdvancedParams(**advanced_params_obj)
|
130 |
+
except ValidationError:
|
131 |
+
return AdvancedParams()
|
132 |
+
return AdvancedParams()
|
fooocusapi/models/common/response.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Fooocus API models for response"""
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
from pydantic import (
|
5 |
+
BaseModel,
|
6 |
+
ConfigDict,
|
7 |
+
Field
|
8 |
+
)
|
9 |
+
|
10 |
+
from fooocusapi.models.common.task import (
|
11 |
+
GeneratedImageResult,
|
12 |
+
AsyncJobStage
|
13 |
+
)
|
14 |
+
from fooocusapi.task_queue import TaskType
|
15 |
+
|
16 |
+
|
17 |
+
class DescribeImageResponse(BaseModel):
|
18 |
+
"""
|
19 |
+
describe image response
|
20 |
+
"""
|
21 |
+
describe: str
|
22 |
+
|
23 |
+
|
24 |
+
class AsyncJobResponse(BaseModel):
|
25 |
+
"""
|
26 |
+
Async job response
|
27 |
+
Attributes:
|
28 |
+
job_id: Job ID
|
29 |
+
job_type: Job type
|
30 |
+
job_stage: Job stage
|
31 |
+
job_progress: Job progress, 0-100
|
32 |
+
job_status: Job status
|
33 |
+
job_step_preview: Job step preview
|
34 |
+
job_result: Job result
|
35 |
+
"""
|
36 |
+
job_id: str = Field(description="Job ID")
|
37 |
+
job_type: TaskType = Field(description="Job type")
|
38 |
+
job_stage: AsyncJobStage = Field(description="Job running stage")
|
39 |
+
job_progress: int = Field(description="Job running progress, 100 is for finished.")
|
40 |
+
job_status: str | None = Field(None, description="Job running status in text")
|
41 |
+
job_step_preview: str | None = Field(None, description="Preview image of generation steps at current time, as base64 image")
|
42 |
+
job_result: List[GeneratedImageResult] | None = Field(None, description="Job generation result")
|
43 |
+
|
44 |
+
|
45 |
+
class JobQueueInfo(BaseModel):
|
46 |
+
"""
|
47 |
+
job queue info
|
48 |
+
Attributes:
|
49 |
+
running_size: int, The current running and waiting job count
|
50 |
+
finished_size: int, The current finished job count
|
51 |
+
last_job_id: str, Last submit generation job id
|
52 |
+
"""
|
53 |
+
running_size: int = Field(description="The current running and waiting job count")
|
54 |
+
finished_size: int = Field(description="Finished job count (after auto clean)")
|
55 |
+
last_job_id: str | None = Field(description="Last submit generation job id")
|
56 |
+
|
57 |
+
|
58 |
+
# TODO May need more detail fields, will add later when someone need
|
59 |
+
class JobHistoryInfo(BaseModel):
|
60 |
+
"""
|
61 |
+
job history info
|
62 |
+
"""
|
63 |
+
job_id: str
|
64 |
+
is_finished: bool = False
|
65 |
+
|
66 |
+
|
67 |
+
# Response model for the historical tasks
|
68 |
+
class JobHistoryResponse(BaseModel):
|
69 |
+
"""
|
70 |
+
job history response
|
71 |
+
"""
|
72 |
+
queue: List[JobHistoryInfo] = []
|
73 |
+
history: List[JobHistoryInfo] = []
|
74 |
+
|
75 |
+
|
76 |
+
class AllModelNamesResponse(BaseModel):
|
77 |
+
"""
|
78 |
+
all model list response
|
79 |
+
"""
|
80 |
+
model_filenames: List[str] = Field(description="All available model filenames")
|
81 |
+
lora_filenames: List[str] = Field(description="All available lora filenames")
|
82 |
+
|
83 |
+
model_config = ConfigDict(
|
84 |
+
protected_namespaces=('protect_me_', 'also_protect_')
|
85 |
+
)
|
86 |
+
|
87 |
+
|
88 |
+
class StopResponse(BaseModel):
|
89 |
+
"""stop task response"""
|
90 |
+
msg: str
|
fooocusapi/models/common/task.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Task and job related models
|
3 |
+
"""
|
4 |
+
from enum import Enum
|
5 |
+
from pydantic import (
|
6 |
+
BaseModel,
|
7 |
+
Field
|
8 |
+
)
|
9 |
+
|
10 |
+
|
11 |
+
class TaskType(str, Enum):
|
12 |
+
"""
|
13 |
+
Task type object
|
14 |
+
"""
|
15 |
+
text_2_img = 'Text to Image'
|
16 |
+
img_uov = 'Image Upscale or Variation'
|
17 |
+
img_inpaint_outpaint = 'Image Inpaint or Outpaint'
|
18 |
+
img_prompt = 'Image Prompt'
|
19 |
+
not_found = 'Not Found'
|
20 |
+
|
21 |
+
|
22 |
+
class GenerationFinishReason(str, Enum):
|
23 |
+
"""
|
24 |
+
Generation finish reason
|
25 |
+
"""
|
26 |
+
success = 'SUCCESS'
|
27 |
+
queue_is_full = 'QUEUE_IS_FULL'
|
28 |
+
user_cancel = 'USER_CANCEL'
|
29 |
+
error = 'ERROR'
|
30 |
+
|
31 |
+
|
32 |
+
class ImageGenerationResult:
|
33 |
+
"""
|
34 |
+
Image generation result
|
35 |
+
"""
|
36 |
+
def __init__(self, im: str | None, seed: str, finish_reason: GenerationFinishReason):
|
37 |
+
self.im = im
|
38 |
+
self.seed = seed
|
39 |
+
self.finish_reason = finish_reason
|
40 |
+
|
41 |
+
|
42 |
+
class AsyncJobStage(str, Enum):
|
43 |
+
"""
|
44 |
+
Async job stage
|
45 |
+
"""
|
46 |
+
waiting = 'WAITING'
|
47 |
+
running = 'RUNNING'
|
48 |
+
success = 'SUCCESS'
|
49 |
+
error = 'ERROR'
|
50 |
+
|
51 |
+
|
52 |
+
class GeneratedImageResult(BaseModel):
|
53 |
+
"""
|
54 |
+
Generated images result
|
55 |
+
"""
|
56 |
+
base64: str | None = Field(
|
57 |
+
description="Image encoded in base64, or null if finishReason is not 'SUCCESS', only return when request require base64")
|
58 |
+
url: str | None = Field(description="Image file static serve url, or null if finishReason is not 'SUCCESS'")
|
59 |
+
seed: str = Field(description="The seed associated with this image")
|
60 |
+
finish_reason: GenerationFinishReason
|
fooocusapi/models/requests_v1.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
requests models for v1 endpoints
|
3 |
+
"""
|
4 |
+
from typing import List
|
5 |
+
from fastapi.params import File
|
6 |
+
from fastapi import (
|
7 |
+
UploadFile,
|
8 |
+
Form
|
9 |
+
)
|
10 |
+
from fooocusapi.models.common.requests import (
|
11 |
+
CommonRequest,
|
12 |
+
advanced_params_parser
|
13 |
+
)
|
14 |
+
from fooocusapi.models.common.base import (
|
15 |
+
ImagePrompt,
|
16 |
+
ControlNetType,
|
17 |
+
OutpaintExpansion,
|
18 |
+
UpscaleOrVaryMethod,
|
19 |
+
PerformanceSelection
|
20 |
+
)
|
21 |
+
|
22 |
+
from fooocusapi.models.common.base import (
|
23 |
+
style_selection_parser,
|
24 |
+
lora_parser,
|
25 |
+
outpaint_selections_parser,
|
26 |
+
image_prompt_parser,
|
27 |
+
default_loras_json
|
28 |
+
)
|
29 |
+
|
30 |
+
from fooocusapi.configs.default import (
|
31 |
+
default_prompt_negative,
|
32 |
+
default_aspect_ratio,
|
33 |
+
default_base_model_name,
|
34 |
+
default_refiner_model_name,
|
35 |
+
default_refiner_switch,
|
36 |
+
default_cfg_scale,
|
37 |
+
default_styles,
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
class ImgUpscaleOrVaryRequest(CommonRequest):
|
42 |
+
"""
|
43 |
+
Request for image upscale or variation
|
44 |
+
Attributes:
|
45 |
+
input_image: Input image
|
46 |
+
uov_method: Upscale or variation method
|
47 |
+
upscale_value: upscale value
|
48 |
+
Functions:
|
49 |
+
as_form: Convert request to form data
|
50 |
+
"""
|
51 |
+
input_image: UploadFile
|
52 |
+
uov_method: UpscaleOrVaryMethod
|
53 |
+
upscale_value: float | None
|
54 |
+
|
55 |
+
@classmethod
|
56 |
+
def as_form(
|
57 |
+
cls,
|
58 |
+
input_image: UploadFile = Form(description="Init image for upscale or outpaint"),
|
59 |
+
uov_method: UpscaleOrVaryMethod = Form(),
|
60 |
+
upscale_value: float | None = Form(None, description="Upscale custom value, None for default value", ge=1.0, le=5.0),
|
61 |
+
prompt: str = Form(''),
|
62 |
+
negative_prompt: str = Form(default_prompt_negative),
|
63 |
+
style_selections: List[str] = Form(default_styles, description="Fooocus style selections, separated by comma"),
|
64 |
+
performance_selection: PerformanceSelection = Form(PerformanceSelection.speed, description="Performance Selection, one of 'Speed','Quality','Extreme Speed'"),
|
65 |
+
aspect_ratios_selection: str = Form(default_aspect_ratio, description="Aspect Ratios Selection, default 1152*896"),
|
66 |
+
image_number: int = Form(default=1, description="Image number", ge=1, le=32),
|
67 |
+
image_seed: int = Form(default=-1, description="Seed to generate image, -1 for random"),
|
68 |
+
sharpness: float = Form(default=2.0, ge=0.0, le=30.0),
|
69 |
+
guidance_scale: float = Form(default=default_cfg_scale, ge=1.0, le=30.0),
|
70 |
+
base_model_name: str = Form(default_base_model_name, description="checkpoint file name"),
|
71 |
+
refiner_model_name: str = Form(default_refiner_model_name, description="refiner file name"),
|
72 |
+
refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
|
73 |
+
loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
|
74 |
+
advanced_params: str | None = Form(default=None, description="Advanced parameters in JSON"),
|
75 |
+
save_meta: bool = Form(default=False, description="Save metadata to image"),
|
76 |
+
meta_scheme: str = Form(default='fooocus', description="Metadata scheme, one of 'fooocus', 'a111'"),
|
77 |
+
save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
|
78 |
+
save_name: str = Form(default="", description="Save name, empty for auto generate"),
|
79 |
+
require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
|
80 |
+
read_wildcards_in_order: bool = Form(default=False, description="Read wildcards in order"),
|
81 |
+
async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generation result later"),
|
82 |
+
webhook_url: str = Form(default="", description="Webhook url for generation result"),
|
83 |
+
):
|
84 |
+
style_selection_arr = style_selection_parser(style_selections)
|
85 |
+
loras_model = lora_parser(loras)
|
86 |
+
advanced_params_obj = advanced_params_parser(advanced_params)
|
87 |
+
|
88 |
+
return cls(
|
89 |
+
input_image=input_image, uov_method=uov_method, upscale_value=upscale_value,
|
90 |
+
prompt=prompt, negative_prompt=negative_prompt, style_selections=style_selection_arr,
|
91 |
+
performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
|
92 |
+
image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
|
93 |
+
base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
|
94 |
+
loras=loras_model, advanced_params=advanced_params_obj, save_meta=save_meta, meta_scheme=meta_scheme,
|
95 |
+
save_extension=save_extension, save_name=save_name, require_base64=require_base64,
|
96 |
+
read_wildcards_in_order=read_wildcards_in_order, async_process=async_process, webhook_url=webhook_url)
|
97 |
+
|
98 |
+
|
99 |
+
class ImgInpaintOrOutpaintRequest(CommonRequest):
|
100 |
+
"""
|
101 |
+
Image Inpaint or Outpaint Request
|
102 |
+
"""
|
103 |
+
input_image: UploadFile | None
|
104 |
+
input_mask: UploadFile | None
|
105 |
+
inpaint_additional_prompt: str | None
|
106 |
+
outpaint_selections: List[OutpaintExpansion]
|
107 |
+
outpaint_distance_left: int
|
108 |
+
outpaint_distance_right: int
|
109 |
+
outpaint_distance_top: int
|
110 |
+
outpaint_distance_bottom: int
|
111 |
+
|
112 |
+
@classmethod
|
113 |
+
def as_form(
|
114 |
+
cls,
|
115 |
+
input_image: UploadFile = Form(description="Init image for inpaint or outpaint"),
|
116 |
+
input_mask: UploadFile = Form(File(None), description="Inpaint or outpaint mask"),
|
117 |
+
inpaint_additional_prompt: str | None = Form("", description="Describe what you want to inpaint"),
|
118 |
+
outpaint_selections: List[str] = Form([], description="Outpaint expansion selections, literal 'Left', 'Right', 'Top', 'Bottom' separated by comma"),
|
119 |
+
outpaint_distance_left: int = Form(default=0, description="Set outpaint left distance, -1 for default"),
|
120 |
+
outpaint_distance_right: int = Form(default=0, description="Set outpaint right distance, -1 for default"),
|
121 |
+
outpaint_distance_top: int = Form(default=0, description="Set outpaint top distance, -1 for default"),
|
122 |
+
outpaint_distance_bottom: int = Form(default=0, description="Set outpaint bottom distance, -1 for default"),
|
123 |
+
prompt: str = Form(''),
|
124 |
+
negative_prompt: str = Form(default_prompt_negative),
|
125 |
+
style_selections: List[str] = Form(default_styles, description="Fooocus style selections, separated by comma"),
|
126 |
+
performance_selection: PerformanceSelection = Form(PerformanceSelection.speed, description="Performance Selection, one of 'Speed','Quality','Extreme Speed'"),
|
127 |
+
aspect_ratios_selection: str = Form(default_aspect_ratio, description="Aspect Ratios Selection, default 1152*896"),
|
128 |
+
image_number: int = Form(default=1, description="Image number", ge=1, le=32),
|
129 |
+
image_seed: int = Form(default=-1, description="Seed to generate image, -1 for random"),
|
130 |
+
sharpness: float = Form(default=2.0, ge=0.0, le=30.0),
|
131 |
+
guidance_scale: float = Form(default=default_cfg_scale, ge=1.0, le=30.0),
|
132 |
+
base_model_name: str = Form(default_base_model_name),
|
133 |
+
refiner_model_name: str = Form(default_refiner_model_name),
|
134 |
+
refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
|
135 |
+
loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
|
136 |
+
advanced_params: str | None = Form(default=None, description="Advanced parameters in JSON"),
|
137 |
+
save_meta: bool = Form(default=False, description="Save metadata to image"),
|
138 |
+
meta_scheme: str = Form(default='fooocus', description="Metadata scheme, one of 'fooocus', 'a111'"),
|
139 |
+
save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
|
140 |
+
save_name: str = Form(default="", description="Save name, empty for auto generate"),
|
141 |
+
require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
|
142 |
+
read_wildcards_in_order: bool = Form(default=False, description="Read wildcards in order"),
|
143 |
+
async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generation result later"),
|
144 |
+
webhook_url: str = Form(default="", description="Webhook url for generation result"),
|
145 |
+
):
|
146 |
+
if isinstance(input_mask, File):
|
147 |
+
input_mask = None
|
148 |
+
|
149 |
+
outpaint_selections_arr = outpaint_selections_parser(outpaint_selections)
|
150 |
+
style_selection_arr = style_selection_parser(style_selections)
|
151 |
+
loras_model = lora_parser(loras)
|
152 |
+
advanced_params_obj = advanced_params_parser(advanced_params)
|
153 |
+
|
154 |
+
return cls(
|
155 |
+
input_image=input_image, input_mask=input_mask, inpaint_additional_prompt=inpaint_additional_prompt,
|
156 |
+
outpaint_selections=outpaint_selections_arr, outpaint_distance_left=outpaint_distance_left,
|
157 |
+
outpaint_distance_right=outpaint_distance_right, outpaint_distance_top=outpaint_distance_top,
|
158 |
+
outpaint_distance_bottom=outpaint_distance_bottom, prompt=prompt, negative_prompt=negative_prompt, style_selections=style_selection_arr,
|
159 |
+
performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
|
160 |
+
image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
|
161 |
+
base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
|
162 |
+
loras=loras_model, advanced_params=advanced_params_obj, save_meta=save_meta, meta_scheme=meta_scheme,
|
163 |
+
save_extension=save_extension, save_name=save_name, require_base64=require_base64,
|
164 |
+
read_wildcards_in_order=read_wildcards_in_order, async_process=async_process, webhook_url=webhook_url)
|
165 |
+
|
166 |
+
|
167 |
+
class ImgPromptRequest(ImgInpaintOrOutpaintRequest):
|
168 |
+
"""
|
169 |
+
Image Prompt Request
|
170 |
+
"""
|
171 |
+
image_prompts: List[ImagePrompt]
|
172 |
+
|
173 |
+
@classmethod
|
174 |
+
def as_form(
|
175 |
+
cls,
|
176 |
+
input_image: UploadFile = Form(File(None), description="Init image for inpaint or outpaint"),
|
177 |
+
input_mask: UploadFile = Form(File(None), description="Inpaint or outpaint mask"),
|
178 |
+
inpaint_additional_prompt: str | None = Form(None, description="Describe what you want to inpaint"),
|
179 |
+
outpaint_selections: List[str] = Form([], description="Outpaint expansion selections, literal 'Left', 'Right', 'Top', 'Bottom' separated by comma"),
|
180 |
+
outpaint_distance_left: int = Form(default=0, description="Set outpaint left distance, 0 for default"),
|
181 |
+
outpaint_distance_right: int = Form(default=0, description="Set outpaint right distance, 0 for default"),
|
182 |
+
outpaint_distance_top: int = Form(default=0, description="Set outpaint top distance, 0 for default"),
|
183 |
+
outpaint_distance_bottom: int = Form(default=0, description="Set outpaint bottom distance, 0 for default"),
|
184 |
+
cn_img1: UploadFile = Form(File(None), description="Input image for image prompt"),
|
185 |
+
cn_stop1: float | None = Form(
|
186 |
+
default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
|
187 |
+
cn_weight1: float | None = Form(
|
188 |
+
default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
|
189 |
+
cn_type1: ControlNetType = Form(
|
190 |
+
default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
|
191 |
+
cn_img2: UploadFile = Form(
|
192 |
+
File(None), description="Input image for image prompt"),
|
193 |
+
cn_stop2: float | None = Form(
|
194 |
+
default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
|
195 |
+
cn_weight2: float | None = Form(
|
196 |
+
default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
|
197 |
+
cn_type2: ControlNetType = Form(
|
198 |
+
default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
|
199 |
+
cn_img3: UploadFile = Form(
|
200 |
+
File(None), description="Input image for image prompt"),
|
201 |
+
cn_stop3: float | None = Form(
|
202 |
+
default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
|
203 |
+
cn_weight3: float | None = Form(
|
204 |
+
default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
|
205 |
+
cn_type3: ControlNetType = Form(
|
206 |
+
default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
|
207 |
+
cn_img4: UploadFile = Form(
|
208 |
+
File(None), description="Input image for image prompt"),
|
209 |
+
cn_stop4: float | None = Form(
|
210 |
+
default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
|
211 |
+
cn_weight4: float | None = Form(
|
212 |
+
default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
|
213 |
+
cn_type4: ControlNetType = Form(
|
214 |
+
default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
|
215 |
+
prompt: str = Form(''),
|
216 |
+
negative_prompt: str = Form(default_prompt_negative),
|
217 |
+
style_selections: List[str] = Form(default_styles, description="Fooocus style selections, separated by comma"),
|
218 |
+
performance_selection: PerformanceSelection = Form(
|
219 |
+
PerformanceSelection.speed),
|
220 |
+
aspect_ratios_selection: str = Form(default_aspect_ratio),
|
221 |
+
image_number: int = Form(
|
222 |
+
default=1, description="Image number", ge=1, le=32),
|
223 |
+
image_seed: int = Form(default=-1, description="Seed to generate image, -1 for random"),
|
224 |
+
sharpness: float = Form(default=2.0, ge=0.0, le=30.0),
|
225 |
+
guidance_scale: float = Form(default=default_cfg_scale, ge=1.0, le=30.0),
|
226 |
+
base_model_name: str = Form(default_base_model_name),
|
227 |
+
refiner_model_name: str = Form(default_refiner_model_name),
|
228 |
+
refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
|
229 |
+
loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
|
230 |
+
advanced_params: str | None = Form(default=None, description="Advanced parameters in JSON"),
|
231 |
+
save_meta: bool = Form(default=False, description="Save metadata to image"),
|
232 |
+
meta_scheme: str = Form(default='fooocus', description="Metadata scheme, one of 'fooocus', 'a111'"),
|
233 |
+
save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
|
234 |
+
save_name: str = Form(default="", description="Save name, empty for auto generate"),
|
235 |
+
require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
|
236 |
+
read_wildcards_in_order: bool = Form(default=False, description="Read wildcards in order"),
|
237 |
+
async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generation result later"),
|
238 |
+
webhook_url: str = Form(default="", description="Webhook url for generation result"),
|
239 |
+
):
|
240 |
+
if isinstance(input_image, File):
|
241 |
+
input_image = None
|
242 |
+
if isinstance(input_mask, File):
|
243 |
+
input_mask = None
|
244 |
+
if isinstance(cn_img1, File):
|
245 |
+
cn_img1 = None
|
246 |
+
if isinstance(cn_img2, File):
|
247 |
+
cn_img2 = None
|
248 |
+
if isinstance(cn_img3, File):
|
249 |
+
cn_img3 = None
|
250 |
+
if isinstance(cn_img4, File):
|
251 |
+
cn_img4 = None
|
252 |
+
|
253 |
+
outpaint_selections_arr = outpaint_selections_parser(outpaint_selections)
|
254 |
+
|
255 |
+
image_prompt_config = [
|
256 |
+
(cn_img1, cn_stop1, cn_weight1, cn_type1),
|
257 |
+
(cn_img2, cn_stop2, cn_weight2, cn_type2),
|
258 |
+
(cn_img3, cn_stop3, cn_weight3, cn_type3),
|
259 |
+
(cn_img4, cn_stop4, cn_weight4, cn_type4)]
|
260 |
+
image_prompts = image_prompt_parser(image_prompt_config)
|
261 |
+
style_selection_arr = style_selection_parser(style_selections)
|
262 |
+
loras_model = lora_parser(loras)
|
263 |
+
advanced_params_obj = advanced_params_parser(advanced_params)
|
264 |
+
|
265 |
+
return cls(
|
266 |
+
input_image=input_image, input_mask=input_mask, inpaint_additional_prompt=inpaint_additional_prompt, outpaint_selections=outpaint_selections_arr,
|
267 |
+
outpaint_distance_left=outpaint_distance_left, outpaint_distance_right=outpaint_distance_right, outpaint_distance_top=outpaint_distance_top, outpaint_distance_bottom=outpaint_distance_bottom,
|
268 |
+
image_prompts=image_prompts, prompt=prompt, negative_prompt=negative_prompt, style_selections=style_selection_arr,
|
269 |
+
performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
|
270 |
+
image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
|
271 |
+
base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
|
272 |
+
loras=loras_model, advanced_params=advanced_params_obj, save_meta=save_meta, meta_scheme=meta_scheme,
|
273 |
+
save_extension=save_extension, save_name=save_name, require_base64=require_base64,
|
274 |
+
read_wildcards_in_order=read_wildcards_in_order, async_process=async_process, webhook_url=webhook_url)
|
fooocusapi/models/requests_v2.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""V2 API models"""
|
2 |
+
from typing import List
|
3 |
+
from pydantic import BaseModel, Field
|
4 |
+
from fooocusapi.models.common.requests import CommonRequest
|
5 |
+
from fooocusapi.models.common.base import (
|
6 |
+
ControlNetType,
|
7 |
+
OutpaintExpansion,
|
8 |
+
ImagePrompt,
|
9 |
+
UpscaleOrVaryMethod
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
class ImagePromptJson(BaseModel):
|
14 |
+
"""Image prompt for V2 API"""
|
15 |
+
cn_img: str | None = Field(None, description="Input image for image prompt as base64")
|
16 |
+
cn_stop: float | None = Field(0, ge=0, le=1, description="Stop at for image prompt, 0 for default value")
|
17 |
+
cn_weight: float | None = Field(0, ge=0, le=2, description="Weight for image prompt, 0 for default value")
|
18 |
+
cn_type: ControlNetType = Field(default=ControlNetType.cn_ip, description="ControlNet type for image prompt")
|
19 |
+
|
20 |
+
|
21 |
+
class ImgInpaintOrOutpaintRequestJson(CommonRequest):
|
22 |
+
"""image inpaint or outpaint request"""
|
23 |
+
input_image: str = Field('', description="Init image for inpaint or outpaint as base64")
|
24 |
+
input_mask: str | None = Field('', description="Inpaint or outpaint mask as base64")
|
25 |
+
inpaint_additional_prompt: str | None = Field('', description="Describe what you want to inpaint")
|
26 |
+
outpaint_selections: List[OutpaintExpansion] = []
|
27 |
+
outpaint_distance_left: int | None = Field(-1, description="Set outpaint left distance")
|
28 |
+
outpaint_distance_right: int | None = Field(-1, description="Set outpaint right distance")
|
29 |
+
outpaint_distance_top: int | None = Field(-1, description="Set outpaint top distance")
|
30 |
+
outpaint_distance_bottom: int | None = Field(-1, description="Set outpaint bottom distance")
|
31 |
+
image_prompts: List[ImagePromptJson | ImagePrompt] = []
|
32 |
+
|
33 |
+
|
34 |
+
class ImgPromptRequestJson(ImgInpaintOrOutpaintRequestJson):
|
35 |
+
"""img prompt request json"""
|
36 |
+
input_image: str | None = Field(None, description="Init image for inpaint or outpaint as base64")
|
37 |
+
image_prompts: List[ImagePromptJson | ImagePrompt]
|
38 |
+
|
39 |
+
|
40 |
+
class Text2ImgRequestWithPrompt(CommonRequest):
|
41 |
+
"""text to image request with prompt"""
|
42 |
+
image_prompts: List[ImagePromptJson] = []
|
43 |
+
|
44 |
+
|
45 |
+
class ImgUpscaleOrVaryRequestJson(CommonRequest):
|
46 |
+
"""img upscale or vary request json"""
|
47 |
+
uov_method: UpscaleOrVaryMethod = UpscaleOrVaryMethod.upscale_2
|
48 |
+
upscale_value: float | None = Field(1.0, ge=1.0, le=5.0, description="Upscale custom value, 1.0 for default value")
|
49 |
+
input_image: str = Field(description="Init image for upscale or outpaint as base64")
|
50 |
+
image_prompts: List[ImagePromptJson | ImagePrompt] = []
|
fooocusapi/parameters.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Tuple
|
2 |
+
import numpy as np
|
3 |
+
import copy
|
4 |
+
|
5 |
+
from fooocusapi.models.common.requests import AdvancedParams
|
6 |
+
|
7 |
+
|
8 |
+
class ImageGenerationParams:
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
prompt: str,
|
12 |
+
negative_prompt: str,
|
13 |
+
style_selections: List[str],
|
14 |
+
performance_selection: str,
|
15 |
+
aspect_ratios_selection: str,
|
16 |
+
image_number: int,
|
17 |
+
image_seed: int | None,
|
18 |
+
sharpness: float,
|
19 |
+
guidance_scale: float,
|
20 |
+
base_model_name: str,
|
21 |
+
refiner_model_name: str,
|
22 |
+
refiner_switch: float,
|
23 |
+
loras: List[Tuple[str, float]],
|
24 |
+
uov_input_image: np.ndarray | None,
|
25 |
+
uov_method: str,
|
26 |
+
upscale_value: float | None,
|
27 |
+
outpaint_selections: List[str],
|
28 |
+
outpaint_distance_left: int,
|
29 |
+
outpaint_distance_right: int,
|
30 |
+
outpaint_distance_top: int,
|
31 |
+
outpaint_distance_bottom: int,
|
32 |
+
inpaint_input_image: Dict[str, np.ndarray] | None,
|
33 |
+
inpaint_additional_prompt: str | None,
|
34 |
+
image_prompts: List[Tuple[np.ndarray, float, float, str]],
|
35 |
+
advanced_params: List[any] | None,
|
36 |
+
save_extension: str,
|
37 |
+
save_meta: bool,
|
38 |
+
meta_scheme: str,
|
39 |
+
save_name: str,
|
40 |
+
require_base64: bool,
|
41 |
+
):
|
42 |
+
self.prompt = prompt
|
43 |
+
self.negative_prompt = negative_prompt
|
44 |
+
self.style_selections = style_selections
|
45 |
+
self.performance_selection = performance_selection
|
46 |
+
self.aspect_ratios_selection = aspect_ratios_selection
|
47 |
+
self.image_number = image_number
|
48 |
+
self.image_seed = image_seed
|
49 |
+
self.sharpness = sharpness
|
50 |
+
self.guidance_scale = guidance_scale
|
51 |
+
self.base_model_name = base_model_name
|
52 |
+
self.refiner_model_name = refiner_model_name
|
53 |
+
self.refiner_switch = refiner_switch
|
54 |
+
self.loras = loras
|
55 |
+
self.uov_input_image = uov_input_image
|
56 |
+
self.uov_method = uov_method
|
57 |
+
self.upscale_value = upscale_value
|
58 |
+
self.outpaint_selections = outpaint_selections
|
59 |
+
self.outpaint_distance_left = outpaint_distance_left
|
60 |
+
self.outpaint_distance_right = outpaint_distance_right
|
61 |
+
self.outpaint_distance_top = outpaint_distance_top
|
62 |
+
self.outpaint_distance_bottom = outpaint_distance_bottom
|
63 |
+
self.inpaint_input_image = inpaint_input_image
|
64 |
+
self.inpaint_additional_prompt = inpaint_additional_prompt
|
65 |
+
self.image_prompts = image_prompts
|
66 |
+
self.save_extension = save_extension
|
67 |
+
self.save_meta = save_meta
|
68 |
+
self.meta_scheme = meta_scheme
|
69 |
+
self.save_name = save_name
|
70 |
+
self.require_base64 = require_base64
|
71 |
+
self.advanced_params = advanced_params
|
72 |
+
|
73 |
+
if self.advanced_params is None:
|
74 |
+
self.advanced_params = AdvancedParams()
|
75 |
+
|
76 |
+
# Auto set mixing_image_prompt_and_inpaint to True
|
77 |
+
if len(self.image_prompts) > 0 and self.inpaint_input_image is not None:
|
78 |
+
print("Mixing Image Prompts and Inpaint Enabled")
|
79 |
+
self.advanced_params.mixing_image_prompt_and_inpaint = True
|
80 |
+
if len(self.image_prompts) > 0 and self.uov_input_image is not None:
|
81 |
+
print("Mixing Image Prompts and Vary Upscale Enabled")
|
82 |
+
self.advanced_params.mixing_image_prompt_and_vary_upscale = True
|
83 |
+
|
84 |
+
def to_dict(self):
|
85 |
+
"""
|
86 |
+
Convert the ImageGenerationParams object to a dictionary.
|
87 |
+
Args:
|
88 |
+
self:
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
self to dict
|
92 |
+
"""
|
93 |
+
obj_dict = copy.deepcopy(self)
|
94 |
+
return obj_dict.__dict__
|
fooocusapi/routes/__init__.py
ADDED
File without changes
|
fooocusapi/routes/generate_v1.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Generate API V1 routes
|
2 |
+
|
3 |
+
"""
|
4 |
+
from typing import List, Optional
|
5 |
+
from fastapi import APIRouter, Depends, Header, Query, UploadFile
|
6 |
+
from fastapi.params import File
|
7 |
+
|
8 |
+
from modules.util import HWC3
|
9 |
+
|
10 |
+
from fooocusapi.models.common.base import DescribeImageType
|
11 |
+
from fooocusapi.utils.api_utils import api_key_auth
|
12 |
+
|
13 |
+
from fooocusapi.models.common.requests import CommonRequest as Text2ImgRequest
|
14 |
+
from fooocusapi.models.requests_v1 import (
|
15 |
+
ImgUpscaleOrVaryRequest,
|
16 |
+
ImgPromptRequest,
|
17 |
+
ImgInpaintOrOutpaintRequest
|
18 |
+
)
|
19 |
+
from fooocusapi.models.common.response import (
|
20 |
+
AsyncJobResponse,
|
21 |
+
GeneratedImageResult,
|
22 |
+
DescribeImageResponse,
|
23 |
+
StopResponse
|
24 |
+
)
|
25 |
+
from fooocusapi.utils.call_worker import call_worker
|
26 |
+
from fooocusapi.utils.img_utils import read_input_image
|
27 |
+
from fooocusapi.configs.default import img_generate_responses
|
28 |
+
from fooocusapi.worker import process_stop
|
29 |
+
|
30 |
+
|
31 |
+
secure_router = APIRouter(
|
32 |
+
dependencies=[Depends(api_key_auth)]
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
def stop_worker():
|
37 |
+
"""Interrupt worker process"""
|
38 |
+
process_stop()
|
39 |
+
|
40 |
+
|
41 |
+
@secure_router.post(
|
42 |
+
path="/v1/generation/text-to-image",
|
43 |
+
response_model=List[GeneratedImageResult] | AsyncJobResponse,
|
44 |
+
responses=img_generate_responses,
|
45 |
+
tags=["GenerateV1"])
|
46 |
+
def text2img_generation(
|
47 |
+
req: Text2ImgRequest,
|
48 |
+
accept: str = Header(None),
|
49 |
+
accept_query: str | None = Query(
|
50 |
+
None, alias='accept',
|
51 |
+
description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
|
52 |
+
"""\nText to Image Generation\n
|
53 |
+
A text to image generation endpoint
|
54 |
+
Arguments:
|
55 |
+
req {Text2ImgRequest} -- Text to image generation request
|
56 |
+
accept {str} -- Accept header
|
57 |
+
accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
|
58 |
+
returns:
|
59 |
+
Response -- img_generate_responses
|
60 |
+
"""
|
61 |
+
if accept_query is not None and len(accept_query) > 0:
|
62 |
+
accept = accept_query
|
63 |
+
|
64 |
+
return call_worker(req, accept)
|
65 |
+
|
66 |
+
|
67 |
+
@secure_router.post(
|
68 |
+
path="/v1/generation/image-upscale-vary",
|
69 |
+
response_model=List[GeneratedImageResult] | AsyncJobResponse,
|
70 |
+
responses=img_generate_responses,
|
71 |
+
tags=["GenerateV1"])
|
72 |
+
def img_upscale_or_vary(
|
73 |
+
input_image: UploadFile,
|
74 |
+
req: ImgUpscaleOrVaryRequest = Depends(ImgUpscaleOrVaryRequest.as_form),
|
75 |
+
accept: str = Header(None),
|
76 |
+
accept_query: str | None = Query(
|
77 |
+
None, alias='accept',
|
78 |
+
description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
|
79 |
+
"""\nImage upscale or vary\n
|
80 |
+
Image upscale or vary
|
81 |
+
Arguments:
|
82 |
+
input_image {UploadFile} -- Input image file
|
83 |
+
req {ImgUpscaleOrVaryRequest} -- Request body
|
84 |
+
accept {str} -- Accept header
|
85 |
+
accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
|
86 |
+
Returns:
|
87 |
+
Response -- img_generate_responses
|
88 |
+
"""
|
89 |
+
if accept_query is not None and len(accept_query) > 0:
|
90 |
+
accept = accept_query
|
91 |
+
|
92 |
+
return call_worker(req, accept)
|
93 |
+
|
94 |
+
|
95 |
+
@secure_router.post(
|
96 |
+
path="/v1/generation/image-inpaint-outpaint",
|
97 |
+
response_model=List[GeneratedImageResult] | AsyncJobResponse,
|
98 |
+
responses=img_generate_responses,
|
99 |
+
tags=["GenerateV1"])
|
100 |
+
def img_inpaint_or_outpaint(
|
101 |
+
input_image: UploadFile,
|
102 |
+
req: ImgInpaintOrOutpaintRequest = Depends(ImgInpaintOrOutpaintRequest.as_form),
|
103 |
+
accept: str = Header(None),
|
104 |
+
accept_query: str | None = Query(
|
105 |
+
None, alias='accept',
|
106 |
+
description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
|
107 |
+
"""\nInpaint or outpaint\n
|
108 |
+
Inpaint or outpaint
|
109 |
+
Arguments:
|
110 |
+
input_image {UploadFile} -- Input image file
|
111 |
+
req {ImgInpaintOrOutpaintRequest} -- Request body
|
112 |
+
accept {str} -- Accept header
|
113 |
+
accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
|
114 |
+
"""
|
115 |
+
if accept_query is not None and len(accept_query) > 0:
|
116 |
+
accept = accept_query
|
117 |
+
|
118 |
+
return call_worker(req, accept)
|
119 |
+
|
120 |
+
|
121 |
+
@secure_router.post(
|
122 |
+
path="/v1/generation/image-prompt",
|
123 |
+
response_model=List[GeneratedImageResult] | AsyncJobResponse,
|
124 |
+
responses=img_generate_responses,
|
125 |
+
tags=["GenerateV1"])
|
126 |
+
def img_prompt(
|
127 |
+
cn_img1: Optional[UploadFile] = File(None),
|
128 |
+
req: ImgPromptRequest = Depends(ImgPromptRequest.as_form),
|
129 |
+
accept: str = Header(None),
|
130 |
+
accept_query: str | None = Query(
|
131 |
+
None, alias='accept',
|
132 |
+
description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
|
133 |
+
"""\nImage Prompt\n
|
134 |
+
Image Prompt
|
135 |
+
A prompt-based image generation.
|
136 |
+
Arguments:
|
137 |
+
cn_img1 {UploadFile} -- Input image file
|
138 |
+
req {ImgPromptRequest} -- Request body
|
139 |
+
accept {str} -- Accept header
|
140 |
+
accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
|
141 |
+
Returns:
|
142 |
+
Response -- img_generate_responses
|
143 |
+
"""
|
144 |
+
if accept_query is not None and len(accept_query) > 0:
|
145 |
+
accept = accept_query
|
146 |
+
|
147 |
+
return call_worker(req, accept)
|
148 |
+
|
149 |
+
|
150 |
+
@secure_router.post(
|
151 |
+
path="/v1/tools/describe-image",
|
152 |
+
response_model=DescribeImageResponse,
|
153 |
+
tags=["GenerateV1"])
|
154 |
+
def describe_image(
|
155 |
+
image: UploadFile,
|
156 |
+
image_type: DescribeImageType = Query(
|
157 |
+
DescribeImageType.photo,
|
158 |
+
description="Image type, 'Photo' or 'Anime'")):
|
159 |
+
"""\nDescribe image\n
|
160 |
+
Describe image, Get tags from an image
|
161 |
+
Arguments:
|
162 |
+
image {UploadFile} -- Image to get tags
|
163 |
+
image_type {DescribeImageType} -- Image type, 'Photo' or 'Anime'
|
164 |
+
Returns:
|
165 |
+
DescribeImageResponse -- Describe image response, a string
|
166 |
+
"""
|
167 |
+
if image_type == DescribeImageType.photo:
|
168 |
+
from extras.interrogate import default_interrogator as default_interrogator_photo
|
169 |
+
interrogator = default_interrogator_photo
|
170 |
+
else:
|
171 |
+
from extras.wd14tagger import default_interrogator as default_interrogator_anime
|
172 |
+
interrogator = default_interrogator_anime
|
173 |
+
img = HWC3(read_input_image(image))
|
174 |
+
result = interrogator(img)
|
175 |
+
return DescribeImageResponse(describe=result)
|
176 |
+
|
177 |
+
|
178 |
+
@secure_router.post(
|
179 |
+
path="/v1/generation/stop",
|
180 |
+
response_model=StopResponse,
|
181 |
+
description="Job stopping",
|
182 |
+
tags=["Default"])
|
183 |
+
def stop():
|
184 |
+
"""Interrupt worker"""
|
185 |
+
stop_worker()
|
186 |
+
return StopResponse(msg="success")
|
fooocusapi/routes/generate_v2.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Generate API V2 routes
|
2 |
+
|
3 |
+
"""
|
4 |
+
from typing import List
|
5 |
+
from fastapi import APIRouter, Depends, Header, Query
|
6 |
+
|
7 |
+
from fooocusapi.utils.api_utils import api_key_auth
|
8 |
+
from fooocusapi.models.requests_v1 import ImagePrompt
|
9 |
+
from fooocusapi.models.requests_v2 import (
|
10 |
+
ImgInpaintOrOutpaintRequestJson,
|
11 |
+
ImgPromptRequestJson,
|
12 |
+
Text2ImgRequestWithPrompt,
|
13 |
+
ImgUpscaleOrVaryRequestJson
|
14 |
+
)
|
15 |
+
from fooocusapi.models.common.response import (
|
16 |
+
AsyncJobResponse,
|
17 |
+
GeneratedImageResult
|
18 |
+
)
|
19 |
+
from fooocusapi.utils.call_worker import call_worker
|
20 |
+
from fooocusapi.utils.img_utils import base64_to_stream
|
21 |
+
from fooocusapi.configs.default import img_generate_responses
|
22 |
+
|
23 |
+
|
24 |
+
secure_router = APIRouter(
|
25 |
+
dependencies=[Depends(api_key_auth)]
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
@secure_router.post(
|
30 |
+
path="/v2/generation/text-to-image-with-ip",
|
31 |
+
response_model=List[GeneratedImageResult] | AsyncJobResponse,
|
32 |
+
responses=img_generate_responses,
|
33 |
+
tags=["GenerateV2"])
|
34 |
+
def text_to_img_with_ip(
|
35 |
+
req: Text2ImgRequestWithPrompt,
|
36 |
+
accept: str = Header(None),
|
37 |
+
accept_query: str | None = Query(
|
38 |
+
default=None, alias='accept',
|
39 |
+
description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
|
40 |
+
"""\nText to image with prompt\n
|
41 |
+
Text to image with prompt
|
42 |
+
Arguments:
|
43 |
+
req {Text2ImgRequestWithPrompt} -- Text to image generation request
|
44 |
+
accept {str} -- Accept header
|
45 |
+
accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
|
46 |
+
Returns:
|
47 |
+
Response -- img_generate_responses
|
48 |
+
"""
|
49 |
+
if accept_query is not None and len(accept_query) > 0:
|
50 |
+
accept = accept_query
|
51 |
+
|
52 |
+
default_image_prompt = ImagePrompt(cn_img=None)
|
53 |
+
image_prompts_files: List[ImagePrompt] = []
|
54 |
+
for image_prompt in req.image_prompts:
|
55 |
+
image_prompt.cn_img = base64_to_stream(image_prompt.cn_img)
|
56 |
+
image = ImagePrompt(
|
57 |
+
cn_img=image_prompt.cn_img,
|
58 |
+
cn_stop=image_prompt.cn_stop,
|
59 |
+
cn_weight=image_prompt.cn_weight,
|
60 |
+
cn_type=image_prompt.cn_type)
|
61 |
+
image_prompts_files.append(image)
|
62 |
+
|
63 |
+
while len(image_prompts_files) <= 4:
|
64 |
+
image_prompts_files.append(default_image_prompt)
|
65 |
+
|
66 |
+
req.image_prompts = image_prompts_files
|
67 |
+
|
68 |
+
return call_worker(req, accept)
|
69 |
+
|
70 |
+
|
71 |
+
@secure_router.post(
|
72 |
+
path="/v2/generation/image-upscale-vary",
|
73 |
+
response_model=List[GeneratedImageResult] | AsyncJobResponse,
|
74 |
+
responses=img_generate_responses,
|
75 |
+
tags=["GenerateV2"])
|
76 |
+
def img_upscale_or_vary(
|
77 |
+
req: ImgUpscaleOrVaryRequestJson,
|
78 |
+
accept: str = Header(None),
|
79 |
+
accept_query: str | None = Query(
|
80 |
+
None, alias='accept', description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
|
81 |
+
"""\nImage upscale or vary\n
|
82 |
+
Image upscale or vary
|
83 |
+
Arguments:
|
84 |
+
req {ImgUpscaleOrVaryRequestJson} -- Image upscale or vary request
|
85 |
+
accept {str} -- Accept header
|
86 |
+
accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
|
87 |
+
Returns:
|
88 |
+
Response -- img_generate_responses
|
89 |
+
"""
|
90 |
+
if accept_query is not None and len(accept_query) > 0:
|
91 |
+
accept = accept_query
|
92 |
+
|
93 |
+
req.input_image = base64_to_stream(req.input_image)
|
94 |
+
|
95 |
+
default_image_prompt = ImagePrompt(cn_img=None)
|
96 |
+
image_prompts_files: List[ImagePrompt] = []
|
97 |
+
for image_prompt in req.image_prompts:
|
98 |
+
image_prompt.cn_img = base64_to_stream(image_prompt.cn_img)
|
99 |
+
image = ImagePrompt(
|
100 |
+
cn_img=image_prompt.cn_img,
|
101 |
+
cn_stop=image_prompt.cn_stop,
|
102 |
+
cn_weight=image_prompt.cn_weight,
|
103 |
+
cn_type=image_prompt.cn_type)
|
104 |
+
image_prompts_files.append(image)
|
105 |
+
while len(image_prompts_files) <= 4:
|
106 |
+
image_prompts_files.append(default_image_prompt)
|
107 |
+
req.image_prompts = image_prompts_files
|
108 |
+
|
109 |
+
return call_worker(req, accept)
|
110 |
+
|
111 |
+
|
112 |
+
@secure_router.post(
|
113 |
+
path="/v2/generation/image-inpaint-outpaint",
|
114 |
+
response_model=List[GeneratedImageResult] | AsyncJobResponse,
|
115 |
+
responses=img_generate_responses,
|
116 |
+
tags=["GenerateV2"])
|
117 |
+
def img_inpaint_or_outpaint(
|
118 |
+
req: ImgInpaintOrOutpaintRequestJson,
|
119 |
+
accept: str = Header(None),
|
120 |
+
accept_query: str | None = Query(
|
121 |
+
None, alias='accept',
|
122 |
+
description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
|
123 |
+
"""\nInpaint or outpaint\n
|
124 |
+
Inpaint or outpaint
|
125 |
+
Arguments:
|
126 |
+
req {ImgInpaintOrOutpaintRequestJson} -- Request body
|
127 |
+
accept {str} -- Accept header
|
128 |
+
accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
|
129 |
+
Returns:
|
130 |
+
Response -- img_generate_responses
|
131 |
+
"""
|
132 |
+
if accept_query is not None and len(accept_query) > 0:
|
133 |
+
accept = accept_query
|
134 |
+
|
135 |
+
req.input_image = base64_to_stream(req.input_image)
|
136 |
+
if req.input_mask is not None:
|
137 |
+
req.input_mask = base64_to_stream(req.input_mask)
|
138 |
+
default_image_prompt = ImagePrompt(cn_img=None)
|
139 |
+
image_prompts_files: List[ImagePrompt] = []
|
140 |
+
for image_prompt in req.image_prompts:
|
141 |
+
image_prompt.cn_img = base64_to_stream(image_prompt.cn_img)
|
142 |
+
image = ImagePrompt(
|
143 |
+
cn_img=image_prompt.cn_img,
|
144 |
+
cn_stop=image_prompt.cn_stop,
|
145 |
+
cn_weight=image_prompt.cn_weight,
|
146 |
+
cn_type=image_prompt.cn_type)
|
147 |
+
image_prompts_files.append(image)
|
148 |
+
while len(image_prompts_files) <= 4:
|
149 |
+
image_prompts_files.append(default_image_prompt)
|
150 |
+
req.image_prompts = image_prompts_files
|
151 |
+
|
152 |
+
return call_worker(req, accept)
|
153 |
+
|
154 |
+
|
155 |
+
@secure_router.post(
|
156 |
+
path="/v2/generation/image-prompt",
|
157 |
+
response_model=List[GeneratedImageResult] | AsyncJobResponse,
|
158 |
+
responses=img_generate_responses,
|
159 |
+
tags=["GenerateV2"])
|
160 |
+
def img_prompt(
|
161 |
+
req: ImgPromptRequestJson,
|
162 |
+
accept: str = Header(None),
|
163 |
+
accept_query: str | None = Query(
|
164 |
+
None, alias='accept',
|
165 |
+
description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
|
166 |
+
"""\nImage prompt\n
|
167 |
+
Image prompt generation
|
168 |
+
Arguments:
|
169 |
+
req {ImgPromptRequest} -- Request body
|
170 |
+
accept {str} -- Accept header
|
171 |
+
accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
|
172 |
+
Returns:
|
173 |
+
Response -- img_generate_responses
|
174 |
+
"""
|
175 |
+
if accept_query is not None and len(accept_query) > 0:
|
176 |
+
accept = accept_query
|
177 |
+
|
178 |
+
if req.input_image is not None:
|
179 |
+
req.input_image = base64_to_stream(req.input_image)
|
180 |
+
if req.input_mask is not None:
|
181 |
+
req.input_mask = base64_to_stream(req.input_mask)
|
182 |
+
|
183 |
+
default_image_prompt = ImagePrompt(cn_img=None)
|
184 |
+
image_prompts_files: List[ImagePrompt] = []
|
185 |
+
for image_prompt in req.image_prompts:
|
186 |
+
image_prompt.cn_img = base64_to_stream(image_prompt.cn_img)
|
187 |
+
image = ImagePrompt(
|
188 |
+
cn_img=image_prompt.cn_img,
|
189 |
+
cn_stop=image_prompt.cn_stop,
|
190 |
+
cn_weight=image_prompt.cn_weight,
|
191 |
+
cn_type=image_prompt.cn_type)
|
192 |
+
image_prompts_files.append(image)
|
193 |
+
|
194 |
+
while len(image_prompts_files) <= 4:
|
195 |
+
image_prompts_files.append(default_image_prompt)
|
196 |
+
|
197 |
+
req.image_prompts = image_prompts_files
|
198 |
+
|
199 |
+
return call_worker(req, accept)
|
fooocusapi/routes/query.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Query API"""
|
2 |
+
from typing import List
|
3 |
+
from fastapi import Depends, Response, APIRouter
|
4 |
+
|
5 |
+
from fooocusapi.args import args
|
6 |
+
|
7 |
+
from fooocusapi.models.common.requests import QueryJobRequest
|
8 |
+
from fooocusapi.models.common.response import (
|
9 |
+
AsyncJobResponse,
|
10 |
+
JobHistoryInfo,
|
11 |
+
JobQueueInfo,
|
12 |
+
JobHistoryResponse,
|
13 |
+
AllModelNamesResponse
|
14 |
+
)
|
15 |
+
from fooocusapi.models.common.task import AsyncJobStage
|
16 |
+
|
17 |
+
from fooocusapi.utils.api_utils import generate_async_output, api_key_auth
|
18 |
+
from fooocusapi.task_queue import TaskType
|
19 |
+
from fooocusapi.worker import worker_queue
|
20 |
+
|
21 |
+
secure_router = APIRouter(dependencies=[Depends(api_key_auth)])
|
22 |
+
|
23 |
+
|
24 |
+
@secure_router.get(path="/", tags=['Query'])
|
25 |
+
def home():
|
26 |
+
"""Home page"""
|
27 |
+
return Response(
|
28 |
+
content='Swagger-UI to: <a href="/docs">/docs</a>',
|
29 |
+
media_type="text/html"
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
@secure_router.get(
|
34 |
+
path="/ping",
|
35 |
+
description="Returns a simple 'pong'",
|
36 |
+
tags=['Query'])
|
37 |
+
async def ping():
|
38 |
+
"""\nPing\n
|
39 |
+
Ping page, just to check if the fastapi is up.
|
40 |
+
Instant return correct, does not mean the service is available.
|
41 |
+
Returns:
|
42 |
+
A simple string pong
|
43 |
+
"""
|
44 |
+
return 'pong'
|
45 |
+
|
46 |
+
|
47 |
+
@secure_router.get(
|
48 |
+
path="/v1/generation/query-job",
|
49 |
+
response_model=AsyncJobResponse,
|
50 |
+
description="Query async generation job",
|
51 |
+
tags=['Query'])
|
52 |
+
def query_job(req: QueryJobRequest = Depends()):
|
53 |
+
"""query job info by id"""
|
54 |
+
queue_task = worker_queue.get_task(req.job_id, True)
|
55 |
+
if queue_task is None:
|
56 |
+
result = AsyncJobResponse(
|
57 |
+
job_id="",
|
58 |
+
job_type=TaskType.not_found,
|
59 |
+
job_stage=AsyncJobStage.error,
|
60 |
+
job_progress=0,
|
61 |
+
job_status="Job not found")
|
62 |
+
content = result.model_dump_json()
|
63 |
+
return Response(content=content, media_type='application/json', status_code=404)
|
64 |
+
return generate_async_output(queue_task, req.require_step_preview)
|
65 |
+
|
66 |
+
|
67 |
+
@secure_router.get(
|
68 |
+
path="/v1/generation/job-queue",
|
69 |
+
response_model=JobQueueInfo,
|
70 |
+
description="Query job queue info",
|
71 |
+
tags=['Query'])
|
72 |
+
def job_queue():
|
73 |
+
"""Query job queue info"""
|
74 |
+
queue = JobQueueInfo(
|
75 |
+
running_size=len(worker_queue.queue),
|
76 |
+
finished_size=len(worker_queue.history),
|
77 |
+
last_job_id=worker_queue.last_job_id
|
78 |
+
)
|
79 |
+
return queue
|
80 |
+
|
81 |
+
|
82 |
+
@secure_router.get(
|
83 |
+
path="/v1/generation/job-history",
|
84 |
+
response_model=JobHistoryResponse | dict,
|
85 |
+
description="Query historical job data",
|
86 |
+
tags=["Query"])
|
87 |
+
def get_history(job_id: str = None, page: int = 0, page_size: int = 20):
|
88 |
+
"""Fetch and return the historical tasks"""
|
89 |
+
queue = [
|
90 |
+
JobHistoryInfo(
|
91 |
+
job_id=item.job_id,
|
92 |
+
is_finished=item.is_finished
|
93 |
+
) for item in worker_queue.queue
|
94 |
+
]
|
95 |
+
if not args.persistent:
|
96 |
+
history = [
|
97 |
+
JobHistoryInfo(
|
98 |
+
job_id=item.job_id,
|
99 |
+
is_finished=item.is_finished
|
100 |
+
) for item in worker_queue.history
|
101 |
+
]
|
102 |
+
return JobHistoryResponse(history=history, queue=queue)
|
103 |
+
|
104 |
+
from fooocusapi.sql_client import query_history
|
105 |
+
history = query_history(task_id=job_id, page=page, page_size=page_size)
|
106 |
+
return {
|
107 |
+
"history": history,
|
108 |
+
"queue": queue
|
109 |
+
}
|
110 |
+
|
111 |
+
|
112 |
+
@secure_router.get(
|
113 |
+
path="/v1/engines/all-models",
|
114 |
+
response_model=AllModelNamesResponse,
|
115 |
+
description="Get all filenames of base model and lora",
|
116 |
+
tags=["Query"])
|
117 |
+
def all_models():
|
118 |
+
"""Refresh and return all models"""
|
119 |
+
from modules import config
|
120 |
+
config.update_files()
|
121 |
+
models = AllModelNamesResponse(
|
122 |
+
model_filenames=config.model_filenames,
|
123 |
+
lora_filenames=config.lora_filenames)
|
124 |
+
return models
|
125 |
+
|
126 |
+
|
127 |
+
@secure_router.get(
|
128 |
+
path="/v1/engines/styles",
|
129 |
+
response_model=List[str],
|
130 |
+
description="Get all legal Fooocus styles",
|
131 |
+
tags=['Query'])
|
132 |
+
def all_styles():
|
133 |
+
"""Return all available styles"""
|
134 |
+
from modules.sdxl_styles import legal_style_names
|
135 |
+
return legal_style_names
|
fooocusapi/sql_client.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
SQLite client for Fooocus API
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
import platform
|
7 |
+
from datetime import datetime
|
8 |
+
from typing import Optional
|
9 |
+
import copy
|
10 |
+
|
11 |
+
from sqlalchemy import Integer, Float, VARCHAR, Boolean, JSON, Text, create_engine
|
12 |
+
from sqlalchemy.orm import declarative_base, Session, Mapped, mapped_column
|
13 |
+
|
14 |
+
|
15 |
+
Base = declarative_base()
|
16 |
+
|
17 |
+
|
18 |
+
if platform.system().lower() == "windows":
|
19 |
+
default_sqlite_db_path = os.path.join(
|
20 |
+
os.path.dirname(__file__), "../database.db"
|
21 |
+
).replace("\\", "/")
|
22 |
+
else:
|
23 |
+
default_sqlite_db_path = os.path.join(os.path.dirname(__file__), "../database.db")
|
24 |
+
|
25 |
+
connection_uri = os.environ.get(
|
26 |
+
"FOOOCUS_DB_CONF", f"sqlite:///{default_sqlite_db_path}"
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
class GenerateRecord(Base):
|
31 |
+
"""
|
32 |
+
GenerateRecord
|
33 |
+
|
34 |
+
__tablename__ = 'generate_record'
|
35 |
+
"""
|
36 |
+
|
37 |
+
__tablename__ = "generate_record"
|
38 |
+
|
39 |
+
task_id: Mapped[str] = mapped_column(VARCHAR(255), nullable=False, primary_key=True)
|
40 |
+
task_type: Mapped[str] = mapped_column(Text, nullable=False)
|
41 |
+
result_url: Mapped[str] = mapped_column(Text, nullable=True)
|
42 |
+
finish_reason: Mapped[str] = mapped_column(Text, nullable=True)
|
43 |
+
date_time: Mapped[int] = mapped_column(Integer, nullable=False)
|
44 |
+
|
45 |
+
prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
46 |
+
negative_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
47 |
+
style_selections: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
|
48 |
+
performance_selection: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
49 |
+
aspect_ratios_selection: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
50 |
+
base_model_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
51 |
+
refiner_model_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
52 |
+
refiner_switch: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
53 |
+
loras: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
|
54 |
+
image_number: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
55 |
+
image_seed: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
56 |
+
sharpness: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
57 |
+
guidance_scale: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
58 |
+
advanced_params: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
59 |
+
|
60 |
+
input_image: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
61 |
+
input_mask: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
62 |
+
image_prompts: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
|
63 |
+
inpaint_additional_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
64 |
+
outpaint_selections: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
|
65 |
+
outpaint_distance_left: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
66 |
+
outpaint_distance_right: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
67 |
+
outpaint_distance_top: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
68 |
+
outpaint_distance_bottom: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
69 |
+
uov_method: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
70 |
+
upscale_value: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
71 |
+
|
72 |
+
webhook_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
73 |
+
require_base64: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
|
74 |
+
async_process: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
|
75 |
+
|
76 |
+
def __repr__(self) -> str:
|
77 |
+
return f"GenerateRecord(task_id={self.task_id!r}, task_type={self.task_type!r}, \
|
78 |
+
result_url={self.result_url!r}, finish_reason={self.finish_reason!r}, date_time={self.date_time!r}, \
|
79 |
+
prompt={self.prompt!r}, negative_prompt={self.negative_prompt!r}, style_selections={self.style_selections!r}, performance_selection={self.performance_selection!r}, \
|
80 |
+
aspect_ratios_selection={self.aspect_ratios_selection!r}, base_model_name={self.base_model_name!r}, \
|
81 |
+
refiner_model_name={self.refiner_model_name!r}, refiner_switch={self.refiner_switch!r}, loras={self.loras!r}, \
|
82 |
+
image_number={self.image_number!r}, image_seed={self.image_seed!r}, sharpness={self.sharpness!r}, \
|
83 |
+
guidance_scale={self.guidance_scale!r}, advanced_params={self.advanced_params!r}, input_image={self.input_image!r}, \
|
84 |
+
input_mask={self.input_mask!r}, image_prompts={self.image_prompts!r}, inpaint_additional_prompt={self.inpaint_additional_prompt!r}, \
|
85 |
+
outpaint_selections={self.outpaint_selections!r}, outpaint_distance_left={self.outpaint_distance_left!r}, outpaint_distance_right={self.outpaint_distance_right!r}, \
|
86 |
+
outpaint_distance_top={self.outpaint_distance_top!r}, outpaint_distance_bottom={self.outpaint_distance_bottom!r}, uov_method={self.uov_method!r}, \
|
87 |
+
upscale_value={self.upscale_value!r}, webhook_url={self.webhook_url!r}, require_base64={self.require_base64!r}, \
|
88 |
+
async_process={self.async_process!r})"
|
89 |
+
|
90 |
+
|
91 |
+
engine = create_engine(connection_uri)
|
92 |
+
|
93 |
+
session = Session(engine)
|
94 |
+
Base.metadata.create_all(engine, checkfirst=True)
|
95 |
+
session.close()
|
96 |
+
|
97 |
+
|
98 |
+
def convert_to_dict_list(obj_list: list[object]) -> list[dict]:
|
99 |
+
"""
|
100 |
+
Convert a list of objects to a list of dictionaries.
|
101 |
+
Args:
|
102 |
+
obj_list:
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
dict_list:
|
106 |
+
"""
|
107 |
+
dict_list = []
|
108 |
+
for obj in obj_list:
|
109 |
+
# 将对象属性转化为字典键值对
|
110 |
+
dict_obj = {}
|
111 |
+
for attr, value in vars(obj).items():
|
112 |
+
if (
|
113 |
+
not callable(value)
|
114 |
+
and not attr.startswith("__")
|
115 |
+
and not attr.startswith("_")
|
116 |
+
):
|
117 |
+
dict_obj[attr] = value
|
118 |
+
task_info = {
|
119 |
+
"task_id": obj.task_id,
|
120 |
+
"task_type": obj.task_type,
|
121 |
+
"result_url": obj.result_url,
|
122 |
+
"finish_reason": obj.finish_reason,
|
123 |
+
"date_time": datetime.fromtimestamp(obj.date_time).strftime(
|
124 |
+
"%Y-%m-%d %H:%M:%S"
|
125 |
+
),
|
126 |
+
}
|
127 |
+
del dict_obj["task_id"]
|
128 |
+
del dict_obj["task_type"]
|
129 |
+
del dict_obj["result_url"]
|
130 |
+
del dict_obj["finish_reason"]
|
131 |
+
del dict_obj["date_time"]
|
132 |
+
dict_list.append({"params": dict_obj, "task_info": task_info})
|
133 |
+
return dict_list
|
134 |
+
|
135 |
+
|
136 |
+
class MySQLAlchemy:
|
137 |
+
"""
|
138 |
+
MySQLAlchemy, a toolkit for managing SQLAlchemy connections and sessions.
|
139 |
+
|
140 |
+
:param uri: SQLAlchemy connection URI
|
141 |
+
"""
|
142 |
+
|
143 |
+
def __init__(self, uri: str):
|
144 |
+
# 'mysql+pymysql://{username}:{password}@{host}:{port}/{database}'
|
145 |
+
self.engine = create_engine(uri)
|
146 |
+
self.session = Session(self.engine)
|
147 |
+
|
148 |
+
def store_history(self, record: dict) -> None:
|
149 |
+
"""
|
150 |
+
Store history to database
|
151 |
+
:param record:
|
152 |
+
:return:
|
153 |
+
"""
|
154 |
+
self.session.add_all([GenerateRecord(**record)])
|
155 |
+
self.session.commit()
|
156 |
+
|
157 |
+
def get_history(
|
158 |
+
self,
|
159 |
+
task_id: str = None,
|
160 |
+
page: int = 0,
|
161 |
+
page_size: int = 20,
|
162 |
+
order_by: str = "date_time",
|
163 |
+
) -> list:
|
164 |
+
"""
|
165 |
+
Get history from database
|
166 |
+
:param task_id:
|
167 |
+
:param page:
|
168 |
+
:param page_size:
|
169 |
+
:param order_by:
|
170 |
+
:return:
|
171 |
+
"""
|
172 |
+
if task_id is not None:
|
173 |
+
res = (
|
174 |
+
self.session.query(GenerateRecord)
|
175 |
+
.filter(GenerateRecord.task_id == task_id)
|
176 |
+
.all()
|
177 |
+
)
|
178 |
+
if len(res) == 0:
|
179 |
+
return []
|
180 |
+
return convert_to_dict_list(res)
|
181 |
+
|
182 |
+
res = (
|
183 |
+
self.session.query(GenerateRecord)
|
184 |
+
.order_by(getattr(GenerateRecord, order_by).desc())
|
185 |
+
.offset(page * page_size)
|
186 |
+
.limit(page_size)
|
187 |
+
.all()
|
188 |
+
)
|
189 |
+
if len(res) == 0:
|
190 |
+
return []
|
191 |
+
return convert_to_dict_list(res)
|
192 |
+
|
193 |
+
|
194 |
+
db = MySQLAlchemy(uri=connection_uri)
|
195 |
+
|
196 |
+
|
197 |
+
def req_to_dict(req: dict) -> dict:
|
198 |
+
"""
|
199 |
+
Convert request to dictionary
|
200 |
+
Args:
|
201 |
+
req:
|
202 |
+
|
203 |
+
Returns:
|
204 |
+
|
205 |
+
"""
|
206 |
+
req["loras"] = [{"model_name": lora[0], "weight": lora[1]} for lora in req["loras"]]
|
207 |
+
# req["advanced_params"] = dict(zip(adv_params_keys, req["advanced_params"]))
|
208 |
+
req["image_prompts"] = [
|
209 |
+
{"cn_img": "", "cn_stop": image[1], "cn_weight": image[2], "cn_type": image[3]}
|
210 |
+
for image in req["image_prompts"]
|
211 |
+
]
|
212 |
+
del req["inpaint_input_image"]
|
213 |
+
del req["uov_input_image"]
|
214 |
+
return req
|
215 |
+
|
216 |
+
|
217 |
+
def add_history(
|
218 |
+
params: dict, task_type: str, task_id: str, result_url: str, finish_reason: str
|
219 |
+
) -> None:
|
220 |
+
"""
|
221 |
+
Store history to database
|
222 |
+
Args:
|
223 |
+
params:
|
224 |
+
task_type:
|
225 |
+
task_id:
|
226 |
+
result_url:
|
227 |
+
finish_reason:
|
228 |
+
|
229 |
+
Returns:
|
230 |
+
|
231 |
+
"""
|
232 |
+
adv = copy.deepcopy(params["advanced_params"])
|
233 |
+
params["advanced_params"] = adv.__dict__
|
234 |
+
params["date_time"] = int(time.time())
|
235 |
+
params["task_type"] = task_type
|
236 |
+
params["task_id"] = task_id
|
237 |
+
params["result_url"] = result_url
|
238 |
+
params["finish_reason"] = finish_reason
|
239 |
+
|
240 |
+
del params["inpaint_input_image"]
|
241 |
+
del params["uov_input_image"]
|
242 |
+
del params["save_extension"]
|
243 |
+
del params["save_meta"]
|
244 |
+
del params["save_name"]
|
245 |
+
del params["meta_scheme"]
|
246 |
+
|
247 |
+
db.store_history(params)
|
248 |
+
|
249 |
+
|
250 |
+
def query_history(
|
251 |
+
task_id: str = None,
|
252 |
+
page: int = 0,
|
253 |
+
page_size: int = 20,
|
254 |
+
order_by: str = "date_time"
|
255 |
+
) -> list:
|
256 |
+
"""
|
257 |
+
Query history from database
|
258 |
+
Args:
|
259 |
+
task_id:
|
260 |
+
page:
|
261 |
+
page_size:
|
262 |
+
order_by:
|
263 |
+
|
264 |
+
Returns:
|
265 |
+
|
266 |
+
"""
|
267 |
+
return db.get_history(
|
268 |
+
task_id=task_id, page=page, page_size=page_size, order_by=order_by
|
269 |
+
)
|
fooocusapi/task_queue.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Task queue management
|
3 |
+
|
4 |
+
This module provides classes and functions for managing the task queue.
|
5 |
+
|
6 |
+
Classes:
|
7 |
+
QueueTask: A class representing a task in the queue.
|
8 |
+
TaskQueue: A class for managing the task queue.
|
9 |
+
"""
|
10 |
+
import uuid
|
11 |
+
import time
|
12 |
+
from typing import List, Tuple
|
13 |
+
import numpy as np
|
14 |
+
import requests
|
15 |
+
|
16 |
+
from fooocusapi.utils.file_utils import delete_output_file, get_file_serve_url
|
17 |
+
from fooocusapi.utils.img_utils import narray_to_base64img
|
18 |
+
from fooocusapi.utils.logger import logger
|
19 |
+
|
20 |
+
from fooocusapi.models.common.task import ImageGenerationResult, GenerationFinishReason
|
21 |
+
from fooocusapi.parameters import ImageGenerationParams
|
22 |
+
from fooocusapi.models.common.task import TaskType
|
23 |
+
|
24 |
+
|
25 |
+
class QueueTask:
|
26 |
+
"""
|
27 |
+
A class representing a task in the queue.
|
28 |
+
|
29 |
+
Attributes:
|
30 |
+
job_id (str): The unique identifier for the task, generated by uuid.
|
31 |
+
task_type (TaskType): The type of task.
|
32 |
+
is_finished (bool): Indicates whether the task has been completed.
|
33 |
+
finish_progress (int): The progress of the task completion.
|
34 |
+
in_queue_mills (int): The time the task was added to the queue, in milliseconds.
|
35 |
+
start_mills (int): The time the task started, in milliseconds.
|
36 |
+
finish_mills (int): The time the task finished, in milliseconds.
|
37 |
+
finish_with_error (bool): Indicates whether the task finished with an error.
|
38 |
+
task_status (str): The status of the task.
|
39 |
+
task_step_preview (str): A list of step previews for the task.
|
40 |
+
task_result (List[ImageGenerationResult]): The result of the task.
|
41 |
+
error_message (str): The error message, if any.
|
42 |
+
webhook_url (str): The webhook URL, if any.
|
43 |
+
"""
|
44 |
+
|
45 |
+
job_id: str
|
46 |
+
task_type: TaskType
|
47 |
+
req_param: ImageGenerationParams
|
48 |
+
is_finished: bool = False
|
49 |
+
finish_progress: int = 0
|
50 |
+
in_queue_mills: int
|
51 |
+
start_mills: int = 0
|
52 |
+
finish_mills: int = 0
|
53 |
+
finish_with_error: bool = False
|
54 |
+
task_status: str | None = None
|
55 |
+
task_step_preview: str | None = None
|
56 |
+
task_result: List[ImageGenerationResult] = None
|
57 |
+
error_message: str | None = None
|
58 |
+
webhook_url: str | None = None # attribute for individual webhook_url
|
59 |
+
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
job_id: str,
|
63 |
+
task_type: TaskType,
|
64 |
+
req_param: ImageGenerationParams,
|
65 |
+
webhook_url: str | None = None,
|
66 |
+
):
|
67 |
+
self.job_id = job_id
|
68 |
+
self.task_type = task_type
|
69 |
+
self.req_param = req_param
|
70 |
+
self.in_queue_mills = int(round(time.time() * 1000))
|
71 |
+
self.webhook_url = webhook_url
|
72 |
+
|
73 |
+
def set_progress(self, progress: int, status: str | None):
|
74 |
+
"""
|
75 |
+
Set progress and status
|
76 |
+
Arguments:
|
77 |
+
progress {int} -- progress
|
78 |
+
status {str} -- status
|
79 |
+
"""
|
80 |
+
progress = min(progress, 100)
|
81 |
+
self.finish_progress = progress
|
82 |
+
self.task_status = status
|
83 |
+
|
84 |
+
def set_step_preview(self, task_step_preview: str | None):
|
85 |
+
"""set step preview
|
86 |
+
Set step preview
|
87 |
+
Arguments:
|
88 |
+
task_step_preview {str} -- step preview
|
89 |
+
"""
|
90 |
+
self.task_step_preview = task_step_preview
|
91 |
+
|
92 |
+
def set_result(
|
93 |
+
self,
|
94 |
+
task_result: List[ImageGenerationResult],
|
95 |
+
finish_with_error: bool,
|
96 |
+
error_message: str | None = None,
|
97 |
+
):
|
98 |
+
"""set result
|
99 |
+
Set task result
|
100 |
+
Arguments:
|
101 |
+
task_result {List[ImageGenerationResult]} -- task result
|
102 |
+
finish_with_error {bool} -- finish with error
|
103 |
+
error_message {str} -- error message
|
104 |
+
"""
|
105 |
+
if not finish_with_error:
|
106 |
+
self.finish_progress = 100
|
107 |
+
self.task_status = "Finished"
|
108 |
+
self.task_result = task_result
|
109 |
+
self.finish_with_error = finish_with_error
|
110 |
+
self.error_message = error_message
|
111 |
+
|
112 |
+
def __str__(self) -> str:
|
113 |
+
return f"QueueTask(job_id={self.job_id}, task_type={self.task_type},\
|
114 |
+
is_finished={self.is_finished}, finished_progress={self.finish_progress}, \
|
115 |
+
in_queue_mills={self.in_queue_mills}, start_mills={self.start_mills}, \
|
116 |
+
finish_mills={self.finish_mills}, finish_with_error={self.finish_with_error}, \
|
117 |
+
error_message={self.error_message}, task_status={self.task_status}, \
|
118 |
+
task_step_preview={self.task_step_preview}, webhook_url={self.webhook_url})"
|
119 |
+
|
120 |
+
|
121 |
+
class TaskQueue:
|
122 |
+
"""
|
123 |
+
TaskQueue is a queue of tasks that are waiting to be processed.
|
124 |
+
|
125 |
+
Attributes:
|
126 |
+
queue: List[QueueTask]
|
127 |
+
history: List[QueueTask]
|
128 |
+
last_job_id: str
|
129 |
+
webhook_url: str
|
130 |
+
persistent: bool
|
131 |
+
"""
|
132 |
+
|
133 |
+
queue: List[QueueTask] = []
|
134 |
+
history: List[QueueTask] = []
|
135 |
+
last_job_id: str = None
|
136 |
+
webhook_url: str | None = None
|
137 |
+
persistent: bool = False
|
138 |
+
|
139 |
+
def __init__(
|
140 |
+
self,
|
141 |
+
queue_size: int,
|
142 |
+
history_size: int,
|
143 |
+
webhook_url: str | None = None,
|
144 |
+
persistent: bool | None = False,
|
145 |
+
):
|
146 |
+
self.queue_size = queue_size
|
147 |
+
self.history_size = history_size
|
148 |
+
self.webhook_url = webhook_url
|
149 |
+
self.persistent = False if persistent is None else persistent
|
150 |
+
|
151 |
+
def add_task(
|
152 |
+
self,
|
153 |
+
task_type: TaskType,
|
154 |
+
req_param: ImageGenerationParams,
|
155 |
+
webhook_url: str | None = None,
|
156 |
+
) -> QueueTask | None:
|
157 |
+
"""
|
158 |
+
Create and add task to queue
|
159 |
+
:param task_type: task type
|
160 |
+
:param req_param: request parameters
|
161 |
+
:param webhook_url: webhook url
|
162 |
+
:returns: The created task's job_id, or None if reach the queue size limit
|
163 |
+
"""
|
164 |
+
if len(self.queue) >= self.queue_size:
|
165 |
+
return None
|
166 |
+
|
167 |
+
if isinstance(req_param, dict):
|
168 |
+
req_param = ImageGenerationParams(**req_param)
|
169 |
+
|
170 |
+
job_id = str(uuid.uuid4())
|
171 |
+
task = QueueTask(
|
172 |
+
job_id=job_id,
|
173 |
+
task_type=task_type,
|
174 |
+
req_param=req_param,
|
175 |
+
webhook_url=webhook_url,
|
176 |
+
)
|
177 |
+
self.queue.append(task)
|
178 |
+
self.last_job_id = job_id
|
179 |
+
return task
|
180 |
+
|
181 |
+
def get_task(self, job_id: str, include_history: bool = False) -> QueueTask | None:
|
182 |
+
"""
|
183 |
+
Get task by job_id
|
184 |
+
:param job_id: job id
|
185 |
+
:param include_history: whether to include history tasks
|
186 |
+
:returns: The task with the given job_id, or None if not found
|
187 |
+
"""
|
188 |
+
for task in self.queue:
|
189 |
+
if task.job_id == job_id:
|
190 |
+
return task
|
191 |
+
|
192 |
+
if include_history:
|
193 |
+
for task in self.history:
|
194 |
+
if task.job_id == job_id:
|
195 |
+
return task
|
196 |
+
|
197 |
+
return None
|
198 |
+
|
199 |
+
def is_task_ready_to_start(self, job_id: str) -> bool:
|
200 |
+
"""
|
201 |
+
Check if the task is ready to start
|
202 |
+
:param job_id: job id
|
203 |
+
:returns: True if the task is ready to start, False otherwise
|
204 |
+
"""
|
205 |
+
task = self.get_task(job_id)
|
206 |
+
if task is None:
|
207 |
+
return False
|
208 |
+
|
209 |
+
return self.queue[0].job_id == job_id
|
210 |
+
|
211 |
+
def is_task_finished(self, job_id: str) -> bool:
|
212 |
+
"""
|
213 |
+
Check if the task is finished
|
214 |
+
:param job_id: job id
|
215 |
+
:returns: True if the task is finished, False otherwise
|
216 |
+
"""
|
217 |
+
task = self.get_task(job_id, True)
|
218 |
+
if task is None:
|
219 |
+
return False
|
220 |
+
|
221 |
+
return task.is_finished
|
222 |
+
|
223 |
+
def start_task(self, job_id: str):
|
224 |
+
"""
|
225 |
+
Start task by job_id
|
226 |
+
:param job_id: job id
|
227 |
+
"""
|
228 |
+
task = self.get_task(job_id)
|
229 |
+
if task is not None:
|
230 |
+
task.start_mills = int(round(time.time() * 1000))
|
231 |
+
|
232 |
+
def finish_task(self, job_id: str):
|
233 |
+
"""
|
234 |
+
Finish task by job_id
|
235 |
+
:param job_id: job id
|
236 |
+
"""
|
237 |
+
task = self.get_task(job_id)
|
238 |
+
if task is not None:
|
239 |
+
task.is_finished = True
|
240 |
+
task.finish_mills = int(round(time.time() * 1000))
|
241 |
+
|
242 |
+
# Use the task's webhook_url if available, else use the default
|
243 |
+
webhook_url = task.webhook_url or self.webhook_url
|
244 |
+
|
245 |
+
data = {"job_id": task.job_id, "job_result": []}
|
246 |
+
|
247 |
+
if isinstance(task.task_result, List):
|
248 |
+
for item in task.task_result:
|
249 |
+
data["job_result"].append(
|
250 |
+
{
|
251 |
+
"url": get_file_serve_url(item.im) if item.im else None,
|
252 |
+
"seed": item.seed if item.seed else "-1",
|
253 |
+
}
|
254 |
+
)
|
255 |
+
|
256 |
+
# Send webhook
|
257 |
+
if task.is_finished and webhook_url:
|
258 |
+
try:
|
259 |
+
res = requests.post(webhook_url, json=data, timeout=15)
|
260 |
+
print(f"Call webhook response status: {res.status_code}")
|
261 |
+
except Exception as e:
|
262 |
+
print("Call webhook error:", e)
|
263 |
+
|
264 |
+
# Move task to history
|
265 |
+
self.queue.remove(task)
|
266 |
+
self.history.append(task)
|
267 |
+
|
268 |
+
# save history to database
|
269 |
+
if self.persistent:
|
270 |
+
from fooocusapi.sql_client import add_history
|
271 |
+
|
272 |
+
add_history(
|
273 |
+
params=task.req_param.to_dict(),
|
274 |
+
task_type=task.task_type.value,
|
275 |
+
task_id=task.job_id,
|
276 |
+
result_url=",".join([job["url"] for job in data["job_result"]]),
|
277 |
+
finish_reason=task.task_result[0].finish_reason.value,
|
278 |
+
)
|
279 |
+
|
280 |
+
# Clean history
|
281 |
+
if len(self.history) > self.history_size != 0:
|
282 |
+
removed_task = self.history.pop(0)
|
283 |
+
if isinstance(removed_task.task_result, List):
|
284 |
+
for item in removed_task.task_result:
|
285 |
+
if (
|
286 |
+
isinstance(item, ImageGenerationResult)
|
287 |
+
and item.finish_reason == GenerationFinishReason.success
|
288 |
+
and item.im is not None
|
289 |
+
):
|
290 |
+
delete_output_file(item.im)
|
291 |
+
logger.std_info(
|
292 |
+
f"[TaskQueue] Clean task history, remove task: {removed_task.job_id}"
|
293 |
+
)
|
294 |
+
|
295 |
+
|
296 |
+
class TaskOutputs:
|
297 |
+
"""
|
298 |
+
TaskOutputs is a container for task outputs
|
299 |
+
"""
|
300 |
+
|
301 |
+
outputs = []
|
302 |
+
|
303 |
+
def __init__(self, task: QueueTask):
|
304 |
+
self.task = task
|
305 |
+
|
306 |
+
def append(self, args: List[any]):
|
307 |
+
"""
|
308 |
+
Append output to task outputs list
|
309 |
+
:param args: output arguments
|
310 |
+
"""
|
311 |
+
self.outputs.append(args)
|
312 |
+
if len(args) >= 2:
|
313 |
+
if (
|
314 |
+
args[0] == "preview"
|
315 |
+
and isinstance(args[1], Tuple)
|
316 |
+
and len(args[1]) >= 2
|
317 |
+
):
|
318 |
+
number = args[1][0]
|
319 |
+
text = args[1][1]
|
320 |
+
self.task.set_progress(number, text)
|
321 |
+
if len(args[1]) >= 3 and isinstance(args[1][2], np.ndarray):
|
322 |
+
base64_preview_img = narray_to_base64img(args[1][2])
|
323 |
+
self.task.set_step_preview(base64_preview_img)
|
fooocusapi/utils/api_utils.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""some utils for api"""
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
from fastapi import Response
|
5 |
+
from fastapi.security import APIKeyHeader
|
6 |
+
from fastapi import HTTPException, Security
|
7 |
+
|
8 |
+
from modules import flags
|
9 |
+
from modules import config
|
10 |
+
from modules.sdxl_styles import legal_style_names
|
11 |
+
|
12 |
+
from fooocusapi.args import args
|
13 |
+
from fooocusapi.utils.img_utils import read_input_image
|
14 |
+
from fooocusapi.utils.file_utils import (
|
15 |
+
get_file_serve_url,
|
16 |
+
output_file_to_base64img,
|
17 |
+
output_file_to_bytesimg
|
18 |
+
)
|
19 |
+
from fooocusapi.utils.logger import logger
|
20 |
+
from fooocusapi.models.common.requests import (
|
21 |
+
CommonRequest as Text2ImgRequest
|
22 |
+
)
|
23 |
+
from fooocusapi.models.common.response import (
|
24 |
+
AsyncJobResponse,
|
25 |
+
AsyncJobStage,
|
26 |
+
GeneratedImageResult
|
27 |
+
)
|
28 |
+
from fooocusapi.models.requests_v1 import (
|
29 |
+
ImgInpaintOrOutpaintRequest,
|
30 |
+
ImgPromptRequest,
|
31 |
+
ImgUpscaleOrVaryRequest
|
32 |
+
)
|
33 |
+
from fooocusapi.models.requests_v2 import (
|
34 |
+
Text2ImgRequestWithPrompt,
|
35 |
+
ImgInpaintOrOutpaintRequestJson,
|
36 |
+
ImgUpscaleOrVaryRequestJson,
|
37 |
+
ImgPromptRequestJson
|
38 |
+
)
|
39 |
+
from fooocusapi.models.common.task import (
|
40 |
+
ImageGenerationResult,
|
41 |
+
GenerationFinishReason
|
42 |
+
)
|
43 |
+
from fooocusapi.configs.default import (
|
44 |
+
default_inpaint_engine_version,
|
45 |
+
default_sampler,
|
46 |
+
default_scheduler,
|
47 |
+
default_base_model_name,
|
48 |
+
default_refiner_model_name
|
49 |
+
)
|
50 |
+
|
51 |
+
from fooocusapi.parameters import ImageGenerationParams
|
52 |
+
from fooocusapi.task_queue import QueueTask
|
53 |
+
|
54 |
+
|
55 |
+
api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False)
|
56 |
+
|
57 |
+
|
58 |
+
def api_key_auth(apikey: str = Security(api_key_header)):
|
59 |
+
"""
|
60 |
+
Check if the API key is valid, API key is not required if no API key is set
|
61 |
+
Args:
|
62 |
+
apikey: API key
|
63 |
+
returns:
|
64 |
+
None if API key is not set, otherwise raise HTTPException
|
65 |
+
"""
|
66 |
+
if args.apikey is None:
|
67 |
+
return # Skip API key check if no API key is set
|
68 |
+
if apikey != args.apikey:
|
69 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
70 |
+
|
71 |
+
|
72 |
+
def req_to_params(req: Text2ImgRequest) -> ImageGenerationParams:
|
73 |
+
"""
|
74 |
+
Convert Request to ImageGenerationParams
|
75 |
+
Args:
|
76 |
+
req: Request, Text2ImgRequest and classes inherited from Text2ImgRequest
|
77 |
+
returns:
|
78 |
+
ImageGenerationParams
|
79 |
+
"""
|
80 |
+
config.update_files()
|
81 |
+
if req.base_model_name is not None:
|
82 |
+
if req.base_model_name not in config.model_filenames:
|
83 |
+
logger.std_warn(f"[Warning] Wrong base_model_name input: {req.base_model_name}, using default")
|
84 |
+
req.base_model_name = default_base_model_name
|
85 |
+
|
86 |
+
if req.refiner_model_name is not None and req.refiner_model_name != 'None':
|
87 |
+
if req.refiner_model_name not in config.model_filenames:
|
88 |
+
logger.std_warn(f"[Warning] Wrong refiner_model_name input: {req.refiner_model_name}, using default")
|
89 |
+
req.refiner_model_name = default_refiner_model_name
|
90 |
+
|
91 |
+
for lora in req.loras:
|
92 |
+
if lora.model_name != 'None' and lora.model_name not in config.lora_filenames:
|
93 |
+
logger.std_warn(f"[Warning] Wrong lora model_name input: {lora.model_name}, using 'None'")
|
94 |
+
lora.model_name = 'None'
|
95 |
+
|
96 |
+
prompt = req.prompt
|
97 |
+
negative_prompt = req.negative_prompt
|
98 |
+
style_selections = [
|
99 |
+
s for s in req.style_selections if s in legal_style_names]
|
100 |
+
performance_selection = req.performance_selection.value
|
101 |
+
aspect_ratios_selection = req.aspect_ratios_selection
|
102 |
+
image_number = req.image_number
|
103 |
+
image_seed = None if req.image_seed == -1 else req.image_seed
|
104 |
+
sharpness = req.sharpness
|
105 |
+
guidance_scale = req.guidance_scale
|
106 |
+
base_model_name = req.base_model_name
|
107 |
+
refiner_model_name = req.refiner_model_name
|
108 |
+
refiner_switch = req.refiner_switch
|
109 |
+
loras = [(lora.model_name, lora.weight) for lora in req.loras]
|
110 |
+
uov_input_image = None
|
111 |
+
if not isinstance(req, Text2ImgRequestWithPrompt):
|
112 |
+
if isinstance(req, (ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson)):
|
113 |
+
uov_input_image = read_input_image(req.input_image)
|
114 |
+
uov_method = flags.disabled if not isinstance(req, (ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson)) else req.uov_method.value
|
115 |
+
upscale_value = None if not isinstance(req, (ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson)) else req.upscale_value
|
116 |
+
outpaint_selections = [] if not isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson)) else [
|
117 |
+
s.value for s in req.outpaint_selections]
|
118 |
+
outpaint_distance_left = None if not isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson)) else req.outpaint_distance_left
|
119 |
+
outpaint_distance_right = None if not isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson)) else req.outpaint_distance_right
|
120 |
+
outpaint_distance_top = None if not isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson)) else req.outpaint_distance_top
|
121 |
+
outpaint_distance_bottom = None if not isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson)) else req.outpaint_distance_bottom
|
122 |
+
|
123 |
+
if refiner_model_name == '':
|
124 |
+
refiner_model_name = 'None'
|
125 |
+
|
126 |
+
inpaint_input_image = None
|
127 |
+
inpaint_additional_prompt = None
|
128 |
+
if isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson)) and req.input_image is not None:
|
129 |
+
inpaint_additional_prompt = req.inpaint_additional_prompt
|
130 |
+
input_image = read_input_image(req.input_image)
|
131 |
+
input_mask = None
|
132 |
+
if req.input_mask is not None:
|
133 |
+
input_mask = read_input_image(req.input_mask)
|
134 |
+
inpaint_input_image = {
|
135 |
+
'image': input_image,
|
136 |
+
'mask': input_mask
|
137 |
+
}
|
138 |
+
|
139 |
+
image_prompts = []
|
140 |
+
if isinstance(req, (ImgInpaintOrOutpaintRequestJson, ImgPromptRequest, ImgPromptRequestJson, ImgUpscaleOrVaryRequestJson, Text2ImgRequestWithPrompt)):
|
141 |
+
# Auto set mixing_image_prompt_and_inpaint to True
|
142 |
+
if len(req.image_prompts) > 0 and uov_input_image is not None:
|
143 |
+
print("[INFO] Mixing image prompt and vary upscale is set to True")
|
144 |
+
req.advanced_params.mixing_image_prompt_and_vary_upscale = True
|
145 |
+
elif len(req.image_prompts) > 0 and not isinstance(req, Text2ImgRequestWithPrompt) and req.input_image is not None:
|
146 |
+
print("[INFO] Mixing image prompt and inpaint is set to True")
|
147 |
+
req.advanced_params.mixing_image_prompt_and_inpaint = True
|
148 |
+
|
149 |
+
for img_prompt in req.image_prompts:
|
150 |
+
if img_prompt.cn_img is not None:
|
151 |
+
cn_img = read_input_image(img_prompt.cn_img)
|
152 |
+
if img_prompt.cn_stop is None or img_prompt.cn_stop == 0:
|
153 |
+
img_prompt.cn_stop = flags.default_parameters[img_prompt.cn_type.value][0]
|
154 |
+
if img_prompt.cn_weight is None or img_prompt.cn_weight == 0:
|
155 |
+
img_prompt.cn_weight = flags.default_parameters[img_prompt.cn_type.value][1]
|
156 |
+
image_prompts.append(
|
157 |
+
(cn_img, img_prompt.cn_stop, img_prompt.cn_weight, img_prompt.cn_type.value))
|
158 |
+
|
159 |
+
advanced_params = None
|
160 |
+
if req.advanced_params is not None:
|
161 |
+
adp = req.advanced_params
|
162 |
+
|
163 |
+
if adp.refiner_swap_method not in ['joint', 'separate', 'vae']:
|
164 |
+
print(f"[Warning] Wrong refiner_swap_method input: {adp.refiner_swap_method}, using default")
|
165 |
+
adp.refiner_swap_method = 'joint'
|
166 |
+
|
167 |
+
if adp.sampler_name not in flags.sampler_list:
|
168 |
+
print(f"[Warning] Wrong sampler_name input: {adp.sampler_name}, using default")
|
169 |
+
adp.sampler_name = default_sampler
|
170 |
+
|
171 |
+
if adp.scheduler_name not in flags.scheduler_list:
|
172 |
+
print(f"[Warning] Wrong scheduler_name input: {adp.scheduler_name}, using default")
|
173 |
+
adp.scheduler_name = default_scheduler
|
174 |
+
|
175 |
+
if adp.inpaint_engine not in flags.inpaint_engine_versions:
|
176 |
+
print(f"[Warning] Wrong inpaint_engine input: {adp.inpaint_engine}, using default")
|
177 |
+
adp.inpaint_engine = default_inpaint_engine_version
|
178 |
+
|
179 |
+
advanced_params = adp
|
180 |
+
|
181 |
+
return ImageGenerationParams(
|
182 |
+
prompt=prompt,
|
183 |
+
negative_prompt=negative_prompt,
|
184 |
+
style_selections=style_selections,
|
185 |
+
performance_selection=performance_selection,
|
186 |
+
aspect_ratios_selection=aspect_ratios_selection,
|
187 |
+
image_number=image_number,
|
188 |
+
image_seed=image_seed,
|
189 |
+
sharpness=sharpness,
|
190 |
+
guidance_scale=guidance_scale,
|
191 |
+
base_model_name=base_model_name,
|
192 |
+
refiner_model_name=refiner_model_name,
|
193 |
+
refiner_switch=refiner_switch,
|
194 |
+
loras=loras,
|
195 |
+
uov_input_image=uov_input_image,
|
196 |
+
uov_method=uov_method,
|
197 |
+
upscale_value=upscale_value,
|
198 |
+
outpaint_selections=outpaint_selections,
|
199 |
+
outpaint_distance_left=outpaint_distance_left,
|
200 |
+
outpaint_distance_right=outpaint_distance_right,
|
201 |
+
outpaint_distance_top=outpaint_distance_top,
|
202 |
+
outpaint_distance_bottom=outpaint_distance_bottom,
|
203 |
+
inpaint_input_image=inpaint_input_image,
|
204 |
+
inpaint_additional_prompt=inpaint_additional_prompt,
|
205 |
+
image_prompts=image_prompts,
|
206 |
+
advanced_params=advanced_params,
|
207 |
+
save_meta=req.save_meta,
|
208 |
+
meta_scheme=req.meta_scheme,
|
209 |
+
save_name=req.save_name,
|
210 |
+
save_extension=req.save_extension,
|
211 |
+
require_base64=req.require_base64,
|
212 |
+
)
|
213 |
+
|
214 |
+
|
215 |
+
def generate_async_output(
|
216 |
+
task: QueueTask,
|
217 |
+
require_step_preview: bool = False) -> AsyncJobResponse:
|
218 |
+
"""
|
219 |
+
Generate output for async job
|
220 |
+
Arguments:
|
221 |
+
task: QueueTask
|
222 |
+
require_step_preview: bool
|
223 |
+
Returns:
|
224 |
+
AsyncJobResponse
|
225 |
+
"""
|
226 |
+
job_stage = AsyncJobStage.running
|
227 |
+
job_result = None
|
228 |
+
|
229 |
+
if task.start_mills == 0:
|
230 |
+
job_stage = AsyncJobStage.waiting
|
231 |
+
|
232 |
+
if task.is_finished:
|
233 |
+
if task.finish_with_error:
|
234 |
+
job_stage = AsyncJobStage.error
|
235 |
+
elif task.task_result is not None:
|
236 |
+
job_stage = AsyncJobStage.success
|
237 |
+
job_result = generate_image_result_output(task.task_result, task.req_param.require_base64)
|
238 |
+
|
239 |
+
result = AsyncJobResponse(
|
240 |
+
job_id=task.job_id,
|
241 |
+
job_type=task.task_type,
|
242 |
+
job_stage=job_stage,
|
243 |
+
job_progress=task.finish_progress,
|
244 |
+
job_status=task.task_status,
|
245 |
+
job_step_preview=task.task_step_preview if require_step_preview else None,
|
246 |
+
job_result=job_result)
|
247 |
+
return result
|
248 |
+
|
249 |
+
|
250 |
+
def generate_streaming_output(results: List[ImageGenerationResult]) -> Response:
|
251 |
+
"""
|
252 |
+
Generate streaming output for image generation results.
|
253 |
+
Args:
|
254 |
+
results (List[ImageGenerationResult]): List of image generation results.
|
255 |
+
Returns:
|
256 |
+
Response: Streaming response object, bytes image.
|
257 |
+
"""
|
258 |
+
if len(results) == 0:
|
259 |
+
return Response(status_code=500)
|
260 |
+
result = results[0]
|
261 |
+
if result.finish_reason == GenerationFinishReason.queue_is_full:
|
262 |
+
return Response(status_code=409, content=result.finish_reason.value)
|
263 |
+
if result.finish_reason == GenerationFinishReason.user_cancel:
|
264 |
+
return Response(status_code=400, content=result.finish_reason.value)
|
265 |
+
if result.finish_reason == GenerationFinishReason.error:
|
266 |
+
return Response(status_code=500, content=result.finish_reason.value)
|
267 |
+
|
268 |
+
img_bytes = output_file_to_bytesimg(results[0].im)
|
269 |
+
return Response(img_bytes, media_type='image/png')
|
270 |
+
|
271 |
+
|
272 |
+
def generate_image_result_output(
|
273 |
+
results: List[ImageGenerationResult],
|
274 |
+
require_base64: bool) -> List[GeneratedImageResult]:
|
275 |
+
"""
|
276 |
+
Generate image result output
|
277 |
+
Arguments:
|
278 |
+
results: List[ImageGenerationResult]
|
279 |
+
require_base64: bool
|
280 |
+
Returns:
|
281 |
+
List[GeneratedImageResult]
|
282 |
+
"""
|
283 |
+
results = [
|
284 |
+
GeneratedImageResult(
|
285 |
+
base64=output_file_to_base64img(item.im) if require_base64 else None,
|
286 |
+
url=get_file_serve_url(item.im),
|
287 |
+
seed=str(item.seed),
|
288 |
+
finish_reason=item.finish_reason
|
289 |
+
) for item in results
|
290 |
+
]
|
291 |
+
return results
|
fooocusapi/utils/call_worker.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""function for call generate worker"""
|
2 |
+
from typing import List
|
3 |
+
from fastapi import Response
|
4 |
+
|
5 |
+
from fooocusapi.models.common.requests import (
|
6 |
+
CommonRequest as Text2ImgRequest
|
7 |
+
)
|
8 |
+
from fooocusapi.models.common.response import (
|
9 |
+
AsyncJobResponse,
|
10 |
+
GeneratedImageResult
|
11 |
+
)
|
12 |
+
from fooocusapi.models.common.task import (
|
13 |
+
GenerationFinishReason,
|
14 |
+
ImageGenerationResult,
|
15 |
+
AsyncJobStage,
|
16 |
+
TaskType
|
17 |
+
)
|
18 |
+
from fooocusapi.utils.api_utils import (
|
19 |
+
req_to_params,
|
20 |
+
generate_async_output,
|
21 |
+
generate_streaming_output,
|
22 |
+
generate_image_result_output
|
23 |
+
)
|
24 |
+
from fooocusapi.models.requests_v1 import (
|
25 |
+
ImgUpscaleOrVaryRequest,
|
26 |
+
ImgPromptRequest,
|
27 |
+
ImgInpaintOrOutpaintRequest
|
28 |
+
)
|
29 |
+
from fooocusapi.models.requests_v2 import (
|
30 |
+
ImgInpaintOrOutpaintRequestJson,
|
31 |
+
ImgPromptRequestJson,
|
32 |
+
ImgUpscaleOrVaryRequestJson
|
33 |
+
)
|
34 |
+
from fooocusapi.worker import worker_queue, blocking_get_task_result
|
35 |
+
|
36 |
+
|
37 |
+
def get_task_type(req: Text2ImgRequest) -> TaskType:
|
38 |
+
"""return task type"""
|
39 |
+
if isinstance(req, (ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson)):
|
40 |
+
return TaskType.img_uov
|
41 |
+
if isinstance(req, (ImgPromptRequest, ImgPromptRequestJson)):
|
42 |
+
return TaskType.img_prompt
|
43 |
+
if isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson)):
|
44 |
+
return TaskType.img_inpaint_outpaint
|
45 |
+
return TaskType.text_2_img
|
46 |
+
|
47 |
+
|
48 |
+
def call_worker(req: Text2ImgRequest, accept: str) -> Response | AsyncJobResponse | List[GeneratedImageResult]:
|
49 |
+
"""call generation worker"""
|
50 |
+
if accept == 'image/png':
|
51 |
+
streaming_output = True
|
52 |
+
# image_number auto set to 1 in streaming mode
|
53 |
+
req.image_number = 1
|
54 |
+
else:
|
55 |
+
streaming_output = False
|
56 |
+
|
57 |
+
task_type = get_task_type(req)
|
58 |
+
params = req_to_params(req)
|
59 |
+
async_task = worker_queue.add_task(task_type, params, req.webhook_url)
|
60 |
+
|
61 |
+
if async_task is None:
|
62 |
+
# add to worker queue failed
|
63 |
+
failure_results = [
|
64 |
+
ImageGenerationResult(
|
65 |
+
im=None,
|
66 |
+
seed='',
|
67 |
+
finish_reason=GenerationFinishReason.queue_is_full
|
68 |
+
)]
|
69 |
+
|
70 |
+
if streaming_output:
|
71 |
+
return generate_streaming_output(failure_results)
|
72 |
+
if req.async_process:
|
73 |
+
return AsyncJobResponse(
|
74 |
+
job_id='',
|
75 |
+
job_type=get_task_type(req),
|
76 |
+
job_stage=AsyncJobStage.error,
|
77 |
+
job_progress=0,
|
78 |
+
job_status=None,
|
79 |
+
job_step_preview=None,
|
80 |
+
job_result=[GeneratedImageResult(
|
81 |
+
base64=None,
|
82 |
+
url=None,
|
83 |
+
seed='',
|
84 |
+
finish_reason=GenerationFinishReason.queue_is_full
|
85 |
+
)])
|
86 |
+
return generate_image_result_output(failure_results, False)
|
87 |
+
|
88 |
+
if req.async_process:
|
89 |
+
# return async response directly
|
90 |
+
return generate_async_output(async_task)
|
91 |
+
|
92 |
+
# blocking get generation result
|
93 |
+
results = blocking_get_task_result(async_task.job_id)
|
94 |
+
|
95 |
+
if streaming_output:
|
96 |
+
return generate_streaming_output(results)
|
97 |
+
return generate_image_result_output(results, req.require_base64)
|
fooocusapi/utils/file_utils.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
""" File utils
|
4 |
+
|
5 |
+
Use for managing generated files
|
6 |
+
|
7 |
+
@file: file_utils.py
|
8 |
+
@author: Konie
|
9 |
+
@update: 2024-03-22
|
10 |
+
"""
|
11 |
+
import base64
|
12 |
+
import datetime
|
13 |
+
from io import BytesIO
|
14 |
+
import os
|
15 |
+
import json
|
16 |
+
from pathlib import Path
|
17 |
+
import numpy as np
|
18 |
+
from PIL import Image
|
19 |
+
from PIL.PngImagePlugin import PngInfo
|
20 |
+
|
21 |
+
from fooocusapi.utils.logger import logger
|
22 |
+
|
23 |
+
|
24 |
+
output_dir = os.path.abspath(os.path.join(
|
25 |
+
os.path.dirname(__file__), '../..', 'outputs', 'files'))
|
26 |
+
os.makedirs(output_dir, exist_ok=True)
|
27 |
+
|
28 |
+
STATIC_SERVER_BASE = 'http://127.0.0.1:8888/files/'
|
29 |
+
|
30 |
+
|
31 |
+
def save_output_file(
|
32 |
+
img: np.ndarray,
|
33 |
+
image_meta: dict = None,
|
34 |
+
image_name: str = '',
|
35 |
+
extension: str = 'png') -> str:
|
36 |
+
"""
|
37 |
+
Save np image to file
|
38 |
+
Args:
|
39 |
+
img: np.ndarray image to save
|
40 |
+
image_meta: dict of image metadata
|
41 |
+
image_name: str of image name
|
42 |
+
extension: str of image extension
|
43 |
+
Returns:
|
44 |
+
str of file name
|
45 |
+
"""
|
46 |
+
current_time = datetime.datetime.now()
|
47 |
+
date_string = current_time.strftime("%Y-%m-%d")
|
48 |
+
|
49 |
+
filename = os.path.join(date_string, image_name + '.' + extension)
|
50 |
+
file_path = os.path.join(output_dir, filename)
|
51 |
+
|
52 |
+
if extension not in ['png', 'jpg', 'webp']:
|
53 |
+
extension = 'png'
|
54 |
+
image_format = Image.registered_extensions()['.'+extension]
|
55 |
+
|
56 |
+
if image_meta is None:
|
57 |
+
image_meta = {}
|
58 |
+
|
59 |
+
meta = None
|
60 |
+
if extension == 'png'and image_meta != {}:
|
61 |
+
meta = PngInfo()
|
62 |
+
meta.add_text("parameters", json.dumps(image_meta))
|
63 |
+
meta.add_text("fooocus_scheme", image_meta['metadata_scheme'])
|
64 |
+
|
65 |
+
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
66 |
+
Image.fromarray(img).save(
|
67 |
+
file_path,
|
68 |
+
format=image_format,
|
69 |
+
pnginfo=meta,
|
70 |
+
optimize=True)
|
71 |
+
return Path(filename).as_posix()
|
72 |
+
|
73 |
+
|
74 |
+
def delete_output_file(filename: str):
|
75 |
+
"""
|
76 |
+
Delete files specified in the output directory
|
77 |
+
Args:
|
78 |
+
filename: str of file name
|
79 |
+
"""
|
80 |
+
file_path = os.path.join(output_dir, filename)
|
81 |
+
if not os.path.exists(file_path) or not os.path.isfile(file_path):
|
82 |
+
logger.std_warn(f'[Fooocus API] {filename} not exists or is not a file')
|
83 |
+
try:
|
84 |
+
os.remove(file_path)
|
85 |
+
logger.std_info(f'[Fooocus API] Delete output file: {filename}')
|
86 |
+
except OSError:
|
87 |
+
logger.std_error(f'[Fooocus API] Delete output file failed: {filename}')
|
88 |
+
|
89 |
+
|
90 |
+
def output_file_to_base64img(filename: str | None) -> str | None:
|
91 |
+
"""
|
92 |
+
Convert an image file to a base64 string.
|
93 |
+
Args:
|
94 |
+
filename: str of file name
|
95 |
+
return: str of base64 string
|
96 |
+
"""
|
97 |
+
if filename is None:
|
98 |
+
return None
|
99 |
+
file_path = os.path.join(output_dir, filename)
|
100 |
+
if not os.path.exists(file_path) or not os.path.isfile(file_path):
|
101 |
+
return None
|
102 |
+
|
103 |
+
ext = filename.split('.')[-1]
|
104 |
+
if ext.lower() not in ['png', 'jpg', 'webp', 'jpeg']:
|
105 |
+
ext = 'png'
|
106 |
+
img = Image.open(file_path)
|
107 |
+
output_buffer = BytesIO()
|
108 |
+
img.save(output_buffer, format=ext.upper())
|
109 |
+
byte_data = output_buffer.getvalue()
|
110 |
+
base64_str = base64.b64encode(byte_data).decode('utf-8')
|
111 |
+
return f"data:image/{ext};base64," + base64_str
|
112 |
+
|
113 |
+
|
114 |
+
def output_file_to_bytesimg(filename: str | None) -> bytes | None:
|
115 |
+
"""
|
116 |
+
Convert an image file to a bytes string.
|
117 |
+
Args:
|
118 |
+
filename: str of file name
|
119 |
+
return: bytes of image data
|
120 |
+
"""
|
121 |
+
if filename is None:
|
122 |
+
return None
|
123 |
+
file_path = os.path.join(output_dir, filename)
|
124 |
+
if not os.path.exists(file_path) or not os.path.isfile(file_path):
|
125 |
+
return None
|
126 |
+
|
127 |
+
img = Image.open(file_path)
|
128 |
+
output_buffer = BytesIO()
|
129 |
+
img.save(output_buffer, format='PNG')
|
130 |
+
byte_data = output_buffer.getvalue()
|
131 |
+
return byte_data
|
132 |
+
|
133 |
+
|
134 |
+
def get_file_serve_url(filename: str | None) -> str | None:
|
135 |
+
"""
|
136 |
+
Get the static serve url of an image file.
|
137 |
+
Args:
|
138 |
+
filename: str of file name
|
139 |
+
return: str of static serve url
|
140 |
+
"""
|
141 |
+
if filename is None:
|
142 |
+
return None
|
143 |
+
return STATIC_SERVER_BASE + filename.replace('\\', '/')
|
fooocusapi/utils/img_utils.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Image process utils. Used to verify, convert and store Images.
|
3 |
+
|
4 |
+
@file: img_utils.py
|
5 |
+
@author: Konie
|
6 |
+
@update: 2024-03-23
|
7 |
+
"""
|
8 |
+
import base64
|
9 |
+
from io import BytesIO
|
10 |
+
from fastapi import UploadFile
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
import requests
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
|
17 |
+
def upload2base64(image: UploadFile) -> str | None:
|
18 |
+
"""
|
19 |
+
Convert UploadFile obj to base64 string
|
20 |
+
Args:
|
21 |
+
image (UploadFile): UploadFile obj
|
22 |
+
Returns:
|
23 |
+
str: base64 string, None for None
|
24 |
+
"""
|
25 |
+
if image is None:
|
26 |
+
return None
|
27 |
+
image_bytes = image.file.read()
|
28 |
+
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
29 |
+
return image_base64
|
30 |
+
|
31 |
+
|
32 |
+
def narray_to_base64img(narray: np.ndarray) -> str | None:
|
33 |
+
"""
|
34 |
+
Convert numpy array to base64 image string.
|
35 |
+
Args:
|
36 |
+
narray: numpy array
|
37 |
+
Returns:
|
38 |
+
base64 image string
|
39 |
+
"""
|
40 |
+
if narray is None:
|
41 |
+
return None
|
42 |
+
|
43 |
+
img = Image.fromarray(narray)
|
44 |
+
output_buffer = BytesIO()
|
45 |
+
img.save(output_buffer, format='PNG')
|
46 |
+
byte_data = output_buffer.getvalue()
|
47 |
+
base64_str = base64.b64encode(byte_data).decode('utf-8')
|
48 |
+
return base64_str
|
49 |
+
|
50 |
+
|
51 |
+
def narray_to_bytesimg(narray) -> bytes | None:
|
52 |
+
"""
|
53 |
+
Convert numpy array to bytes image.
|
54 |
+
Args:
|
55 |
+
narray: numpy array
|
56 |
+
Returns:
|
57 |
+
bytes image
|
58 |
+
"""
|
59 |
+
if narray is None:
|
60 |
+
return None
|
61 |
+
|
62 |
+
img = Image.fromarray(narray)
|
63 |
+
output_buffer = BytesIO()
|
64 |
+
img.save(output_buffer, format='PNG')
|
65 |
+
byte_data = output_buffer.getvalue()
|
66 |
+
return byte_data
|
67 |
+
|
68 |
+
|
69 |
+
def read_input_image(input_image: UploadFile | str | None) -> np.ndarray | None:
|
70 |
+
"""
|
71 |
+
Read input image from UploadFile or base64 string.
|
72 |
+
Args:
|
73 |
+
input_image: UploadFile, or base64 image string, or None
|
74 |
+
Returns:
|
75 |
+
numpy array of image
|
76 |
+
"""
|
77 |
+
if input_image is None or input_image == '':
|
78 |
+
return None
|
79 |
+
if isinstance(input_image, str):
|
80 |
+
input_image_bytes = base64.b64decode(input_image)
|
81 |
+
else:
|
82 |
+
input_image_bytes = input_image.file.read()
|
83 |
+
pil_image = Image.open(BytesIO(input_image_bytes))
|
84 |
+
image = np.array(pil_image)
|
85 |
+
return image
|
86 |
+
|
87 |
+
|
88 |
+
def base64_to_stream(image: str) -> UploadFile | None:
|
89 |
+
"""
|
90 |
+
Convert base64 image string to UploadFile.
|
91 |
+
Args:
|
92 |
+
image: base64 image string
|
93 |
+
Returns:
|
94 |
+
UploadFile or None
|
95 |
+
"""
|
96 |
+
if image in ['', None, 'None', 'none', 'string', 'null']:
|
97 |
+
return None
|
98 |
+
if image.startswith('http'):
|
99 |
+
return get_check_image(url=image)
|
100 |
+
if image.startswith('data:image'):
|
101 |
+
image = image.split(sep=',', maxsplit=1)[1]
|
102 |
+
image_bytes = base64.b64decode(image)
|
103 |
+
byte_stream = BytesIO()
|
104 |
+
byte_stream.write(image_bytes)
|
105 |
+
byte_stream.seek(0)
|
106 |
+
return UploadFile(file=byte_stream)
|
107 |
+
|
108 |
+
|
109 |
+
def get_check_image(url: str) -> UploadFile | None:
|
110 |
+
"""
|
111 |
+
Get image from url and check if it's valid.
|
112 |
+
Args:
|
113 |
+
url: image url
|
114 |
+
Returns:
|
115 |
+
UploadFile or None
|
116 |
+
"""
|
117 |
+
if url == '':
|
118 |
+
return None
|
119 |
+
headers = {
|
120 |
+
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
|
121 |
+
}
|
122 |
+
try:
|
123 |
+
response = requests.get(url, headers=headers, timeout=10)
|
124 |
+
binary_image = response.content
|
125 |
+
except Exception:
|
126 |
+
return None
|
127 |
+
try:
|
128 |
+
buffer = BytesIO(binary_image)
|
129 |
+
Image.open(buffer) # This validates the image
|
130 |
+
except Exception:
|
131 |
+
return None
|
132 |
+
byte_stream = BytesIO()
|
133 |
+
byte_stream.write(binary_image)
|
134 |
+
byte_stream.seek(0)
|
135 |
+
return UploadFile(file=byte_stream)
|
136 |
+
|
137 |
+
|
138 |
+
def bytes_image_to_io(binary_image: bytes) -> BytesIO | None:
|
139 |
+
"""
|
140 |
+
Convert bytes image to BytesIO.
|
141 |
+
Args:
|
142 |
+
binary_image: bytes image
|
143 |
+
Returns:
|
144 |
+
BytesIO or None
|
145 |
+
"""
|
146 |
+
try:
|
147 |
+
buffer = BytesIO(binary_image)
|
148 |
+
Image.open(buffer)
|
149 |
+
except Exception:
|
150 |
+
return None
|
151 |
+
byte_stream = BytesIO()
|
152 |
+
byte_stream.write(binary_image)
|
153 |
+
byte_stream.seek(0)
|
154 |
+
return byte_stream
|
155 |
+
|
156 |
+
|
157 |
+
def bytes_to_base64img(byte_data: bytes) -> str | None:
|
158 |
+
"""
|
159 |
+
Convert bytes image to base64 image string.
|
160 |
+
Args:
|
161 |
+
byte_data: bytes image
|
162 |
+
Returns:
|
163 |
+
base64 image string or None
|
164 |
+
"""
|
165 |
+
if byte_data is None:
|
166 |
+
return None
|
167 |
+
|
168 |
+
base64_str = base64.b64encode(byte_data).decode('utf-8')
|
169 |
+
return base64_str
|
170 |
+
|
171 |
+
|
172 |
+
def base64_to_bytesimg(base64_str: str) -> bytes | None:
|
173 |
+
"""
|
174 |
+
Convert base64 image string to bytes image.
|
175 |
+
Args:
|
176 |
+
base64_str: base64 image string
|
177 |
+
Returns:
|
178 |
+
bytes image or None
|
179 |
+
"""
|
180 |
+
if base64_str == '':
|
181 |
+
return None
|
182 |
+
bytes_image = base64.b64decode(base64_str)
|
183 |
+
return bytes_image
|
184 |
+
|
185 |
+
|
186 |
+
def base64_to_narray(base64_str: str) -> np.ndarray | None:
|
187 |
+
"""
|
188 |
+
Convert base64 image string to numpy array.
|
189 |
+
Args:
|
190 |
+
base64_str: base64 image string
|
191 |
+
Returns:
|
192 |
+
numpy array or None
|
193 |
+
"""
|
194 |
+
if base64_str == '':
|
195 |
+
return None
|
196 |
+
bytes_image = base64.b64decode(base64_str)
|
197 |
+
image = np.frombuffer(bytes_image, np.uint8)
|
198 |
+
return image
|
fooocusapi/utils/logger.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
""" A simply logger.
|
4 |
+
|
5 |
+
This module is used to log the program.
|
6 |
+
|
7 |
+
@file: logger.py
|
8 |
+
@author: mrhan1993
|
9 |
+
@update: 2024-03-22
|
10 |
+
"""
|
11 |
+
import logging
|
12 |
+
import os
|
13 |
+
import sys
|
14 |
+
|
15 |
+
try:
|
16 |
+
from colorlog import ColoredFormatter
|
17 |
+
except ImportError:
|
18 |
+
from fooocusapi.utils.tools import run_pip
|
19 |
+
run_pip(
|
20 |
+
command="install colorlog",
|
21 |
+
desc="Install colorlog for logger.",
|
22 |
+
live=True
|
23 |
+
)
|
24 |
+
finally:
|
25 |
+
from colorlog import ColoredFormatter
|
26 |
+
|
27 |
+
|
28 |
+
own_path = os.path.dirname(os.path.abspath(__file__))
|
29 |
+
log_dir = "logs"
|
30 |
+
default_log_path = os.path.join(own_path, '../../', log_dir)
|
31 |
+
|
32 |
+
std_formatter = ColoredFormatter(
|
33 |
+
fmt="%(log_color)s[%(asctime)s] %(levelname)-8s%(reset)s %(blue)s%(message)s",
|
34 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
35 |
+
reset=True,
|
36 |
+
log_colors={
|
37 |
+
'DEBUG': 'cyan',
|
38 |
+
'INFO': 'green',
|
39 |
+
'WARNING': 'yellow',
|
40 |
+
'ERROR': 'red',
|
41 |
+
'CRITICAL': 'red,bg_white',
|
42 |
+
},
|
43 |
+
secondary_log_colors={},
|
44 |
+
style='%'
|
45 |
+
)
|
46 |
+
|
47 |
+
file_formatter = ColoredFormatter(
|
48 |
+
fmt="[%(asctime)s] %(levelname)-8s%(reset)s %(message)s",
|
49 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
50 |
+
reset=True,
|
51 |
+
no_color=True,
|
52 |
+
style='%'
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
class ConfigLogger:
|
57 |
+
"""
|
58 |
+
Configure logger.
|
59 |
+
:param log_path: log file path, better absolute path
|
60 |
+
:param std_format: stdout log format
|
61 |
+
:param file_format: file log format
|
62 |
+
"""
|
63 |
+
def __init__(self,
|
64 |
+
log_path: str = default_log_path,
|
65 |
+
std_format: ColoredFormatter = std_formatter,
|
66 |
+
file_format: ColoredFormatter = file_formatter) -> None:
|
67 |
+
self.log_path = log_path
|
68 |
+
self.std_format = std_format
|
69 |
+
self.file_format = file_format
|
70 |
+
|
71 |
+
|
72 |
+
class Logger:
|
73 |
+
"""
|
74 |
+
A simple logger.
|
75 |
+
:param log_name: log name
|
76 |
+
:param config: config logger
|
77 |
+
"""
|
78 |
+
def __init__(self, log_name, config: ConfigLogger = ConfigLogger()):
|
79 |
+
log_path = config.log_path
|
80 |
+
err_log_path = os.path.join(str(log_path), f"{log_name}_error.log")
|
81 |
+
info_log_path = os.path.join(str(log_path), f"{log_name}_info.log")
|
82 |
+
if not os.path.exists(log_path):
|
83 |
+
os.makedirs(log_path, exist_ok=True)
|
84 |
+
|
85 |
+
self._file_logger = logging.getLogger(log_name)
|
86 |
+
self._file_logger.setLevel("INFO")
|
87 |
+
|
88 |
+
self._std_logger = logging.getLogger()
|
89 |
+
self._std_logger.setLevel("INFO")
|
90 |
+
|
91 |
+
# 创建一个ERROR级别的handler,将日志记录到error.log文件中
|
92 |
+
error_handler = logging.FileHandler(err_log_path, encoding='utf-8')
|
93 |
+
error_handler.setLevel(logging.ERROR)
|
94 |
+
|
95 |
+
# 创建一个INFO级别的handler,将日志记录到info.log文件中
|
96 |
+
info_handler = logging.FileHandler(info_log_path, encoding='utf-8')
|
97 |
+
info_handler.setLevel(logging.INFO)
|
98 |
+
|
99 |
+
# 创建一个 stream handler
|
100 |
+
stream_handler = logging.StreamHandler(sys.stdout)
|
101 |
+
|
102 |
+
error_handler.setFormatter(config.file_format)
|
103 |
+
info_handler.setFormatter(config.file_format)
|
104 |
+
stream_handler.setFormatter(config.std_format)
|
105 |
+
|
106 |
+
# 将handler添加到logger中
|
107 |
+
self._file_logger.addHandler(error_handler)
|
108 |
+
self._file_logger.addHandler(info_handler)
|
109 |
+
self._std_logger.addHandler(stream_handler)
|
110 |
+
|
111 |
+
def file_error(self, message):
|
112 |
+
"""file error log"""
|
113 |
+
self._file_logger.error(message)
|
114 |
+
|
115 |
+
def file_info(self, message):
|
116 |
+
"""file info log"""
|
117 |
+
self._file_logger.info(message)
|
118 |
+
|
119 |
+
def std_info(self, message):
|
120 |
+
"""std info log"""
|
121 |
+
self._std_logger.info(message)
|
122 |
+
|
123 |
+
def std_warn(self, message):
|
124 |
+
"""std warn log"""
|
125 |
+
self._std_logger.warning(message)
|
126 |
+
|
127 |
+
def std_error(self, message):
|
128 |
+
"""std error log"""
|
129 |
+
self._std_logger.error(message)
|
130 |
+
|
131 |
+
|
132 |
+
logger = Logger(log_name="fooocus_api")
|
fooocusapi/utils/lora_manager.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import requests
|
4 |
+
import tarfile
|
5 |
+
|
6 |
+
def _hash_url(url):
|
7 |
+
"""Generates a hash value for a given URL."""
|
8 |
+
return hashlib.md5(url.encode('utf-8')).hexdigest()
|
9 |
+
|
10 |
+
class LoraManager:
|
11 |
+
"""
|
12 |
+
Manager loras from url
|
13 |
+
"""
|
14 |
+
def __init__(self):
|
15 |
+
self.cache_dir = os.path.join(
|
16 |
+
os.path.dirname(os.path.realpath(__file__)),
|
17 |
+
'../../',
|
18 |
+
'repositories/Fooocus/models/loras')
|
19 |
+
|
20 |
+
def _download_lora(self, url):
|
21 |
+
"""
|
22 |
+
Downloads a LoRa from a URL, saves it in the cache, and if it's a .tar file, extracts it and returns the .safetensors file.
|
23 |
+
"""
|
24 |
+
url_hash = _hash_url(url)
|
25 |
+
file_ext = url.split('.')[-1]
|
26 |
+
filepath = os.path.join(self.cache_dir, f"{url_hash}.{file_ext}")
|
27 |
+
|
28 |
+
if not os.path.exists(filepath):
|
29 |
+
print(f"Start download for: {url}")
|
30 |
+
|
31 |
+
try:
|
32 |
+
response = requests.get(url, timeout=10, stream=True)
|
33 |
+
response.raise_for_status()
|
34 |
+
with open(filepath, 'wb') as f:
|
35 |
+
for chunk in response.iter_content(chunk_size=8192):
|
36 |
+
f.write(chunk)
|
37 |
+
|
38 |
+
if file_ext == "tar":
|
39 |
+
print("Extracting the tar file...")
|
40 |
+
with tarfile.open(filepath, 'r:*') as tar:
|
41 |
+
tar.extractall(path=self.cache_dir)
|
42 |
+
print("Extraction completed.")
|
43 |
+
return self._find_safetensors_file(self.cache_dir)
|
44 |
+
|
45 |
+
print(f"Download successfully, saved as {filepath}")
|
46 |
+
except Exception as e:
|
47 |
+
raise Exception(f"Error downloading {url}: {e}") from e
|
48 |
+
|
49 |
+
else:
|
50 |
+
print(f"LoRa already downloaded {url}")
|
51 |
+
|
52 |
+
return filepath
|
53 |
+
|
54 |
+
def _find_safetensors_file(self, directory):
|
55 |
+
"""
|
56 |
+
Finds the first .safetensors file in the specified directory.
|
57 |
+
"""
|
58 |
+
print("Searching for .safetensors file.")
|
59 |
+
for root, dirs, files in os.walk(directory):
|
60 |
+
for file in files:
|
61 |
+
if file.endswith('.safetensors'):
|
62 |
+
return os.path.join(root, file)
|
63 |
+
raise FileNotFoundError("No .safetensors file found in the extracted files.")
|
64 |
+
|
65 |
+
def check(self, urls):
|
66 |
+
"""Manages the specified LoRAs: downloads missing ones and returns their file names."""
|
67 |
+
paths = []
|
68 |
+
for url in urls:
|
69 |
+
path = self._download_lora(url)
|
70 |
+
paths.append(path)
|
71 |
+
return paths
|
fooocusapi/utils/model_loader.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
"""
|
4 |
+
Download models from url
|
5 |
+
|
6 |
+
@file: model_loader.py
|
7 |
+
@author: Konie
|
8 |
+
@update: 2024-03-22
|
9 |
+
"""
|
10 |
+
from modules.model_loader import load_file_from_url
|
11 |
+
|
12 |
+
|
13 |
+
def download_models():
|
14 |
+
"""
|
15 |
+
Download models from config
|
16 |
+
"""
|
17 |
+
vae_approx_filenames = [
|
18 |
+
('xlvaeapp.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth'),
|
19 |
+
('vaeapp_sd15.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/vaeapp_sd15.pt'),
|
20 |
+
('xl-to-v1_interposer-v3.1.safetensors', 'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors')
|
21 |
+
]
|
22 |
+
|
23 |
+
from modules.config import (
|
24 |
+
paths_checkpoints as modelfile_path,
|
25 |
+
paths_loras as lorafile_path,
|
26 |
+
path_vae_approx as vae_approx_path,
|
27 |
+
path_fooocus_expansion as fooocus_expansion_path,
|
28 |
+
path_embeddings as embeddings_path,
|
29 |
+
checkpoint_downloads,
|
30 |
+
embeddings_downloads,
|
31 |
+
lora_downloads)
|
32 |
+
|
33 |
+
for file_name, url in checkpoint_downloads.items():
|
34 |
+
load_file_from_url(url=url, model_dir=modelfile_path[0], file_name=file_name)
|
35 |
+
for file_name, url in embeddings_downloads.items():
|
36 |
+
load_file_from_url(url=url, model_dir=embeddings_path, file_name=file_name)
|
37 |
+
for file_name, url in lora_downloads.items():
|
38 |
+
load_file_from_url(url=url, model_dir=lorafile_path[0], file_name=file_name)
|
39 |
+
for file_name, url in vae_approx_filenames:
|
40 |
+
load_file_from_url(url=url, model_dir=vae_approx_path, file_name=file_name)
|
41 |
+
|
42 |
+
load_file_from_url(
|
43 |
+
url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_expansion.bin',
|
44 |
+
model_dir=fooocus_expansion_path,
|
45 |
+
file_name='pytorch_model.bin'
|
46 |
+
)
|
fooocusapi/utils/tools.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
""" Some tools
|
4 |
+
|
5 |
+
@file: tools.py
|
6 |
+
@author: Konie
|
7 |
+
@update: 2024-03-22
|
8 |
+
"""
|
9 |
+
# pylint: disable=line-too-long
|
10 |
+
# pylint: disable=broad-exception-caught
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
import re
|
14 |
+
import subprocess
|
15 |
+
from importlib.util import find_spec
|
16 |
+
from importlib import metadata
|
17 |
+
from packaging import version
|
18 |
+
|
19 |
+
|
20 |
+
PYTHON_EXEC = sys.executable
|
21 |
+
INDEX_URL = os.environ.get('INDEX_URL', "")
|
22 |
+
PATTERN = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
23 |
+
|
24 |
+
|
25 |
+
# This function was copied from [Fooocus](https://github.com/lllyasviel/Fooocus) repository.
|
26 |
+
def run_command(command: str,
|
27 |
+
desc: str = None,
|
28 |
+
error_desc: str = None,
|
29 |
+
custom_env: str = None,
|
30 |
+
live: bool = True) -> str:
|
31 |
+
"""
|
32 |
+
Run a command and return the output
|
33 |
+
Args:
|
34 |
+
command: Command to run
|
35 |
+
desc: Description of the command
|
36 |
+
error_desc: Description of the error
|
37 |
+
custom_env: Custom environment variables
|
38 |
+
live: Whether to print the output
|
39 |
+
Returns:
|
40 |
+
The output of the command
|
41 |
+
"""
|
42 |
+
if desc is not None:
|
43 |
+
print(desc)
|
44 |
+
|
45 |
+
run_kwargs = {
|
46 |
+
"args": command,
|
47 |
+
"shell": True,
|
48 |
+
"env": os.environ if custom_env is None else custom_env,
|
49 |
+
"encoding": 'utf8',
|
50 |
+
"errors": 'ignore'
|
51 |
+
}
|
52 |
+
|
53 |
+
if not live:
|
54 |
+
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
|
55 |
+
|
56 |
+
result = subprocess.run(check=False, **run_kwargs)
|
57 |
+
|
58 |
+
if result.returncode != 0:
|
59 |
+
error_bits = [
|
60 |
+
f"{error_desc or 'Error running command'}.",
|
61 |
+
f"Command: {command}",
|
62 |
+
f"Error code: {result.returncode}",
|
63 |
+
]
|
64 |
+
if result.stdout:
|
65 |
+
error_bits.append(f"stdout: {result.stdout}")
|
66 |
+
if result.stderr:
|
67 |
+
error_bits.append(f"stderr: {result.stderr}")
|
68 |
+
raise RuntimeError("\n".join(error_bits))
|
69 |
+
|
70 |
+
return result.stdout or ""
|
71 |
+
|
72 |
+
|
73 |
+
# This function was copied from [Fooocus](https://github.com/lllyasviel/Fooocus) repository.
|
74 |
+
def run_pip(command, desc=None, live=True):
|
75 |
+
"""
|
76 |
+
Run a pip command
|
77 |
+
Args:
|
78 |
+
command: Command to run
|
79 |
+
desc: Description of the command
|
80 |
+
live: Whether to print the output
|
81 |
+
Returns:
|
82 |
+
The output of the command
|
83 |
+
"""
|
84 |
+
try:
|
85 |
+
index_url_line = f' --index-url {INDEX_URL}' if INDEX_URL != '' else ''
|
86 |
+
return run_command(
|
87 |
+
command=f'"{PYTHON_EXEC}" -m pip {command} --prefer-binary{index_url_line}',
|
88 |
+
desc=f"Installing {desc}",
|
89 |
+
error_desc=f"Couldn't install {desc}",
|
90 |
+
live=live
|
91 |
+
)
|
92 |
+
except Exception as e:
|
93 |
+
print(f'CMD Failed {command}: {e}')
|
94 |
+
return None
|
95 |
+
|
96 |
+
|
97 |
+
def is_installed(package: str) -> bool:
|
98 |
+
"""
|
99 |
+
Check if a package is installed
|
100 |
+
Args:
|
101 |
+
package: Package name
|
102 |
+
Returns:
|
103 |
+
Whether the package is installed
|
104 |
+
"""
|
105 |
+
try:
|
106 |
+
spec = find_spec(package)
|
107 |
+
except ModuleNotFoundError:
|
108 |
+
return False
|
109 |
+
|
110 |
+
return spec is not None
|
111 |
+
|
112 |
+
|
113 |
+
def check_torch_cuda() -> bool:
|
114 |
+
"""
|
115 |
+
Check if torch and CUDA is available
|
116 |
+
Returns:
|
117 |
+
Whether CUDA is available
|
118 |
+
"""
|
119 |
+
try:
|
120 |
+
import torch
|
121 |
+
return torch.cuda.is_available()
|
122 |
+
except ImportError:
|
123 |
+
return False
|
124 |
+
|
125 |
+
|
126 |
+
def requirements_check(requirements_file: str = 'requirements.txt',
|
127 |
+
pattern: re.Pattern = PATTERN) -> bool:
|
128 |
+
"""
|
129 |
+
Check if the requirements file is satisfied
|
130 |
+
Args:
|
131 |
+
requirements_file: Path to the requirements file
|
132 |
+
pattern: Pattern to match the requirements
|
133 |
+
Returns:
|
134 |
+
Whether the requirements file is satisfied
|
135 |
+
"""
|
136 |
+
with open(requirements_file, "r", encoding="utf8") as file:
|
137 |
+
for line in file:
|
138 |
+
if line.strip() == "":
|
139 |
+
continue
|
140 |
+
|
141 |
+
m = re.match(pattern, line)
|
142 |
+
if m is None:
|
143 |
+
return False
|
144 |
+
|
145 |
+
package = m.group(1).strip()
|
146 |
+
version_required = (m.group(2) or "").strip()
|
147 |
+
|
148 |
+
if version_required == "":
|
149 |
+
continue
|
150 |
+
|
151 |
+
try:
|
152 |
+
version_installed = metadata.version(package)
|
153 |
+
except Exception:
|
154 |
+
return False
|
155 |
+
|
156 |
+
if version.parse(version_required) != version.parse(version_installed):
|
157 |
+
return False
|
158 |
+
|
159 |
+
return True
|
fooocusapi/worker.py
ADDED
@@ -0,0 +1,1044 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Worker, modify from https://github.com/lllyasviel/Fooocus/blob/main/modules/async_worker.py
|
3 |
+
"""
|
4 |
+
import copy
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import time
|
8 |
+
from typing import List
|
9 |
+
import logging
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from fooocusapi.models.common.image_meta import image_parse
|
14 |
+
from modules.patch import PatchSettings, patch_settings, patch_all
|
15 |
+
from modules.flags import Performance
|
16 |
+
|
17 |
+
from fooocusapi.utils.file_utils import save_output_file
|
18 |
+
from fooocusapi.models.common.task import (
|
19 |
+
GenerationFinishReason,
|
20 |
+
ImageGenerationResult
|
21 |
+
)
|
22 |
+
from fooocusapi.utils.logger import logger
|
23 |
+
from fooocusapi.task_queue import (
|
24 |
+
QueueTask,
|
25 |
+
TaskQueue,
|
26 |
+
TaskOutputs
|
27 |
+
)
|
28 |
+
|
29 |
+
patch_all()
|
30 |
+
|
31 |
+
worker_queue: TaskQueue | None = None
|
32 |
+
last_model_name = None
|
33 |
+
|
34 |
+
|
35 |
+
def process_stop():
|
36 |
+
"""Stop process"""
|
37 |
+
import ldm_patched.modules.model_management
|
38 |
+
ldm_patched.modules.model_management.interrupt_current_processing()
|
39 |
+
|
40 |
+
|
41 |
+
@torch.no_grad()
|
42 |
+
@torch.inference_mode()
|
43 |
+
def task_schedule_loop():
|
44 |
+
"""Task schedule loop"""
|
45 |
+
while True:
|
46 |
+
if len(worker_queue.queue) == 0:
|
47 |
+
time.sleep(0.05)
|
48 |
+
continue
|
49 |
+
|
50 |
+
current_task = worker_queue.queue[0]
|
51 |
+
if current_task.start_mills == 0:
|
52 |
+
process_generate(current_task)
|
53 |
+
|
54 |
+
|
55 |
+
@torch.no_grad()
|
56 |
+
@torch.inference_mode()
|
57 |
+
def blocking_get_task_result(job_id: str) -> List[ImageGenerationResult]:
|
58 |
+
"""
|
59 |
+
Get task result, when async_task is false
|
60 |
+
:param job_id:
|
61 |
+
:return:
|
62 |
+
"""
|
63 |
+
waiting_sleep_steps: int = 0
|
64 |
+
waiting_start_time = time.perf_counter()
|
65 |
+
while not worker_queue.is_task_finished(job_id):
|
66 |
+
if waiting_sleep_steps == 0:
|
67 |
+
logger.std_info(f"[Task Queue] Waiting for task finished, job_id={job_id}")
|
68 |
+
delay = 0.05
|
69 |
+
time.sleep(delay)
|
70 |
+
waiting_sleep_steps += 1
|
71 |
+
if waiting_sleep_steps % int(10 / delay) == 0:
|
72 |
+
waiting_time = time.perf_counter() - waiting_start_time
|
73 |
+
logger.std_info(f"[Task Queue] Already waiting for {round(waiting_time, 1)} seconds, job_id={job_id}")
|
74 |
+
|
75 |
+
task = worker_queue.get_task(job_id, True)
|
76 |
+
return task.task_result
|
77 |
+
|
78 |
+
|
79 |
+
@torch.no_grad()
|
80 |
+
@torch.inference_mode()
|
81 |
+
def process_generate(async_task: QueueTask):
|
82 |
+
"""Generate image"""
|
83 |
+
try:
|
84 |
+
import modules.default_pipeline as pipeline
|
85 |
+
except Exception as e:
|
86 |
+
logger.std_error(f'[Task Queue] Import default pipeline error: {e}')
|
87 |
+
if not async_task.is_finished:
|
88 |
+
worker_queue.finish_task(async_task.job_id)
|
89 |
+
async_task.set_result([], True, str(e))
|
90 |
+
logger.std_error(f"[Task Queue] Finish task with error, seq={async_task.job_id}")
|
91 |
+
return []
|
92 |
+
|
93 |
+
import modules.flags as flags
|
94 |
+
import modules.core as core
|
95 |
+
import modules.inpaint_worker as inpaint_worker
|
96 |
+
import modules.config as config
|
97 |
+
import modules.constants as constants
|
98 |
+
import extras.preprocessors as preprocessors
|
99 |
+
import extras.ip_adapter as ip_adapter
|
100 |
+
import extras.face_crop as face_crop
|
101 |
+
import ldm_patched.modules.model_management as model_management
|
102 |
+
from modules.util import (
|
103 |
+
remove_empty_str, HWC3, resize_image,
|
104 |
+
get_image_shape_ceil, set_image_shape_ceil,
|
105 |
+
get_shape_ceil, resample_image, erode_or_dilate,
|
106 |
+
get_enabled_loras, parse_lora_references_from_prompt, apply_wildcards,
|
107 |
+
remove_performance_lora
|
108 |
+
)
|
109 |
+
|
110 |
+
from modules.upscaler import perform_upscale
|
111 |
+
from extras.expansion import safe_str
|
112 |
+
from extras.censor import default_censor
|
113 |
+
from modules.sdxl_styles import (
|
114 |
+
apply_style, get_random_style,
|
115 |
+
fooocus_expansion, apply_arrays, random_style_name
|
116 |
+
)
|
117 |
+
|
118 |
+
pid = os.getpid()
|
119 |
+
|
120 |
+
outputs = TaskOutputs(async_task)
|
121 |
+
results = []
|
122 |
+
|
123 |
+
def refresh_seed(seed_string: int | str | None) -> int:
|
124 |
+
"""
|
125 |
+
Refresh and check seed number.
|
126 |
+
:params seed_string: seed, str or int. None means random
|
127 |
+
:return: seed number
|
128 |
+
"""
|
129 |
+
if seed_string is None or seed_string == -1:
|
130 |
+
return random.randint(constants.MIN_SEED, constants.MAX_SEED)
|
131 |
+
|
132 |
+
try:
|
133 |
+
seed_value = int(seed_string)
|
134 |
+
if constants.MIN_SEED <= seed_value <= constants.MAX_SEED:
|
135 |
+
return seed_value
|
136 |
+
except ValueError:
|
137 |
+
pass
|
138 |
+
return random.randint(constants.MIN_SEED, constants.MAX_SEED)
|
139 |
+
|
140 |
+
def progressbar(_, number, text):
|
141 |
+
"""progress bar"""
|
142 |
+
logger.std_info(f'[Fooocus] {text}')
|
143 |
+
outputs.append(['preview', (number, text, None)])
|
144 |
+
|
145 |
+
def yield_result(_, images, tasks, extension='png',
|
146 |
+
blockout_nsfw=False, censor=True):
|
147 |
+
"""
|
148 |
+
Yield result
|
149 |
+
:param _: async task object
|
150 |
+
:param images: list for generated image
|
151 |
+
:param tasks: the image was generated one by one, when image number is not one, it will be a task list
|
152 |
+
:param extension: extension for saved image
|
153 |
+
:param blockout_nsfw: blockout nsfw image
|
154 |
+
:param censor: censor image
|
155 |
+
:return:
|
156 |
+
"""
|
157 |
+
if not isinstance(images, list):
|
158 |
+
images = [images]
|
159 |
+
|
160 |
+
if censor and (config.default_black_out_nsfw or black_out_nsfw):
|
161 |
+
images = default_censor(images)
|
162 |
+
|
163 |
+
results = []
|
164 |
+
for index, im in enumerate(images):
|
165 |
+
if async_task.req_param.save_name == '':
|
166 |
+
image_name = f"{async_task.job_id}-{str(index)}"
|
167 |
+
else:
|
168 |
+
image_name = f"{async_task.req_param.save_name}-{str(index)}"
|
169 |
+
if len(tasks) == 0:
|
170 |
+
img_seed = -1
|
171 |
+
img_meta = {}
|
172 |
+
else:
|
173 |
+
img_seed = tasks[index]['task_seed']
|
174 |
+
img_meta = image_parse(
|
175 |
+
async_tak=async_task,
|
176 |
+
task=tasks[index])
|
177 |
+
img_filename = save_output_file(
|
178 |
+
img=im,
|
179 |
+
image_name=image_name,
|
180 |
+
image_meta=img_meta,
|
181 |
+
extension=extension)
|
182 |
+
results.append(ImageGenerationResult(
|
183 |
+
im=img_filename,
|
184 |
+
seed=str(img_seed),
|
185 |
+
finish_reason=GenerationFinishReason.success))
|
186 |
+
async_task.set_result(results, False)
|
187 |
+
worker_queue.finish_task(async_task.job_id)
|
188 |
+
logger.std_info(f"[Task Queue] Finish task, job_id={async_task.job_id}")
|
189 |
+
|
190 |
+
outputs.append(['results', images])
|
191 |
+
pipeline.prepare_text_encoder(async_call=True)
|
192 |
+
|
193 |
+
try:
|
194 |
+
logger.std_info(f"[Task Queue] Task queue start task, job_id={async_task.job_id}")
|
195 |
+
# clear memory
|
196 |
+
global last_model_name
|
197 |
+
|
198 |
+
if last_model_name is None:
|
199 |
+
last_model_name = async_task.req_param.base_model_name
|
200 |
+
if last_model_name != async_task.req_param.base_model_name:
|
201 |
+
model_management.cleanup_models() # key1
|
202 |
+
model_management.unload_all_models()
|
203 |
+
model_management.soft_empty_cache() # key2
|
204 |
+
last_model_name = async_task.req_param.base_model_name
|
205 |
+
|
206 |
+
worker_queue.start_task(async_task.job_id)
|
207 |
+
|
208 |
+
execution_start_time = time.perf_counter()
|
209 |
+
|
210 |
+
# Transform parameters
|
211 |
+
params = async_task.req_param
|
212 |
+
prompt = params.prompt
|
213 |
+
negative_prompt = params.negative_prompt
|
214 |
+
style_selections = params.style_selections
|
215 |
+
performance_selection = Performance(params.performance_selection)
|
216 |
+
aspect_ratios_selection = params.aspect_ratios_selection
|
217 |
+
image_number = params.image_number
|
218 |
+
save_metadata_to_images = params.save_meta
|
219 |
+
metadata_scheme = params.meta_scheme
|
220 |
+
save_extension = params.save_extension
|
221 |
+
save_name = params.save_name
|
222 |
+
image_seed = refresh_seed(params.image_seed)
|
223 |
+
read_wildcards_in_order = False
|
224 |
+
sharpness = params.sharpness
|
225 |
+
guidance_scale = params.guidance_scale
|
226 |
+
base_model_name = params.base_model_name
|
227 |
+
refiner_model_name = params.refiner_model_name
|
228 |
+
refiner_switch = params.refiner_switch
|
229 |
+
loras = params.loras
|
230 |
+
input_image_checkbox = params.uov_input_image is not None or params.inpaint_input_image is not None or len(params.image_prompts) > 0
|
231 |
+
current_tab = 'uov' if params.uov_method != flags.disabled else 'ip' if len(params.image_prompts) > 0 else 'inpaint' if params.inpaint_input_image is not None else None
|
232 |
+
uov_method = params.uov_method
|
233 |
+
upscale_value = params.upscale_value
|
234 |
+
uov_input_image = params.uov_input_image
|
235 |
+
outpaint_selections = params.outpaint_selections
|
236 |
+
outpaint_distance_left = params.outpaint_distance_left
|
237 |
+
outpaint_distance_top = params.outpaint_distance_top
|
238 |
+
outpaint_distance_right = params.outpaint_distance_right
|
239 |
+
outpaint_distance_bottom = params.outpaint_distance_bottom
|
240 |
+
inpaint_input_image = params.inpaint_input_image
|
241 |
+
inpaint_additional_prompt = '' if params.inpaint_additional_prompt is None else params.inpaint_additional_prompt
|
242 |
+
inpaint_mask_image_upload = None
|
243 |
+
|
244 |
+
adp = params.advanced_params
|
245 |
+
disable_preview = adp.disable_preview
|
246 |
+
disable_intermediate_results = adp.disable_intermediate_results
|
247 |
+
disable_seed_increment = adp.disable_seed_increment
|
248 |
+
adm_scaler_positive = adp.adm_scaler_positive
|
249 |
+
adm_scaler_negative = adp.adm_scaler_negative
|
250 |
+
adm_scaler_end = adp.adm_scaler_end
|
251 |
+
adaptive_cfg = adp.adaptive_cfg
|
252 |
+
sampler_name = adp.sampler_name
|
253 |
+
scheduler_name = adp.scheduler_name
|
254 |
+
overwrite_step = adp.overwrite_step
|
255 |
+
overwrite_switch = adp.overwrite_switch
|
256 |
+
overwrite_width = adp.overwrite_width
|
257 |
+
overwrite_height = adp.overwrite_height
|
258 |
+
overwrite_vary_strength = adp.overwrite_vary_strength
|
259 |
+
overwrite_upscale_strength = adp.overwrite_upscale_strength
|
260 |
+
mixing_image_prompt_and_vary_upscale = adp.mixing_image_prompt_and_vary_upscale
|
261 |
+
mixing_image_prompt_and_inpaint = adp.mixing_image_prompt_and_inpaint
|
262 |
+
debugging_cn_preprocessor = adp.debugging_cn_preprocessor
|
263 |
+
skipping_cn_preprocessor = adp.skipping_cn_preprocessor
|
264 |
+
canny_low_threshold = adp.canny_low_threshold
|
265 |
+
canny_high_threshold = adp.canny_high_threshold
|
266 |
+
refiner_swap_method = adp.refiner_swap_method
|
267 |
+
controlnet_softness = adp.controlnet_softness
|
268 |
+
freeu_enabled = adp.freeu_enabled
|
269 |
+
freeu_b1 = adp.freeu_b1
|
270 |
+
freeu_b2 = adp.freeu_b2
|
271 |
+
freeu_s1 = adp.freeu_s1
|
272 |
+
freeu_s2 = adp.freeu_s2
|
273 |
+
debugging_inpaint_preprocessor = adp.debugging_inpaint_preprocessor
|
274 |
+
inpaint_disable_initial_latent = adp.inpaint_disable_initial_latent
|
275 |
+
inpaint_engine = adp.inpaint_engine
|
276 |
+
inpaint_strength = adp.inpaint_strength
|
277 |
+
inpaint_respective_field = adp.inpaint_respective_field
|
278 |
+
inpaint_mask_upload_checkbox = adp.inpaint_mask_upload_checkbox
|
279 |
+
invert_mask_checkbox = adp.invert_mask_checkbox
|
280 |
+
inpaint_erode_or_dilate = adp.inpaint_erode_or_dilate
|
281 |
+
black_out_nsfw = adp.black_out_nsfw
|
282 |
+
vae_name = adp.vae_name
|
283 |
+
clip_skip = adp.clip_skip
|
284 |
+
|
285 |
+
cn_tasks = {x: [] for x in flags.ip_list}
|
286 |
+
for img_prompt in params.image_prompts:
|
287 |
+
cn_img, cn_stop, cn_weight, cn_type = img_prompt
|
288 |
+
cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight])
|
289 |
+
|
290 |
+
if inpaint_input_image is not None and inpaint_input_image['image'] is not None:
|
291 |
+
inpaint_image_size = inpaint_input_image['image'].shape[:2]
|
292 |
+
if inpaint_input_image['mask'] is None:
|
293 |
+
inpaint_input_image['mask'] = np.zeros(inpaint_image_size, dtype=np.uint8)
|
294 |
+
else:
|
295 |
+
inpaint_mask_upload_checkbox = True
|
296 |
+
|
297 |
+
inpaint_input_image['mask'] = HWC3(inpaint_input_image['mask'])
|
298 |
+
inpaint_mask_image_upload = inpaint_input_image['mask']
|
299 |
+
|
300 |
+
# Fooocus async_worker.py code start
|
301 |
+
|
302 |
+
outpaint_selections = [o.lower() for o in outpaint_selections]
|
303 |
+
base_model_additional_loras = []
|
304 |
+
raw_style_selections = copy.deepcopy(style_selections)
|
305 |
+
uov_method = uov_method.lower()
|
306 |
+
|
307 |
+
if fooocus_expansion in style_selections:
|
308 |
+
use_expansion = True
|
309 |
+
style_selections.remove(fooocus_expansion)
|
310 |
+
else:
|
311 |
+
use_expansion = False
|
312 |
+
|
313 |
+
use_style = len(style_selections) > 0
|
314 |
+
|
315 |
+
if base_model_name == refiner_model_name:
|
316 |
+
logger.std_warn('[Fooocus] Refiner disabled because base model and refiner are same.')
|
317 |
+
refiner_model_name = 'None'
|
318 |
+
|
319 |
+
steps = performance_selection.steps()
|
320 |
+
|
321 |
+
performance_loras = []
|
322 |
+
|
323 |
+
if performance_selection == Performance.EXTREME_SPEED:
|
324 |
+
logger.std_warn('[Fooocus] Enter LCM mode.')
|
325 |
+
progressbar(async_task, 1, 'Downloading LCM components ...')
|
326 |
+
performance_loras += [(config.downloading_sdxl_lcm_lora(), 1.0)]
|
327 |
+
|
328 |
+
if refiner_model_name != 'None':
|
329 |
+
logger.std_info('[Fooocus] Refiner disabled in LCM mode.')
|
330 |
+
|
331 |
+
refiner_model_name = 'None'
|
332 |
+
sampler_name = 'lcm'
|
333 |
+
scheduler_name = 'lcm'
|
334 |
+
sharpness = 0.0
|
335 |
+
guidance_scale = 1.0
|
336 |
+
adaptive_cfg = 1.0
|
337 |
+
refiner_switch = 1.0
|
338 |
+
adm_scaler_positive = 1.0
|
339 |
+
adm_scaler_negative = 1.0
|
340 |
+
adm_scaler_end = 0.0
|
341 |
+
|
342 |
+
elif performance_selection == Performance.LIGHTNING:
|
343 |
+
logger.std_info('[Fooocus] Enter Lightning mode.')
|
344 |
+
progressbar(async_task, 1, 'Downloading Lightning components ...')
|
345 |
+
performance_loras += [(config.downloading_sdxl_lightning_lora(), 1.0)]
|
346 |
+
|
347 |
+
if refiner_model_name != 'None':
|
348 |
+
logger.std_info('[Fooocus] Refiner disabled in Lightning mode.')
|
349 |
+
|
350 |
+
refiner_model_name = 'None'
|
351 |
+
sampler_name = 'euler'
|
352 |
+
scheduler_name = 'sgm_uniform'
|
353 |
+
sharpness = 0.0
|
354 |
+
guidance_scale = 1.0
|
355 |
+
adaptive_cfg = 1.0
|
356 |
+
refiner_switch = 1.0
|
357 |
+
adm_scaler_positive = 1.0
|
358 |
+
adm_scaler_negative = 1.0
|
359 |
+
adm_scaler_end = 0.0
|
360 |
+
|
361 |
+
elif performance_selection == Performance.HYPER_SD:
|
362 |
+
print('Enter Hyper-SD mode.')
|
363 |
+
progressbar(async_task, 1, 'Downloading Hyper-SD components ...')
|
364 |
+
performance_loras += [(config.downloading_sdxl_hyper_sd_lora(), 0.8)]
|
365 |
+
|
366 |
+
if refiner_model_name != 'None':
|
367 |
+
logger.std_info('[Fooocus] Refiner disabled in Hyper-SD mode.')
|
368 |
+
|
369 |
+
refiner_model_name = 'None'
|
370 |
+
sampler_name = 'dpmpp_sde_gpu'
|
371 |
+
scheduler_name = 'karras'
|
372 |
+
sharpness = 0.0
|
373 |
+
guidance_scale = 1.0
|
374 |
+
adaptive_cfg = 1.0
|
375 |
+
refiner_switch = 1.0
|
376 |
+
adm_scaler_positive = 1.0
|
377 |
+
adm_scaler_negative = 1.0
|
378 |
+
adm_scaler_end = 0.0
|
379 |
+
|
380 |
+
logger.std_info(f'[Parameters] Adaptive CFG = {adaptive_cfg}')
|
381 |
+
logger.std_info(f'[Parameters] CLIP Skip = {clip_skip}')
|
382 |
+
logger.std_info(f'[Parameters] Sharpness = {sharpness}')
|
383 |
+
logger.std_info(f'[Parameters] ControlNet Softness = {controlnet_softness}')
|
384 |
+
logger.std_info(f'[Parameters] ADM Scale = '
|
385 |
+
f'{adm_scaler_positive} : '
|
386 |
+
f'{adm_scaler_negative} : '
|
387 |
+
f'{adm_scaler_end}')
|
388 |
+
|
389 |
+
patch_settings[pid] = PatchSettings(
|
390 |
+
sharpness,
|
391 |
+
adm_scaler_end,
|
392 |
+
adm_scaler_positive,
|
393 |
+
adm_scaler_negative,
|
394 |
+
controlnet_softness,
|
395 |
+
adaptive_cfg
|
396 |
+
)
|
397 |
+
|
398 |
+
cfg_scale = float(guidance_scale)
|
399 |
+
logger.std_info(f'[Parameters] CFG = {cfg_scale}')
|
400 |
+
|
401 |
+
initial_latent = None
|
402 |
+
denoising_strength = 1.0
|
403 |
+
tiled = False
|
404 |
+
|
405 |
+
width, height = aspect_ratios_selection.replace('×', ' ').replace('*', ' ').split(' ')[:2]
|
406 |
+
width, height = int(width), int(height)
|
407 |
+
|
408 |
+
skip_prompt_processing = False
|
409 |
+
|
410 |
+
inpaint_worker.current_task = None
|
411 |
+
inpaint_parameterized = inpaint_engine != 'None'
|
412 |
+
inpaint_image = None
|
413 |
+
inpaint_mask = None
|
414 |
+
inpaint_head_model_path = None
|
415 |
+
|
416 |
+
use_synthetic_refiner = False
|
417 |
+
|
418 |
+
controlnet_canny_path = None
|
419 |
+
controlnet_cpds_path = None
|
420 |
+
clip_vision_path, ip_negative_path, ip_adapter_path, ip_adapter_face_path = None, None, None, None
|
421 |
+
|
422 |
+
seed = int(image_seed)
|
423 |
+
logger.std_info(f'[Parameters] Seed = {seed}')
|
424 |
+
|
425 |
+
goals = []
|
426 |
+
tasks = []
|
427 |
+
|
428 |
+
if input_image_checkbox:
|
429 |
+
if (current_tab == 'uov' or (
|
430 |
+
current_tab == 'ip' and mixing_image_prompt_and_vary_upscale)) \
|
431 |
+
and uov_method != flags.disabled and uov_input_image is not None:
|
432 |
+
uov_input_image = HWC3(uov_input_image)
|
433 |
+
if 'vary' in uov_method:
|
434 |
+
goals.append('vary')
|
435 |
+
elif 'upscale' in uov_method:
|
436 |
+
goals.append('upscale')
|
437 |
+
if 'fast' in uov_method:
|
438 |
+
skip_prompt_processing = True
|
439 |
+
else:
|
440 |
+
steps = performance_selection.steps_uov()
|
441 |
+
|
442 |
+
progressbar(async_task, 1, 'Downloading upscale models ...')
|
443 |
+
config.downloading_upscale_model()
|
444 |
+
if (current_tab == 'inpaint' or (
|
445 |
+
current_tab == 'ip' and mixing_image_prompt_and_inpaint)) \
|
446 |
+
and isinstance(inpaint_input_image, dict):
|
447 |
+
inpaint_image = inpaint_input_image['image']
|
448 |
+
inpaint_mask = inpaint_input_image['mask'][:, :, 0]
|
449 |
+
|
450 |
+
if inpaint_mask_upload_checkbox:
|
451 |
+
if isinstance(inpaint_mask_image_upload, np.ndarray):
|
452 |
+
if inpaint_mask_image_upload.ndim == 3:
|
453 |
+
H, W, C = inpaint_image.shape
|
454 |
+
inpaint_mask_image_upload = resample_image(inpaint_mask_image_upload, width=W, height=H)
|
455 |
+
inpaint_mask_image_upload = np.mean(inpaint_mask_image_upload, axis=2)
|
456 |
+
inpaint_mask_image_upload = (inpaint_mask_image_upload > 127).astype(np.uint8) * 255
|
457 |
+
inpaint_mask = np.maximum(np.zeros(shape=(H, W), dtype=np.uint8), inpaint_mask_image_upload)
|
458 |
+
|
459 |
+
if int(inpaint_erode_or_dilate) != 0:
|
460 |
+
inpaint_mask = erode_or_dilate(inpaint_mask, inpaint_erode_or_dilate)
|
461 |
+
|
462 |
+
if invert_mask_checkbox:
|
463 |
+
inpaint_mask = 255 - inpaint_mask
|
464 |
+
|
465 |
+
inpaint_image = HWC3(inpaint_image)
|
466 |
+
if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
|
467 |
+
and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0):
|
468 |
+
progressbar(async_task, 1, 'Downloading upscale models ...')
|
469 |
+
config.downloading_upscale_model()
|
470 |
+
if inpaint_parameterized:
|
471 |
+
progressbar(async_task, 1, 'Downloading inpainter ...')
|
472 |
+
inpaint_head_model_path, inpaint_patch_model_path = config.downloading_inpaint_models(
|
473 |
+
inpaint_engine)
|
474 |
+
base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
|
475 |
+
logger.std_info(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
|
476 |
+
if refiner_model_name == 'None':
|
477 |
+
use_synthetic_refiner = True
|
478 |
+
refiner_switch = 0.8
|
479 |
+
else:
|
480 |
+
inpaint_head_model_path, inpaint_patch_model_path = None, None
|
481 |
+
logger.std_info('[Inpaint] Parameterized inpaint is disabled.')
|
482 |
+
if inpaint_additional_prompt != '':
|
483 |
+
if prompt == '':
|
484 |
+
prompt = inpaint_additional_prompt
|
485 |
+
else:
|
486 |
+
prompt = inpaint_additional_prompt + '\n' + prompt
|
487 |
+
goals.append('inpaint')
|
488 |
+
if current_tab == 'ip' or \
|
489 |
+
mixing_image_prompt_and_vary_upscale or \
|
490 |
+
mixing_image_prompt_and_inpaint:
|
491 |
+
goals.append('cn')
|
492 |
+
progressbar(async_task, 1, 'Downloading control models ...')
|
493 |
+
if len(cn_tasks[flags.cn_canny]) > 0:
|
494 |
+
controlnet_canny_path = config.downloading_controlnet_canny()
|
495 |
+
if len(cn_tasks[flags.cn_cpds]) > 0:
|
496 |
+
controlnet_cpds_path = config.downloading_controlnet_cpds()
|
497 |
+
if len(cn_tasks[flags.cn_ip]) > 0:
|
498 |
+
clip_vision_path, ip_negative_path, ip_adapter_path = config.downloading_ip_adapters('ip')
|
499 |
+
if len(cn_tasks[flags.cn_ip_face]) > 0:
|
500 |
+
clip_vision_path, ip_negative_path, ip_adapter_face_path = config.downloading_ip_adapters(
|
501 |
+
'face')
|
502 |
+
progressbar(async_task, 1, 'Loading control models ...')
|
503 |
+
|
504 |
+
# Load or unload CNs
|
505 |
+
pipeline.refresh_controlnets([controlnet_canny_path, controlnet_cpds_path])
|
506 |
+
ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path)
|
507 |
+
ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_face_path)
|
508 |
+
|
509 |
+
if overwrite_step > 0:
|
510 |
+
steps = overwrite_step
|
511 |
+
|
512 |
+
switch = int(round(steps * refiner_switch))
|
513 |
+
|
514 |
+
if overwrite_switch > 0:
|
515 |
+
switch = overwrite_switch
|
516 |
+
|
517 |
+
if overwrite_width > 0:
|
518 |
+
width = overwrite_width
|
519 |
+
|
520 |
+
if overwrite_height > 0:
|
521 |
+
height = overwrite_height
|
522 |
+
|
523 |
+
logger.std_info(f'[Parameters] Sampler = {sampler_name} - {scheduler_name}')
|
524 |
+
logger.std_info(f'[Parameters] Steps = {steps} - {switch}')
|
525 |
+
|
526 |
+
progressbar(async_task, 1, 'Initializing ...')
|
527 |
+
|
528 |
+
if not skip_prompt_processing:
|
529 |
+
|
530 |
+
prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='')
|
531 |
+
negative_prompts = remove_empty_str([safe_str(p) for p in negative_prompt.splitlines()], default='')
|
532 |
+
|
533 |
+
prompt = prompts[0]
|
534 |
+
negative_prompt = negative_prompts[0]
|
535 |
+
|
536 |
+
if prompt == '':
|
537 |
+
# disable expansion when empty since it is not meaningful and influences image prompt
|
538 |
+
use_expansion = False
|
539 |
+
|
540 |
+
extra_positive_prompts = prompts[1:] if len(prompts) > 1 else []
|
541 |
+
extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []
|
542 |
+
|
543 |
+
progressbar(async_task, 3, 'Loading models ...')
|
544 |
+
lora_filenames = remove_performance_lora(config.lora_filenames, performance_selection)
|
545 |
+
loras, prompt = parse_lora_references_from_prompt(prompt, loras, config.default_max_lora_number, lora_filenames=lora_filenames)
|
546 |
+
loras += performance_loras
|
547 |
+
|
548 |
+
pipeline.refresh_everything(
|
549 |
+
refiner_model_name=refiner_model_name,
|
550 |
+
base_model_name=base_model_name,
|
551 |
+
loras=loras,
|
552 |
+
base_model_additional_loras=base_model_additional_loras,
|
553 |
+
use_synthetic_refiner=use_synthetic_refiner)
|
554 |
+
|
555 |
+
pipeline.set_clip_skip(clip_skip)
|
556 |
+
|
557 |
+
progressbar(async_task, 3, 'Processing prompts ...')
|
558 |
+
tasks = []
|
559 |
+
|
560 |
+
for i in range(image_number):
|
561 |
+
if disable_seed_increment:
|
562 |
+
task_seed = seed % (constants.MAX_SEED + 1)
|
563 |
+
else:
|
564 |
+
task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not
|
565 |
+
|
566 |
+
task_rng = random.Random(task_seed) # may bind to inpaint noise in the future
|
567 |
+
task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order)
|
568 |
+
task_prompt = apply_arrays(task_prompt, i)
|
569 |
+
task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order)
|
570 |
+
task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
|
571 |
+
extra_positive_prompts]
|
572 |
+
task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
|
573 |
+
extra_negative_prompts]
|
574 |
+
|
575 |
+
positive_basic_workloads = []
|
576 |
+
negative_basic_workloads = []
|
577 |
+
|
578 |
+
task_styles = style_selections.copy()
|
579 |
+
if use_style:
|
580 |
+
for index, style in enumerate(task_styles):
|
581 |
+
if style == random_style_name:
|
582 |
+
style = get_random_style(task_rng)
|
583 |
+
task_styles[index] = style
|
584 |
+
p, n = apply_style(style, positive=task_prompt)
|
585 |
+
positive_basic_workloads = positive_basic_workloads + p
|
586 |
+
negative_basic_workloads = negative_basic_workloads + n
|
587 |
+
else:
|
588 |
+
positive_basic_workloads.append(task_prompt)
|
589 |
+
|
590 |
+
negative_basic_workloads.append(task_negative_prompt) # Always use independent workload for negative.
|
591 |
+
|
592 |
+
positive_basic_workloads = positive_basic_workloads + task_extra_positive_prompts
|
593 |
+
negative_basic_workloads = negative_basic_workloads + task_extra_negative_prompts
|
594 |
+
|
595 |
+
positive_basic_workloads = remove_empty_str(positive_basic_workloads, default=task_prompt)
|
596 |
+
negative_basic_workloads = remove_empty_str(negative_basic_workloads, default=task_negative_prompt)
|
597 |
+
|
598 |
+
tasks.append(dict(
|
599 |
+
task_seed=task_seed,
|
600 |
+
task_prompt=task_prompt,
|
601 |
+
task_negative_prompt=task_negative_prompt,
|
602 |
+
positive=positive_basic_workloads,
|
603 |
+
negative=negative_basic_workloads,
|
604 |
+
expansion='',
|
605 |
+
c=None,
|
606 |
+
uc=None,
|
607 |
+
positive_top_k=len(positive_basic_workloads),
|
608 |
+
negative_top_k=len(negative_basic_workloads),
|
609 |
+
log_positive_prompt='\n'.join([task_prompt] + task_extra_positive_prompts),
|
610 |
+
log_negative_prompt='\n'.join([task_negative_prompt] + task_extra_negative_prompts),
|
611 |
+
styles=task_styles
|
612 |
+
))
|
613 |
+
|
614 |
+
if use_expansion:
|
615 |
+
for i, t in enumerate(tasks):
|
616 |
+
progressbar(async_task, 4, f'Preparing Fooocus text #{i + 1} ...')
|
617 |
+
expansion = pipeline.final_expansion(t['task_prompt'], t['task_seed'])
|
618 |
+
logger.std_info(f'[Prompt Expansion] {expansion}')
|
619 |
+
t['expansion'] = expansion
|
620 |
+
t['positive'] = copy.deepcopy(t['positive']) + [expansion] # Deep copy.
|
621 |
+
|
622 |
+
for i, t in enumerate(tasks):
|
623 |
+
progressbar(async_task, 5, f'Encoding positive #{i + 1} ...')
|
624 |
+
t['c'] = pipeline.clip_encode(texts=t['positive'], pool_top_k=t['positive_top_k'])
|
625 |
+
|
626 |
+
for i, t in enumerate(tasks):
|
627 |
+
if abs(float(cfg_scale) - 1.0) < 1e-4:
|
628 |
+
t['uc'] = pipeline.clone_cond(t['c'])
|
629 |
+
else:
|
630 |
+
progressbar(async_task, 6, f'Encoding negative #{i + 1} ...')
|
631 |
+
t['uc'] = pipeline.clip_encode(texts=t['negative'], pool_top_k=t['negative_top_k'])
|
632 |
+
|
633 |
+
if len(goals) > 0:
|
634 |
+
progressbar(async_task, 7, 'Image processing ...')
|
635 |
+
|
636 |
+
if 'vary' in goals:
|
637 |
+
if 'subtle' in uov_method:
|
638 |
+
denoising_strength = 0.5
|
639 |
+
if 'strong' in uov_method:
|
640 |
+
denoising_strength = 0.85
|
641 |
+
if overwrite_vary_strength > 0:
|
642 |
+
denoising_strength = overwrite_vary_strength
|
643 |
+
|
644 |
+
shape_ceil = get_image_shape_ceil(uov_input_image)
|
645 |
+
if shape_ceil < 1024:
|
646 |
+
logger.std_warn('[Vary] Image is resized because it is too small.')
|
647 |
+
shape_ceil = 1024
|
648 |
+
elif shape_ceil > 2048:
|
649 |
+
logger.std_warn('[Vary] Image is resized because it is too big.')
|
650 |
+
shape_ceil = 2048
|
651 |
+
|
652 |
+
uov_input_image = set_image_shape_ceil(uov_input_image, shape_ceil)
|
653 |
+
|
654 |
+
initial_pixels = core.numpy_to_pytorch(uov_input_image)
|
655 |
+
progressbar(async_task, 8, 'VAE encoding ...')
|
656 |
+
|
657 |
+
candidate_vae, _ = pipeline.get_candidate_vae(
|
658 |
+
steps=steps,
|
659 |
+
switch=switch,
|
660 |
+
denoise=denoising_strength,
|
661 |
+
refiner_swap_method=refiner_swap_method
|
662 |
+
)
|
663 |
+
|
664 |
+
initial_latent = core.encode_vae(vae=candidate_vae, pixels=initial_pixels)
|
665 |
+
B, C, H, W = initial_latent['samples'].shape
|
666 |
+
width = W * 8
|
667 |
+
height = H * 8
|
668 |
+
logger.std_info(f'[Vary] Final resolution is {str((height, width))}.')
|
669 |
+
|
670 |
+
if 'upscale' in goals:
|
671 |
+
H, W, C = uov_input_image.shape
|
672 |
+
progressbar(async_task, 9, f'Upscaling image from {str((H, W))} ...')
|
673 |
+
uov_input_image = perform_upscale(uov_input_image)
|
674 |
+
logger.std_info('[Upscale] Image upscale.')
|
675 |
+
|
676 |
+
if upscale_value is not None and upscale_value > 1.0:
|
677 |
+
f = upscale_value
|
678 |
+
else:
|
679 |
+
if '1.5x' in uov_method:
|
680 |
+
f = 1.5
|
681 |
+
elif '2x' in uov_method:
|
682 |
+
f = 2.0
|
683 |
+
else:
|
684 |
+
f = 1.0
|
685 |
+
|
686 |
+
shape_ceil = get_shape_ceil(H * f, W * f)
|
687 |
+
|
688 |
+
if shape_ceil < 1024:
|
689 |
+
logger.std_info('[Upscale] Image is resized because it is too small.')
|
690 |
+
uov_input_image = set_image_shape_ceil(uov_input_image, 1024)
|
691 |
+
shape_ceil = 1024
|
692 |
+
else:
|
693 |
+
uov_input_image = resample_image(uov_input_image, width=W * f, height=H * f)
|
694 |
+
|
695 |
+
image_is_super_large = shape_ceil > 2800
|
696 |
+
|
697 |
+
if 'fast' in uov_method:
|
698 |
+
direct_return = True
|
699 |
+
elif image_is_super_large:
|
700 |
+
logger.std_info('[Upscale] Image is too large. Directly returned the SR image. '
|
701 |
+
'Usually directly return SR image at 4K resolution '
|
702 |
+
'yields better results than SDXL diffusion.')
|
703 |
+
direct_return = True
|
704 |
+
else:
|
705 |
+
direct_return = False
|
706 |
+
|
707 |
+
if direct_return:
|
708 |
+
# d = [('Upscale (Fast)', '2x')]
|
709 |
+
# log(uov_input_image, d, output_format=save_extension)
|
710 |
+
if config.default_black_out_nsfw or black_out_nsfw:
|
711 |
+
uov_input_image = default_censor(uov_input_image)
|
712 |
+
yield_result(async_task, uov_input_image, tasks, save_extension, False, False)
|
713 |
+
return
|
714 |
+
|
715 |
+
tiled = True
|
716 |
+
denoising_strength = 0.382
|
717 |
+
|
718 |
+
if overwrite_upscale_strength > 0:
|
719 |
+
denoising_strength = overwrite_upscale_strength
|
720 |
+
|
721 |
+
initial_pixels = core.numpy_to_pytorch(uov_input_image)
|
722 |
+
progressbar(async_task, 10, 'VAE encoding ...')
|
723 |
+
|
724 |
+
candidate_vae, _ = pipeline.get_candidate_vae(
|
725 |
+
steps=steps,
|
726 |
+
switch=switch,
|
727 |
+
denoise=denoising_strength,
|
728 |
+
refiner_swap_method=refiner_swap_method
|
729 |
+
)
|
730 |
+
|
731 |
+
initial_latent = core.encode_vae(
|
732 |
+
vae=candidate_vae,
|
733 |
+
pixels=initial_pixels, tiled=True)
|
734 |
+
B, C, H, W = initial_latent['samples'].shape
|
735 |
+
width = W * 8
|
736 |
+
height = H * 8
|
737 |
+
logger.std_info(f'[Upscale] Final resolution is {str((height, width))}.')
|
738 |
+
|
739 |
+
if 'inpaint' in goals:
|
740 |
+
if len(outpaint_selections) > 0:
|
741 |
+
H, W, C = inpaint_image.shape
|
742 |
+
if 'top' in outpaint_selections:
|
743 |
+
distance_top = int(H * 0.3)
|
744 |
+
if outpaint_distance_top > 0:
|
745 |
+
distance_top = outpaint_distance_top
|
746 |
+
|
747 |
+
inpaint_image = np.pad(inpaint_image, [[distance_top, 0], [0, 0], [0, 0]], mode='edge')
|
748 |
+
inpaint_mask = np.pad(inpaint_mask, [[distance_top, 0], [0, 0]], mode='constant',
|
749 |
+
constant_values=255)
|
750 |
+
|
751 |
+
if 'bottom' in outpaint_selections:
|
752 |
+
distance_bottom = int(H * 0.3)
|
753 |
+
if outpaint_distance_bottom > 0:
|
754 |
+
distance_bottom = outpaint_distance_bottom
|
755 |
+
|
756 |
+
inpaint_image = np.pad(inpaint_image, [[0, distance_bottom], [0, 0], [0, 0]], mode='edge')
|
757 |
+
inpaint_mask = np.pad(inpaint_mask, [[0, distance_bottom], [0, 0]], mode='constant',
|
758 |
+
constant_values=255)
|
759 |
+
|
760 |
+
H, W, C = inpaint_image.shape
|
761 |
+
if 'left' in outpaint_selections:
|
762 |
+
distance_left = int(W * 0.3)
|
763 |
+
if outpaint_distance_left > 0:
|
764 |
+
distance_left = outpaint_distance_left
|
765 |
+
|
766 |
+
inpaint_image = np.pad(inpaint_image, [[0, 0], [distance_left, 0], [0, 0]], mode='edge')
|
767 |
+
inpaint_mask = np.pad(inpaint_mask, [[0, 0], [distance_left, 0]], mode='constant',
|
768 |
+
constant_values=255)
|
769 |
+
|
770 |
+
if 'right' in outpaint_selections:
|
771 |
+
distance_right = int(W * 0.3)
|
772 |
+
if outpaint_distance_right > 0:
|
773 |
+
distance_right = outpaint_distance_right
|
774 |
+
|
775 |
+
inpaint_image = np.pad(inpaint_image, [[0, 0], [0, distance_right], [0, 0]], mode='edge')
|
776 |
+
inpaint_mask = np.pad(inpaint_mask, [[0, 0], [0, distance_right]], mode='constant',
|
777 |
+
constant_values=255)
|
778 |
+
|
779 |
+
inpaint_image = np.ascontiguousarray(inpaint_image.copy())
|
780 |
+
inpaint_mask = np.ascontiguousarray(inpaint_mask.copy())
|
781 |
+
inpaint_strength = 1.0
|
782 |
+
inpaint_respective_field = 1.0
|
783 |
+
|
784 |
+
denoising_strength = inpaint_strength
|
785 |
+
|
786 |
+
inpaint_worker.current_task = inpaint_worker.InpaintWorker(
|
787 |
+
image=inpaint_image,
|
788 |
+
mask=inpaint_mask,
|
789 |
+
use_fill=denoising_strength > 0.99,
|
790 |
+
k=inpaint_respective_field
|
791 |
+
)
|
792 |
+
|
793 |
+
if debugging_inpaint_preprocessor:
|
794 |
+
yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), tasks,
|
795 |
+
black_out_nsfw)
|
796 |
+
return
|
797 |
+
|
798 |
+
progressbar(async_task, 11, 'VAE Inpaint encoding ...')
|
799 |
+
|
800 |
+
inpaint_pixel_fill = core.numpy_to_pytorch(inpaint_worker.current_task.interested_fill)
|
801 |
+
inpaint_pixel_image = core.numpy_to_pytorch(inpaint_worker.current_task.interested_image)
|
802 |
+
inpaint_pixel_mask = core.numpy_to_pytorch(inpaint_worker.current_task.interested_mask)
|
803 |
+
|
804 |
+
candidate_vae, candidate_vae_swap = pipeline.get_candidate_vae(
|
805 |
+
steps=steps,
|
806 |
+
switch=switch,
|
807 |
+
denoise=denoising_strength,
|
808 |
+
refiner_swap_method=refiner_swap_method
|
809 |
+
)
|
810 |
+
|
811 |
+
latent_inpaint, latent_mask = core.encode_vae_inpaint(
|
812 |
+
mask=inpaint_pixel_mask,
|
813 |
+
vae=candidate_vae,
|
814 |
+
pixels=inpaint_pixel_image)
|
815 |
+
|
816 |
+
latent_swap = None
|
817 |
+
if candidate_vae_swap is not None:
|
818 |
+
progressbar(async_task, 12, 'VAE SD15 encoding ...')
|
819 |
+
latent_swap = core.encode_vae(
|
820 |
+
vae=candidate_vae_swap,
|
821 |
+
pixels=inpaint_pixel_fill)['samples']
|
822 |
+
|
823 |
+
progressbar(async_task, 13, 'VAE encoding ...')
|
824 |
+
latent_fill = core.encode_vae(
|
825 |
+
vae=candidate_vae,
|
826 |
+
pixels=inpaint_pixel_fill)['samples']
|
827 |
+
|
828 |
+
inpaint_worker.current_task.load_latent(
|
829 |
+
latent_fill=latent_fill, latent_mask=latent_mask, latent_swap=latent_swap)
|
830 |
+
|
831 |
+
if inpaint_parameterized:
|
832 |
+
pipeline.final_unet = inpaint_worker.current_task.patch(
|
833 |
+
inpaint_head_model_path=inpaint_head_model_path,
|
834 |
+
inpaint_latent=latent_inpaint,
|
835 |
+
inpaint_latent_mask=latent_mask,
|
836 |
+
model=pipeline.final_unet
|
837 |
+
)
|
838 |
+
|
839 |
+
if not inpaint_disable_initial_latent:
|
840 |
+
initial_latent = {'samples': latent_fill}
|
841 |
+
|
842 |
+
B, C, H, W = latent_fill.shape
|
843 |
+
height, width = H * 8, W * 8
|
844 |
+
final_height, final_width = inpaint_worker.current_task.image.shape[:2]
|
845 |
+
logger.std_info(f'[Inpaint] Final resolution is {str((final_height, final_width))}, latent is {str((height, width))}.')
|
846 |
+
|
847 |
+
if 'cn' in goals:
|
848 |
+
for task in cn_tasks[flags.cn_canny]:
|
849 |
+
cn_img, cn_stop, cn_weight = task
|
850 |
+
cn_img = resize_image(HWC3(cn_img), width=width, height=height)
|
851 |
+
|
852 |
+
if not skipping_cn_preprocessor:
|
853 |
+
cn_img = preprocessors.canny_pyramid(cn_img, canny_low_threshold, canny_high_threshold)
|
854 |
+
|
855 |
+
cn_img = HWC3(cn_img)
|
856 |
+
task[0] = core.numpy_to_pytorch(cn_img)
|
857 |
+
if debugging_cn_preprocessor:
|
858 |
+
yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
|
859 |
+
return
|
860 |
+
for task in cn_tasks[flags.cn_cpds]:
|
861 |
+
cn_img, cn_stop, cn_weight = task
|
862 |
+
cn_img = resize_image(HWC3(cn_img), width=width, height=height)
|
863 |
+
|
864 |
+
if not skipping_cn_preprocessor:
|
865 |
+
cn_img = preprocessors.cpds(cn_img)
|
866 |
+
|
867 |
+
cn_img = HWC3(cn_img)
|
868 |
+
task[0] = core.numpy_to_pytorch(cn_img)
|
869 |
+
if debugging_cn_preprocessor:
|
870 |
+
yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
|
871 |
+
return
|
872 |
+
for task in cn_tasks[flags.cn_ip]:
|
873 |
+
cn_img, cn_stop, cn_weight = task
|
874 |
+
cn_img = HWC3(cn_img)
|
875 |
+
|
876 |
+
# https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75
|
877 |
+
cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0)
|
878 |
+
|
879 |
+
task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path)
|
880 |
+
if debugging_cn_preprocessor:
|
881 |
+
yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
|
882 |
+
return
|
883 |
+
for task in cn_tasks[flags.cn_ip_face]:
|
884 |
+
cn_img, cn_stop, cn_weight = task
|
885 |
+
cn_img = HWC3(cn_img)
|
886 |
+
|
887 |
+
if not skipping_cn_preprocessor:
|
888 |
+
cn_img = face_crop.crop_image(cn_img)
|
889 |
+
|
890 |
+
# https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75
|
891 |
+
cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0)
|
892 |
+
|
893 |
+
task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path)
|
894 |
+
if debugging_cn_preprocessor:
|
895 |
+
yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
|
896 |
+
return
|
897 |
+
|
898 |
+
all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face]
|
899 |
+
|
900 |
+
if len(all_ip_tasks) > 0:
|
901 |
+
pipeline.final_unet = ip_adapter.patch_model(pipeline.final_unet, all_ip_tasks)
|
902 |
+
|
903 |
+
if freeu_enabled:
|
904 |
+
logger.std_info('[Fooocus] FreeU is enabled!')
|
905 |
+
pipeline.final_unet = core.apply_freeu(
|
906 |
+
pipeline.final_unet,
|
907 |
+
freeu_b1,
|
908 |
+
freeu_b2,
|
909 |
+
freeu_s1,
|
910 |
+
freeu_s2
|
911 |
+
)
|
912 |
+
|
913 |
+
all_steps = steps * image_number
|
914 |
+
|
915 |
+
logger.std_info(f'[Parameters] Denoising Strength = {denoising_strength}')
|
916 |
+
|
917 |
+
if isinstance(initial_latent, dict) and 'samples' in initial_latent:
|
918 |
+
log_shape = initial_latent['samples'].shape
|
919 |
+
else:
|
920 |
+
log_shape = f'Image Space {(height, width)}'
|
921 |
+
|
922 |
+
logger.std_info(f'[Parameters] Initial Latent shape: {log_shape}')
|
923 |
+
|
924 |
+
preparation_time = time.perf_counter() - execution_start_time
|
925 |
+
logger.std_info(f'[Fooocus] Preparation time: {preparation_time:.2f} seconds')
|
926 |
+
|
927 |
+
final_sampler_name = sampler_name
|
928 |
+
final_scheduler_name = scheduler_name
|
929 |
+
|
930 |
+
if scheduler_name in ['lcm', 'tcd']:
|
931 |
+
final_scheduler_name = 'sgm_uniform'
|
932 |
+
|
933 |
+
def patch_discrete(unet):
|
934 |
+
return core.opModelSamplingDiscrete.patch(
|
935 |
+
pipeline.final_unet,
|
936 |
+
sampling=scheduler_name,
|
937 |
+
zsnr=False)[0]
|
938 |
+
|
939 |
+
if pipeline.final_unet is not None:
|
940 |
+
pipeline.final_unet = patch_discrete(pipeline.final_unet)
|
941 |
+
if pipeline.final_refiner_unet is not None:
|
942 |
+
pipeline.final_refiner_unet = patch_discrete(pipeline.final_refiner_unet)
|
943 |
+
logger.std_info(f'[Fooocus] Using {scheduler_name} scheduler.')
|
944 |
+
elif scheduler_name == 'edm_playground_v2.5':
|
945 |
+
final_scheduler_name = 'karras'
|
946 |
+
|
947 |
+
def patch_edm(unet):
|
948 |
+
return core.opModelSamplingContinuousEDM.patch(
|
949 |
+
unet,
|
950 |
+
sampling=scheduler_name,
|
951 |
+
sigma_max=120.0,
|
952 |
+
sigma_min=0.002)[0]
|
953 |
+
|
954 |
+
if pipeline.final_unet is not None:
|
955 |
+
pipeline.final_unet = patch_edm(pipeline.final_unet)
|
956 |
+
if pipeline.final_refiner_unet is not None:
|
957 |
+
pipeline.final_refiner_unet = patch_edm(pipeline.final_refiner_unet)
|
958 |
+
|
959 |
+
logger.std_info(f'[Fooocus] Using {scheduler_name} scheduler.')
|
960 |
+
|
961 |
+
outputs.append(['preview', (13, 'Moving model to GPU ...', None)])
|
962 |
+
|
963 |
+
def callback(step, x0, x, total_steps, y):
|
964 |
+
"""callback, used for progress and preview"""
|
965 |
+
done_steps = current_task_id * steps + step
|
966 |
+
outputs.append(['preview', (
|
967 |
+
int(15.0 + 85.0 * float(done_steps) / float(all_steps)),
|
968 |
+
f'Step {step}/{total_steps} in the {current_task_id + 1}-th Sampling',
|
969 |
+
y)])
|
970 |
+
|
971 |
+
for current_task_id, task in enumerate(tasks):
|
972 |
+
execution_start_time = time.perf_counter()
|
973 |
+
|
974 |
+
try:
|
975 |
+
positive_cond, negative_cond = task['c'], task['uc']
|
976 |
+
|
977 |
+
if 'cn' in goals:
|
978 |
+
for cn_flag, cn_path in [
|
979 |
+
(flags.cn_canny, controlnet_canny_path),
|
980 |
+
(flags.cn_cpds, controlnet_cpds_path)
|
981 |
+
]:
|
982 |
+
for cn_img, cn_stop, cn_weight in cn_tasks[cn_flag]:
|
983 |
+
positive_cond, negative_cond = core.apply_controlnet(
|
984 |
+
positive_cond, negative_cond,
|
985 |
+
pipeline.loaded_ControlNets[cn_path], cn_img, cn_weight, 0, cn_stop)
|
986 |
+
|
987 |
+
imgs = pipeline.process_diffusion(
|
988 |
+
positive_cond=positive_cond,
|
989 |
+
negative_cond=negative_cond,
|
990 |
+
steps=steps,
|
991 |
+
switch=switch,
|
992 |
+
width=width,
|
993 |
+
height=height,
|
994 |
+
image_seed=task['task_seed'],
|
995 |
+
callback=callback,
|
996 |
+
sampler_name=final_sampler_name,
|
997 |
+
scheduler_name=final_scheduler_name,
|
998 |
+
latent=initial_latent,
|
999 |
+
denoise=denoising_strength,
|
1000 |
+
tiled=tiled,
|
1001 |
+
cfg_scale=cfg_scale,
|
1002 |
+
refiner_swap_method=refiner_swap_method,
|
1003 |
+
disable_preview=disable_preview
|
1004 |
+
)
|
1005 |
+
|
1006 |
+
del task['c'], task['uc'], positive_cond, negative_cond # Save memory
|
1007 |
+
|
1008 |
+
if inpaint_worker.current_task is not None:
|
1009 |
+
imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
|
1010 |
+
|
1011 |
+
# Fooocus async_worker.py code end
|
1012 |
+
|
1013 |
+
results += imgs
|
1014 |
+
except model_management.InterruptProcessingException as e:
|
1015 |
+
logger.std_warn("[Fooocus] User stopped")
|
1016 |
+
results = []
|
1017 |
+
results.append(ImageGenerationResult(
|
1018 |
+
im=None, seed=task['task_seed'], finish_reason=GenerationFinishReason.user_cancel))
|
1019 |
+
async_task.set_result(results, True, str(e))
|
1020 |
+
break
|
1021 |
+
except Exception as e:
|
1022 |
+
logger.std_error(f'[Fooocus] Process error: {e}')
|
1023 |
+
logging.exception(e)
|
1024 |
+
results = []
|
1025 |
+
results.append(ImageGenerationResult(
|
1026 |
+
im=None, seed=task['task_seed'], finish_reason=GenerationFinishReason.error))
|
1027 |
+
async_task.set_result(results, True, str(e))
|
1028 |
+
break
|
1029 |
+
|
1030 |
+
execution_time = time.perf_counter() - execution_start_time
|
1031 |
+
logger.std_info(f'[Fooocus] Generating and saving time: {execution_time:.2f} seconds')
|
1032 |
+
|
1033 |
+
if async_task.finish_with_error:
|
1034 |
+
worker_queue.finish_task(async_task.job_id)
|
1035 |
+
return async_task.task_result
|
1036 |
+
yield_result(None, results, tasks, save_extension, black_out_nsfw)
|
1037 |
+
return
|
1038 |
+
except Exception as e:
|
1039 |
+
logger.std_error(f'[Fooocus] Worker error: {e}')
|
1040 |
+
|
1041 |
+
if not async_task.is_finished:
|
1042 |
+
async_task.set_result([], True, str(e))
|
1043 |
+
worker_queue.finish_task(async_task.job_id)
|
1044 |
+
logger.std_info(f"[Task Queue] Finish task with error, job_id={async_task.job_id}")
|
main.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
""" Entry for Fooocus API.
|
4 |
+
|
5 |
+
Use for starting Fooocus API.
|
6 |
+
python main.py --help for more usage
|
7 |
+
|
8 |
+
@file: main.py
|
9 |
+
@author: Konie
|
10 |
+
@update: 2024-03-22
|
11 |
+
"""
|
12 |
+
import argparse
|
13 |
+
import os
|
14 |
+
import re
|
15 |
+
import shutil
|
16 |
+
import sys
|
17 |
+
from threading import Thread
|
18 |
+
|
19 |
+
from fooocusapi.utils.logger import logger
|
20 |
+
from fooocusapi.utils.tools import run_pip, check_torch_cuda, requirements_check
|
21 |
+
from fooocus_api_version import version
|
22 |
+
|
23 |
+
script_path = os.path.dirname(os.path.realpath(__file__))
|
24 |
+
module_path = os.path.join(script_path, "repositories/Fooocus")
|
25 |
+
|
26 |
+
sys.path.append(script_path)
|
27 |
+
sys.path.append(module_path)
|
28 |
+
|
29 |
+
logger.std_info("[System ARGV] " + str(sys.argv))
|
30 |
+
|
31 |
+
try:
|
32 |
+
index = sys.argv.index('--gpu-device-id')
|
33 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(sys.argv[index + 1])
|
34 |
+
logger.std_info(f"[Fooocus] Set device to: {str(sys.argv[index + 1])}")
|
35 |
+
except ValueError:
|
36 |
+
pass
|
37 |
+
|
38 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
39 |
+
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
|
40 |
+
|
41 |
+
python = sys.executable
|
42 |
+
default_command_live = True
|
43 |
+
index_url = os.environ.get("INDEX_URL", "")
|
44 |
+
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
45 |
+
|
46 |
+
|
47 |
+
def install_dependents(skip: bool = False):
|
48 |
+
"""
|
49 |
+
Check and install dependencies
|
50 |
+
Args:
|
51 |
+
skip: skip pip install
|
52 |
+
"""
|
53 |
+
if skip:
|
54 |
+
return
|
55 |
+
|
56 |
+
torch_index_url = os.environ.get(
|
57 |
+
"TORCH_INDEX_URL", "https://download.pytorch.org/whl/cu121"
|
58 |
+
)
|
59 |
+
|
60 |
+
# Check if you need pip install
|
61 |
+
if not requirements_check():
|
62 |
+
run_pip("install -r requirements.txt", "requirements")
|
63 |
+
|
64 |
+
if not check_torch_cuda():
|
65 |
+
run_pip(
|
66 |
+
f"install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}",
|
67 |
+
desc="torch",
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
def preload_pipeline():
|
72 |
+
"""Preload pipeline"""
|
73 |
+
logger.std_info("[Fooocus-API] Preloading pipeline ...")
|
74 |
+
import modules.default_pipeline as _
|
75 |
+
|
76 |
+
|
77 |
+
def prepare_environments(args) -> bool:
|
78 |
+
"""
|
79 |
+
Prepare environments
|
80 |
+
Args:
|
81 |
+
args: command line arguments
|
82 |
+
"""
|
83 |
+
|
84 |
+
if args.base_url is None or len(args.base_url.strip()) == 0:
|
85 |
+
host = args.host
|
86 |
+
if host == "0.0.0.0":
|
87 |
+
host = "127.0.0.1"
|
88 |
+
args.base_url = f"http://{host}:{args.port}"
|
89 |
+
|
90 |
+
sys.argv = [sys.argv[0]]
|
91 |
+
|
92 |
+
# Remove and copy preset folder
|
93 |
+
origin_preset_folder = os.path.abspath(os.path.join(module_path, "presets"))
|
94 |
+
preset_folder = os.path.abspath(os.path.join(script_path, "presets"))
|
95 |
+
if os.path.exists(preset_folder):
|
96 |
+
shutil.rmtree(preset_folder)
|
97 |
+
shutil.copytree(origin_preset_folder, preset_folder)
|
98 |
+
|
99 |
+
from modules import config
|
100 |
+
from fooocusapi.configs import default
|
101 |
+
from fooocusapi.utils.model_loader import download_models
|
102 |
+
|
103 |
+
default.default_inpaint_engine_version = config.default_inpaint_engine_version
|
104 |
+
default.default_styles = config.default_styles
|
105 |
+
default.default_base_model_name = config.default_base_model_name
|
106 |
+
default.default_refiner_model_name = config.default_refiner_model_name
|
107 |
+
default.default_refiner_switch = config.default_refiner_switch
|
108 |
+
default.default_loras = config.default_loras
|
109 |
+
default.default_cfg_scale = config.default_cfg_scale
|
110 |
+
default.default_prompt_negative = config.default_prompt_negative
|
111 |
+
default.default_aspect_ratio = default.get_aspect_ratio_value(
|
112 |
+
config.default_aspect_ratio
|
113 |
+
)
|
114 |
+
default.available_aspect_ratios = [
|
115 |
+
default.get_aspect_ratio_value(a) for a in config.available_aspect_ratios
|
116 |
+
]
|
117 |
+
|
118 |
+
download_models()
|
119 |
+
|
120 |
+
# Init task queue
|
121 |
+
from fooocusapi import worker
|
122 |
+
from fooocusapi.task_queue import TaskQueue
|
123 |
+
|
124 |
+
worker.worker_queue = TaskQueue(
|
125 |
+
queue_size=args.queue_size,
|
126 |
+
history_size=args.queue_history,
|
127 |
+
webhook_url=args.webhook_url,
|
128 |
+
persistent=args.persistent,
|
129 |
+
)
|
130 |
+
|
131 |
+
logger.std_info(f"[Fooocus-API] Task queue size: {args.queue_size}")
|
132 |
+
logger.std_info(f"[Fooocus-API] Queue history size: {args.queue_history}")
|
133 |
+
logger.std_info(f"[Fooocus-API] Webhook url: {args.webhook_url}")
|
134 |
+
|
135 |
+
return True
|
136 |
+
|
137 |
+
|
138 |
+
def pre_setup():
|
139 |
+
"""
|
140 |
+
Pre setup, for replicate
|
141 |
+
"""
|
142 |
+
|
143 |
+
class Args(object):
|
144 |
+
"""
|
145 |
+
Arguments object
|
146 |
+
"""
|
147 |
+
host = "0.0.0.0"
|
148 |
+
port = 8000
|
149 |
+
base_url = None
|
150 |
+
sync_repo = "skip"
|
151 |
+
disable_image_log = True
|
152 |
+
skip_pip = True
|
153 |
+
preload_pipeline = True
|
154 |
+
queue_size = 100
|
155 |
+
queue_history = 0
|
156 |
+
preset = "default"
|
157 |
+
webhook_url = None
|
158 |
+
persistent = False
|
159 |
+
always_gpu = False
|
160 |
+
all_in_fp16 = False
|
161 |
+
gpu_device_id = None
|
162 |
+
apikey = None
|
163 |
+
|
164 |
+
print("[Pre Setup] Prepare environments")
|
165 |
+
|
166 |
+
arguments = Args()
|
167 |
+
sys.argv = [sys.argv[0]]
|
168 |
+
sys.argv.append("--disable-image-log")
|
169 |
+
|
170 |
+
install_dependents(arguments.skip_pip)
|
171 |
+
|
172 |
+
prepare_environments(arguments)
|
173 |
+
|
174 |
+
# Start task schedule thread
|
175 |
+
from fooocusapi.worker import task_schedule_loop
|
176 |
+
|
177 |
+
task_thread = Thread(target=task_schedule_loop, daemon=True)
|
178 |
+
task_thread.start()
|
179 |
+
|
180 |
+
print("[Pre Setup] Finished")
|
181 |
+
|
182 |
+
|
183 |
+
if __name__ == "__main__":
|
184 |
+
logger.std_info(f"[Fooocus API] Python {sys.version}")
|
185 |
+
logger.std_info(f"[Fooocus API] Fooocus API version: {version}")
|
186 |
+
|
187 |
+
from fooocusapi.base_args import add_base_args
|
188 |
+
|
189 |
+
parser = argparse.ArgumentParser()
|
190 |
+
add_base_args(parser, True)
|
191 |
+
|
192 |
+
args, _ = parser.parse_known_args()
|
193 |
+
install_dependents(skip=args.skip_pip)
|
194 |
+
|
195 |
+
from fooocusapi.args import args
|
196 |
+
|
197 |
+
if prepare_environments(args):
|
198 |
+
sys.argv = [sys.argv[0]]
|
199 |
+
|
200 |
+
# Load pipeline in new thread
|
201 |
+
preload_pipeline_thread = Thread(target=preload_pipeline, daemon=True)
|
202 |
+
preload_pipeline_thread.start()
|
203 |
+
|
204 |
+
# Start task schedule thread
|
205 |
+
from fooocusapi.worker import task_schedule_loop
|
206 |
+
|
207 |
+
task_schedule_thread = Thread(target=task_schedule_loop, daemon=True)
|
208 |
+
task_schedule_thread.start()
|
209 |
+
|
210 |
+
# Start api server
|
211 |
+
from fooocusapi.api import start_app
|
212 |
+
|
213 |
+
start_app(args)
|
mannequin_to_model.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Created By: ishwor subedi
|
3 |
+
Date: 2024-07-03
|
4 |
+
"""
|
5 |
+
import os
|
6 |
+
import cv2
|
7 |
+
import base64
|
8 |
+
import requests
|
9 |
+
import numpy as np
|
10 |
+
from io import BytesIO
|
11 |
+
from PIL import Image
|
12 |
+
from fastapi import File, UploadFile, Form, Depends
|
13 |
+
from fastapi import APIRouter
|
14 |
+
from fastapi.responses import JSONResponse
|
15 |
+
|
16 |
+
from fooocusapi.utils.api_utils import api_key_auth
|
17 |
+
from src.pipeline.main_pipeline import MainPipeline
|
18 |
+
from supabase import create_client
|
19 |
+
|
20 |
+
secure_router = APIRouter(dependencies=[Depends(api_key_auth)])
|
21 |
+
|
22 |
+
pipeline = MainPipeline()
|
23 |
+
|
24 |
+
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
|
25 |
+
SUPABASE_URL = os.getenv("SUPABASE_URL")
|
26 |
+
|
27 |
+
supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
|
28 |
+
|
29 |
+
|
30 |
+
@secure_router.post("/mannequin_to_model")
|
31 |
+
async def mannequinToModel(
|
32 |
+
store_name: str = Form(...),
|
33 |
+
clothing_category: str = Form(...),
|
34 |
+
product_id: str = Form(...),
|
35 |
+
body_structure: str = Form(...),
|
36 |
+
skin_complexion: str = Form(...),
|
37 |
+
facial_structure: str = Form(...),
|
38 |
+
person_img: UploadFile = File(...)
|
39 |
+
):
|
40 |
+
if body_structure == "medium":
|
41 |
+
body_structure = "fat"
|
42 |
+
|
43 |
+
person_image = await person_img.read()
|
44 |
+
|
45 |
+
mannequin_image_url = f"{SUPABASE_URL}/storage/v1/object/public/ClothingTryOn/{store_name}/{clothing_category}/{product_id}/{product_id}_{skin_complexion}_{facial_structure}_{body_structure}.webp"
|
46 |
+
mannequin_img = read_return(mannequin_image_url)
|
47 |
+
|
48 |
+
try:
|
49 |
+
person_image, mannequin_image = Image.open(BytesIO(person_image)), Image.open(BytesIO(mannequin_img))
|
50 |
+
except:
|
51 |
+
|
52 |
+
error_message = {
|
53 |
+
"code": 404,
|
54 |
+
"error": "The requested resource is not available. Please verify the availability and try again."
|
55 |
+
|
56 |
+
}
|
57 |
+
|
58 |
+
return JSONResponse(content=error_message, status_code=404)
|
59 |
+
|
60 |
+
mannequin_image = cv2.cvtColor(np.array(mannequin_image), cv2.COLOR_RGB2BGR)
|
61 |
+
person_image = cv2.cvtColor(np.array(person_image), cv2.COLOR_RGB2BGR)
|
62 |
+
result = pipeline.face_swap(mannequin_image, person_image, enhance=False)
|
63 |
+
result = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
|
64 |
+
inMemFile = BytesIO()
|
65 |
+
result.save(inMemFile, format="WEBP", quality=85)
|
66 |
+
outputBytes = inMemFile.getvalue()
|
67 |
+
response = {
|
68 |
+
"code": 200,
|
69 |
+
"output": f"data:image/WEBP;base64,{base64.b64encode(outputBytes).decode('utf-8')}",
|
70 |
+
}
|
71 |
+
|
72 |
+
return response
|
73 |
+
|
74 |
+
|
75 |
+
@secure_router.get("/mannequin_catalogue")
|
76 |
+
async def returnJsonData(gender: str):
|
77 |
+
folderImageURL = supabase.storage.get_bucket("JSON").create_signed_url(
|
78 |
+
path=os.path.join("MannequinInfo.json"), expires_in=3600)["signedURL"]
|
79 |
+
r = requests.get(folderImageURL).content.decode()
|
80 |
+
|
81 |
+
mannequin_data = json.loads(r)
|
82 |
+
|
83 |
+
if gender.lower() == "female":
|
84 |
+
res = [item for item in mannequin_data if item["gender"] == "female"]
|
85 |
+
elif gender.lower() == "male":
|
86 |
+
res = [item for item in mannequin_data if item["gender"] == "male"]
|
87 |
+
else:
|
88 |
+
res = []
|
89 |
+
|
90 |
+
return res
|
predict.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Prediction interface for Cog ⚙️
|
3 |
+
https://github.com/replicate/cog/blob/main/docs/python.md
|
4 |
+
"""
|
5 |
+
|
6 |
+
import copy
|
7 |
+
import os
|
8 |
+
from typing import List
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from PIL import Image
|
12 |
+
from cog import BasePredictor, BaseModel, Input, Path
|
13 |
+
from fooocusapi.utils.lora_manager import LoraManager
|
14 |
+
from fooocusapi.utils.file_utils import output_dir
|
15 |
+
from fooocusapi.models.common.task import GenerationFinishReason
|
16 |
+
from fooocusapi.configs.default import (
|
17 |
+
available_aspect_ratios,
|
18 |
+
uov_methods,
|
19 |
+
outpaint_expansions,
|
20 |
+
default_styles,
|
21 |
+
default_base_model_name,
|
22 |
+
default_refiner_model_name,
|
23 |
+
default_loras,
|
24 |
+
default_refiner_switch,
|
25 |
+
default_cfg_scale,
|
26 |
+
default_prompt_negative
|
27 |
+
)
|
28 |
+
|
29 |
+
from fooocusapi.parameters import ImageGenerationParams
|
30 |
+
from fooocusapi.task_queue import TaskType
|
31 |
+
|
32 |
+
|
33 |
+
class Output(BaseModel):
|
34 |
+
"""
|
35 |
+
Output model
|
36 |
+
"""
|
37 |
+
seeds: List[str]
|
38 |
+
paths: List[Path]
|
39 |
+
|
40 |
+
|
41 |
+
class Predictor(BasePredictor):
|
42 |
+
"""Predictor"""
|
43 |
+
def setup(self) -> None:
|
44 |
+
"""
|
45 |
+
Load the model into memory to make running multiple predictions efficient
|
46 |
+
"""
|
47 |
+
from main import pre_setup
|
48 |
+
pre_setup()
|
49 |
+
|
50 |
+
def predict(
|
51 |
+
self,
|
52 |
+
prompt: str = Input(
|
53 |
+
default='',
|
54 |
+
description="Prompt for image generation"),
|
55 |
+
negative_prompt: str = Input(
|
56 |
+
default=default_prompt_negative,
|
57 |
+
description="Negative prompt for image generation"),
|
58 |
+
style_selections: str = Input(
|
59 |
+
default=','.join(default_styles),
|
60 |
+
description="Fooocus styles applied for image generation, separated by comma"),
|
61 |
+
performance_selection: str = Input(
|
62 |
+
default='Speed',
|
63 |
+
choices=['Speed', 'Quality', 'Extreme Speed', 'Lightning'],
|
64 |
+
description="Performance selection"),
|
65 |
+
aspect_ratios_selection: str = Input(
|
66 |
+
default='1152*896',
|
67 |
+
choices=available_aspect_ratios,
|
68 |
+
description="The generated image's size"),
|
69 |
+
image_number: int = Input(
|
70 |
+
default=1,
|
71 |
+
ge=1, le=8,
|
72 |
+
description="How many image to generate"),
|
73 |
+
image_seed: int = Input(
|
74 |
+
default=-1,
|
75 |
+
description="Seed to generate image, -1 for random"),
|
76 |
+
use_default_loras: bool = Input(
|
77 |
+
default=True,
|
78 |
+
description="Use default LoRAs"),
|
79 |
+
loras_custom_urls: str = Input(
|
80 |
+
default="",
|
81 |
+
description="Custom LoRAs URLs in the format 'url,weight' provide multiple separated by ; (example 'url1,0.3;url2,0.1')"),
|
82 |
+
sharpness: float = Input(
|
83 |
+
default=2.0,
|
84 |
+
ge=0.0, le=30.0),
|
85 |
+
guidance_scale: float = Input(
|
86 |
+
default=default_cfg_scale,
|
87 |
+
ge=1.0, le=30.0),
|
88 |
+
refiner_switch: float = Input(
|
89 |
+
default=default_refiner_switch,
|
90 |
+
ge=0.1, le=1.0),
|
91 |
+
uov_input_image: Path = Input(
|
92 |
+
default=None,
|
93 |
+
description="Input image for upscale or variation, keep None for not upscale or variation"),
|
94 |
+
uov_method: str = Input(
|
95 |
+
default='Disabled',
|
96 |
+
choices=uov_methods),
|
97 |
+
uov_upscale_value: float = Input(
|
98 |
+
default=0,
|
99 |
+
description="Only when Upscale (Custom)"),
|
100 |
+
inpaint_additional_prompt: str = Input(
|
101 |
+
default='',
|
102 |
+
description="Prompt for image generation"),
|
103 |
+
inpaint_input_image: Path = Input(
|
104 |
+
default=None,
|
105 |
+
description="Input image for inpaint or outpaint, keep None for not inpaint or outpaint. Please noticed, `uov_input_image` has bigger priority is not None."),
|
106 |
+
inpaint_input_mask: Path = Input(
|
107 |
+
default=None,
|
108 |
+
description="Input mask for inpaint"),
|
109 |
+
outpaint_selections: str = Input(
|
110 |
+
default='',
|
111 |
+
description="Outpaint expansion selections, literal 'Left', 'Right', 'Top', 'Bottom' separated by comma"),
|
112 |
+
outpaint_distance_left: int = Input(
|
113 |
+
default=0,
|
114 |
+
description="Outpaint expansion distance from Left of the image"),
|
115 |
+
outpaint_distance_top: int = Input(
|
116 |
+
default=0,
|
117 |
+
description="Outpaint expansion distance from Top of the image"),
|
118 |
+
outpaint_distance_right: int = Input(
|
119 |
+
default=0,
|
120 |
+
description="Outpaint expansion distance from Right of the image"),
|
121 |
+
outpaint_distance_bottom: int = Input(
|
122 |
+
default=0,
|
123 |
+
description="Outpaint expansion distance from Bottom of the image"),
|
124 |
+
cn_img1: Path = Input(
|
125 |
+
default=None,
|
126 |
+
description="Input image for image prompt. If all cn_img[n] are None, image prompt will not applied."),
|
127 |
+
cn_stop1: float = Input(
|
128 |
+
default=None,
|
129 |
+
ge=0, le=1,
|
130 |
+
description="Stop at for image prompt, None for default value"),
|
131 |
+
cn_weight1: float = Input(
|
132 |
+
default=None,
|
133 |
+
ge=0, le=2,
|
134 |
+
description="Weight for image prompt, None for default value"),
|
135 |
+
cn_type1: str = Input(
|
136 |
+
default='ImagePrompt',
|
137 |
+
choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
|
138 |
+
description="ControlNet type for image prompt"),
|
139 |
+
cn_img2: Path = Input(
|
140 |
+
default=None,
|
141 |
+
description="Input image for image prompt. If all cn_img[n] are None, image prompt will not applied."),
|
142 |
+
cn_stop2: float = Input(
|
143 |
+
default=None,
|
144 |
+
ge=0, le=1,
|
145 |
+
description="Stop at for image prompt, None for default value"),
|
146 |
+
cn_weight2: float = Input(
|
147 |
+
default=None,
|
148 |
+
ge=0, le=2,
|
149 |
+
description="Weight for image prompt, None for default value"),
|
150 |
+
cn_type2: str = Input(
|
151 |
+
default='ImagePrompt',
|
152 |
+
choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
|
153 |
+
description="ControlNet type for image prompt"),
|
154 |
+
cn_img3: Path = Input(
|
155 |
+
default=None,
|
156 |
+
description="Input image for image prompt. If all cn_img[n] are None, image prompt will not applied."),
|
157 |
+
cn_stop3: float = Input(
|
158 |
+
default=None,
|
159 |
+
ge=0, le=1,
|
160 |
+
description="Stop at for image prompt, None for default value"),
|
161 |
+
cn_weight3: float = Input(
|
162 |
+
default=None,
|
163 |
+
ge=0, le=2,
|
164 |
+
description="Weight for image prompt, None for default value"),
|
165 |
+
cn_type3: str = Input(
|
166 |
+
default='ImagePrompt',
|
167 |
+
choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
|
168 |
+
description="ControlNet type for image prompt"),
|
169 |
+
cn_img4: Path = Input(
|
170 |
+
default=None,
|
171 |
+
description="Input image for image prompt. If all cn_img[n] are None, image prompt will not applied."),
|
172 |
+
cn_stop4: float = Input(
|
173 |
+
default=None,
|
174 |
+
ge=0, le=1,
|
175 |
+
description="Stop at for image prompt, None for default value"),
|
176 |
+
cn_weight4: float = Input(
|
177 |
+
default=None,
|
178 |
+
ge=0, le=2,
|
179 |
+
description="Weight for image prompt, None for default value"),
|
180 |
+
cn_type4: str = Input(
|
181 |
+
default='ImagePrompt',
|
182 |
+
choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
|
183 |
+
description="ControlNet type for image prompt")
|
184 |
+
) -> Output:
|
185 |
+
"""Run a single prediction on the model"""
|
186 |
+
from modules import flags
|
187 |
+
from modules.sdxl_styles import legal_style_names
|
188 |
+
from fooocusapi.worker import blocking_get_task_result, worker_queue
|
189 |
+
|
190 |
+
base_model_name = default_base_model_name
|
191 |
+
refiner_model_name = default_refiner_model_name
|
192 |
+
|
193 |
+
lora_manager = LoraManager()
|
194 |
+
|
195 |
+
# Use default loras if selected
|
196 |
+
loras = copy.copy(default_loras) if use_default_loras else []
|
197 |
+
|
198 |
+
# add custom user loras if provided
|
199 |
+
if loras_custom_urls:
|
200 |
+
urls = [url.strip() for url in loras_custom_urls.split(';')]
|
201 |
+
|
202 |
+
loras_with_weights = [url.split(',') for url in urls]
|
203 |
+
|
204 |
+
custom_lora_paths = lora_manager.check([lw[0] for lw in loras_with_weights])
|
205 |
+
custom_loras = [[path, float(lw[1]) if len(lw) > 1 else 1.0] for path, lw in
|
206 |
+
zip(custom_lora_paths, loras_with_weights)]
|
207 |
+
|
208 |
+
loras.extend(custom_loras)
|
209 |
+
|
210 |
+
style_selections_arr = []
|
211 |
+
for s in style_selections.strip().split(','):
|
212 |
+
style = s.strip()
|
213 |
+
if style in legal_style_names:
|
214 |
+
style_selections_arr.append(style)
|
215 |
+
|
216 |
+
if uov_input_image is not None:
|
217 |
+
im = Image.open(str(uov_input_image))
|
218 |
+
uov_input_image = np.array(im)
|
219 |
+
|
220 |
+
inpaint_input_image_dict = None
|
221 |
+
if inpaint_input_image is not None:
|
222 |
+
im = Image.open(str(inpaint_input_image))
|
223 |
+
inpaint_input_image = np.array(im)
|
224 |
+
|
225 |
+
if inpaint_input_mask is not None:
|
226 |
+
im = Image.open(str(inpaint_input_mask))
|
227 |
+
inpaint_input_mask = np.array(im)
|
228 |
+
|
229 |
+
inpaint_input_image_dict = {
|
230 |
+
'image': inpaint_input_image,
|
231 |
+
'mask': inpaint_input_mask
|
232 |
+
}
|
233 |
+
|
234 |
+
outpaint_selections_arr = []
|
235 |
+
for e in outpaint_selections.strip().split(','):
|
236 |
+
expansion = e.strip()
|
237 |
+
if expansion in outpaint_expansions:
|
238 |
+
outpaint_selections_arr.append(expansion)
|
239 |
+
|
240 |
+
image_prompts = []
|
241 |
+
image_prompt_config = [
|
242 |
+
(cn_img1, cn_stop1, cn_weight1, cn_type1),
|
243 |
+
(cn_img2, cn_stop2, cn_weight2, cn_type2),
|
244 |
+
(cn_img3, cn_stop3, cn_weight3, cn_type3),
|
245 |
+
(cn_img4, cn_stop4, cn_weight4, cn_type4)]
|
246 |
+
for config in image_prompt_config:
|
247 |
+
cn_img, cn_stop, cn_weight, cn_type = config
|
248 |
+
if cn_img is not None:
|
249 |
+
im = Image.open(str(cn_img))
|
250 |
+
cn_img = np.array(im)
|
251 |
+
if cn_stop is None:
|
252 |
+
cn_stop = flags.default_parameters[cn_type][0]
|
253 |
+
if cn_weight is None:
|
254 |
+
cn_weight = flags.default_parameters[cn_type][1]
|
255 |
+
image_prompts.append((cn_img, cn_stop, cn_weight, cn_type))
|
256 |
+
|
257 |
+
advanced_params = None
|
258 |
+
|
259 |
+
params = ImageGenerationParams(
|
260 |
+
prompt=prompt,
|
261 |
+
negative_prompt=negative_prompt,
|
262 |
+
style_selections=style_selections_arr,
|
263 |
+
performance_selection=performance_selection,
|
264 |
+
aspect_ratios_selection=aspect_ratios_selection,
|
265 |
+
image_number=image_number,
|
266 |
+
image_seed=image_seed,
|
267 |
+
sharpness=sharpness,
|
268 |
+
guidance_scale=guidance_scale,
|
269 |
+
base_model_name=base_model_name,
|
270 |
+
refiner_model_name=refiner_model_name,
|
271 |
+
refiner_switch=refiner_switch,
|
272 |
+
loras=loras,
|
273 |
+
uov_input_image=uov_input_image,
|
274 |
+
uov_method=uov_method,
|
275 |
+
upscale_value=uov_upscale_value,
|
276 |
+
outpaint_selections=outpaint_selections_arr,
|
277 |
+
inpaint_input_image=inpaint_input_image_dict,
|
278 |
+
image_prompts=image_prompts,
|
279 |
+
advanced_params=advanced_params,
|
280 |
+
inpaint_additional_prompt=inpaint_additional_prompt,
|
281 |
+
outpaint_distance_left=outpaint_distance_left,
|
282 |
+
outpaint_distance_top=outpaint_distance_top,
|
283 |
+
outpaint_distance_right=outpaint_distance_right,
|
284 |
+
outpaint_distance_bottom=outpaint_distance_bottom,
|
285 |
+
save_meta=True,
|
286 |
+
meta_scheme='fooocus',
|
287 |
+
save_extension='png',
|
288 |
+
save_name='',
|
289 |
+
require_base64=False,
|
290 |
+
)
|
291 |
+
|
292 |
+
print(f"[Predictor Predict] Params: {params.__dict__}")
|
293 |
+
|
294 |
+
async_task = worker_queue.add_task(
|
295 |
+
TaskType.text_2_img,
|
296 |
+
params)
|
297 |
+
|
298 |
+
if async_task is None:
|
299 |
+
print("[Task Queue] The task queue has reached limit")
|
300 |
+
raise Exception("The task queue has reached limit.")
|
301 |
+
|
302 |
+
results = blocking_get_task_result(async_task.job_id)
|
303 |
+
|
304 |
+
output_paths: List[Path] = []
|
305 |
+
output_seeds: List[str] = []
|
306 |
+
for r in results:
|
307 |
+
if r.finish_reason == GenerationFinishReason.success and r.im is not None:
|
308 |
+
output_seeds.append(r.seed)
|
309 |
+
output_paths.append(Path(os.path.join(output_dir, r.im)))
|
310 |
+
|
311 |
+
print(f"[Predictor Predict] Finished with {len(output_paths)} images")
|
312 |
+
|
313 |
+
if len(output_paths) == 0:
|
314 |
+
raise Exception("Process failed.")
|
315 |
+
|
316 |
+
return Output(seeds=output_seeds, paths=output_paths)
|
repositories/Fooocus/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Created By: ishwor subedi
|
3 |
+
Date: 2024-07-19
|
4 |
+
"""
|
repositories/Fooocus/args_manager.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ldm_patched.modules.args_parser as args_parser
|
2 |
+
|
3 |
+
args_parser.parser.add_argument("--share", action='store_true', help="Set whether to share on Gradio.")
|
4 |
+
|
5 |
+
args_parser.parser.add_argument("--preset", type=str, default=None, help="Apply specified UI preset.")
|
6 |
+
args_parser.parser.add_argument("--disable-preset-selection", action='store_true',
|
7 |
+
help="Disables preset selection in Gradio.")
|
8 |
+
|
9 |
+
args_parser.parser.add_argument("--language", type=str, default='default',
|
10 |
+
help="Translate UI using json files in [language] folder. "
|
11 |
+
"For example, [--language example] will use [language/example.json] for translation.")
|
12 |
+
|
13 |
+
# For example, https://github.com/lllyasviel/Fooocus/issues/849
|
14 |
+
args_parser.parser.add_argument("--disable-offload-from-vram", action="store_true",
|
15 |
+
help="Force loading models to vram when the unload can be avoided. "
|
16 |
+
"Some Mac users may need this.")
|
17 |
+
|
18 |
+
args_parser.parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
19 |
+
args_parser.parser.add_argument("--disable-image-log", action='store_true',
|
20 |
+
help="Prevent writing images and logs to hard drive.")
|
21 |
+
|
22 |
+
args_parser.parser.add_argument("--disable-analytics", action='store_true',
|
23 |
+
help="Disables analytics for Gradio.")
|
24 |
+
|
25 |
+
args_parser.parser.add_argument("--disable-metadata", action='store_true',
|
26 |
+
help="Disables saving metadata to images.")
|
27 |
+
|
28 |
+
args_parser.parser.add_argument("--disable-preset-download", action='store_true',
|
29 |
+
help="Disables downloading models for presets", default=False)
|
30 |
+
|
31 |
+
args_parser.parser.add_argument("--enable-describe-uov-image", action='store_true',
|
32 |
+
help="Disables automatic description of uov images when prompt is empty", default=False)
|
33 |
+
|
34 |
+
args_parser.parser.add_argument("--always-download-new-model", action='store_true',
|
35 |
+
help="Always download newer models ", default=False)
|
36 |
+
|
37 |
+
args_parser.parser.set_defaults(
|
38 |
+
disable_cuda_malloc=True,
|
39 |
+
in_browser=True,
|
40 |
+
port=None
|
41 |
+
)
|
42 |
+
|
43 |
+
args_parser.args = args_parser.parser.parse_args()
|
44 |
+
|
45 |
+
# (Disable by default because of issues like https://github.com/lllyasviel/Fooocus/issues/724)
|
46 |
+
args_parser.args.always_offload_from_vram = not args_parser.args.disable_offload_from_vram
|
47 |
+
|
48 |
+
if args_parser.args.disable_analytics:
|
49 |
+
import os
|
50 |
+
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
51 |
+
|
52 |
+
if args_parser.args.disable_in_browser:
|
53 |
+
args_parser.args.in_browser = False
|
54 |
+
|
55 |
+
args = args_parser.args
|
repositories/Fooocus/extras/BLIP/configs/bert_config.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"BertModel"
|
4 |
+
],
|
5 |
+
"attention_probs_dropout_prob": 0.1,
|
6 |
+
"hidden_act": "gelu",
|
7 |
+
"hidden_dropout_prob": 0.1,
|
8 |
+
"hidden_size": 768,
|
9 |
+
"initializer_range": 0.02,
|
10 |
+
"intermediate_size": 3072,
|
11 |
+
"layer_norm_eps": 1e-12,
|
12 |
+
"max_position_embeddings": 512,
|
13 |
+
"model_type": "bert",
|
14 |
+
"num_attention_heads": 12,
|
15 |
+
"num_hidden_layers": 12,
|
16 |
+
"pad_token_id": 0,
|
17 |
+
"type_vocab_size": 2,
|
18 |
+
"vocab_size": 30522,
|
19 |
+
"encoder_width": 768,
|
20 |
+
"add_cross_attention": true
|
21 |
+
}
|
repositories/Fooocus/extras/BLIP/configs/caption_coco.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
image_root: '/export/share/datasets/vision/coco/images/'
|
2 |
+
ann_root: 'annotation'
|
3 |
+
coco_gt_root: 'annotation/coco_gt'
|
4 |
+
|
5 |
+
# set pretrained as a file path or an url
|
6 |
+
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
7 |
+
|
8 |
+
# size of vit model; base or large
|
9 |
+
vit: 'base'
|
10 |
+
vit_grad_ckpt: False
|
11 |
+
vit_ckpt_layer: 0
|
12 |
+
batch_size: 32
|
13 |
+
init_lr: 1e-5
|
14 |
+
|
15 |
+
# vit: 'large'
|
16 |
+
# vit_grad_ckpt: True
|
17 |
+
# vit_ckpt_layer: 5
|
18 |
+
# batch_size: 16
|
19 |
+
# init_lr: 2e-6
|
20 |
+
|
21 |
+
image_size: 384
|
22 |
+
|
23 |
+
# generation configs
|
24 |
+
max_length: 20
|
25 |
+
min_length: 5
|
26 |
+
num_beams: 3
|
27 |
+
prompt: 'a picture of '
|
28 |
+
|
29 |
+
# optimizer
|
30 |
+
weight_decay: 0.05
|
31 |
+
min_lr: 0
|
32 |
+
max_epoch: 5
|
33 |
+
|
repositories/Fooocus/extras/BLIP/configs/med_config.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"BertModel"
|
4 |
+
],
|
5 |
+
"attention_probs_dropout_prob": 0.1,
|
6 |
+
"hidden_act": "gelu",
|
7 |
+
"hidden_dropout_prob": 0.1,
|
8 |
+
"hidden_size": 768,
|
9 |
+
"initializer_range": 0.02,
|
10 |
+
"intermediate_size": 3072,
|
11 |
+
"layer_norm_eps": 1e-12,
|
12 |
+
"max_position_embeddings": 512,
|
13 |
+
"model_type": "bert",
|
14 |
+
"num_attention_heads": 12,
|
15 |
+
"num_hidden_layers": 12,
|
16 |
+
"pad_token_id": 0,
|
17 |
+
"type_vocab_size": 2,
|
18 |
+
"vocab_size": 30524,
|
19 |
+
"encoder_width": 768,
|
20 |
+
"add_cross_attention": true
|
21 |
+
}
|
repositories/Fooocus/extras/BLIP/configs/nlvr.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
image_root: '/export/share/datasets/vision/NLVR2/'
|
2 |
+
ann_root: 'annotation'
|
3 |
+
|
4 |
+
# set pretrained as a file path or an url
|
5 |
+
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
|
6 |
+
|
7 |
+
#size of vit model; base or large
|
8 |
+
vit: 'base'
|
9 |
+
batch_size_train: 16
|
10 |
+
batch_size_test: 64
|
11 |
+
vit_grad_ckpt: False
|
12 |
+
vit_ckpt_layer: 0
|
13 |
+
max_epoch: 15
|
14 |
+
|
15 |
+
image_size: 384
|
16 |
+
|
17 |
+
# optimizer
|
18 |
+
weight_decay: 0.05
|
19 |
+
init_lr: 3e-5
|
20 |
+
min_lr: 0
|
21 |
+
|
repositories/Fooocus/extras/BLIP/configs/nocaps.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
image_root: '/export/share/datasets/vision/nocaps/'
|
2 |
+
ann_root: 'annotation'
|
3 |
+
|
4 |
+
# set pretrained as a file path or an url
|
5 |
+
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
6 |
+
|
7 |
+
vit: 'base'
|
8 |
+
batch_size: 32
|
9 |
+
|
10 |
+
image_size: 384
|
11 |
+
|
12 |
+
max_length: 20
|
13 |
+
min_length: 5
|
14 |
+
num_beams: 3
|
15 |
+
prompt: 'a picture of '
|
repositories/Fooocus/extras/BLIP/configs/pretrain.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
|
2 |
+
'/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
|
3 |
+
]
|
4 |
+
laion_path: ''
|
5 |
+
|
6 |
+
# size of vit model; base or large
|
7 |
+
vit: 'base'
|
8 |
+
vit_grad_ckpt: False
|
9 |
+
vit_ckpt_layer: 0
|
10 |
+
|
11 |
+
image_size: 224
|
12 |
+
batch_size: 75
|
13 |
+
|
14 |
+
queue_size: 57600
|
15 |
+
alpha: 0.4
|
16 |
+
|
17 |
+
# optimizer
|
18 |
+
weight_decay: 0.05
|
19 |
+
init_lr: 3e-4
|
20 |
+
min_lr: 1e-6
|
21 |
+
warmup_lr: 1e-6
|
22 |
+
lr_decay_rate: 0.9
|
23 |
+
max_epoch: 20
|
24 |
+
warmup_steps: 3000
|
25 |
+
|
26 |
+
|
27 |
+
|
repositories/Fooocus/extras/BLIP/configs/retrieval_coco.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
image_root: '/export/share/datasets/vision/coco/images/'
|
2 |
+
ann_root: 'annotation'
|
3 |
+
dataset: 'coco'
|
4 |
+
|
5 |
+
# set pretrained as a file path or an url
|
6 |
+
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
|
7 |
+
|
8 |
+
# size of vit model; base or large
|
9 |
+
|
10 |
+
vit: 'base'
|
11 |
+
batch_size_train: 32
|
12 |
+
batch_size_test: 64
|
13 |
+
vit_grad_ckpt: True
|
14 |
+
vit_ckpt_layer: 4
|
15 |
+
init_lr: 1e-5
|
16 |
+
|
17 |
+
# vit: 'large'
|
18 |
+
# batch_size_train: 16
|
19 |
+
# batch_size_test: 32
|
20 |
+
# vit_grad_ckpt: True
|
21 |
+
# vit_ckpt_layer: 12
|
22 |
+
# init_lr: 5e-6
|
23 |
+
|
24 |
+
image_size: 384
|
25 |
+
queue_size: 57600
|
26 |
+
alpha: 0.4
|
27 |
+
k_test: 256
|
28 |
+
negative_all_rank: True
|
29 |
+
|
30 |
+
# optimizer
|
31 |
+
weight_decay: 0.05
|
32 |
+
min_lr: 0
|
33 |
+
max_epoch: 6
|
34 |
+
|
repositories/Fooocus/extras/BLIP/configs/retrieval_flickr.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
image_root: '/export/share/datasets/vision/flickr30k/'
|
2 |
+
ann_root: 'annotation'
|
3 |
+
dataset: 'flickr'
|
4 |
+
|
5 |
+
# set pretrained as a file path or an url
|
6 |
+
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
|
7 |
+
|
8 |
+
# size of vit model; base or large
|
9 |
+
|
10 |
+
vit: 'base'
|
11 |
+
batch_size_train: 32
|
12 |
+
batch_size_test: 64
|
13 |
+
vit_grad_ckpt: True
|
14 |
+
vit_ckpt_layer: 4
|
15 |
+
init_lr: 1e-5
|
16 |
+
|
17 |
+
# vit: 'large'
|
18 |
+
# batch_size_train: 16
|
19 |
+
# batch_size_test: 32
|
20 |
+
# vit_grad_ckpt: True
|
21 |
+
# vit_ckpt_layer: 10
|
22 |
+
# init_lr: 5e-6
|
23 |
+
|
24 |
+
image_size: 384
|
25 |
+
queue_size: 57600
|
26 |
+
alpha: 0.4
|
27 |
+
k_test: 128
|
28 |
+
negative_all_rank: False
|
29 |
+
|
30 |
+
# optimizer
|
31 |
+
weight_decay: 0.05
|
32 |
+
min_lr: 0
|
33 |
+
max_epoch: 6
|
34 |
+
|
repositories/Fooocus/extras/BLIP/configs/retrieval_msrvtt.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
|
2 |
+
ann_root: 'annotation'
|
3 |
+
|
4 |
+
# set pretrained as a file path or an url
|
5 |
+
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
|
6 |
+
|
7 |
+
# size of vit model; base or large
|
8 |
+
vit: 'base'
|
9 |
+
batch_size: 64
|
10 |
+
k_test: 128
|
11 |
+
image_size: 384
|
12 |
+
num_frm_test: 8
|
repositories/Fooocus/extras/BLIP/configs/vqa.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
|
2 |
+
vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
|
3 |
+
train_files: ['vqa_train','vqa_val','vg_qa']
|
4 |
+
ann_root: 'annotation'
|
5 |
+
|
6 |
+
# set pretrained as a file path or an url
|
7 |
+
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
|
8 |
+
|
9 |
+
# size of vit model; base or large
|
10 |
+
vit: 'base'
|
11 |
+
batch_size_train: 16
|
12 |
+
batch_size_test: 32
|
13 |
+
vit_grad_ckpt: False
|
14 |
+
vit_ckpt_layer: 0
|
15 |
+
init_lr: 2e-5
|
16 |
+
|
17 |
+
image_size: 480
|
18 |
+
|
19 |
+
k_test: 128
|
20 |
+
inference: 'rank'
|
21 |
+
|
22 |
+
# optimizer
|
23 |
+
weight_decay: 0.05
|
24 |
+
min_lr: 0
|
25 |
+
max_epoch: 10
|
repositories/Fooocus/extras/BLIP/models/bert_tokenizer/config.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"BertForMaskedLM"
|
4 |
+
],
|
5 |
+
"attention_probs_dropout_prob": 0.1,
|
6 |
+
"gradient_checkpointing": false,
|
7 |
+
"hidden_act": "gelu",
|
8 |
+
"hidden_dropout_prob": 0.1,
|
9 |
+
"hidden_size": 768,
|
10 |
+
"initializer_range": 0.02,
|
11 |
+
"intermediate_size": 3072,
|
12 |
+
"layer_norm_eps": 1e-12,
|
13 |
+
"max_position_embeddings": 512,
|
14 |
+
"model_type": "bert",
|
15 |
+
"num_attention_heads": 12,
|
16 |
+
"num_hidden_layers": 12,
|
17 |
+
"pad_token_id": 0,
|
18 |
+
"position_embedding_type": "absolute",
|
19 |
+
"transformers_version": "4.6.0.dev0",
|
20 |
+
"type_vocab_size": 2,
|
21 |
+
"use_cache": true,
|
22 |
+
"vocab_size": 30522
|
23 |
+
}
|
repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer_config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"do_lower_case": true
|
3 |
+
}
|
repositories/Fooocus/extras/BLIP/models/bert_tokenizer/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|