ishworrsubedii commited on
Commit
36cd99b
1 Parent(s): 7277638

Updated the latest changes

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .github/workflows/dockerhub.yaml +39 -0
  2. .gitignore +61 -0
  3. Dockerfile +22 -0
  4. fooocus_api_version.py +1 -0
  5. fooocusapi/api.py +41 -0
  6. fooocusapi/args.py +20 -0
  7. fooocusapi/base_args.py +27 -0
  8. fooocusapi/configs/default.py +92 -0
  9. fooocusapi/models/common/base.py +189 -0
  10. fooocusapi/models/common/image_meta.py +118 -0
  11. fooocusapi/models/common/requests.py +132 -0
  12. fooocusapi/models/common/response.py +90 -0
  13. fooocusapi/models/common/task.py +60 -0
  14. fooocusapi/models/requests_v1.py +274 -0
  15. fooocusapi/models/requests_v2.py +50 -0
  16. fooocusapi/parameters.py +94 -0
  17. fooocusapi/routes/__init__.py +0 -0
  18. fooocusapi/routes/generate_v1.py +186 -0
  19. fooocusapi/routes/generate_v2.py +199 -0
  20. fooocusapi/routes/query.py +135 -0
  21. fooocusapi/sql_client.py +269 -0
  22. fooocusapi/task_queue.py +323 -0
  23. fooocusapi/utils/api_utils.py +291 -0
  24. fooocusapi/utils/call_worker.py +97 -0
  25. fooocusapi/utils/file_utils.py +143 -0
  26. fooocusapi/utils/img_utils.py +198 -0
  27. fooocusapi/utils/logger.py +132 -0
  28. fooocusapi/utils/lora_manager.py +71 -0
  29. fooocusapi/utils/model_loader.py +46 -0
  30. fooocusapi/utils/tools.py +159 -0
  31. fooocusapi/worker.py +1044 -0
  32. main.py +213 -0
  33. mannequin_to_model.py +90 -0
  34. predict.py +316 -0
  35. repositories/Fooocus/__init__.py +4 -0
  36. repositories/Fooocus/args_manager.py +55 -0
  37. repositories/Fooocus/extras/BLIP/configs/bert_config.json +21 -0
  38. repositories/Fooocus/extras/BLIP/configs/caption_coco.yaml +33 -0
  39. repositories/Fooocus/extras/BLIP/configs/med_config.json +21 -0
  40. repositories/Fooocus/extras/BLIP/configs/nlvr.yaml +21 -0
  41. repositories/Fooocus/extras/BLIP/configs/nocaps.yaml +15 -0
  42. repositories/Fooocus/extras/BLIP/configs/pretrain.yaml +27 -0
  43. repositories/Fooocus/extras/BLIP/configs/retrieval_coco.yaml +34 -0
  44. repositories/Fooocus/extras/BLIP/configs/retrieval_flickr.yaml +34 -0
  45. repositories/Fooocus/extras/BLIP/configs/retrieval_msrvtt.yaml +12 -0
  46. repositories/Fooocus/extras/BLIP/configs/vqa.yaml +25 -0
  47. repositories/Fooocus/extras/BLIP/models/bert_tokenizer/config.json +23 -0
  48. repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer.json +0 -0
  49. repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer_config.json +3 -0
  50. repositories/Fooocus/extras/BLIP/models/bert_tokenizer/vocab.txt +0 -0
.github/workflows/dockerhub.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# CI workflow: build the Docker image and publish it to Docker Hub on
# every push to the main branch.
name: Publish MannequinToModel Docker image

on:
  push:
    branches: [ main ]

jobs:
  push_to_registry:
    name: Push Docker image to Docker Hub
    runs-on: ubuntu-latest
    permissions:
      packages: write
      contents: read
      attestations: write
    steps:
      - name: Check out the repo
        uses: actions/checkout@v4

      # Third-party docker/* actions are pinned to commit SHAs (not tags)
      # so a moved tag cannot change the code this workflow runs.
      - name: Log in to Docker Hub
        uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}

      # Derives image tags/labels (e.g. branch name, sha) for the push step.
      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
        with:
          images: techconsp/mannequin_to_model

      - name: Build and push Docker image
        id: push
        uses: docker/build-push-action@3b5e8027fcad23fda98b2e3ac259d8d67585f671
        with:
          context: .
          file: ./Dockerfile
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
.gitignore ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+
27
+ # Jupyter Notebook
28
+ .ipynb_checkpoints/
29
+ *.ipynb_checkpoints/
30
+ # Exclude if necessary for security reasons:
31
+ #*.ipynb
32
+
33
+ # Virtual environments
34
+ venv/
35
+ env/
36
+ ENV/
37
+ .venv/
38
+ .ENV/
39
+
40
+ # IDEs
41
+ .idea/
42
+ .vscode/
43
+ *.sublime-project
44
+ *.sublime-workspace
45
+
46
+ htmlcov/
47
+ .coverage
48
+ .coverage.*
49
+ .cache/
50
+
51
+
52
+ .env
53
+
54
+ yolov8m-seg.pt
55
+ .yoloface
56
+ logs
57
+ experiments
58
+ artifacts
59
+ examples
60
+ resources
61
+ *.safetensors
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Runtime image for the Fooocus API / MannequinToModel server.
FROM python:3.10-slim

WORKDIR /mannequin_to_model

# System libraries required by OpenCV / ffmpeg based image processing,
# plus build tools and git for pip packages built from source.
# Removing the apt lists keeps the layer small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1-mesa-glx \
    ffmpeg \
    libsm6 \
    libxext6 \
    build-essential \
    git \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir --upgrade pip==23.0

# Copy only the dependency manifest first so the (slow) dependency
# install layer is cached and not rebuilt on every source change.
COPY requirements.txt /mannequin_to_model/requirements.txt

RUN pip install --no-cache-dir -r requirements.txt

# Copy the application source last — source edits only invalidate this layer.
COPY . /mannequin_to_model

EXPOSE 8000

CMD ["python", "main.py"]
fooocus_api_version.py ADDED
@@ -0,0 +1 @@
 
 
1
# Version string reported by the Fooocus API (single source of truth).
version = '0.4.1.1'
fooocusapi/api.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Entry for startup fastapi server
3
+ """
4
+ from fastapi import FastAPI
5
+ from fastapi.staticfiles import StaticFiles
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+
8
+ import uvicorn
9
+
10
+ from fooocusapi.utils import file_utils
11
+ from fooocusapi.routes.generate_v1 import secure_router as generate_v1
12
+ from fooocusapi.routes.generate_v2 import secure_router as generate_v2
13
+ from fooocusapi.routes.query import secure_router as query
14
+ from mannequin_to_model import secure_router as mannequin_to_model
15
+
16
+ app = FastAPI()
17
+
18
+ app.add_middleware(
19
+ CORSMiddleware,
20
+ allow_origins=["*"], # Allow access from all sources
21
+ allow_credentials=True,
22
+ allow_methods=["*"], # Allow all HTTP methods
23
+ allow_headers=["*"], # Allow all request headers
24
+ )
25
+
26
+ app.mount("/files", StaticFiles(directory=file_utils.output_dir), name="files")
27
+
28
+ app.include_router(query)
29
+ app.include_router(generate_v1)
30
+ app.include_router(generate_v2)
31
+ app.include_router(mannequin_to_model)
32
+
33
+
34
+ def start_app(args):
35
+ """Start the FastAPI application"""
36
+ file_utils.STATIC_SERVER_BASE = args.base_url + "/files/"
37
+ uvicorn.run(
38
+ app="fooocusapi.api:app",
39
+ host="0.0.0.0",
40
+ port=8000,
41
+ log_level=args.log_level)
fooocusapi/args.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Do not modify the import order
3
+ """
4
+ from fooocusapi.base_args import add_base_args
5
+ import ldm_patched.modules.args_parser as args_parser
6
+
7
+ # Add Fooocus-API args to parser
8
+ add_base_args(args_parser.parser, False)
9
+
10
+ # Apply Fooocus args
11
+ from args_manager import args_parser
12
+
13
+ # Override the port default value
14
+ args_parser.parser.set_defaults(
15
+ port=8888
16
+ )
17
+
18
+ # Execute args parse again
19
+ args_parser.args = args_parser.parser.parse_args()
20
+ args = args_parser.args
fooocusapi/base_args.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ base_args.py
3
+ """
4
+ from argparse import ArgumentParser
5
+
6
+
7
def add_base_args(parser: ArgumentParser, before_prepared: bool):
    """
    Register the Fooocus-API command line options on *parser*.
    Args:
        parser: ArgumentParser to extend
        before_prepared: when True, also register --port (only safe before
            the environment-preparation step defines its own)
    Returns:
    """
    if before_prepared:
        parser.add_argument(
            "--port", type=int, default=8888,
            help="Set the listen port, default: 8888")

    # (flag, kwargs) specs registered in declaration order so the
    # generated --help output is identical to the hand-written version.
    option_specs = (
        ("--host", {"type": str, "default": '127.0.0.1', "help": "Set the listen host, default: 127.0.0.1"}),
        ("--base-url", {"type": str, "default": None, "help": "Set base url for outside visit, default is http://host:port"}),
        ("--log-level", {"type": str, "default": 'info', "help": "Log info for Uvicorn, default: info"}),
        ("--skip-pip", {"default": False, "action": "store_true", "help": "Skip automatic pip install when setup"}),
        ("--preload-pipeline", {"default": False, "action": "store_true", "help": "Preload pipeline before start http server"}),
        ("--queue-size", {"type": int, "default": 100, "help": "Working queue size, default: 100, generation requests exceeding working queue size will return failure"}),
        ("--queue-history", {"type": int, "default": 0, "help": "Finished jobs reserve size, tasks exceeding the limit will be deleted, including output image files, default: 0, means no limit"}),
        ('--webhook-url', {"type": str, "default": None, "help": 'The URL to send a POST request when a job is finished'}),
        ('--persistent', {"default": False, "action": "store_true", "help": "Store history to db"}),
        ("--apikey", {"type": str, "default": None, "help": "API key for authenticating requests"}),
    )
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)
fooocusapi/configs/default.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Static variables for Fooocus API
3
+ """
4
+ img_generate_responses = {
5
+ "200": {
6
+ "description": "PNG bytes if request's 'Accept' header is 'image/png', otherwise JSON",
7
+ "content": {
8
+ "application/json": {
9
+ "example": [{
10
+ "base64": "...very long string...",
11
+ "seed": "1050625087",
12
+ "finish_reason": "SUCCESS",
13
+ }]
14
+ },
15
+ "application/json async": {
16
+ "example": {
17
+ "job_id": 1,
18
+ "job_type": "Text to Image"
19
+ }
20
+ },
21
+ "image/png": {
22
+ "example": "PNG bytes, what did you expect?"
23
+ },
24
+ },
25
+ }
26
+ }
27
+
28
+ default_inpaint_engine_version = "v2.6"
29
+
30
+ default_styles = ["Fooocus V2", "Fooocus Enhance", "Fooocus Sharp"]
31
+ default_base_model_name = "juggernautXL_v8Rundiffusion.safetensors"
32
+ default_refiner_model_name = "None"
33
+ default_refiner_switch = 0.5
34
+ default_loras = [[True, "sd_xl_offset_example-lora_1.0.safetensors", 0.1]]
35
+ default_cfg_scale = 7.0
36
+ default_prompt_negative = ""
37
+ default_aspect_ratio = "1152*896"
38
+ default_sampler = "dpmpp_2m_sde_gpu"
39
+ default_scheduler = "karras"
40
+
41
+ available_aspect_ratios = [
42
+ "704*1408",
43
+ "704*1344",
44
+ "768*1344",
45
+ "768*1280",
46
+ "832*1216",
47
+ "832*1152",
48
+ "896*1152",
49
+ "896*1088",
50
+ "960*1088",
51
+ "960*1024",
52
+ "1024*1024",
53
+ "1024*960",
54
+ "1088*960",
55
+ "1088*896",
56
+ "1152*896",
57
+ "1152*832",
58
+ "1216*832",
59
+ "1280*768",
60
+ "1344*768",
61
+ "1344*704",
62
+ "1408*704",
63
+ "1472*704",
64
+ "1536*640",
65
+ "1600*640",
66
+ "1664*576",
67
+ "1728*576",
68
+ ]
69
+
70
+ uov_methods = [
71
+ "Disabled",
72
+ "Vary (Subtle)",
73
+ "Vary (Strong)",
74
+ "Upscale (1.5x)",
75
+ "Upscale (2x)",
76
+ "Upscale (Fast 2x)",
77
+ "Upscale (Custom)",
78
+ ]
79
+
80
+ outpaint_expansions = ["Left", "Right", "Top", "Bottom"]
81
+
82
+
83
def get_aspect_ratio_value(label: str) -> str:
    """
    Normalize an aspect-ratio label to the canonical "W*H" form.
    Args:
        label: aspect ratio label, e.g. "1152×896 (4:3)"; only the first
            space-separated token is used

    Returns:
        The "W*H" string with the unicode multiplication sign replaced
        by an asterisk.
    """
    token, _, _ = label.partition(" ")
    return token.replace("×", "*")
fooocusapi/models/common/base.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Common models"""
2
+ from typing import List, Tuple
3
+ from enum import Enum
4
+ from fastapi import UploadFile
5
+ from fastapi.exceptions import RequestValidationError
6
+ from pydantic import (
7
+ ValidationError,
8
+ ConfigDict,
9
+ BaseModel,
10
+ TypeAdapter,
11
+ Field
12
+ )
13
+ from pydantic_core import InitErrorDetails
14
+
15
+ from fooocusapi.configs.default import default_loras
16
+
17
+
18
class PerformanceSelection(str, Enum):
    """Performance selection"""
    speed = 'Speed'
    quality = 'Quality'
    extreme_speed = 'Extreme Speed'
    lightning = 'Lightning'
    hyper_sd = 'Hyper-SD'


class Lora(BaseModel):
    """Common params lora model"""
    enabled: bool
    model_name: str
    # Lora strength, clamped to [-2, 2].
    weight: float = Field(default=0.5, ge=-2, le=2)

    # Replaces pydantic's protected "model_" prefix so the field name
    # "model_name" does not trigger a namespace warning.
    model_config = ConfigDict(
        protected_namespaces=('protect_me_', 'also_protect_')
    )


# Validator/serializer for lists of Lora objects.
LoraList = TypeAdapter(List[Lora])

# Default lora list built from the raw config entries
# ([enabled, filename, weight]).
# NOTE(review): the guard compares lora[0] (the *enabled* flag) against
# the string 'None'; it looks like the filename lora[1] was intended —
# confirm against modules.config's default_loras format.
default_loras_model = []
for lora in default_loras:
    if lora[0] != 'None':
        default_loras_model.append(
            Lora(
                enabled=lora[0],
                model_name=lora[1],
                weight=lora[2])
        )
default_loras_json = LoraList.dump_json(default_loras_model)


class UpscaleOrVaryMethod(str, Enum):
    """Upscale or Vary method"""
    subtle_variation = 'Vary (Subtle)'
    strong_variation = 'Vary (Strong)'
    upscale_15 = 'Upscale (1.5x)'
    upscale_2 = 'Upscale (2x)'
    upscale_fast = 'Upscale (Fast 2x)'
    upscale_custom = 'Upscale (Custom)'


class OutpaintExpansion(str, Enum):
    """Outpaint expansion"""
    left = 'Left'
    right = 'Right'
    top = 'Top'
    bottom = 'Bottom'


class ControlNetType(str, Enum):
    """ControlNet Type"""
    cn_ip = "ImagePrompt"
    cn_ip_face = "FaceSwap"
    cn_canny = "PyraCanny"
    cn_cpds = "CPDS"


class ImagePrompt(BaseModel):
    """Common params object ImagePrompt"""
    cn_img: UploadFile | None = Field(default=None)
    # Stop-at fraction of the sampling steps, [0, 1].
    cn_stop: float | None = Field(default=None, ge=0, le=1)
    cn_weight: float | None = Field(default=None, ge=0, le=2, description="None for default value")
    cn_type: ControlNetType = Field(default=ControlNetType.cn_ip)


class DescribeImageType(str, Enum):
    """Image type for image to prompt"""
    photo = 'Photo'
    anime = 'Anime'


class ImageMetaScheme(str, Enum):
    """Scheme for save image meta
    Attributes:
        Fooocus: json format
        A111: string
    """
    Fooocus = 'fooocus'
    A111 = 'a111'
99
+
100
+
101
+ def style_selection_parser(style_selections: str | List[str]) -> List[str]:
102
+ """
103
+ Parse style selections, Convert to list
104
+ Args:
105
+ style_selections: str, comma separated Fooocus style selections
106
+ e.g. Fooocus V2, Fooocus Enhance, Fooocus Sharp
107
+ Returns:
108
+ List[str]
109
+ """
110
+ style_selection_arr: List[str] = []
111
+ if style_selections is None or len(style_selections) == 0:
112
+ return []
113
+ for part in style_selections:
114
+ if len(part) > 0:
115
+ for s in part.split(','):
116
+ style = s.strip()
117
+ style_selection_arr.append(style)
118
+ return style_selection_arr
119
+
120
+
121
def lora_parser(loras: str) -> List[Lora]:
    """
    Parse lora config, Convert to list
    Args:
        loras: a json string for loras, e.g.
            '[{"enabled": true, "model_name": "a.safetensors", "weight": 0.5}]'
    Returns:
        List[Lora]; empty list for None/empty input
    Raises:
        RequestValidationError: when the JSON does not validate
    """
    loras_model: List[Lora] = []
    if loras is None or len(loras) == 0:
        return loras_model
    try:
        loras_model = LoraList.validate_json(loras)
        return loras_model
    except ValidationError as ve:
        # Bug fix: the previous `raise RequestValidationError from errs`
        # instantiated RequestValidationError without its required error
        # list and used a non-exception (a list of dicts) in the `from`
        # clause — both raise TypeError instead of the intended 422.
        raise RequestValidationError(ve.errors()) from ve
138
+
139
+
140
def outpaint_selections_parser(outpaint_selections: str | list[str]) -> List[OutpaintExpansion]:
    """
    Parse outpaint selections, Convert to list
    Args:
        outpaint_selections: str, comma separated Left, Right, Top, Bottom
            e.g. Left, Right, Top, Bottom
    Returns:
        List[OutpaintExpansion]
    Raises:
        RequestValidationError: when a token is not a valid direction
    """
    outpaint_selections_arr: List[OutpaintExpansion] = []
    if outpaint_selections is None or len(outpaint_selections) == 0:
        return []
    for part in outpaint_selections:
        if len(part) == 0:
            continue
        for s in part.split(','):
            # Strip whitespace so the documented form "Left, Right"
            # validates (the old code passed " Right" through and the
            # enum lookup always failed).
            token = s.strip()
            try:
                expansion = OutpaintExpansion(token)
                outpaint_selections_arr.append(expansion)
            except ValueError as exc:
                # Bug fixes vs. the previous version:
                #  - loc must be a one-element tuple; tuple('outpaint_selections')
                #    exploded the string into single characters.
                #  - RequestValidationError requires its error-list argument,
                #    and `from` needs an exception, not an error dict.
                err = InitErrorDetails(
                    type='enum',
                    loc=('outpaint_selections',),
                    input=outpaint_selections,
                    ctx={
                        'expected': "str, comma separated Left, Right, Top, Bottom"
                    })
                raise RequestValidationError([err]) from exc
    return outpaint_selections_arr
168
+
169
+
170
def image_prompt_parser(image_prompts_config: List[Tuple]) -> List[ImagePrompt]:
    """
    Convert (cn_img, cn_stop, cn_weight, cn_type) tuples to ImagePrompt models.
    Args:
        image_prompts_config: List[Tuple]
            e.g. ('image1.jpg', 0.5, 1.0, 'normal'), ('image2.jpg', 0.5, 1.0, 'normal')
    Returns:
        List[ImagePrompt]; empty list for empty/None input
    """
    if not image_prompts_config:
        return []
    return [
        ImagePrompt(
            cn_img=img,
            cn_stop=stop,
            cn_weight=weight,
            cn_type=cn_type)
        for img, stop, weight, cn_type in image_prompts_config
    ]
fooocusapi/models/common/image_meta.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image meta schema
3
+ """
4
+ from typing import List
5
+
6
+ from fooocus_version import version
7
+ from pydantic import BaseModel
8
+
9
+
10
class ImageMeta(BaseModel):
    """
    Image meta data model

    Captures the full set of generation parameters that gets embedded in
    (or saved alongside) an output image.
    """

    # 'fooocus' (json) or 'a111' (string) — see ImageMetaScheme.
    metadata_scheme: str = "fooocus"

    # Checkpoint name without extension and its hash.
    base_model: str
    base_model_hash: str

    prompt: str
    full_prompt: List[str]
    prompt_expansion: str

    negative_prompt: str
    full_negative_prompt: List[str]

    performance: str

    style: str

    refiner_model: str = "None"
    refiner_switch: float = 0.5

    # Each entry: [basename, weight, hash] (see loras_parser).
    loras: List[list]

    # "(width, height)" string.
    resolution: str

    sampler: str = "dpmpp_2m_sde_gpu"
    scheduler: str = "karras"
    seed: str
    adm_guidance: str
    guidance_scale: float
    sharpness: float
    steps: int
    vae_name: str

    # Fooocus version the image was generated with.
    version: str = version

    def __repr__(self):
        # Intentionally empty — presumably to keep the large meta payload
        # out of logs/tracebacks. TODO(review): confirm this is deliberate.
        return ""
51
+
52
+
53
def loras_parser(loras: list) -> list:
    """
    Parse lora list

    Converts [filename, weight] pairs into meta rows of
    [basename-without-extension, weight, placeholder-hash], skipping
    entries whose filename is None or the string 'None'.

    Args:
        loras: list of [filename, weight] pairs
    Returns:
        list of [basename, weight, "hash_not_calculated"] rows
    """
    return [
        [
            # Idiom fix: `[:1][0]` was an obfuscated spelling of `[0]`.
            lora[0].rsplit('.', maxsplit=1)[0],
            lora[1],
            "hash_not_calculated",
        ] for lora in loras if lora[0] != 'None' and lora[0] is not None]
63
+
64
+
65
def image_parse(
    async_tak: object,
    task: dict
) -> dict | str:
    """
    Parse image meta data
    Generate meta data for image from task and async task object
    Args:
        async_tak: async task obj (NOTE(review): name looks like a typo
            for "async_task" but is part of the public signature — keep)
        task: task obj

    Returns:
        dict: image meta data
    """
    req_param = async_tak.req_param
    meta = ImageMeta(
        metadata_scheme=req_param.meta_scheme,
        # Strip the file extension from the checkpoint filename.
        base_model=req_param.base_model_name.rsplit('.', maxsplit=1)[:1][0],
        base_model_hash='',
        prompt=req_param.prompt,
        full_prompt=task['positive'],
        prompt_expansion=task['expansion'],
        negative_prompt=req_param.negative_prompt,
        full_negative_prompt=task['negative'],
        performance=req_param.performance_selection,
        style=str(req_param.style_selections),
        refiner_model=req_param.refiner_model_name,
        refiner_switch=req_param.refiner_switch,
        loras=loras_parser(req_param.loras),
        # "W*H" string -> "(W, H)" tuple repr.
        resolution=str(tuple([int(n) for n in req_param.aspect_ratios_selection.split('*')])),
        sampler=req_param.advanced_params.sampler_name,
        scheduler=req_param.advanced_params.scheduler_name,
        seed=str(task['task_seed']),
        adm_guidance=str((
            req_param.advanced_params.adm_scaler_positive,
            req_param.advanced_params.adm_scaler_negative,
            req_param.advanced_params.adm_scaler_end)),
        guidance_scale=req_param.guidance_scale,
        sharpness=req_param.sharpness,
        # NOTE(review): the real step count is not tracked here; -1 is a
        # placeholder — confirm whether callers rely on it.
        steps=-1,
        vae_name=req_param.advanced_params.vae_name,
        version=version
    )
    # Unknown schemes silently fall back to 'fooocus'.
    if meta.metadata_scheme not in ["fooocus", "a111"]:
        meta.metadata_scheme = "fooocus"
    if meta.metadata_scheme == "fooocus":
        meta_dict = meta.model_dump()
        # Fooocus scheme additionally flattens loras into
        # "lora_combined_N": "<name> : <weight>" entries.
        for i, lora in enumerate(meta.loras):
            attr_name = f"lora_combined_{i+1}"
            lr = [str(x) for x in lora]
            meta_dict[attr_name] = f"{lr[0]} : {lr[1]}"
    else:
        meta_dict = meta.model_dump()
    return meta_dict
fooocusapi/models/common/requests.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Common model for requests"""
2
+ from typing import List
3
+ from pydantic import (
4
+ BaseModel,
5
+ Field,
6
+ ValidationError
7
+ )
8
+
9
+ from modules.config import (
10
+ default_sampler,
11
+ default_scheduler,
12
+ default_prompt,
13
+ default_prompt_negative,
14
+ default_aspect_ratio,
15
+ default_base_model_name,
16
+ default_refiner_model_name,
17
+ default_refiner_switch,
18
+ default_cfg_scale,
19
+ default_styles,
20
+ default_overwrite_step,
21
+ default_inpaint_engine_version,
22
+ default_overwrite_switch,
23
+ default_cfg_tsnr,
24
+ default_sample_sharpness,
25
+ default_vae,
26
+ default_clip_skip
27
+ )
28
+
29
+ from modules.flags import clip_skip_max
30
+
31
+ from fooocusapi.models.common.base import (
32
+ PerformanceSelection,
33
+ Lora,
34
+ default_loras_model
35
+ )
36
+
37
+ default_aspect_ratio = default_aspect_ratio.split(" ")[0].replace("×", "*")
38
+
39
+
40
class QueryJobRequest(BaseModel):
    """Query job request"""
    job_id: str = Field(description="Job ID to query")
    require_step_preview: bool = Field(
        default=False,
        description="Set to true will return preview image of generation steps at current time")


class AdvancedParams(BaseModel):
    """Common params object AdvancedParams

    Defaults mirror the values exposed by the Fooocus UI / modules.config.
    """
    disable_preview: bool = Field(False, description="Disable preview during generation")
    disable_intermediate_results: bool = Field(False, description="Disable intermediate results")
    disable_seed_increment: bool = Field(False, description="Disable Seed Increment")
    adm_scaler_positive: float = Field(1.5, description="Positive ADM Guidance Scaler", ge=0.1, le=3.0)
    adm_scaler_negative: float = Field(0.8, description="Negative ADM Guidance Scaler", ge=0.1, le=3.0)
    adm_scaler_end: float = Field(0.3, description="ADM Guidance End At Step", ge=0.0, le=1.0)
    adaptive_cfg: float = Field(default_cfg_tsnr, description="CFG Mimicking from TSNR", ge=1.0, le=30.0)
    clip_skip: int = Field(default_clip_skip, description="Clip Skip", ge=1, le=clip_skip_max)
    sampler_name: str = Field(default_sampler, description="Sampler")
    scheduler_name: str = Field(default_scheduler, description="Scheduler")
    # -1 means "do not overwrite" for all overwrite_* fields below.
    overwrite_step: int = Field(default_overwrite_step, description="Forced Overwrite of Sampling Step", ge=-1, le=200)
    overwrite_switch: float = Field(default_overwrite_switch, description="Forced Overwrite of Refiner Switch Step", ge=-1, le=1)
    overwrite_width: int = Field(-1, description="Forced Overwrite of Generating Width", ge=-1, le=2048)
    overwrite_height: int = Field(-1, description="Forced Overwrite of Generating Height", ge=-1, le=2048)
    overwrite_vary_strength: float = Field(-1, description='Forced Overwrite of Denoising Strength of "Vary"', ge=-1, le=1.0)
    overwrite_upscale_strength: float = Field(-1, description='Forced Overwrite of Denoising Strength of "Upscale"', ge=-1, le=1.0)
    mixing_image_prompt_and_vary_upscale: bool = Field(False, description="Mixing Image Prompt and Vary/Upscale")
    mixing_image_prompt_and_inpaint: bool = Field(False, description="Mixing Image Prompt and Inpaint")
    debugging_cn_preprocessor: bool = Field(False, description="Debug Preprocessors")
    skipping_cn_preprocessor: bool = Field(False, description="Skip Preprocessors")
    canny_low_threshold: int = Field(64, description="Canny Low Threshold", ge=1, le=255)
    canny_high_threshold: int = Field(128, description="Canny High Threshold", ge=1, le=255)
    refiner_swap_method: str = Field('joint', description="Refiner swap method")
    controlnet_softness: float = Field(0.25, description="Softness of ControlNet", ge=0.0, le=1.0)
    freeu_enabled: bool = Field(False, description="FreeU enabled")
    freeu_b1: float = Field(1.01, description="FreeU B1")
    freeu_b2: float = Field(1.02, description="FreeU B2")
    freeu_s1: float = Field(0.99, description="FreeU B3")
    freeu_s2: float = Field(0.95, description="FreeU B4")
    debugging_inpaint_preprocessor: bool = Field(False, description="Debug Inpaint Preprocessing")
    inpaint_disable_initial_latent: bool = Field(False, description="Disable initial latent in inpaint")
    inpaint_engine: str = Field(default_inpaint_engine_version, description="Inpaint Engine")
    inpaint_strength: float = Field(1.0, description="Inpaint Denoising Strength", ge=0.0, le=1.0)
    inpaint_respective_field: float = Field(1.0, description="Inpaint Respective Field", ge=0.0, le=1.0)
    inpaint_mask_upload_checkbox: bool = Field(False, description="Upload Mask")
    invert_mask_checkbox: bool = Field(False, description="Invert Mask")
    inpaint_erode_or_dilate: int = Field(0, description="Mask Erode or Dilate", ge=-64, le=64)
    black_out_nsfw: bool = Field(False, description="Block out NSFW")
    vae_name: str = Field(default_vae, description="VAE name")


class CommonRequest(BaseModel):
    """All generate request based on this model"""
    prompt: str = default_prompt
    negative_prompt: str = default_prompt_negative
    style_selections: List[str] = default_styles
    performance_selection: PerformanceSelection = PerformanceSelection.speed
    aspect_ratios_selection: str = default_aspect_ratio
    image_number: int = Field(default=1, description="Image number", ge=1, le=32)
    image_seed: int = Field(default=-1, description="Seed to generate image, -1 for random")
    sharpness: float = Field(default=default_sample_sharpness, ge=0.0, le=30.0)
    guidance_scale: float = Field(default=default_cfg_scale, ge=1.0, le=30.0)
    base_model_name: str = default_base_model_name
    refiner_model_name: str = default_refiner_model_name
    refiner_switch: float = Field(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0)
    loras: List[Lora] = Field(default=default_loras_model)
    # Pydantic deep-copies model defaults per instance, so sharing this
    # default AdvancedParams() object across requests is safe.
    advanced_params: AdvancedParams = AdvancedParams()
    save_meta: bool = Field(default=True, description="Save meta data")
    meta_scheme: str = Field(default='fooocus', description="Meta data scheme, one of [fooocus, a111]")
    save_extension: str = Field(default='png', description="Save extension, one of [png, jpg, webp]")
    save_name: str = Field(default='', description="Image name for output image, default is job id + seq")
    read_wildcards_in_order: bool = Field(default=False, description="Read wildcards in order")
    require_base64: bool = Field(default=False, description="Return base64 data of generated image")
    async_process: bool = Field(default=False, description="Set to true will run async and return job info for retrieve generation result later")
    webhook_url: str | None = Field(default='', description="Optional URL for a webhook callback. If provided, the system will send a POST request to this URL upon task completion or failure."
                                                           " This allows for asynchronous notification of task status.")
116
+
117
+
118
def advanced_params_parser(advanced_params: str | None) -> AdvancedParams:
    """
    Parse advanced params, Convert to AdvancedParams
    Args:
        advanced_params: str, json format
    Returns:
        AdvancedParams object, if validate error return default value
    """
    if advanced_params is not None and len(advanced_params) > 0:
        try:
            # Bug fix: validate_json already returns an AdvancedParams
            # instance. The previous code then called
            # AdvancedParams(**instance), which raises TypeError because
            # a pydantic model is not a mapping — so every valid payload
            # crashed instead of being returned.
            return AdvancedParams.model_validate_json(advanced_params)
        except ValidationError:
            return AdvancedParams()
    return AdvancedParams()
fooocusapi/models/common/response.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Fooocus API models for response"""
2
+ from typing import List
3
+
4
+ from pydantic import (
5
+ BaseModel,
6
+ ConfigDict,
7
+ Field
8
+ )
9
+
10
+ from fooocusapi.models.common.task import (
11
+ GeneratedImageResult,
12
+ AsyncJobStage
13
+ )
14
+ from fooocusapi.task_queue import TaskType
15
+
16
+
17
class DescribeImageResponse(BaseModel):
    """
    describe image response
    """
    # Natural-language caption produced for the uploaded image.
    describe: str


class AsyncJobResponse(BaseModel):
    """
    Async job response
    Attributes:
        job_id: Job ID
        job_type: Job type
        job_stage: Job stage
        job_progress: Job progress, 0-100
        job_status: Job status
        job_step_preview: Job step preview
        job_result: Job result
    """
    job_id: str = Field(description="Job ID")
    job_type: TaskType = Field(description="Job type")
    job_stage: AsyncJobStage = Field(description="Job running stage")
    job_progress: int = Field(description="Job running progress, 100 is for finished.")
    job_status: str | None = Field(None, description="Job running status in text")
    job_step_preview: str | None = Field(None, description="Preview image of generation steps at current time, as base64 image")
    job_result: List[GeneratedImageResult] | None = Field(None, description="Job generation result")


class JobQueueInfo(BaseModel):
    """
    job queue info
    Attributes:
        running_size: int, The current running and waiting job count
        finished_size: int, The current finished job count
        last_job_id: str, Last submit generation job id
    """
    running_size: int = Field(description="The current running and waiting job count")
    finished_size: int = Field(description="Finished job count (after auto clean)")
    last_job_id: str | None = Field(description="Last submit generation job id")


# TODO May need more detail fields, will add later when someone need
class JobHistoryInfo(BaseModel):
    """
    job history info
    """
    job_id: str
    is_finished: bool = False


# Response model for the historical tasks
class JobHistoryResponse(BaseModel):
    """
    job history response
    """
    # Jobs still waiting/running vs. already completed.
    queue: List[JobHistoryInfo] = []
    history: List[JobHistoryInfo] = []


class AllModelNamesResponse(BaseModel):
    """
    all model list response
    """
    model_filenames: List[str] = Field(description="All available model filenames")
    lora_filenames: List[str] = Field(description="All available lora filenames")

    # Replaces pydantic's protected "model_" prefix so the field name
    # "model_filenames" does not trigger a namespace warning.
    model_config = ConfigDict(
        protected_namespaces=('protect_me_', 'also_protect_')
    )


class StopResponse(BaseModel):
    """stop task response"""
    msg: str
fooocusapi/models/common/task.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Task and job related models
3
+ """
4
+ from enum import Enum
5
+ from pydantic import (
6
+ BaseModel,
7
+ Field
8
+ )
9
+
10
+
11
class TaskType(str, Enum):
    """
    Task type object.

    Human-readable names for the supported generation task kinds.
    """
    text_2_img = 'Text to Image'
    img_uov = 'Image Upscale or Variation'
    img_inpaint_outpaint = 'Image Inpaint or Outpaint'
    img_prompt = 'Image Prompt'
    # Sentinel value; presumably used when a job cannot be resolved to a task
    not_found = 'Not Found'
20
+
21
+
22
class GenerationFinishReason(str, Enum):
    """
    Generation finish reason.

    Why a generation task ended: completed normally, rejected because the
    queue was full, cancelled by the user, or failed with an error.
    """
    success = 'SUCCESS'
    queue_is_full = 'QUEUE_IS_FULL'
    user_cancel = 'USER_CANCEL'
    error = 'ERROR'
30
+
31
+
32
class ImageGenerationResult:
    """
    Result of generating a single image.

    Attributes:
        im: generated image payload as a string (or None when generation
            did not succeed)
        seed: the seed associated with this image
        finish_reason: why generation ended (see GenerationFinishReason)
    """

    def __init__(self, im: str | None, seed: str, finish_reason: GenerationFinishReason):
        self.im = im
        self.seed = seed
        self.finish_reason = finish_reason

    def __repr__(self) -> str:
        # Debug-friendly representation; 'im' may be a very large string
        # (e.g. base64 data), so only report whether it is present.
        im_repr = '<set>' if self.im is not None else 'None'
        return (f"{type(self).__name__}(im={im_repr}, "
                f"seed={self.seed!r}, finish_reason={self.finish_reason!r})")
40
+
41
+
42
class AsyncJobStage(str, Enum):
    """
    Async job stage.

    Lifecycle of an asynchronous generation job, from waiting in the
    queue to its terminal success/error state.
    """
    waiting = 'WAITING'
    running = 'RUNNING'
    success = 'SUCCESS'
    error = 'ERROR'
50
+
51
+
52
class GeneratedImageResult(BaseModel):
    """
    Generated images result.

    One image of a generation response; exactly which of base64/url is
    populated depends on the request options and the finish reason.
    """
    base64: str | None = Field(
        description="Image encoded in base64, or null if finishReason is not 'SUCCESS', only return when request require base64")
    url: str | None = Field(description="Image file static serve url, or null if finishReason is not 'SUCCESS'")
    seed: str = Field(description="The seed associated with this image")
    finish_reason: GenerationFinishReason
fooocusapi/models/requests_v1.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ requests models for v1 endpoints
3
+ """
4
+ from typing import List
5
+ from fastapi.params import File
6
+ from fastapi import (
7
+ UploadFile,
8
+ Form
9
+ )
10
+ from fooocusapi.models.common.requests import (
11
+ CommonRequest,
12
+ advanced_params_parser
13
+ )
14
+ from fooocusapi.models.common.base import (
15
+ ImagePrompt,
16
+ ControlNetType,
17
+ OutpaintExpansion,
18
+ UpscaleOrVaryMethod,
19
+ PerformanceSelection
20
+ )
21
+
22
+ from fooocusapi.models.common.base import (
23
+ style_selection_parser,
24
+ lora_parser,
25
+ outpaint_selections_parser,
26
+ image_prompt_parser,
27
+ default_loras_json
28
+ )
29
+
30
+ from fooocusapi.configs.default import (
31
+ default_prompt_negative,
32
+ default_aspect_ratio,
33
+ default_base_model_name,
34
+ default_refiner_model_name,
35
+ default_refiner_switch,
36
+ default_cfg_scale,
37
+ default_styles,
38
+ )
39
+
40
+
41
class ImgUpscaleOrVaryRequest(CommonRequest):
    """
    Request for image upscale or variation
    Attributes:
        input_image: Input image
        uov_method: Upscale or variation method
        upscale_value: upscale value
    Functions:
        as_form: Convert request to form data
    """
    input_image: UploadFile
    uov_method: UpscaleOrVaryMethod
    upscale_value: float | None

    @classmethod
    def as_form(
            cls,
            input_image: UploadFile = Form(description="Init image for upscale or outpaint"),
            uov_method: UpscaleOrVaryMethod = Form(),
            upscale_value: float | None = Form(None, description="Upscale custom value, None for default value", ge=1.0, le=5.0),
            prompt: str = Form(''),
            negative_prompt: str = Form(default_prompt_negative),
            style_selections: List[str] = Form(default_styles, description="Fooocus style selections, separated by comma"),
            performance_selection: PerformanceSelection = Form(PerformanceSelection.speed, description="Performance Selection, one of 'Speed','Quality','Extreme Speed'"),
            aspect_ratios_selection: str = Form(default_aspect_ratio, description="Aspect Ratios Selection, default 1152*896"),
            image_number: int = Form(default=1, description="Image number", ge=1, le=32),
            image_seed: int = Form(default=-1, description="Seed to generate image, -1 for random"),
            sharpness: float = Form(default=2.0, ge=0.0, le=30.0),
            guidance_scale: float = Form(default=default_cfg_scale, ge=1.0, le=30.0),
            base_model_name: str = Form(default_base_model_name, description="checkpoint file name"),
            refiner_model_name: str = Form(default_refiner_model_name, description="refiner file name"),
            refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
            loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
            advanced_params: str | None = Form(default=None, description="Advanced parameters in JSON"),
            save_meta: bool = Form(default=False, description="Save metadata to image"),
            meta_scheme: str = Form(default='fooocus', description="Metadata scheme, one of 'fooocus', 'a111'"),
            save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
            save_name: str = Form(default="", description="Save name, empty for auto generate"),
            require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
            read_wildcards_in_order: bool = Form(default=False, description="Read wildcards in order"),
            async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generation result later"),
            webhook_url: str = Form(default="", description="Webhook url for generation result"),
    ):
        """
        Build an ImgUpscaleOrVaryRequest from multipart form fields.

        Intended to be used via FastAPI's Depends(); the string-encoded
        form fields (styles, loras, advanced params) are parsed into
        their structured model types here.
        """
        # Parse comma-separated / JSON-encoded form fields into model values
        style_selection_arr = style_selection_parser(style_selections)
        loras_model = lora_parser(loras)
        advanced_params_obj = advanced_params_parser(advanced_params)

        return cls(
            input_image=input_image, uov_method=uov_method, upscale_value=upscale_value,
            prompt=prompt, negative_prompt=negative_prompt, style_selections=style_selection_arr,
            performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
            image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
            base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
            loras=loras_model, advanced_params=advanced_params_obj, save_meta=save_meta, meta_scheme=meta_scheme,
            save_extension=save_extension, save_name=save_name, require_base64=require_base64,
            read_wildcards_in_order=read_wildcards_in_order, async_process=async_process, webhook_url=webhook_url)
97
+
98
+
99
class ImgInpaintOrOutpaintRequest(CommonRequest):
    """
    Image Inpaint or Outpaint Request

    Attributes:
        input_image: init image to inpaint or outpaint
        input_mask: optional mask selecting the region to repaint
        inpaint_additional_prompt: extra prompt describing the inpaint content
        outpaint_selections: which edges to expand
        outpaint_distance_*: per-edge expansion distance in pixels
    """
    input_image: UploadFile | None
    input_mask: UploadFile | None
    inpaint_additional_prompt: str | None
    outpaint_selections: List[OutpaintExpansion]
    outpaint_distance_left: int
    outpaint_distance_right: int
    outpaint_distance_top: int
    outpaint_distance_bottom: int

    @classmethod
    def as_form(
            cls,
            input_image: UploadFile = Form(description="Init image for inpaint or outpaint"),
            input_mask: UploadFile = Form(File(None), description="Inpaint or outpaint mask"),
            inpaint_additional_prompt: str | None = Form("", description="Describe what you want to inpaint"),
            outpaint_selections: List[str] = Form([], description="Outpaint expansion selections, literal 'Left', 'Right', 'Top', 'Bottom' separated by comma"),
            outpaint_distance_left: int = Form(default=0, description="Set outpaint left distance, -1 for default"),
            outpaint_distance_right: int = Form(default=0, description="Set outpaint right distance, -1 for default"),
            outpaint_distance_top: int = Form(default=0, description="Set outpaint top distance, -1 for default"),
            outpaint_distance_bottom: int = Form(default=0, description="Set outpaint bottom distance, -1 for default"),
            prompt: str = Form(''),
            negative_prompt: str = Form(default_prompt_negative),
            style_selections: List[str] = Form(default_styles, description="Fooocus style selections, separated by comma"),
            performance_selection: PerformanceSelection = Form(PerformanceSelection.speed, description="Performance Selection, one of 'Speed','Quality','Extreme Speed'"),
            aspect_ratios_selection: str = Form(default_aspect_ratio, description="Aspect Ratios Selection, default 1152*896"),
            image_number: int = Form(default=1, description="Image number", ge=1, le=32),
            image_seed: int = Form(default=-1, description="Seed to generate image, -1 for random"),
            sharpness: float = Form(default=2.0, ge=0.0, le=30.0),
            guidance_scale: float = Form(default=default_cfg_scale, ge=1.0, le=30.0),
            base_model_name: str = Form(default_base_model_name),
            refiner_model_name: str = Form(default_refiner_model_name),
            refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
            loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
            advanced_params: str | None = Form(default=None, description="Advanced parameters in JSON"),
            save_meta: bool = Form(default=False, description="Save metadata to image"),
            meta_scheme: str = Form(default='fooocus', description="Metadata scheme, one of 'fooocus', 'a111'"),
            save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
            save_name: str = Form(default="", description="Save name, empty for auto generate"),
            require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
            read_wildcards_in_order: bool = Form(default=False, description="Read wildcards in order"),
            async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generation result later"),
            webhook_url: str = Form(default="", description="Webhook url for generation result"),
    ):
        """
        Build an ImgInpaintOrOutpaintRequest from multipart form fields.
        """
        # When the optional upload is omitted, FastAPI injects the File(...)
        # default object itself; normalize that sentinel to None.
        if isinstance(input_mask, File):
            input_mask = None

        # Parse comma-separated / JSON-encoded form fields into model values
        outpaint_selections_arr = outpaint_selections_parser(outpaint_selections)
        style_selection_arr = style_selection_parser(style_selections)
        loras_model = lora_parser(loras)
        advanced_params_obj = advanced_params_parser(advanced_params)

        return cls(
            input_image=input_image, input_mask=input_mask, inpaint_additional_prompt=inpaint_additional_prompt,
            outpaint_selections=outpaint_selections_arr, outpaint_distance_left=outpaint_distance_left,
            outpaint_distance_right=outpaint_distance_right, outpaint_distance_top=outpaint_distance_top,
            outpaint_distance_bottom=outpaint_distance_bottom, prompt=prompt, negative_prompt=negative_prompt, style_selections=style_selection_arr,
            performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
            image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
            base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
            loras=loras_model, advanced_params=advanced_params_obj, save_meta=save_meta, meta_scheme=meta_scheme,
            save_extension=save_extension, save_name=save_name, require_base64=require_base64,
            read_wildcards_in_order=read_wildcards_in_order, async_process=async_process, webhook_url=webhook_url)
165
+
166
+
167
class ImgPromptRequest(ImgInpaintOrOutpaintRequest):
    """
    Image Prompt Request

    Extends the inpaint/outpaint request with up to four controlnet
    image prompts (cn_img1..cn_img4).
    """
    image_prompts: List[ImagePrompt]

    @classmethod
    def as_form(
            cls,
            input_image: UploadFile = Form(File(None), description="Init image for inpaint or outpaint"),
            input_mask: UploadFile = Form(File(None), description="Inpaint or outpaint mask"),
            inpaint_additional_prompt: str | None = Form(None, description="Describe what you want to inpaint"),
            outpaint_selections: List[str] = Form([], description="Outpaint expansion selections, literal 'Left', 'Right', 'Top', 'Bottom' separated by comma"),
            outpaint_distance_left: int = Form(default=0, description="Set outpaint left distance, 0 for default"),
            outpaint_distance_right: int = Form(default=0, description="Set outpaint right distance, 0 for default"),
            outpaint_distance_top: int = Form(default=0, description="Set outpaint top distance, 0 for default"),
            outpaint_distance_bottom: int = Form(default=0, description="Set outpaint bottom distance, 0 for default"),
            cn_img1: UploadFile = Form(File(None), description="Input image for image prompt"),
            cn_stop1: float | None = Form(
                default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
            cn_weight1: float | None = Form(
                default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
            cn_type1: ControlNetType = Form(
                default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
            cn_img2: UploadFile = Form(
                File(None), description="Input image for image prompt"),
            cn_stop2: float | None = Form(
                default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
            cn_weight2: float | None = Form(
                default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
            cn_type2: ControlNetType = Form(
                default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
            cn_img3: UploadFile = Form(
                File(None), description="Input image for image prompt"),
            cn_stop3: float | None = Form(
                default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
            cn_weight3: float | None = Form(
                default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
            cn_type3: ControlNetType = Form(
                default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
            cn_img4: UploadFile = Form(
                File(None), description="Input image for image prompt"),
            cn_stop4: float | None = Form(
                default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
            cn_weight4: float | None = Form(
                default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
            cn_type4: ControlNetType = Form(
                default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
            prompt: str = Form(''),
            negative_prompt: str = Form(default_prompt_negative),
            style_selections: List[str] = Form(default_styles, description="Fooocus style selections, separated by comma"),
            performance_selection: PerformanceSelection = Form(
                PerformanceSelection.speed),
            aspect_ratios_selection: str = Form(default_aspect_ratio),
            image_number: int = Form(
                default=1, description="Image number", ge=1, le=32),
            image_seed: int = Form(default=-1, description="Seed to generate image, -1 for random"),
            sharpness: float = Form(default=2.0, ge=0.0, le=30.0),
            guidance_scale: float = Form(default=default_cfg_scale, ge=1.0, le=30.0),
            base_model_name: str = Form(default_base_model_name),
            refiner_model_name: str = Form(default_refiner_model_name),
            refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
            loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
            advanced_params: str | None = Form(default=None, description="Advanced parameters in JSON"),
            save_meta: bool = Form(default=False, description="Save metadata to image"),
            meta_scheme: str = Form(default='fooocus', description="Metadata scheme, one of 'fooocus', 'a111'"),
            save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
            save_name: str = Form(default="", description="Save name, empty for auto generate"),
            require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
            read_wildcards_in_order: bool = Form(default=False, description="Read wildcards in order"),
            async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generation result later"),
            webhook_url: str = Form(default="", description="Webhook url for generation result"),
    ):
        """
        Build an ImgPromptRequest from multipart form fields.
        """
        # When an optional upload is omitted, FastAPI injects the File(...)
        # default object itself; normalize those sentinels to None.
        if isinstance(input_image, File):
            input_image = None
        if isinstance(input_mask, File):
            input_mask = None
        if isinstance(cn_img1, File):
            cn_img1 = None
        if isinstance(cn_img2, File):
            cn_img2 = None
        if isinstance(cn_img3, File):
            cn_img3 = None
        if isinstance(cn_img4, File):
            cn_img4 = None

        outpaint_selections_arr = outpaint_selections_parser(outpaint_selections)

        # Collect the four (image, stop, weight, type) slots and parse them
        # into ImagePrompt models; empty slots are handled by the parser.
        image_prompt_config = [
            (cn_img1, cn_stop1, cn_weight1, cn_type1),
            (cn_img2, cn_stop2, cn_weight2, cn_type2),
            (cn_img3, cn_stop3, cn_weight3, cn_type3),
            (cn_img4, cn_stop4, cn_weight4, cn_type4)]
        image_prompts = image_prompt_parser(image_prompt_config)
        style_selection_arr = style_selection_parser(style_selections)
        loras_model = lora_parser(loras)
        advanced_params_obj = advanced_params_parser(advanced_params)

        return cls(
            input_image=input_image, input_mask=input_mask, inpaint_additional_prompt=inpaint_additional_prompt, outpaint_selections=outpaint_selections_arr,
            outpaint_distance_left=outpaint_distance_left, outpaint_distance_right=outpaint_distance_right, outpaint_distance_top=outpaint_distance_top, outpaint_distance_bottom=outpaint_distance_bottom,
            image_prompts=image_prompts, prompt=prompt, negative_prompt=negative_prompt, style_selections=style_selection_arr,
            performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
            image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
            base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
            loras=loras_model, advanced_params=advanced_params_obj, save_meta=save_meta, meta_scheme=meta_scheme,
            save_extension=save_extension, save_name=save_name, require_base64=require_base64,
            read_wildcards_in_order=read_wildcards_in_order, async_process=async_process, webhook_url=webhook_url)
fooocusapi/models/requests_v2.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """V2 API models"""
2
+ from typing import List
3
+ from pydantic import BaseModel, Field
4
+ from fooocusapi.models.common.requests import CommonRequest
5
+ from fooocusapi.models.common.base import (
6
+ ControlNetType,
7
+ OutpaintExpansion,
8
+ ImagePrompt,
9
+ UpscaleOrVaryMethod
10
+ )
11
+
12
+
13
class ImagePromptJson(BaseModel):
    """Image prompt for V2 API (base64 image instead of file upload)"""
    cn_img: str | None = Field(None, description="Input image for image prompt as base64")
    cn_stop: float | None = Field(0, ge=0, le=1, description="Stop at for image prompt, 0 for default value")
    cn_weight: float | None = Field(0, ge=0, le=2, description="Weight for image prompt, 0 for default value")
    cn_type: ControlNetType = Field(default=ControlNetType.cn_ip, description="ControlNet type for image prompt")
19
+
20
+
21
class ImgInpaintOrOutpaintRequestJson(CommonRequest):
    """image inpaint or outpaint request (V2, base64 images)"""
    input_image: str = Field('', description="Init image for inpaint or outpaint as base64")
    input_mask: str | None = Field('', description="Inpaint or outpaint mask as base64")
    inpaint_additional_prompt: str | None = Field('', description="Describe what you want to inpaint")
    outpaint_selections: List[OutpaintExpansion] = []
    # -1 means "use the default distance" for each edge
    outpaint_distance_left: int | None = Field(-1, description="Set outpaint left distance")
    outpaint_distance_right: int | None = Field(-1, description="Set outpaint right distance")
    outpaint_distance_top: int | None = Field(-1, description="Set outpaint top distance")
    outpaint_distance_bottom: int | None = Field(-1, description="Set outpaint bottom distance")
    # Accepts either the V2 base64 form or the V1 ImagePrompt form
    image_prompts: List[ImagePromptJson | ImagePrompt] = []
32
+
33
+
34
class ImgPromptRequestJson(ImgInpaintOrOutpaintRequestJson):
    """img prompt request json; input_image becomes optional, image_prompts required"""
    input_image: str | None = Field(None, description="Init image for inpaint or outpaint as base64")
    image_prompts: List[ImagePromptJson | ImagePrompt]
38
+
39
+
40
class Text2ImgRequestWithPrompt(CommonRequest):
    """text to image request with optional image prompts"""
    image_prompts: List[ImagePromptJson] = []
43
+
44
+
45
class ImgUpscaleOrVaryRequestJson(CommonRequest):
    """img upscale or vary request json (V2, base64 image)"""
    uov_method: UpscaleOrVaryMethod = UpscaleOrVaryMethod.upscale_2
    upscale_value: float | None = Field(1.0, ge=1.0, le=5.0, description="Upscale custom value, 1.0 for default value")
    input_image: str = Field(description="Init image for upscale or outpaint as base64")
    image_prompts: List[ImagePromptJson | ImagePrompt] = []
fooocusapi/parameters.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple
2
+ import numpy as np
3
+ import copy
4
+
5
+ from fooocusapi.models.common.requests import AdvancedParams
6
+
7
+
8
class ImageGenerationParams:
    """
    Plain parameter container handed to the generation worker.

    Collects every option of a generation request — prompts, model and lora
    selection, upscale/vary input, outpaint distances, inpaint input, image
    prompts and saving options — in one object. When image prompts are
    combined with an inpaint or upscale/vary input, the matching advanced
    "mixing" flag is switched on automatically.
    """
    def __init__(
        self,
        prompt: str,
        negative_prompt: str,
        style_selections: List[str],
        performance_selection: str,
        aspect_ratios_selection: str,
        image_number: int,
        image_seed: int | None,
        sharpness: float,
        guidance_scale: float,
        base_model_name: str,
        refiner_model_name: str,
        refiner_switch: float,
        loras: List[Tuple[str, float]],
        uov_input_image: np.ndarray | None,
        uov_method: str,
        upscale_value: float | None,
        outpaint_selections: List[str],
        outpaint_distance_left: int,
        outpaint_distance_right: int,
        outpaint_distance_top: int,
        outpaint_distance_bottom: int,
        inpaint_input_image: Dict[str, np.ndarray] | None,
        inpaint_additional_prompt: str | None,
        image_prompts: List[Tuple[np.ndarray, float, float, str]],
        # Fixed annotation: the original said `List[any] | None`, which uses
        # the builtin function `any` and the wrong container — the body
        # clearly treats this as a single AdvancedParams instance.
        advanced_params: AdvancedParams | None,
        save_extension: str,
        save_meta: bool,
        meta_scheme: str,
        save_name: str,
        require_base64: bool,
    ):
        self.prompt = prompt
        self.negative_prompt = negative_prompt
        self.style_selections = style_selections
        self.performance_selection = performance_selection
        self.aspect_ratios_selection = aspect_ratios_selection
        self.image_number = image_number
        self.image_seed = image_seed
        self.sharpness = sharpness
        self.guidance_scale = guidance_scale
        self.base_model_name = base_model_name
        self.refiner_model_name = refiner_model_name
        self.refiner_switch = refiner_switch
        self.loras = loras
        self.uov_input_image = uov_input_image
        self.uov_method = uov_method
        self.upscale_value = upscale_value
        self.outpaint_selections = outpaint_selections
        self.outpaint_distance_left = outpaint_distance_left
        self.outpaint_distance_right = outpaint_distance_right
        self.outpaint_distance_top = outpaint_distance_top
        self.outpaint_distance_bottom = outpaint_distance_bottom
        self.inpaint_input_image = inpaint_input_image
        self.inpaint_additional_prompt = inpaint_additional_prompt
        self.image_prompts = image_prompts
        self.save_extension = save_extension
        self.save_meta = save_meta
        self.meta_scheme = meta_scheme
        self.save_name = save_name
        self.require_base64 = require_base64
        self.advanced_params = advanced_params

        # Fall back to default advanced parameters when none were provided
        if self.advanced_params is None:
            self.advanced_params = AdvancedParams()

        # Auto set mixing_image_prompt_and_inpaint to True
        if len(self.image_prompts) > 0 and self.inpaint_input_image is not None:
            print("Mixing Image Prompts and Inpaint Enabled")
            self.advanced_params.mixing_image_prompt_and_inpaint = True
        if len(self.image_prompts) > 0 and self.uov_input_image is not None:
            print("Mixing Image Prompts and Vary Upscale Enabled")
            self.advanced_params.mixing_image_prompt_and_vary_upscale = True

    def to_dict(self):
        """
        Convert the ImageGenerationParams object to a dictionary.

        Returns:
            A dict of deep-copied attribute values, so mutating the result
            (or the numpy arrays inside it) does not affect this instance.
        """
        obj_dict = copy.deepcopy(self)
        return obj_dict.__dict__
fooocusapi/routes/__init__.py ADDED
File without changes
fooocusapi/routes/generate_v1.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate API V1 routes
2
+
3
+ """
4
+ from typing import List, Optional
5
+ from fastapi import APIRouter, Depends, Header, Query, UploadFile
6
+ from fastapi.params import File
7
+
8
+ from modules.util import HWC3
9
+
10
+ from fooocusapi.models.common.base import DescribeImageType
11
+ from fooocusapi.utils.api_utils import api_key_auth
12
+
13
+ from fooocusapi.models.common.requests import CommonRequest as Text2ImgRequest
14
+ from fooocusapi.models.requests_v1 import (
15
+ ImgUpscaleOrVaryRequest,
16
+ ImgPromptRequest,
17
+ ImgInpaintOrOutpaintRequest
18
+ )
19
+ from fooocusapi.models.common.response import (
20
+ AsyncJobResponse,
21
+ GeneratedImageResult,
22
+ DescribeImageResponse,
23
+ StopResponse
24
+ )
25
+ from fooocusapi.utils.call_worker import call_worker
26
+ from fooocusapi.utils.img_utils import read_input_image
27
+ from fooocusapi.configs.default import img_generate_responses
28
+ from fooocusapi.worker import process_stop
29
+
30
+
31
+ secure_router = APIRouter(
32
+ dependencies=[Depends(api_key_auth)]
33
+ )
34
+
35
+
36
def stop_worker():
    """Interrupt worker process"""
    # Delegates to the worker module's stop routine (fooocusapi.worker.process_stop)
    process_stop()
39
+
40
+
41
@secure_router.post(
    path="/v1/generation/text-to-image",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV1"])
def text2img_generation(
        req: Text2ImgRequest,
        accept: str = Header(None),
        accept_query: str | None = Query(
            None, alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
    """\nText to Image Generation\n
    A text to image generation endpoint
    Arguments:
        req {Text2ImgRequest} -- Text to image generation request
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    returns:
        Response -- img_generate_responses
    """
    # A non-empty 'accept' query parameter takes precedence over the header.
    effective_accept = accept_query if accept_query else accept
    return call_worker(req, effective_accept)
65
+
66
+
67
@secure_router.post(
    path="/v1/generation/image-upscale-vary",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV1"])
def img_upscale_or_vary(
        input_image: UploadFile,
        req: ImgUpscaleOrVaryRequest = Depends(ImgUpscaleOrVaryRequest.as_form),
        accept: str = Header(None),
        accept_query: str | None = Query(
            None, alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
    """\nImage upscale or vary\n
    Image upscale or vary
    Arguments:
        input_image {UploadFile} -- Input image file
        req {ImgUpscaleOrVaryRequest} -- Request body
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    Returns:
        Response -- img_generate_responses
    """
    # A non-empty 'accept' query parameter takes precedence over the header.
    effective_accept = accept_query if accept_query else accept
    return call_worker(req, effective_accept)
93
+
94
+
95
@secure_router.post(
    path="/v1/generation/image-inpaint-outpaint",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV1"])
def img_inpaint_or_outpaint(
        input_image: UploadFile,
        req: ImgInpaintOrOutpaintRequest = Depends(ImgInpaintOrOutpaintRequest.as_form),
        accept: str = Header(None),
        accept_query: str | None = Query(
            None, alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
    """\nInpaint or outpaint\n
    Inpaint or outpaint
    Arguments:
        input_image {UploadFile} -- Input image file
        req {ImgInpaintOrOutpaintRequest} -- Request body
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    """
    # A non-empty 'accept' query parameter takes precedence over the header.
    effective_accept = accept_query if accept_query else accept
    return call_worker(req, effective_accept)
119
+
120
+
121
@secure_router.post(
    path="/v1/generation/image-prompt",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV1"])
def img_prompt(
        cn_img1: Optional[UploadFile] = File(None),
        req: ImgPromptRequest = Depends(ImgPromptRequest.as_form),
        accept: str = Header(None),
        accept_query: str | None = Query(
            None, alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
    """\nImage Prompt\n
    Image Prompt
    A prompt-based image generation.
    Arguments:
        cn_img1 {UploadFile} -- Input image file
        req {ImgPromptRequest} -- Request body
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    Returns:
        Response -- img_generate_responses
    """
    # A non-empty 'accept' query parameter takes precedence over the header.
    effective_accept = accept_query if accept_query else accept
    return call_worker(req, effective_accept)
148
+
149
+
150
@secure_router.post(
    path="/v1/tools/describe-image",
    response_model=DescribeImageResponse,
    tags=["GenerateV1"])
def describe_image(
        image: UploadFile,
        image_type: DescribeImageType = Query(
            DescribeImageType.photo,
            description="Image type, 'Photo' or 'Anime'")):
    """\nDescribe image\n
    Describe image, Get tags from an image
    Arguments:
        image {UploadFile} -- Image to get tags
        image_type {DescribeImageType} -- Image type, 'Photo' or 'Anime'
    Returns:
        DescribeImageResponse -- Describe image response, a string
    """
    # Import the interrogator lazily so only the model actually needed is loaded
    if image_type == DescribeImageType.photo:
        from extras.interrogate import default_interrogator as default_interrogator_photo
        interrogator = default_interrogator_photo
    else:
        from extras.wd14tagger import default_interrogator as default_interrogator_anime
        interrogator = default_interrogator_anime
    # HWC3 presumably normalizes the decoded image to a 3-channel HxWxC array
    # — confirm in modules.util before relying on this.
    img = HWC3(read_input_image(image))
    result = interrogator(img)
    return DescribeImageResponse(describe=result)
176
+
177
+
178
@secure_router.post(
    path="/v1/generation/stop",
    response_model=StopResponse,
    description="Job stopping",
    tags=["Default"])
def stop():
    """Interrupt worker"""
    # Signals the worker to interrupt the current generation, then reports success
    stop_worker()
    return StopResponse(msg="success")
fooocusapi/routes/generate_v2.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate API V2 routes
2
+
3
+ """
4
+ from typing import List
5
+ from fastapi import APIRouter, Depends, Header, Query
6
+
7
+ from fooocusapi.utils.api_utils import api_key_auth
8
+ from fooocusapi.models.requests_v1 import ImagePrompt
9
+ from fooocusapi.models.requests_v2 import (
10
+ ImgInpaintOrOutpaintRequestJson,
11
+ ImgPromptRequestJson,
12
+ Text2ImgRequestWithPrompt,
13
+ ImgUpscaleOrVaryRequestJson
14
+ )
15
+ from fooocusapi.models.common.response import (
16
+ AsyncJobResponse,
17
+ GeneratedImageResult
18
+ )
19
+ from fooocusapi.utils.call_worker import call_worker
20
+ from fooocusapi.utils.img_utils import base64_to_stream
21
+ from fooocusapi.configs.default import img_generate_responses
22
+
23
+
24
+ secure_router = APIRouter(
25
+ dependencies=[Depends(api_key_auth)]
26
+ )
27
+
28
+
29
+ @secure_router.post(
30
+ path="/v2/generation/text-to-image-with-ip",
31
+ response_model=List[GeneratedImageResult] | AsyncJobResponse,
32
+ responses=img_generate_responses,
33
+ tags=["GenerateV2"])
34
+ def text_to_img_with_ip(
35
+ req: Text2ImgRequestWithPrompt,
36
+ accept: str = Header(None),
37
+ accept_query: str | None = Query(
38
+ default=None, alias='accept',
39
+ description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
40
+ """\nText to image with prompt\n
41
+ Text to image with prompt
42
+ Arguments:
43
+ req {Text2ImgRequestWithPrompt} -- Text to image generation request
44
+ accept {str} -- Accept header
45
+ accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
46
+ Returns:
47
+ Response -- img_generate_responses
48
+ """
49
+ if accept_query is not None and len(accept_query) > 0:
50
+ accept = accept_query
51
+
52
+ default_image_prompt = ImagePrompt(cn_img=None)
53
+ image_prompts_files: List[ImagePrompt] = []
54
+ for image_prompt in req.image_prompts:
55
+ image_prompt.cn_img = base64_to_stream(image_prompt.cn_img)
56
+ image = ImagePrompt(
57
+ cn_img=image_prompt.cn_img,
58
+ cn_stop=image_prompt.cn_stop,
59
+ cn_weight=image_prompt.cn_weight,
60
+ cn_type=image_prompt.cn_type)
61
+ image_prompts_files.append(image)
62
+
63
+ while len(image_prompts_files) <= 4:
64
+ image_prompts_files.append(default_image_prompt)
65
+
66
+ req.image_prompts = image_prompts_files
67
+
68
+ return call_worker(req, accept)
69
+
70
+
71
+ @secure_router.post(
72
+ path="/v2/generation/image-upscale-vary",
73
+ response_model=List[GeneratedImageResult] | AsyncJobResponse,
74
+ responses=img_generate_responses,
75
+ tags=["GenerateV2"])
76
+ def img_upscale_or_vary(
77
+ req: ImgUpscaleOrVaryRequestJson,
78
+ accept: str = Header(None),
79
+ accept_query: str | None = Query(
80
+ None, alias='accept', description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
81
+ """\nImage upscale or vary\n
82
+ Image upscale or vary
83
+ Arguments:
84
+ req {ImgUpscaleOrVaryRequestJson} -- Image upscale or vary request
85
+ accept {str} -- Accept header
86
+ accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
87
+ Returns:
88
+ Response -- img_generate_responses
89
+ """
90
+ if accept_query is not None and len(accept_query) > 0:
91
+ accept = accept_query
92
+
93
+ req.input_image = base64_to_stream(req.input_image)
94
+
95
+ default_image_prompt = ImagePrompt(cn_img=None)
96
+ image_prompts_files: List[ImagePrompt] = []
97
+ for image_prompt in req.image_prompts:
98
+ image_prompt.cn_img = base64_to_stream(image_prompt.cn_img)
99
+ image = ImagePrompt(
100
+ cn_img=image_prompt.cn_img,
101
+ cn_stop=image_prompt.cn_stop,
102
+ cn_weight=image_prompt.cn_weight,
103
+ cn_type=image_prompt.cn_type)
104
+ image_prompts_files.append(image)
105
+ while len(image_prompts_files) <= 4:
106
+ image_prompts_files.append(default_image_prompt)
107
+ req.image_prompts = image_prompts_files
108
+
109
+ return call_worker(req, accept)
110
+
111
+
112
+ @secure_router.post(
113
+ path="/v2/generation/image-inpaint-outpaint",
114
+ response_model=List[GeneratedImageResult] | AsyncJobResponse,
115
+ responses=img_generate_responses,
116
+ tags=["GenerateV2"])
117
+ def img_inpaint_or_outpaint(
118
+ req: ImgInpaintOrOutpaintRequestJson,
119
+ accept: str = Header(None),
120
+ accept_query: str | None = Query(
121
+ None, alias='accept',
122
+ description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
123
+ """\nInpaint or outpaint\n
124
+ Inpaint or outpaint
125
+ Arguments:
126
+ req {ImgInpaintOrOutpaintRequestJson} -- Request body
127
+ accept {str} -- Accept header
128
+ accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
129
+ Returns:
130
+ Response -- img_generate_responses
131
+ """
132
+ if accept_query is not None and len(accept_query) > 0:
133
+ accept = accept_query
134
+
135
+ req.input_image = base64_to_stream(req.input_image)
136
+ if req.input_mask is not None:
137
+ req.input_mask = base64_to_stream(req.input_mask)
138
+ default_image_prompt = ImagePrompt(cn_img=None)
139
+ image_prompts_files: List[ImagePrompt] = []
140
+ for image_prompt in req.image_prompts:
141
+ image_prompt.cn_img = base64_to_stream(image_prompt.cn_img)
142
+ image = ImagePrompt(
143
+ cn_img=image_prompt.cn_img,
144
+ cn_stop=image_prompt.cn_stop,
145
+ cn_weight=image_prompt.cn_weight,
146
+ cn_type=image_prompt.cn_type)
147
+ image_prompts_files.append(image)
148
+ while len(image_prompts_files) <= 4:
149
+ image_prompts_files.append(default_image_prompt)
150
+ req.image_prompts = image_prompts_files
151
+
152
+ return call_worker(req, accept)
153
+
154
+
155
+ @secure_router.post(
156
+ path="/v2/generation/image-prompt",
157
+ response_model=List[GeneratedImageResult] | AsyncJobResponse,
158
+ responses=img_generate_responses,
159
+ tags=["GenerateV2"])
160
+ def img_prompt(
161
+ req: ImgPromptRequestJson,
162
+ accept: str = Header(None),
163
+ accept_query: str | None = Query(
164
+ None, alias='accept',
165
+ description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
166
+ """\nImage prompt\n
167
+ Image prompt generation
168
+ Arguments:
169
+ req {ImgPromptRequest} -- Request body
170
+ accept {str} -- Accept header
171
+ accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
172
+ Returns:
173
+ Response -- img_generate_responses
174
+ """
175
+ if accept_query is not None and len(accept_query) > 0:
176
+ accept = accept_query
177
+
178
+ if req.input_image is not None:
179
+ req.input_image = base64_to_stream(req.input_image)
180
+ if req.input_mask is not None:
181
+ req.input_mask = base64_to_stream(req.input_mask)
182
+
183
+ default_image_prompt = ImagePrompt(cn_img=None)
184
+ image_prompts_files: List[ImagePrompt] = []
185
+ for image_prompt in req.image_prompts:
186
+ image_prompt.cn_img = base64_to_stream(image_prompt.cn_img)
187
+ image = ImagePrompt(
188
+ cn_img=image_prompt.cn_img,
189
+ cn_stop=image_prompt.cn_stop,
190
+ cn_weight=image_prompt.cn_weight,
191
+ cn_type=image_prompt.cn_type)
192
+ image_prompts_files.append(image)
193
+
194
+ while len(image_prompts_files) <= 4:
195
+ image_prompts_files.append(default_image_prompt)
196
+
197
+ req.image_prompts = image_prompts_files
198
+
199
+ return call_worker(req, accept)
fooocusapi/routes/query.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Query API"""
2
+ from typing import List
3
+ from fastapi import Depends, Response, APIRouter
4
+
5
+ from fooocusapi.args import args
6
+
7
+ from fooocusapi.models.common.requests import QueryJobRequest
8
+ from fooocusapi.models.common.response import (
9
+ AsyncJobResponse,
10
+ JobHistoryInfo,
11
+ JobQueueInfo,
12
+ JobHistoryResponse,
13
+ AllModelNamesResponse
14
+ )
15
+ from fooocusapi.models.common.task import AsyncJobStage
16
+
17
+ from fooocusapi.utils.api_utils import generate_async_output, api_key_auth
18
+ from fooocusapi.task_queue import TaskType
19
+ from fooocusapi.worker import worker_queue
20
+
21
+ secure_router = APIRouter(dependencies=[Depends(api_key_auth)])
22
+
23
+
24
+ @secure_router.get(path="/", tags=['Query'])
25
+ def home():
26
+ """Home page"""
27
+ return Response(
28
+ content='Swagger-UI to: <a href="/docs">/docs</a>',
29
+ media_type="text/html"
30
+ )
31
+
32
+
33
+ @secure_router.get(
34
+ path="/ping",
35
+ description="Returns a simple 'pong'",
36
+ tags=['Query'])
37
+ async def ping():
38
+ """\nPing\n
39
+ Ping page, just to check if the fastapi is up.
40
+ Instant return correct, does not mean the service is available.
41
+ Returns:
42
+ A simple string pong
43
+ """
44
+ return 'pong'
45
+
46
+
47
+ @secure_router.get(
48
+ path="/v1/generation/query-job",
49
+ response_model=AsyncJobResponse,
50
+ description="Query async generation job",
51
+ tags=['Query'])
52
+ def query_job(req: QueryJobRequest = Depends()):
53
+ """query job info by id"""
54
+ queue_task = worker_queue.get_task(req.job_id, True)
55
+ if queue_task is None:
56
+ result = AsyncJobResponse(
57
+ job_id="",
58
+ job_type=TaskType.not_found,
59
+ job_stage=AsyncJobStage.error,
60
+ job_progress=0,
61
+ job_status="Job not found")
62
+ content = result.model_dump_json()
63
+ return Response(content=content, media_type='application/json', status_code=404)
64
+ return generate_async_output(queue_task, req.require_step_preview)
65
+
66
+
67
+ @secure_router.get(
68
+ path="/v1/generation/job-queue",
69
+ response_model=JobQueueInfo,
70
+ description="Query job queue info",
71
+ tags=['Query'])
72
+ def job_queue():
73
+ """Query job queue info"""
74
+ queue = JobQueueInfo(
75
+ running_size=len(worker_queue.queue),
76
+ finished_size=len(worker_queue.history),
77
+ last_job_id=worker_queue.last_job_id
78
+ )
79
+ return queue
80
+
81
+
82
+ @secure_router.get(
83
+ path="/v1/generation/job-history",
84
+ response_model=JobHistoryResponse | dict,
85
+ description="Query historical job data",
86
+ tags=["Query"])
87
+ def get_history(job_id: str = None, page: int = 0, page_size: int = 20):
88
+ """Fetch and return the historical tasks"""
89
+ queue = [
90
+ JobHistoryInfo(
91
+ job_id=item.job_id,
92
+ is_finished=item.is_finished
93
+ ) for item in worker_queue.queue
94
+ ]
95
+ if not args.persistent:
96
+ history = [
97
+ JobHistoryInfo(
98
+ job_id=item.job_id,
99
+ is_finished=item.is_finished
100
+ ) for item in worker_queue.history
101
+ ]
102
+ return JobHistoryResponse(history=history, queue=queue)
103
+
104
+ from fooocusapi.sql_client import query_history
105
+ history = query_history(task_id=job_id, page=page, page_size=page_size)
106
+ return {
107
+ "history": history,
108
+ "queue": queue
109
+ }
110
+
111
+
112
+ @secure_router.get(
113
+ path="/v1/engines/all-models",
114
+ response_model=AllModelNamesResponse,
115
+ description="Get all filenames of base model and lora",
116
+ tags=["Query"])
117
+ def all_models():
118
+ """Refresh and return all models"""
119
+ from modules import config
120
+ config.update_files()
121
+ models = AllModelNamesResponse(
122
+ model_filenames=config.model_filenames,
123
+ lora_filenames=config.lora_filenames)
124
+ return models
125
+
126
+
127
+ @secure_router.get(
128
+ path="/v1/engines/styles",
129
+ response_model=List[str],
130
+ description="Get all legal Fooocus styles",
131
+ tags=['Query'])
132
+ def all_styles():
133
+ """Return all available styles"""
134
+ from modules.sdxl_styles import legal_style_names
135
+ return legal_style_names
fooocusapi/sql_client.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQLite client for Fooocus API
3
+ """
4
+ import os
5
+ import time
6
+ import platform
7
+ from datetime import datetime
8
+ from typing import Optional
9
+ import copy
10
+
11
+ from sqlalchemy import Integer, Float, VARCHAR, Boolean, JSON, Text, create_engine
12
+ from sqlalchemy.orm import declarative_base, Session, Mapped, mapped_column
13
+
14
+
15
+ Base = declarative_base()
16
+
17
+
18
+ if platform.system().lower() == "windows":
19
+ default_sqlite_db_path = os.path.join(
20
+ os.path.dirname(__file__), "../database.db"
21
+ ).replace("\\", "/")
22
+ else:
23
+ default_sqlite_db_path = os.path.join(os.path.dirname(__file__), "../database.db")
24
+
25
+ connection_uri = os.environ.get(
26
+ "FOOOCUS_DB_CONF", f"sqlite:///{default_sqlite_db_path}"
27
+ )
28
+
29
+
30
+ class GenerateRecord(Base):
31
+ """
32
+ GenerateRecord
33
+
34
+ __tablename__ = 'generate_record'
35
+ """
36
+
37
+ __tablename__ = "generate_record"
38
+
39
+ task_id: Mapped[str] = mapped_column(VARCHAR(255), nullable=False, primary_key=True)
40
+ task_type: Mapped[str] = mapped_column(Text, nullable=False)
41
+ result_url: Mapped[str] = mapped_column(Text, nullable=True)
42
+ finish_reason: Mapped[str] = mapped_column(Text, nullable=True)
43
+ date_time: Mapped[int] = mapped_column(Integer, nullable=False)
44
+
45
+ prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
46
+ negative_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
47
+ style_selections: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
48
+ performance_selection: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
49
+ aspect_ratios_selection: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
50
+ base_model_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
51
+ refiner_model_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
52
+ refiner_switch: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
53
+ loras: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
54
+ image_number: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
55
+ image_seed: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
56
+ sharpness: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
57
+ guidance_scale: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
58
+ advanced_params: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
59
+
60
+ input_image: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
61
+ input_mask: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
62
+ image_prompts: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
63
+ inpaint_additional_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
64
+ outpaint_selections: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
65
+ outpaint_distance_left: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
66
+ outpaint_distance_right: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
67
+ outpaint_distance_top: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
68
+ outpaint_distance_bottom: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
69
+ uov_method: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
70
+ upscale_value: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
71
+
72
+ webhook_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
73
+ require_base64: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
74
+ async_process: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
75
+
76
+ def __repr__(self) -> str:
77
+ return f"GenerateRecord(task_id={self.task_id!r}, task_type={self.task_type!r}, \
78
+ result_url={self.result_url!r}, finish_reason={self.finish_reason!r}, date_time={self.date_time!r}, \
79
+ prompt={self.prompt!r}, negative_prompt={self.negative_prompt!r}, style_selections={self.style_selections!r}, performance_selection={self.performance_selection!r}, \
80
+ aspect_ratios_selection={self.aspect_ratios_selection!r}, base_model_name={self.base_model_name!r}, \
81
+ refiner_model_name={self.refiner_model_name!r}, refiner_switch={self.refiner_switch!r}, loras={self.loras!r}, \
82
+ image_number={self.image_number!r}, image_seed={self.image_seed!r}, sharpness={self.sharpness!r}, \
83
+ guidance_scale={self.guidance_scale!r}, advanced_params={self.advanced_params!r}, input_image={self.input_image!r}, \
84
+ input_mask={self.input_mask!r}, image_prompts={self.image_prompts!r}, inpaint_additional_prompt={self.inpaint_additional_prompt!r}, \
85
+ outpaint_selections={self.outpaint_selections!r}, outpaint_distance_left={self.outpaint_distance_left!r}, outpaint_distance_right={self.outpaint_distance_right!r}, \
86
+ outpaint_distance_top={self.outpaint_distance_top!r}, outpaint_distance_bottom={self.outpaint_distance_bottom!r}, uov_method={self.uov_method!r}, \
87
+ upscale_value={self.upscale_value!r}, webhook_url={self.webhook_url!r}, require_base64={self.require_base64!r}, \
88
+ async_process={self.async_process!r})"
89
+
90
+
91
+ engine = create_engine(connection_uri)
92
+
93
+ session = Session(engine)
94
+ Base.metadata.create_all(engine, checkfirst=True)
95
+ session.close()
96
+
97
+
98
+ def convert_to_dict_list(obj_list: list[object]) -> list[dict]:
99
+ """
100
+ Convert a list of objects to a list of dictionaries.
101
+ Args:
102
+ obj_list:
103
+
104
+ Returns:
105
+ dict_list:
106
+ """
107
+ dict_list = []
108
+ for obj in obj_list:
109
+ # 将对象属性转化为字典键值对
110
+ dict_obj = {}
111
+ for attr, value in vars(obj).items():
112
+ if (
113
+ not callable(value)
114
+ and not attr.startswith("__")
115
+ and not attr.startswith("_")
116
+ ):
117
+ dict_obj[attr] = value
118
+ task_info = {
119
+ "task_id": obj.task_id,
120
+ "task_type": obj.task_type,
121
+ "result_url": obj.result_url,
122
+ "finish_reason": obj.finish_reason,
123
+ "date_time": datetime.fromtimestamp(obj.date_time).strftime(
124
+ "%Y-%m-%d %H:%M:%S"
125
+ ),
126
+ }
127
+ del dict_obj["task_id"]
128
+ del dict_obj["task_type"]
129
+ del dict_obj["result_url"]
130
+ del dict_obj["finish_reason"]
131
+ del dict_obj["date_time"]
132
+ dict_list.append({"params": dict_obj, "task_info": task_info})
133
+ return dict_list
134
+
135
+
136
+ class MySQLAlchemy:
137
+ """
138
+ MySQLAlchemy, a toolkit for managing SQLAlchemy connections and sessions.
139
+
140
+ :param uri: SQLAlchemy connection URI
141
+ """
142
+
143
+ def __init__(self, uri: str):
144
+ # 'mysql+pymysql://{username}:{password}@{host}:{port}/{database}'
145
+ self.engine = create_engine(uri)
146
+ self.session = Session(self.engine)
147
+
148
+ def store_history(self, record: dict) -> None:
149
+ """
150
+ Store history to database
151
+ :param record:
152
+ :return:
153
+ """
154
+ self.session.add_all([GenerateRecord(**record)])
155
+ self.session.commit()
156
+
157
+ def get_history(
158
+ self,
159
+ task_id: str = None,
160
+ page: int = 0,
161
+ page_size: int = 20,
162
+ order_by: str = "date_time",
163
+ ) -> list:
164
+ """
165
+ Get history from database
166
+ :param task_id:
167
+ :param page:
168
+ :param page_size:
169
+ :param order_by:
170
+ :return:
171
+ """
172
+ if task_id is not None:
173
+ res = (
174
+ self.session.query(GenerateRecord)
175
+ .filter(GenerateRecord.task_id == task_id)
176
+ .all()
177
+ )
178
+ if len(res) == 0:
179
+ return []
180
+ return convert_to_dict_list(res)
181
+
182
+ res = (
183
+ self.session.query(GenerateRecord)
184
+ .order_by(getattr(GenerateRecord, order_by).desc())
185
+ .offset(page * page_size)
186
+ .limit(page_size)
187
+ .all()
188
+ )
189
+ if len(res) == 0:
190
+ return []
191
+ return convert_to_dict_list(res)
192
+
193
+
194
+ db = MySQLAlchemy(uri=connection_uri)
195
+
196
+
197
+ def req_to_dict(req: dict) -> dict:
198
+ """
199
+ Convert request to dictionary
200
+ Args:
201
+ req:
202
+
203
+ Returns:
204
+
205
+ """
206
+ req["loras"] = [{"model_name": lora[0], "weight": lora[1]} for lora in req["loras"]]
207
+ # req["advanced_params"] = dict(zip(adv_params_keys, req["advanced_params"]))
208
+ req["image_prompts"] = [
209
+ {"cn_img": "", "cn_stop": image[1], "cn_weight": image[2], "cn_type": image[3]}
210
+ for image in req["image_prompts"]
211
+ ]
212
+ del req["inpaint_input_image"]
213
+ del req["uov_input_image"]
214
+ return req
215
+
216
+
217
+ def add_history(
218
+ params: dict, task_type: str, task_id: str, result_url: str, finish_reason: str
219
+ ) -> None:
220
+ """
221
+ Store history to database
222
+ Args:
223
+ params:
224
+ task_type:
225
+ task_id:
226
+ result_url:
227
+ finish_reason:
228
+
229
+ Returns:
230
+
231
+ """
232
+ adv = copy.deepcopy(params["advanced_params"])
233
+ params["advanced_params"] = adv.__dict__
234
+ params["date_time"] = int(time.time())
235
+ params["task_type"] = task_type
236
+ params["task_id"] = task_id
237
+ params["result_url"] = result_url
238
+ params["finish_reason"] = finish_reason
239
+
240
+ del params["inpaint_input_image"]
241
+ del params["uov_input_image"]
242
+ del params["save_extension"]
243
+ del params["save_meta"]
244
+ del params["save_name"]
245
+ del params["meta_scheme"]
246
+
247
+ db.store_history(params)
248
+
249
+
250
+ def query_history(
251
+ task_id: str = None,
252
+ page: int = 0,
253
+ page_size: int = 20,
254
+ order_by: str = "date_time"
255
+ ) -> list:
256
+ """
257
+ Query history from database
258
+ Args:
259
+ task_id:
260
+ page:
261
+ page_size:
262
+ order_by:
263
+
264
+ Returns:
265
+
266
+ """
267
+ return db.get_history(
268
+ task_id=task_id, page=page, page_size=page_size, order_by=order_by
269
+ )
fooocusapi/task_queue.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Task queue management
3
+
4
+ This module provides classes and functions for managing the task queue.
5
+
6
+ Classes:
7
+ QueueTask: A class representing a task in the queue.
8
+ TaskQueue: A class for managing the task queue.
9
+ """
10
+ import uuid
11
+ import time
12
+ from typing import List, Tuple
13
+ import numpy as np
14
+ import requests
15
+
16
+ from fooocusapi.utils.file_utils import delete_output_file, get_file_serve_url
17
+ from fooocusapi.utils.img_utils import narray_to_base64img
18
+ from fooocusapi.utils.logger import logger
19
+
20
+ from fooocusapi.models.common.task import ImageGenerationResult, GenerationFinishReason
21
+ from fooocusapi.parameters import ImageGenerationParams
22
+ from fooocusapi.models.common.task import TaskType
23
+
24
+
25
+ class QueueTask:
26
+ """
27
+ A class representing a task in the queue.
28
+
29
+ Attributes:
30
+ job_id (str): The unique identifier for the task, generated by uuid.
31
+ task_type (TaskType): The type of task.
32
+ is_finished (bool): Indicates whether the task has been completed.
33
+ finish_progress (int): The progress of the task completion.
34
+ in_queue_mills (int): The time the task was added to the queue, in milliseconds.
35
+ start_mills (int): The time the task started, in milliseconds.
36
+ finish_mills (int): The time the task finished, in milliseconds.
37
+ finish_with_error (bool): Indicates whether the task finished with an error.
38
+ task_status (str): The status of the task.
39
+ task_step_preview (str): A list of step previews for the task.
40
+ task_result (List[ImageGenerationResult]): The result of the task.
41
+ error_message (str): The error message, if any.
42
+ webhook_url (str): The webhook URL, if any.
43
+ """
44
+
45
+ job_id: str
46
+ task_type: TaskType
47
+ req_param: ImageGenerationParams
48
+ is_finished: bool = False
49
+ finish_progress: int = 0
50
+ in_queue_mills: int
51
+ start_mills: int = 0
52
+ finish_mills: int = 0
53
+ finish_with_error: bool = False
54
+ task_status: str | None = None
55
+ task_step_preview: str | None = None
56
+ task_result: List[ImageGenerationResult] = None
57
+ error_message: str | None = None
58
+ webhook_url: str | None = None # attribute for individual webhook_url
59
+
60
+ def __init__(
61
+ self,
62
+ job_id: str,
63
+ task_type: TaskType,
64
+ req_param: ImageGenerationParams,
65
+ webhook_url: str | None = None,
66
+ ):
67
+ self.job_id = job_id
68
+ self.task_type = task_type
69
+ self.req_param = req_param
70
+ self.in_queue_mills = int(round(time.time() * 1000))
71
+ self.webhook_url = webhook_url
72
+
73
+ def set_progress(self, progress: int, status: str | None):
74
+ """
75
+ Set progress and status
76
+ Arguments:
77
+ progress {int} -- progress
78
+ status {str} -- status
79
+ """
80
+ progress = min(progress, 100)
81
+ self.finish_progress = progress
82
+ self.task_status = status
83
+
84
+ def set_step_preview(self, task_step_preview: str | None):
85
+ """set step preview
86
+ Set step preview
87
+ Arguments:
88
+ task_step_preview {str} -- step preview
89
+ """
90
+ self.task_step_preview = task_step_preview
91
+
92
+ def set_result(
93
+ self,
94
+ task_result: List[ImageGenerationResult],
95
+ finish_with_error: bool,
96
+ error_message: str | None = None,
97
+ ):
98
+ """set result
99
+ Set task result
100
+ Arguments:
101
+ task_result {List[ImageGenerationResult]} -- task result
102
+ finish_with_error {bool} -- finish with error
103
+ error_message {str} -- error message
104
+ """
105
+ if not finish_with_error:
106
+ self.finish_progress = 100
107
+ self.task_status = "Finished"
108
+ self.task_result = task_result
109
+ self.finish_with_error = finish_with_error
110
+ self.error_message = error_message
111
+
112
+ def __str__(self) -> str:
113
+ return f"QueueTask(job_id={self.job_id}, task_type={self.task_type},\
114
+ is_finished={self.is_finished}, finished_progress={self.finish_progress}, \
115
+ in_queue_mills={self.in_queue_mills}, start_mills={self.start_mills}, \
116
+ finish_mills={self.finish_mills}, finish_with_error={self.finish_with_error}, \
117
+ error_message={self.error_message}, task_status={self.task_status}, \
118
+ task_step_preview={self.task_step_preview}, webhook_url={self.webhook_url})"
119
+
120
+
121
+ class TaskQueue:
122
+ """
123
+ TaskQueue is a queue of tasks that are waiting to be processed.
124
+
125
+ Attributes:
126
+ queue: List[QueueTask]
127
+ history: List[QueueTask]
128
+ last_job_id: str
129
+ webhook_url: str
130
+ persistent: bool
131
+ """
132
+
133
+ queue: List[QueueTask] = []
134
+ history: List[QueueTask] = []
135
+ last_job_id: str = None
136
+ webhook_url: str | None = None
137
+ persistent: bool = False
138
+
139
+ def __init__(
140
+ self,
141
+ queue_size: int,
142
+ history_size: int,
143
+ webhook_url: str | None = None,
144
+ persistent: bool | None = False,
145
+ ):
146
+ self.queue_size = queue_size
147
+ self.history_size = history_size
148
+ self.webhook_url = webhook_url
149
+ self.persistent = False if persistent is None else persistent
150
+
151
+ def add_task(
152
+ self,
153
+ task_type: TaskType,
154
+ req_param: ImageGenerationParams,
155
+ webhook_url: str | None = None,
156
+ ) -> QueueTask | None:
157
+ """
158
+ Create and add task to queue
159
+ :param task_type: task type
160
+ :param req_param: request parameters
161
+ :param webhook_url: webhook url
162
+ :returns: The created task's job_id, or None if reach the queue size limit
163
+ """
164
+ if len(self.queue) >= self.queue_size:
165
+ return None
166
+
167
+ if isinstance(req_param, dict):
168
+ req_param = ImageGenerationParams(**req_param)
169
+
170
+ job_id = str(uuid.uuid4())
171
+ task = QueueTask(
172
+ job_id=job_id,
173
+ task_type=task_type,
174
+ req_param=req_param,
175
+ webhook_url=webhook_url,
176
+ )
177
+ self.queue.append(task)
178
+ self.last_job_id = job_id
179
+ return task
180
+
181
+ def get_task(self, job_id: str, include_history: bool = False) -> QueueTask | None:
182
+ """
183
+ Get task by job_id
184
+ :param job_id: job id
185
+ :param include_history: whether to include history tasks
186
+ :returns: The task with the given job_id, or None if not found
187
+ """
188
+ for task in self.queue:
189
+ if task.job_id == job_id:
190
+ return task
191
+
192
+ if include_history:
193
+ for task in self.history:
194
+ if task.job_id == job_id:
195
+ return task
196
+
197
+ return None
198
+
199
+ def is_task_ready_to_start(self, job_id: str) -> bool:
200
+ """
201
+ Check if the task is ready to start
202
+ :param job_id: job id
203
+ :returns: True if the task is ready to start, False otherwise
204
+ """
205
+ task = self.get_task(job_id)
206
+ if task is None:
207
+ return False
208
+
209
+ return self.queue[0].job_id == job_id
210
+
211
+ def is_task_finished(self, job_id: str) -> bool:
212
+ """
213
+ Check if the task is finished
214
+ :param job_id: job id
215
+ :returns: True if the task is finished, False otherwise
216
+ """
217
+ task = self.get_task(job_id, True)
218
+ if task is None:
219
+ return False
220
+
221
+ return task.is_finished
222
+
223
+ def start_task(self, job_id: str):
224
+ """
225
+ Start task by job_id
226
+ :param job_id: job id
227
+ """
228
+ task = self.get_task(job_id)
229
+ if task is not None:
230
+ task.start_mills = int(round(time.time() * 1000))
231
+
232
+ def finish_task(self, job_id: str):
233
+ """
234
+ Finish task by job_id
235
+ :param job_id: job id
236
+ """
237
+ task = self.get_task(job_id)
238
+ if task is not None:
239
+ task.is_finished = True
240
+ task.finish_mills = int(round(time.time() * 1000))
241
+
242
+ # Use the task's webhook_url if available, else use the default
243
+ webhook_url = task.webhook_url or self.webhook_url
244
+
245
+ data = {"job_id": task.job_id, "job_result": []}
246
+
247
+ if isinstance(task.task_result, List):
248
+ for item in task.task_result:
249
+ data["job_result"].append(
250
+ {
251
+ "url": get_file_serve_url(item.im) if item.im else None,
252
+ "seed": item.seed if item.seed else "-1",
253
+ }
254
+ )
255
+
256
+ # Send webhook
257
+ if task.is_finished and webhook_url:
258
+ try:
259
+ res = requests.post(webhook_url, json=data, timeout=15)
260
+ print(f"Call webhook response status: {res.status_code}")
261
+ except Exception as e:
262
+ print("Call webhook error:", e)
263
+
264
+ # Move task to history
265
+ self.queue.remove(task)
266
+ self.history.append(task)
267
+
268
+ # save history to database
269
+ if self.persistent:
270
+ from fooocusapi.sql_client import add_history
271
+
272
+ add_history(
273
+ params=task.req_param.to_dict(),
274
+ task_type=task.task_type.value,
275
+ task_id=task.job_id,
276
+ result_url=",".join([job["url"] for job in data["job_result"]]),
277
+ finish_reason=task.task_result[0].finish_reason.value,
278
+ )
279
+
280
+ # Clean history
281
+ if len(self.history) > self.history_size != 0:
282
+ removed_task = self.history.pop(0)
283
+ if isinstance(removed_task.task_result, List):
284
+ for item in removed_task.task_result:
285
+ if (
286
+ isinstance(item, ImageGenerationResult)
287
+ and item.finish_reason == GenerationFinishReason.success
288
+ and item.im is not None
289
+ ):
290
+ delete_output_file(item.im)
291
+ logger.std_info(
292
+ f"[TaskQueue] Clean task history, remove task: {removed_task.job_id}"
293
+ )
294
+
295
+
296
+ class TaskOutputs:
297
+ """
298
+ TaskOutputs is a container for task outputs
299
+ """
300
+
301
+ outputs = []
302
+
303
+ def __init__(self, task: QueueTask):
304
+ self.task = task
305
+
306
+ def append(self, args: List[any]):
307
+ """
308
+ Append output to task outputs list
309
+ :param args: output arguments
310
+ """
311
+ self.outputs.append(args)
312
+ if len(args) >= 2:
313
+ if (
314
+ args[0] == "preview"
315
+ and isinstance(args[1], Tuple)
316
+ and len(args[1]) >= 2
317
+ ):
318
+ number = args[1][0]
319
+ text = args[1][1]
320
+ self.task.set_progress(number, text)
321
+ if len(args[1]) >= 3 and isinstance(args[1][2], np.ndarray):
322
+ base64_preview_img = narray_to_base64img(args[1][2])
323
+ self.task.set_step_preview(base64_preview_img)
fooocusapi/utils/api_utils.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """some utils for api"""
2
+ from typing import List
3
+
4
+ from fastapi import Response
5
+ from fastapi.security import APIKeyHeader
6
+ from fastapi import HTTPException, Security
7
+
8
+ from modules import flags
9
+ from modules import config
10
+ from modules.sdxl_styles import legal_style_names
11
+
12
+ from fooocusapi.args import args
13
+ from fooocusapi.utils.img_utils import read_input_image
14
+ from fooocusapi.utils.file_utils import (
15
+ get_file_serve_url,
16
+ output_file_to_base64img,
17
+ output_file_to_bytesimg
18
+ )
19
+ from fooocusapi.utils.logger import logger
20
+ from fooocusapi.models.common.requests import (
21
+ CommonRequest as Text2ImgRequest
22
+ )
23
+ from fooocusapi.models.common.response import (
24
+ AsyncJobResponse,
25
+ AsyncJobStage,
26
+ GeneratedImageResult
27
+ )
28
+ from fooocusapi.models.requests_v1 import (
29
+ ImgInpaintOrOutpaintRequest,
30
+ ImgPromptRequest,
31
+ ImgUpscaleOrVaryRequest
32
+ )
33
+ from fooocusapi.models.requests_v2 import (
34
+ Text2ImgRequestWithPrompt,
35
+ ImgInpaintOrOutpaintRequestJson,
36
+ ImgUpscaleOrVaryRequestJson,
37
+ ImgPromptRequestJson
38
+ )
39
+ from fooocusapi.models.common.task import (
40
+ ImageGenerationResult,
41
+ GenerationFinishReason
42
+ )
43
+ from fooocusapi.configs.default import (
44
+ default_inpaint_engine_version,
45
+ default_sampler,
46
+ default_scheduler,
47
+ default_base_model_name,
48
+ default_refiner_model_name
49
+ )
50
+
51
+ from fooocusapi.parameters import ImageGenerationParams
52
+ from fooocusapi.task_queue import QueueTask
53
+
54
+
55
# Header used to transport the API key on incoming requests
api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False)


def api_key_auth(apikey: str = Security(api_key_header)):
    """
    Validate the request API key.

    When the server was started without an API key configured, every
    request is accepted; otherwise the supplied header must match exactly.

    Args:
        apikey: value of the ``X-API-KEY`` request header
    Raises:
        HTTPException: 403 when a key is configured and does not match
    """
    if args.apikey is not None and apikey != args.apikey:
        raise HTTPException(status_code=403, detail="Forbidden")
70
+
71
+
72
def req_to_params(req: Text2ImgRequest) -> ImageGenerationParams:
    """
    Convert an API request object into ImageGenerationParams.

    Validates model/lora/sampler/scheduler names against the currently
    available files and flags, falling back to defaults (with a warning)
    for anything unknown, and decodes any input images carried by the
    request.

    Args:
        req: Text2ImgRequest or any class inherited from it
    Returns:
        ImageGenerationParams ready to be queued for the worker
    """
    config.update_files()
    if req.base_model_name is not None:
        if req.base_model_name not in config.model_filenames:
            logger.std_warn(f"[Warning] Wrong base_model_name input: {req.base_model_name}, using default")
            req.base_model_name = default_base_model_name

    if req.refiner_model_name is not None and req.refiner_model_name != 'None':
        if req.refiner_model_name not in config.model_filenames:
            logger.std_warn(f"[Warning] Wrong refiner_model_name input: {req.refiner_model_name}, using default")
            req.refiner_model_name = default_refiner_model_name

    for lora in req.loras:
        if lora.model_name != 'None' and lora.model_name not in config.lora_filenames:
            logger.std_warn(f"[Warning] Wrong lora model_name input: {lora.model_name}, using 'None'")
            lora.model_name = 'None'

    # Request-type flags computed once, instead of repeating the same
    # isinstance(...) pair for every derived value below.
    is_uov_request = isinstance(req, (ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson))
    is_inpaint_request = isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson))

    prompt = req.prompt
    negative_prompt = req.negative_prompt
    style_selections = [
        s for s in req.style_selections if s in legal_style_names]
    performance_selection = req.performance_selection.value
    aspect_ratios_selection = req.aspect_ratios_selection
    image_number = req.image_number
    # -1 means "pick a random seed"
    image_seed = None if req.image_seed == -1 else req.image_seed
    sharpness = req.sharpness
    guidance_scale = req.guidance_scale
    base_model_name = req.base_model_name
    refiner_model_name = req.refiner_model_name
    refiner_switch = req.refiner_switch
    loras = [(lora.model_name, lora.weight) for lora in req.loras]

    uov_input_image = None
    if not isinstance(req, Text2ImgRequestWithPrompt):
        if is_uov_request:
            uov_input_image = read_input_image(req.input_image)
    uov_method = req.uov_method.value if is_uov_request else flags.disabled
    upscale_value = req.upscale_value if is_uov_request else None
    outpaint_selections = [s.value for s in req.outpaint_selections] if is_inpaint_request else []
    outpaint_distance_left = req.outpaint_distance_left if is_inpaint_request else None
    outpaint_distance_right = req.outpaint_distance_right if is_inpaint_request else None
    outpaint_distance_top = req.outpaint_distance_top if is_inpaint_request else None
    outpaint_distance_bottom = req.outpaint_distance_bottom if is_inpaint_request else None

    if refiner_model_name == '':
        refiner_model_name = 'None'

    inpaint_input_image = None
    inpaint_additional_prompt = None
    if is_inpaint_request and req.input_image is not None:
        inpaint_additional_prompt = req.inpaint_additional_prompt
        input_image = read_input_image(req.input_image)
        input_mask = None
        if req.input_mask is not None:
            input_mask = read_input_image(req.input_mask)
        inpaint_input_image = {
            'image': input_image,
            'mask': input_mask
        }

    image_prompts = []
    if isinstance(req, (ImgInpaintOrOutpaintRequestJson, ImgPromptRequest, ImgPromptRequestJson, ImgUpscaleOrVaryRequestJson, Text2ImgRequestWithPrompt)):
        # Auto-enable prompt mixing when image prompts are combined with an
        # upscale/vary or inpaint input image.
        if len(req.image_prompts) > 0 and uov_input_image is not None:
            logger.std_info("[INFO] Mixing image prompt and vary upscale is set to True")
            req.advanced_params.mixing_image_prompt_and_vary_upscale = True
        elif len(req.image_prompts) > 0 and not isinstance(req, Text2ImgRequestWithPrompt) and req.input_image is not None:
            logger.std_info("[INFO] Mixing image prompt and inpaint is set to True")
            req.advanced_params.mixing_image_prompt_and_inpaint = True

        for img_prompt in req.image_prompts:
            if img_prompt.cn_img is not None:
                cn_img = read_input_image(img_prompt.cn_img)
                # Unset (None/0) stop and weight fall back to the defaults
                # for this control-net type.
                if img_prompt.cn_stop is None or img_prompt.cn_stop == 0:
                    img_prompt.cn_stop = flags.default_parameters[img_prompt.cn_type.value][0]
                if img_prompt.cn_weight is None or img_prompt.cn_weight == 0:
                    img_prompt.cn_weight = flags.default_parameters[img_prompt.cn_type.value][1]
                image_prompts.append(
                    (cn_img, img_prompt.cn_stop, img_prompt.cn_weight, img_prompt.cn_type.value))

    advanced_params = None
    if req.advanced_params is not None:
        adp = req.advanced_params

        if adp.refiner_swap_method not in ['joint', 'separate', 'vae']:
            logger.std_warn(f"[Warning] Wrong refiner_swap_method input: {adp.refiner_swap_method}, using default")
            adp.refiner_swap_method = 'joint'

        if adp.sampler_name not in flags.sampler_list:
            logger.std_warn(f"[Warning] Wrong sampler_name input: {adp.sampler_name}, using default")
            adp.sampler_name = default_sampler

        if adp.scheduler_name not in flags.scheduler_list:
            logger.std_warn(f"[Warning] Wrong scheduler_name input: {adp.scheduler_name}, using default")
            adp.scheduler_name = default_scheduler

        if adp.inpaint_engine not in flags.inpaint_engine_versions:
            logger.std_warn(f"[Warning] Wrong inpaint_engine input: {adp.inpaint_engine}, using default")
            adp.inpaint_engine = default_inpaint_engine_version

        advanced_params = adp

    return ImageGenerationParams(
        prompt=prompt,
        negative_prompt=negative_prompt,
        style_selections=style_selections,
        performance_selection=performance_selection,
        aspect_ratios_selection=aspect_ratios_selection,
        image_number=image_number,
        image_seed=image_seed,
        sharpness=sharpness,
        guidance_scale=guidance_scale,
        base_model_name=base_model_name,
        refiner_model_name=refiner_model_name,
        refiner_switch=refiner_switch,
        loras=loras,
        uov_input_image=uov_input_image,
        uov_method=uov_method,
        upscale_value=upscale_value,
        outpaint_selections=outpaint_selections,
        outpaint_distance_left=outpaint_distance_left,
        outpaint_distance_right=outpaint_distance_right,
        outpaint_distance_top=outpaint_distance_top,
        outpaint_distance_bottom=outpaint_distance_bottom,
        inpaint_input_image=inpaint_input_image,
        inpaint_additional_prompt=inpaint_additional_prompt,
        image_prompts=image_prompts,
        advanced_params=advanced_params,
        save_meta=req.save_meta,
        meta_scheme=req.meta_scheme,
        save_name=req.save_name,
        save_extension=req.save_extension,
        require_base64=req.require_base64,
    )
213
+
214
+
215
def generate_async_output(
        task: QueueTask,
        require_step_preview: bool = False) -> AsyncJobResponse:
    """
    Build the AsyncJobResponse describing the current state of a queued task.

    Arguments:
        task: QueueTask being reported on
        require_step_preview: include the latest step-preview image when True
    Returns:
        AsyncJobResponse with stage, progress, status and (on success) results
    """
    # A task that has not started yet (start_mills still 0) is waiting.
    job_stage = AsyncJobStage.waiting if task.start_mills == 0 else AsyncJobStage.running
    job_result = None

    if task.is_finished:
        if task.finish_with_error:
            job_stage = AsyncJobStage.error
        elif task.task_result is not None:
            job_stage = AsyncJobStage.success
            job_result = generate_image_result_output(
                task.task_result, task.req_param.require_base64)

    step_preview = task.task_step_preview if require_step_preview else None
    return AsyncJobResponse(
        job_id=task.job_id,
        job_type=task.task_type,
        job_stage=job_stage,
        job_progress=task.finish_progress,
        job_status=task.task_status,
        job_step_preview=step_preview,
        job_result=job_result)
248
+
249
+
250
def generate_streaming_output(results: List[ImageGenerationResult]) -> Response:
    """
    Generate a streaming (raw bytes) response for image generation results.

    Args:
        results (List[ImageGenerationResult]): list of image generation results;
            only the first entry is streamed.
    Returns:
        Response: PNG image bytes on success, or an error status response.
    """
    if not results:
        return Response(status_code=500)

    result = results[0]
    # Map failure reasons onto HTTP status codes.
    status_by_reason = {
        GenerationFinishReason.queue_is_full: 409,
        GenerationFinishReason.user_cancel: 400,
        GenerationFinishReason.error: 500,
    }
    status = status_by_reason.get(result.finish_reason)
    if status is not None:
        return Response(status_code=status, content=result.finish_reason.value)

    img_bytes = output_file_to_bytesimg(result.im)
    return Response(img_bytes, media_type='image/png')
270
+
271
+
272
def generate_image_result_output(
        results: List[ImageGenerationResult],
        require_base64: bool) -> List[GeneratedImageResult]:
    """
    Convert worker results into API result objects.

    Arguments:
        results: List[ImageGenerationResult] from the worker
        require_base64: when True, also embed each image as base64
    Returns:
        List[GeneratedImageResult]
    """
    outputs = []
    for item in results:
        # base64 payload is only produced on request; the URL is always set.
        encoded = output_file_to_base64img(item.im) if require_base64 else None
        outputs.append(GeneratedImageResult(
            base64=encoded,
            url=get_file_serve_url(item.im),
            seed=str(item.seed),
            finish_reason=item.finish_reason))
    return outputs
fooocusapi/utils/call_worker.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """function for call generate worker"""
2
+ from typing import List
3
+ from fastapi import Response
4
+
5
+ from fooocusapi.models.common.requests import (
6
+ CommonRequest as Text2ImgRequest
7
+ )
8
+ from fooocusapi.models.common.response import (
9
+ AsyncJobResponse,
10
+ GeneratedImageResult
11
+ )
12
+ from fooocusapi.models.common.task import (
13
+ GenerationFinishReason,
14
+ ImageGenerationResult,
15
+ AsyncJobStage,
16
+ TaskType
17
+ )
18
+ from fooocusapi.utils.api_utils import (
19
+ req_to_params,
20
+ generate_async_output,
21
+ generate_streaming_output,
22
+ generate_image_result_output
23
+ )
24
+ from fooocusapi.models.requests_v1 import (
25
+ ImgUpscaleOrVaryRequest,
26
+ ImgPromptRequest,
27
+ ImgInpaintOrOutpaintRequest
28
+ )
29
+ from fooocusapi.models.requests_v2 import (
30
+ ImgInpaintOrOutpaintRequestJson,
31
+ ImgPromptRequestJson,
32
+ ImgUpscaleOrVaryRequestJson
33
+ )
34
+ from fooocusapi.worker import worker_queue, blocking_get_task_result
35
+
36
+
37
def get_task_type(req: Text2ImgRequest) -> TaskType:
    """Map a request class to its TaskType; plain requests are text-to-image."""
    type_by_request = (
        ((ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson), TaskType.img_uov),
        ((ImgPromptRequest, ImgPromptRequestJson), TaskType.img_prompt),
        ((ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson), TaskType.img_inpaint_outpaint),
    )
    # Checked in the same order as the original if-chain.
    for request_classes, task_type in type_by_request:
        if isinstance(req, request_classes):
            return task_type
    return TaskType.text_2_img
46
+
47
+
48
def call_worker(req: Text2ImgRequest, accept: str) -> Response | AsyncJobResponse | List[GeneratedImageResult]:
    """
    Queue a generation request and return its result in the shape the
    caller asked for (raw image stream, async job handle, or result list).

    Args:
        req: the generation request
        accept: HTTP Accept header; ``image/png`` selects streaming output
    """
    streaming_output = accept == 'image/png'
    if streaming_output:
        # Streaming can only carry a single image.
        req.image_number = 1

    task_type = get_task_type(req)
    params = req_to_params(req)
    async_task = worker_queue.add_task(task_type, params, req.webhook_url)

    if async_task is None:
        # The queue rejected the task (it is full); report the failure in
        # whichever shape the caller expects.
        failure_results = [
            ImageGenerationResult(
                im=None,
                seed='',
                finish_reason=GenerationFinishReason.queue_is_full
            )]

        if streaming_output:
            return generate_streaming_output(failure_results)
        if req.async_process:
            return AsyncJobResponse(
                job_id='',
                job_type=get_task_type(req),
                job_stage=AsyncJobStage.error,
                job_progress=0,
                job_status=None,
                job_step_preview=None,
                job_result=[GeneratedImageResult(
                    base64=None,
                    url=None,
                    seed='',
                    finish_reason=GenerationFinishReason.queue_is_full
                )])
        return generate_image_result_output(failure_results, False)

    if req.async_process:
        # Async mode: hand back the queued-job descriptor immediately.
        return generate_async_output(async_task)

    # Sync mode: block until the worker finishes this job.
    results = blocking_get_task_result(async_task.job_id)

    if streaming_output:
        return generate_streaming_output(results)
    return generate_image_result_output(results, req.require_base64)
fooocusapi/utils/file_utils.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """ File utils
4
+
5
+ Use for managing generated files
6
+
7
+ @file: file_utils.py
8
+ @author: Konie
9
+ @update: 2024-03-22
10
+ """
11
+ import base64
12
+ import datetime
13
+ from io import BytesIO
14
+ import os
15
+ import json
16
+ from pathlib import Path
17
+ import numpy as np
18
+ from PIL import Image
19
+ from PIL.PngImagePlugin import PngInfo
20
+
21
+ from fooocusapi.utils.logger import logger
22
+
23
+
24
# Absolute path where generated files are stored: <repo root>/outputs/files
output_dir = os.path.abspath(os.path.join(
    os.path.dirname(__file__), '../..', 'outputs', 'files'))
os.makedirs(output_dir, exist_ok=True)

# Base URL under which files in output_dir are served.
# NOTE(review): hard-coded host/port — presumably matches the default API
# binding; confirm against the server startup arguments.
STATIC_SERVER_BASE = 'http://127.0.0.1:8888/files/'
29
+
30
+
31
def save_output_file(
        img: np.ndarray,
        image_meta: dict = None,
        image_name: str = '',
        extension: str = 'png') -> str:
    """
    Save a numpy image into the dated output directory.

    Args:
        img: np.ndarray image to save
        image_meta: dict of image metadata; embedded as PNG text chunks
            when saving as PNG
        image_name: file name without extension (a date folder is prepended)
        extension: one of 'png', 'jpg', 'webp'; anything else falls back to 'png'
    Returns:
        str file name relative to the output directory (POSIX separators)
    """
    # Validate the extension BEFORE building the file name, so the name on
    # disk always matches the actual image format.  The original built the
    # name first, so e.g. extension='gif' produced a '.gif' file containing
    # PNG data.
    if extension not in ['png', 'jpg', 'webp']:
        extension = 'png'
    image_format = Image.registered_extensions()['.' + extension]

    current_time = datetime.datetime.now()
    date_string = current_time.strftime("%Y-%m-%d")

    filename = os.path.join(date_string, image_name + '.' + extension)
    file_path = os.path.join(output_dir, filename)

    if image_meta is None:
        image_meta = {}

    meta = None
    if extension == 'png' and image_meta != {}:
        meta = PngInfo()
        meta.add_text("parameters", json.dumps(image_meta))
        # .get() so a metadata dict without the scheme key cannot crash the
        # save (the original indexed it unconditionally).
        meta.add_text("fooocus_scheme", image_meta.get('metadata_scheme', ''))

    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    Image.fromarray(img).save(
        file_path,
        format=image_format,
        pnginfo=meta,
        optimize=True)
    return Path(filename).as_posix()
72
+
73
+
74
def delete_output_file(filename: str):
    """
    Delete a file from the output directory.

    Args:
        filename: file name relative to the output directory
    """
    file_path = os.path.join(output_dir, filename)
    if not os.path.exists(file_path) or not os.path.isfile(file_path):
        # Nothing to delete — return early.  The original fell through and
        # attempted the removal anyway, and its f-strings had no
        # placeholders, so the log never named the file.
        logger.std_warn(f'[Fooocus API] {filename} not exists or is not a file')
        return
    try:
        os.remove(file_path)
        logger.std_info(f'[Fooocus API] Delete output file: {filename}')
    except OSError:
        logger.std_error(f'[Fooocus API] Delete output file failed: {filename}')
88
+
89
+
90
+ def output_file_to_base64img(filename: str | None) -> str | None:
91
+ """
92
+ Convert an image file to a base64 string.
93
+ Args:
94
+ filename: str of file name
95
+ return: str of base64 string
96
+ """
97
+ if filename is None:
98
+ return None
99
+ file_path = os.path.join(output_dir, filename)
100
+ if not os.path.exists(file_path) or not os.path.isfile(file_path):
101
+ return None
102
+
103
+ ext = filename.split('.')[-1]
104
+ if ext.lower() not in ['png', 'jpg', 'webp', 'jpeg']:
105
+ ext = 'png'
106
+ img = Image.open(file_path)
107
+ output_buffer = BytesIO()
108
+ img.save(output_buffer, format=ext.upper())
109
+ byte_data = output_buffer.getvalue()
110
+ base64_str = base64.b64encode(byte_data).decode('utf-8')
111
+ return f"data:image/{ext};base64," + base64_str
112
+
113
+
114
+ def output_file_to_bytesimg(filename: str | None) -> bytes | None:
115
+ """
116
+ Convert an image file to a bytes string.
117
+ Args:
118
+ filename: str of file name
119
+ return: bytes of image data
120
+ """
121
+ if filename is None:
122
+ return None
123
+ file_path = os.path.join(output_dir, filename)
124
+ if not os.path.exists(file_path) or not os.path.isfile(file_path):
125
+ return None
126
+
127
+ img = Image.open(file_path)
128
+ output_buffer = BytesIO()
129
+ img.save(output_buffer, format='PNG')
130
+ byte_data = output_buffer.getvalue()
131
+ return byte_data
132
+
133
+
134
+ def get_file_serve_url(filename: str | None) -> str | None:
135
+ """
136
+ Get the static serve url of an image file.
137
+ Args:
138
+ filename: str of file name
139
+ return: str of static serve url
140
+ """
141
+ if filename is None:
142
+ return None
143
+ return STATIC_SERVER_BASE + filename.replace('\\', '/')
fooocusapi/utils/img_utils.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image process utils. Used to verify, convert and store Images.
3
+
4
+ @file: img_utils.py
5
+ @author: Konie
6
+ @update: 2024-03-23
7
+ """
8
+ import base64
9
+ from io import BytesIO
10
+ from fastapi import UploadFile
11
+ from PIL import Image
12
+
13
+ import requests
14
+ import numpy as np
15
+
16
+
17
def upload2base64(image: UploadFile) -> str | None:
    """
    Encode an uploaded file's content as a base64 string.

    Args:
        image (UploadFile): uploaded file, or None
    Returns:
        str: base64 string, None for None input
    """
    if image is None:
        return None
    raw = image.file.read()
    return base64.b64encode(raw).decode("utf-8")
+ return image_base64
30
+
31
+
32
+ def narray_to_base64img(narray: np.ndarray) -> str | None:
33
+ """
34
+ Convert numpy array to base64 image string.
35
+ Args:
36
+ narray: numpy array
37
+ Returns:
38
+ base64 image string
39
+ """
40
+ if narray is None:
41
+ return None
42
+
43
+ img = Image.fromarray(narray)
44
+ output_buffer = BytesIO()
45
+ img.save(output_buffer, format='PNG')
46
+ byte_data = output_buffer.getvalue()
47
+ base64_str = base64.b64encode(byte_data).decode('utf-8')
48
+ return base64_str
49
+
50
+
51
+ def narray_to_bytesimg(narray) -> bytes | None:
52
+ """
53
+ Convert numpy array to bytes image.
54
+ Args:
55
+ narray: numpy array
56
+ Returns:
57
+ bytes image
58
+ """
59
+ if narray is None:
60
+ return None
61
+
62
+ img = Image.fromarray(narray)
63
+ output_buffer = BytesIO()
64
+ img.save(output_buffer, format='PNG')
65
+ byte_data = output_buffer.getvalue()
66
+ return byte_data
67
+
68
+
69
def read_input_image(input_image: UploadFile | str | None) -> np.ndarray | None:
    """
    Decode an input image coming in as an UploadFile or a base64 string.

    Args:
        input_image: UploadFile, base64 image string, or None
    Returns:
        numpy array of the decoded image, or None for empty input
    """
    if input_image is None or input_image == '':
        return None
    if isinstance(input_image, str):
        # Base64-encoded payload
        raw = base64.b64decode(input_image)
    else:
        raw = input_image.file.read()
    return np.array(Image.open(BytesIO(raw)))
+ return image
86
+
87
+
88
def base64_to_stream(image: str) -> UploadFile | None:
    """
    Turn a base64 string (or URL) into an UploadFile.

    Args:
        image: base64 image string, data URL, or http(s) URL
    Returns:
        UploadFile or None
    """
    # Placeholder values the API treats as "no image supplied".
    if image in ['', None, 'None', 'none', 'string', 'null']:
        return None
    if image.startswith('http'):
        # Remote image: download and validate it instead.
        return get_check_image(url=image)
    if image.startswith('data:image'):
        # Strip the data-URL header, keep only the base64 payload.
        image = image.split(sep=',', maxsplit=1)[1]
    stream = BytesIO(base64.b64decode(image))
    stream.seek(0)
    return UploadFile(file=stream)
107
+
108
+
109
def get_check_image(url: str) -> UploadFile | None:
    """
    Download an image from a URL and validate it.

    Args:
        url: image url
    Returns:
        UploadFile wrapping the image bytes, or None when the download
        fails or the payload is not a decodable image
    """
    if url == '':
        return None
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    try:
        response = requests.get(url, headers=headers, timeout=10)
        # Treat HTTP errors as failures; the original accepted any status
        # code and could pass an HTML error page to the image validator.
        response.raise_for_status()
        binary_image = response.content
    except Exception:
        return None
    try:
        Image.open(BytesIO(binary_image))  # validates the payload is an image
    except Exception:
        return None
    byte_stream = BytesIO(binary_image)
    byte_stream.seek(0)
    return UploadFile(file=byte_stream)
+ return UploadFile(file=byte_stream)
136
+
137
+
138
+ def bytes_image_to_io(binary_image: bytes) -> BytesIO | None:
139
+ """
140
+ Convert bytes image to BytesIO.
141
+ Args:
142
+ binary_image: bytes image
143
+ Returns:
144
+ BytesIO or None
145
+ """
146
+ try:
147
+ buffer = BytesIO(binary_image)
148
+ Image.open(buffer)
149
+ except Exception:
150
+ return None
151
+ byte_stream = BytesIO()
152
+ byte_stream.write(binary_image)
153
+ byte_stream.seek(0)
154
+ return byte_stream
155
+
156
+
157
+ def bytes_to_base64img(byte_data: bytes) -> str | None:
158
+ """
159
+ Convert bytes image to base64 image string.
160
+ Args:
161
+ byte_data: bytes image
162
+ Returns:
163
+ base64 image string or None
164
+ """
165
+ if byte_data is None:
166
+ return None
167
+
168
+ base64_str = base64.b64encode(byte_data).decode('utf-8')
169
+ return base64_str
170
+
171
+
172
+ def base64_to_bytesimg(base64_str: str) -> bytes | None:
173
+ """
174
+ Convert base64 image string to bytes image.
175
+ Args:
176
+ base64_str: base64 image string
177
+ Returns:
178
+ bytes image or None
179
+ """
180
+ if base64_str == '':
181
+ return None
182
+ bytes_image = base64.b64decode(base64_str)
183
+ return bytes_image
184
+
185
+
186
+ def base64_to_narray(base64_str: str) -> np.ndarray | None:
187
+ """
188
+ Convert base64 image string to numpy array.
189
+ Args:
190
+ base64_str: base64 image string
191
+ Returns:
192
+ numpy array or None
193
+ """
194
+ if base64_str == '':
195
+ return None
196
+ bytes_image = base64.b64decode(base64_str)
197
+ image = np.frombuffer(bytes_image, np.uint8)
198
+ return image
fooocusapi/utils/logger.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """ A simply logger.
4
+
5
+ This module is used to log the program.
6
+
7
+ @file: logger.py
8
+ @author: mrhan1993
9
+ @update: 2024-03-22
10
+ """
11
+ import logging
12
+ import os
13
+ import sys
14
+
15
# colorlog is required for colored console output; install it on the fly
# if it is missing (e.g. on first run in a fresh environment).
try:
    from colorlog import ColoredFormatter
except ImportError:
    from fooocusapi.utils.tools import run_pip
    run_pip(
        command="install colorlog",
        desc="Install colorlog for logger.",
        live=True
    )
finally:
    # Re-import after a possible installation; still raises if the
    # installation itself failed.
    from colorlog import ColoredFormatter


# Log files live under <repo root>/logs by default.
own_path = os.path.dirname(os.path.abspath(__file__))
log_dir = "logs"
default_log_path = os.path.join(own_path, '../../', log_dir)

# Console formatter: colored level names for interactive output.
std_formatter = ColoredFormatter(
    fmt="%(log_color)s[%(asctime)s] %(levelname)-8s%(reset)s %(blue)s%(message)s",
    datefmt='%Y-%m-%d %H:%M:%S',
    reset=True,
    log_colors={
        'DEBUG': 'cyan',
        'INFO': 'green',
        'WARNING': 'yellow',
        'ERROR': 'red',
        'CRITICAL': 'red,bg_white',
    },
    secondary_log_colors={},
    style='%'
)

# File formatter: same layout without ANSI colors.
file_formatter = ColoredFormatter(
    fmt="[%(asctime)s] %(levelname)-8s%(reset)s %(message)s",
    datefmt='%Y-%m-%d %H:%M:%S',
    reset=True,
    no_color=True,
    style='%'
)
54
+
55
+
56
class ConfigLogger:
    """
    Holder for logger configuration.

    :param log_path: log file directory, preferably an absolute path
    :param std_format: formatter for stdout output
    :param file_format: formatter for file output
    """
    def __init__(self,
                 log_path: str = default_log_path,
                 std_format: ColoredFormatter = std_formatter,
                 file_format: ColoredFormatter = file_formatter) -> None:
        self.log_path = log_path
        self.std_format = std_format
        self.file_format = file_format
70
+
71
+
72
class Logger:
    """
    A simple logger writing to per-name log files and to stdout.

    :param log_name: base name for the log files
    :param config: logger configuration (paths and formatters); a fresh
        default ConfigLogger is used when omitted
    """
    def __init__(self, log_name, config: ConfigLogger = None):
        # Build the default inside the body instead of as a shared mutable
        # default argument.
        if config is None:
            config = ConfigLogger()
        log_path = config.log_path
        err_log_path = os.path.join(str(log_path), f"{log_name}_error.log")
        info_log_path = os.path.join(str(log_path), f"{log_name}_info.log")
        if not os.path.exists(log_path):
            os.makedirs(log_path, exist_ok=True)

        self._file_logger = logging.getLogger(log_name)
        self._file_logger.setLevel("INFO")

        # std_* methods log through the root logger.
        self._std_logger = logging.getLogger()
        self._std_logger.setLevel("INFO")

        # Attach handlers only once per logger: the original added a fresh
        # set on every instantiation, so repeated construction duplicated
        # every log line (including on the shared root logger).
        if not self._file_logger.handlers:
            # ERROR-level handler writing to <name>_error.log
            error_handler = logging.FileHandler(err_log_path, encoding='utf-8')
            error_handler.setLevel(logging.ERROR)
            error_handler.setFormatter(config.file_format)

            # INFO-level handler writing to <name>_info.log
            info_handler = logging.FileHandler(info_log_path, encoding='utf-8')
            info_handler.setLevel(logging.INFO)
            info_handler.setFormatter(config.file_format)

            self._file_logger.addHandler(error_handler)
            self._file_logger.addHandler(info_handler)

        if not self._std_logger.handlers:
            # Colored console handler on stdout
            stream_handler = logging.StreamHandler(sys.stdout)
            stream_handler.setFormatter(config.std_format)
            self._std_logger.addHandler(stream_handler)

    def file_error(self, message):
        """Log an error message to the file logger."""
        self._file_logger.error(message)

    def file_info(self, message):
        """Log an info message to the file logger."""
        self._file_logger.info(message)

    def std_info(self, message):
        """Log an info message to stdout."""
        self._std_logger.info(message)

    def std_warn(self, message):
        """Log a warning message to stdout."""
        self._std_logger.warning(message)

    def std_error(self, message):
        """Log an error message to stdout."""
        self._std_logger.error(message)


# Shared module-level logger instance used across the API.
logger = Logger(log_name="fooocus_api")
fooocusapi/utils/lora_manager.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import requests
4
+ import tarfile
5
+
6
+ def _hash_url(url):
7
+ """Generates a hash value for a given URL."""
8
+ return hashlib.md5(url.encode('utf-8')).hexdigest()
9
+
10
class LoraManager:
    """
    Download and cache LoRA files referenced by URL.

    Files are stored in the bundled Fooocus ``models/loras`` directory
    under a name derived from the MD5 hash of the URL, so each URL is
    downloaded at most once.
    """
    def __init__(self):
        # Cache directory inside the bundled Fooocus repository
        self.cache_dir = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            '../../',
            'repositories/Fooocus/models/loras')

    def _download_lora(self, url):
        """
        Download a LoRA from *url* into the cache and return its file path.

        A ``.tar`` archive is extracted and the first contained
        ``.safetensors`` file is returned instead.
        """
        url_hash = _hash_url(url)
        file_ext = url.split('.')[-1]
        filepath = os.path.join(self.cache_dir, f"{url_hash}.{file_ext}")

        if os.path.exists(filepath):
            print(f"LoRa already downloaded {url}")
            return filepath

        print(f"Start download for: {url}")
        try:
            response = requests.get(url, timeout=10, stream=True)
            response.raise_for_status()
            with open(filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

            if file_ext == "tar":
                print("Extracting the tar file...")
                with tarfile.open(filepath, 'r:*') as tar:
                    tar.extractall(path=self.cache_dir)
                print("Extraction completed.")
                return self._find_safetensors_file(self.cache_dir)

            print(f"Download successfully, saved as {filepath}")
        except Exception as e:
            # Remove any partially-written file so a failed download is not
            # mistaken for a cached one on the next call (the original left
            # partial files in place and served them as "already downloaded").
            if os.path.exists(filepath):
                try:
                    os.remove(filepath)
                except OSError:
                    pass
            raise Exception(f"Error downloading {url}: {e}") from e

        return filepath

    def _find_safetensors_file(self, directory):
        """
        Return the first .safetensors file found under *directory*.

        Raises:
            FileNotFoundError: when no .safetensors file exists
        """
        print("Searching for .safetensors file.")
        for root, _dirs, files in os.walk(directory):
            for file in files:
                if file.endswith('.safetensors'):
                    return os.path.join(root, file)
        raise FileNotFoundError("No .safetensors file found in the extracted files.")

    def check(self, urls):
        """Download any missing LoRAs and return their local file paths."""
        paths = []
        for url in urls:
            paths.append(self._download_lora(url))
        return paths
fooocusapi/utils/model_loader.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ Download models from url
5
+
6
+ @file: model_loader.py
7
+ @author: Konie
8
+ @update: 2024-03-22
9
+ """
10
+ from modules.model_loader import load_file_from_url
11
+
12
+
13
def download_models():
    """
    Fetch every model declared in the Fooocus config into its target directory.
    """
    vae_approx_filenames = [
        ('xlvaeapp.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth'),
        ('vaeapp_sd15.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/vaeapp_sd15.pt'),
        ('xl-to-v1_interposer-v3.1.safetensors', 'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors')
    ]

    from modules.config import (
        paths_checkpoints as modelfile_path,
        paths_loras as lorafile_path,
        path_vae_approx as vae_approx_path,
        path_fooocus_expansion as fooocus_expansion_path,
        path_embeddings as embeddings_path,
        checkpoint_downloads,
        embeddings_downloads,
        lora_downloads)

    # (entries, destination directory) pairs, downloaded in declaration order.
    download_plan = [
        (checkpoint_downloads.items(), modelfile_path[0]),
        (embeddings_downloads.items(), embeddings_path),
        (lora_downloads.items(), lorafile_path[0]),
        (vae_approx_filenames, vae_approx_path),
    ]
    for entries, target_dir in download_plan:
        for file_name, url in entries:
            load_file_from_url(url=url, model_dir=target_dir, file_name=file_name)

    load_file_from_url(
        url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_expansion.bin',
        model_dir=fooocus_expansion_path,
        file_name='pytorch_model.bin'
    )
fooocusapi/utils/tools.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """ Some tools
4
+
5
+ @file: tools.py
6
+ @author: Konie
7
+ @update: 2024-03-22
8
+ """
9
+ # pylint: disable=line-too-long
10
+ # pylint: disable=broad-exception-caught
11
+ import os
12
+ import sys
13
+ import re
14
+ import subprocess
15
+ from importlib.util import find_spec
16
+ from importlib import metadata
17
+ from packaging import version
18
+
19
+
20
# Interpreter used to spawn pip subprocesses (same environment as this process).
PYTHON_EXEC = sys.executable
# Optional custom package index, forwarded to pip via --index-url.
INDEX_URL = os.environ.get('INDEX_URL', "")
# Matches "package" or "package==version" lines in a requirements file.
PATTERN = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
23
+
24
+
25
+ # This function was copied from [Fooocus](https://github.com/lllyasviel/Fooocus) repository.
26
def run_command(command: str,
                desc: str | None = None,
                error_desc: str | None = None,
                custom_env: dict | None = None,
                live: bool = True) -> str:
    """
    Run a shell command and return its captured stdout.
    Args:
        command: Command line to run (executed with shell=True)
        desc: Description printed before the command runs
        error_desc: Description used in the error message on failure
        custom_env: Environment mapping for the subprocess; defaults to os.environ
        live: When True, output streams to the console and is not captured
    Returns:
        The command's stdout ("" when live=True, since nothing is captured)
    Raises:
        RuntimeError: when the command exits with a non-zero return code
    """
    if desc is not None:
        print(desc)

    run_kwargs = {
        "args": command,
        "shell": True,
        "env": os.environ if custom_env is None else custom_env,
        "encoding": 'utf8',
        "errors": 'ignore'
    }

    # When not live, capture both streams so they can be reported on failure.
    if not live:
        run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE

    result = subprocess.run(check=False, **run_kwargs)

    if result.returncode != 0:
        error_bits = [
            f"{error_desc or 'Error running command'}.",
            f"Command: {command}",
            f"Error code: {result.returncode}",
        ]
        if result.stdout:
            error_bits.append(f"stdout: {result.stdout}")
        if result.stderr:
            error_bits.append(f"stderr: {result.stderr}")
        raise RuntimeError("\n".join(error_bits))

    # stdout is None when live=True; normalize to an empty string.
    return result.stdout or ""
71
+
72
+
73
+ # This function was copied from [Fooocus](https://github.com/lllyasviel/Fooocus) repository.
74
def run_pip(command, desc=None, live=True):
    """
    Run a pip command through the current interpreter.
    Args:
        command: pip sub-command and arguments (e.g. "install foo")
        desc: Short description of what is being installed
        live: Whether to stream the output to the console
    Returns:
        The command output, or None when the pip invocation failed
    """
    try:
        index_arg = f' --index-url {INDEX_URL}' if INDEX_URL != '' else ''
        full_command = f'"{PYTHON_EXEC}" -m pip {command} --prefer-binary{index_arg}'
        return run_command(
            command=full_command,
            desc=f"Installing {desc}",
            error_desc=f"Couldn't install {desc}",
            live=live,
        )
    except Exception as e:
        print(f'CMD Failed {command}: {e}')
        return None
95
+
96
+
97
def is_installed(package: str) -> bool:
    """
    Return True when *package* can be located by the import machinery.
    Args:
        package: Importable module/package name
    Returns:
        Whether the package is installed
    """
    try:
        return find_spec(package) is not None
    except ModuleNotFoundError:
        # Raised when a parent package of a dotted name does not exist.
        return False
111
+
112
+
113
def check_torch_cuda() -> bool:
    """
    Report whether torch is importable and a CUDA device is available.
    Returns:
        Whether CUDA is available
    """
    try:
        import torch
    except ImportError:
        return False
    return torch.cuda.is_available()
124
+
125
+
126
+ def requirements_check(requirements_file: str = 'requirements.txt',
127
+ pattern: re.Pattern = PATTERN) -> bool:
128
+ """
129
+ Check if the requirements file is satisfied
130
+ Args:
131
+ requirements_file: Path to the requirements file
132
+ pattern: Pattern to match the requirements
133
+ Returns:
134
+ Whether the requirements file is satisfied
135
+ """
136
+ with open(requirements_file, "r", encoding="utf8") as file:
137
+ for line in file:
138
+ if line.strip() == "":
139
+ continue
140
+
141
+ m = re.match(pattern, line)
142
+ if m is None:
143
+ return False
144
+
145
+ package = m.group(1).strip()
146
+ version_required = (m.group(2) or "").strip()
147
+
148
+ if version_required == "":
149
+ continue
150
+
151
+ try:
152
+ version_installed = metadata.version(package)
153
+ except Exception:
154
+ return False
155
+
156
+ if version.parse(version_required) != version.parse(version_installed):
157
+ return False
158
+
159
+ return True
fooocusapi/worker.py ADDED
@@ -0,0 +1,1044 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Worker, modify from https://github.com/lllyasviel/Fooocus/blob/main/modules/async_worker.py
3
+ """
4
+ import copy
5
+ import os
6
+ import random
7
+ import time
8
+ from typing import List
9
+ import logging
10
+ import numpy as np
11
+ import torch
12
+
13
+ from fooocusapi.models.common.image_meta import image_parse
14
+ from modules.patch import PatchSettings, patch_settings, patch_all
15
+ from modules.flags import Performance
16
+
17
+ from fooocusapi.utils.file_utils import save_output_file
18
+ from fooocusapi.models.common.task import (
19
+ GenerationFinishReason,
20
+ ImageGenerationResult
21
+ )
22
+ from fooocusapi.utils.logger import logger
23
+ from fooocusapi.task_queue import (
24
+ QueueTask,
25
+ TaskQueue,
26
+ TaskOutputs
27
+ )
28
+
29
# Apply Fooocus's runtime monkey-patches once at import time.
patch_all()

# Shared task queue consumed by task_schedule_loop.
# NOTE(review): assumed to be assigned by the host application before the
# scheduling loop starts -- confirm against the startup code.
worker_queue: TaskQueue | None = None
# Base model name used by the previous task; compared against the next
# request to decide when model caches must be flushed.
last_model_name = None
33
+
34
+
35
def process_stop():
    """Request interruption of the generation currently in progress."""
    from ldm_patched.modules import model_management
    model_management.interrupt_current_processing()
39
+
40
+
41
@torch.no_grad()
@torch.inference_mode()
def task_schedule_loop():
    """Poll the shared queue forever, launching the head task if not yet started."""
    while True:
        pending = worker_queue.queue
        if len(pending) == 0:
            time.sleep(0.05)
        else:
            head_task = pending[0]
            # start_mills == 0 marks a task that has not been started yet.
            if head_task.start_mills == 0:
                process_generate(head_task)
53
+
54
+
55
@torch.no_grad()
@torch.inference_mode()
def blocking_get_task_result(job_id: str) -> List[ImageGenerationResult]:
    """
    Block until the task identified by *job_id* finishes, then return its results.
    Used when async_task is false.
    :param job_id: identifier of the queued task
    :return: the finished task's result list
    """
    poll_interval = 0.05
    polls = 0
    started_at = time.perf_counter()
    while not worker_queue.is_task_finished(job_id):
        if polls == 0:
            logger.std_info(f"[Task Queue] Waiting for task finished, job_id={job_id}")
        time.sleep(poll_interval)
        polls += 1
        # Log progress roughly every 10 seconds of waiting.
        if polls % int(10 / poll_interval) == 0:
            waiting_time = time.perf_counter() - started_at
            logger.std_info(f"[Task Queue] Already waiting for {round(waiting_time, 1)} seconds, job_id={job_id}")

    finished = worker_queue.get_task(job_id, True)
    return finished.task_result
77
+
78
+
79
+ @torch.no_grad()
80
+ @torch.inference_mode()
81
+ def process_generate(async_task: QueueTask):
82
+ """Generate image"""
83
+ try:
84
+ import modules.default_pipeline as pipeline
85
+ except Exception as e:
86
+ logger.std_error(f'[Task Queue] Import default pipeline error: {e}')
87
+ if not async_task.is_finished:
88
+ worker_queue.finish_task(async_task.job_id)
89
+ async_task.set_result([], True, str(e))
90
+ logger.std_error(f"[Task Queue] Finish task with error, seq={async_task.job_id}")
91
+ return []
92
+
93
+ import modules.flags as flags
94
+ import modules.core as core
95
+ import modules.inpaint_worker as inpaint_worker
96
+ import modules.config as config
97
+ import modules.constants as constants
98
+ import extras.preprocessors as preprocessors
99
+ import extras.ip_adapter as ip_adapter
100
+ import extras.face_crop as face_crop
101
+ import ldm_patched.modules.model_management as model_management
102
+ from modules.util import (
103
+ remove_empty_str, HWC3, resize_image,
104
+ get_image_shape_ceil, set_image_shape_ceil,
105
+ get_shape_ceil, resample_image, erode_or_dilate,
106
+ get_enabled_loras, parse_lora_references_from_prompt, apply_wildcards,
107
+ remove_performance_lora
108
+ )
109
+
110
+ from modules.upscaler import perform_upscale
111
+ from extras.expansion import safe_str
112
+ from extras.censor import default_censor
113
+ from modules.sdxl_styles import (
114
+ apply_style, get_random_style,
115
+ fooocus_expansion, apply_arrays, random_style_name
116
+ )
117
+
118
+ pid = os.getpid()
119
+
120
+ outputs = TaskOutputs(async_task)
121
+ results = []
122
+
123
+ def refresh_seed(seed_string: int | str | None) -> int:
124
+ """
125
+ Refresh and check seed number.
126
+ :params seed_string: seed, str or int. None means random
127
+ :return: seed number
128
+ """
129
+ if seed_string is None or seed_string == -1:
130
+ return random.randint(constants.MIN_SEED, constants.MAX_SEED)
131
+
132
+ try:
133
+ seed_value = int(seed_string)
134
+ if constants.MIN_SEED <= seed_value <= constants.MAX_SEED:
135
+ return seed_value
136
+ except ValueError:
137
+ pass
138
+ return random.randint(constants.MIN_SEED, constants.MAX_SEED)
139
+
140
+ def progressbar(_, number, text):
141
+ """progress bar"""
142
+ logger.std_info(f'[Fooocus] {text}')
143
+ outputs.append(['preview', (number, text, None)])
144
+
145
+ def yield_result(_, images, tasks, extension='png',
146
+ blockout_nsfw=False, censor=True):
147
+ """
148
+ Yield result
149
+ :param _: async task object
150
+ :param images: list for generated image
151
+ :param tasks: the image was generated one by one, when image number is not one, it will be a task list
152
+ :param extension: extension for saved image
153
+ :param blockout_nsfw: blockout nsfw image
154
+ :param censor: censor image
155
+ :return:
156
+ """
157
+ if not isinstance(images, list):
158
+ images = [images]
159
+
160
+ if censor and (config.default_black_out_nsfw or black_out_nsfw):
161
+ images = default_censor(images)
162
+
163
+ results = []
164
+ for index, im in enumerate(images):
165
+ if async_task.req_param.save_name == '':
166
+ image_name = f"{async_task.job_id}-{str(index)}"
167
+ else:
168
+ image_name = f"{async_task.req_param.save_name}-{str(index)}"
169
+ if len(tasks) == 0:
170
+ img_seed = -1
171
+ img_meta = {}
172
+ else:
173
+ img_seed = tasks[index]['task_seed']
174
+ img_meta = image_parse(
175
+ async_tak=async_task,
176
+ task=tasks[index])
177
+ img_filename = save_output_file(
178
+ img=im,
179
+ image_name=image_name,
180
+ image_meta=img_meta,
181
+ extension=extension)
182
+ results.append(ImageGenerationResult(
183
+ im=img_filename,
184
+ seed=str(img_seed),
185
+ finish_reason=GenerationFinishReason.success))
186
+ async_task.set_result(results, False)
187
+ worker_queue.finish_task(async_task.job_id)
188
+ logger.std_info(f"[Task Queue] Finish task, job_id={async_task.job_id}")
189
+
190
+ outputs.append(['results', images])
191
+ pipeline.prepare_text_encoder(async_call=True)
192
+
193
+ try:
194
+ logger.std_info(f"[Task Queue] Task queue start task, job_id={async_task.job_id}")
195
+ # clear memory
196
+ global last_model_name
197
+
198
+ if last_model_name is None:
199
+ last_model_name = async_task.req_param.base_model_name
200
+ if last_model_name != async_task.req_param.base_model_name:
201
+ model_management.cleanup_models() # key1
202
+ model_management.unload_all_models()
203
+ model_management.soft_empty_cache() # key2
204
+ last_model_name = async_task.req_param.base_model_name
205
+
206
+ worker_queue.start_task(async_task.job_id)
207
+
208
+ execution_start_time = time.perf_counter()
209
+
210
+ # Transform parameters
211
+ params = async_task.req_param
212
+ prompt = params.prompt
213
+ negative_prompt = params.negative_prompt
214
+ style_selections = params.style_selections
215
+ performance_selection = Performance(params.performance_selection)
216
+ aspect_ratios_selection = params.aspect_ratios_selection
217
+ image_number = params.image_number
218
+ save_metadata_to_images = params.save_meta
219
+ metadata_scheme = params.meta_scheme
220
+ save_extension = params.save_extension
221
+ save_name = params.save_name
222
+ image_seed = refresh_seed(params.image_seed)
223
+ read_wildcards_in_order = False
224
+ sharpness = params.sharpness
225
+ guidance_scale = params.guidance_scale
226
+ base_model_name = params.base_model_name
227
+ refiner_model_name = params.refiner_model_name
228
+ refiner_switch = params.refiner_switch
229
+ loras = params.loras
230
+ input_image_checkbox = params.uov_input_image is not None or params.inpaint_input_image is not None or len(params.image_prompts) > 0
231
+ current_tab = 'uov' if params.uov_method != flags.disabled else 'ip' if len(params.image_prompts) > 0 else 'inpaint' if params.inpaint_input_image is not None else None
232
+ uov_method = params.uov_method
233
+ upscale_value = params.upscale_value
234
+ uov_input_image = params.uov_input_image
235
+ outpaint_selections = params.outpaint_selections
236
+ outpaint_distance_left = params.outpaint_distance_left
237
+ outpaint_distance_top = params.outpaint_distance_top
238
+ outpaint_distance_right = params.outpaint_distance_right
239
+ outpaint_distance_bottom = params.outpaint_distance_bottom
240
+ inpaint_input_image = params.inpaint_input_image
241
+ inpaint_additional_prompt = '' if params.inpaint_additional_prompt is None else params.inpaint_additional_prompt
242
+ inpaint_mask_image_upload = None
243
+
244
+ adp = params.advanced_params
245
+ disable_preview = adp.disable_preview
246
+ disable_intermediate_results = adp.disable_intermediate_results
247
+ disable_seed_increment = adp.disable_seed_increment
248
+ adm_scaler_positive = adp.adm_scaler_positive
249
+ adm_scaler_negative = adp.adm_scaler_negative
250
+ adm_scaler_end = adp.adm_scaler_end
251
+ adaptive_cfg = adp.adaptive_cfg
252
+ sampler_name = adp.sampler_name
253
+ scheduler_name = adp.scheduler_name
254
+ overwrite_step = adp.overwrite_step
255
+ overwrite_switch = adp.overwrite_switch
256
+ overwrite_width = adp.overwrite_width
257
+ overwrite_height = adp.overwrite_height
258
+ overwrite_vary_strength = adp.overwrite_vary_strength
259
+ overwrite_upscale_strength = adp.overwrite_upscale_strength
260
+ mixing_image_prompt_and_vary_upscale = adp.mixing_image_prompt_and_vary_upscale
261
+ mixing_image_prompt_and_inpaint = adp.mixing_image_prompt_and_inpaint
262
+ debugging_cn_preprocessor = adp.debugging_cn_preprocessor
263
+ skipping_cn_preprocessor = adp.skipping_cn_preprocessor
264
+ canny_low_threshold = adp.canny_low_threshold
265
+ canny_high_threshold = adp.canny_high_threshold
266
+ refiner_swap_method = adp.refiner_swap_method
267
+ controlnet_softness = adp.controlnet_softness
268
+ freeu_enabled = adp.freeu_enabled
269
+ freeu_b1 = adp.freeu_b1
270
+ freeu_b2 = adp.freeu_b2
271
+ freeu_s1 = adp.freeu_s1
272
+ freeu_s2 = adp.freeu_s2
273
+ debugging_inpaint_preprocessor = adp.debugging_inpaint_preprocessor
274
+ inpaint_disable_initial_latent = adp.inpaint_disable_initial_latent
275
+ inpaint_engine = adp.inpaint_engine
276
+ inpaint_strength = adp.inpaint_strength
277
+ inpaint_respective_field = adp.inpaint_respective_field
278
+ inpaint_mask_upload_checkbox = adp.inpaint_mask_upload_checkbox
279
+ invert_mask_checkbox = adp.invert_mask_checkbox
280
+ inpaint_erode_or_dilate = adp.inpaint_erode_or_dilate
281
+ black_out_nsfw = adp.black_out_nsfw
282
+ vae_name = adp.vae_name
283
+ clip_skip = adp.clip_skip
284
+
285
+ cn_tasks = {x: [] for x in flags.ip_list}
286
+ for img_prompt in params.image_prompts:
287
+ cn_img, cn_stop, cn_weight, cn_type = img_prompt
288
+ cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight])
289
+
290
+ if inpaint_input_image is not None and inpaint_input_image['image'] is not None:
291
+ inpaint_image_size = inpaint_input_image['image'].shape[:2]
292
+ if inpaint_input_image['mask'] is None:
293
+ inpaint_input_image['mask'] = np.zeros(inpaint_image_size, dtype=np.uint8)
294
+ else:
295
+ inpaint_mask_upload_checkbox = True
296
+
297
+ inpaint_input_image['mask'] = HWC3(inpaint_input_image['mask'])
298
+ inpaint_mask_image_upload = inpaint_input_image['mask']
299
+
300
+ # Fooocus async_worker.py code start
301
+
302
+ outpaint_selections = [o.lower() for o in outpaint_selections]
303
+ base_model_additional_loras = []
304
+ raw_style_selections = copy.deepcopy(style_selections)
305
+ uov_method = uov_method.lower()
306
+
307
+ if fooocus_expansion in style_selections:
308
+ use_expansion = True
309
+ style_selections.remove(fooocus_expansion)
310
+ else:
311
+ use_expansion = False
312
+
313
+ use_style = len(style_selections) > 0
314
+
315
+ if base_model_name == refiner_model_name:
316
+ logger.std_warn('[Fooocus] Refiner disabled because base model and refiner are same.')
317
+ refiner_model_name = 'None'
318
+
319
+ steps = performance_selection.steps()
320
+
321
+ performance_loras = []
322
+
323
+ if performance_selection == Performance.EXTREME_SPEED:
324
+ logger.std_warn('[Fooocus] Enter LCM mode.')
325
+ progressbar(async_task, 1, 'Downloading LCM components ...')
326
+ performance_loras += [(config.downloading_sdxl_lcm_lora(), 1.0)]
327
+
328
+ if refiner_model_name != 'None':
329
+ logger.std_info('[Fooocus] Refiner disabled in LCM mode.')
330
+
331
+ refiner_model_name = 'None'
332
+ sampler_name = 'lcm'
333
+ scheduler_name = 'lcm'
334
+ sharpness = 0.0
335
+ guidance_scale = 1.0
336
+ adaptive_cfg = 1.0
337
+ refiner_switch = 1.0
338
+ adm_scaler_positive = 1.0
339
+ adm_scaler_negative = 1.0
340
+ adm_scaler_end = 0.0
341
+
342
+ elif performance_selection == Performance.LIGHTNING:
343
+ logger.std_info('[Fooocus] Enter Lightning mode.')
344
+ progressbar(async_task, 1, 'Downloading Lightning components ...')
345
+ performance_loras += [(config.downloading_sdxl_lightning_lora(), 1.0)]
346
+
347
+ if refiner_model_name != 'None':
348
+ logger.std_info('[Fooocus] Refiner disabled in Lightning mode.')
349
+
350
+ refiner_model_name = 'None'
351
+ sampler_name = 'euler'
352
+ scheduler_name = 'sgm_uniform'
353
+ sharpness = 0.0
354
+ guidance_scale = 1.0
355
+ adaptive_cfg = 1.0
356
+ refiner_switch = 1.0
357
+ adm_scaler_positive = 1.0
358
+ adm_scaler_negative = 1.0
359
+ adm_scaler_end = 0.0
360
+
361
+ elif performance_selection == Performance.HYPER_SD:
362
+ print('Enter Hyper-SD mode.')
363
+ progressbar(async_task, 1, 'Downloading Hyper-SD components ...')
364
+ performance_loras += [(config.downloading_sdxl_hyper_sd_lora(), 0.8)]
365
+
366
+ if refiner_model_name != 'None':
367
+ logger.std_info('[Fooocus] Refiner disabled in Hyper-SD mode.')
368
+
369
+ refiner_model_name = 'None'
370
+ sampler_name = 'dpmpp_sde_gpu'
371
+ scheduler_name = 'karras'
372
+ sharpness = 0.0
373
+ guidance_scale = 1.0
374
+ adaptive_cfg = 1.0
375
+ refiner_switch = 1.0
376
+ adm_scaler_positive = 1.0
377
+ adm_scaler_negative = 1.0
378
+ adm_scaler_end = 0.0
379
+
380
+ logger.std_info(f'[Parameters] Adaptive CFG = {adaptive_cfg}')
381
+ logger.std_info(f'[Parameters] CLIP Skip = {clip_skip}')
382
+ logger.std_info(f'[Parameters] Sharpness = {sharpness}')
383
+ logger.std_info(f'[Parameters] ControlNet Softness = {controlnet_softness}')
384
+ logger.std_info(f'[Parameters] ADM Scale = '
385
+ f'{adm_scaler_positive} : '
386
+ f'{adm_scaler_negative} : '
387
+ f'{adm_scaler_end}')
388
+
389
+ patch_settings[pid] = PatchSettings(
390
+ sharpness,
391
+ adm_scaler_end,
392
+ adm_scaler_positive,
393
+ adm_scaler_negative,
394
+ controlnet_softness,
395
+ adaptive_cfg
396
+ )
397
+
398
+ cfg_scale = float(guidance_scale)
399
+ logger.std_info(f'[Parameters] CFG = {cfg_scale}')
400
+
401
+ initial_latent = None
402
+ denoising_strength = 1.0
403
+ tiled = False
404
+
405
+ width, height = aspect_ratios_selection.replace('×', ' ').replace('*', ' ').split(' ')[:2]
406
+ width, height = int(width), int(height)
407
+
408
+ skip_prompt_processing = False
409
+
410
+ inpaint_worker.current_task = None
411
+ inpaint_parameterized = inpaint_engine != 'None'
412
+ inpaint_image = None
413
+ inpaint_mask = None
414
+ inpaint_head_model_path = None
415
+
416
+ use_synthetic_refiner = False
417
+
418
+ controlnet_canny_path = None
419
+ controlnet_cpds_path = None
420
+ clip_vision_path, ip_negative_path, ip_adapter_path, ip_adapter_face_path = None, None, None, None
421
+
422
+ seed = int(image_seed)
423
+ logger.std_info(f'[Parameters] Seed = {seed}')
424
+
425
+ goals = []
426
+ tasks = []
427
+
428
+ if input_image_checkbox:
429
+ if (current_tab == 'uov' or (
430
+ current_tab == 'ip' and mixing_image_prompt_and_vary_upscale)) \
431
+ and uov_method != flags.disabled and uov_input_image is not None:
432
+ uov_input_image = HWC3(uov_input_image)
433
+ if 'vary' in uov_method:
434
+ goals.append('vary')
435
+ elif 'upscale' in uov_method:
436
+ goals.append('upscale')
437
+ if 'fast' in uov_method:
438
+ skip_prompt_processing = True
439
+ else:
440
+ steps = performance_selection.steps_uov()
441
+
442
+ progressbar(async_task, 1, 'Downloading upscale models ...')
443
+ config.downloading_upscale_model()
444
+ if (current_tab == 'inpaint' or (
445
+ current_tab == 'ip' and mixing_image_prompt_and_inpaint)) \
446
+ and isinstance(inpaint_input_image, dict):
447
+ inpaint_image = inpaint_input_image['image']
448
+ inpaint_mask = inpaint_input_image['mask'][:, :, 0]
449
+
450
+ if inpaint_mask_upload_checkbox:
451
+ if isinstance(inpaint_mask_image_upload, np.ndarray):
452
+ if inpaint_mask_image_upload.ndim == 3:
453
+ H, W, C = inpaint_image.shape
454
+ inpaint_mask_image_upload = resample_image(inpaint_mask_image_upload, width=W, height=H)
455
+ inpaint_mask_image_upload = np.mean(inpaint_mask_image_upload, axis=2)
456
+ inpaint_mask_image_upload = (inpaint_mask_image_upload > 127).astype(np.uint8) * 255
457
+ inpaint_mask = np.maximum(np.zeros(shape=(H, W), dtype=np.uint8), inpaint_mask_image_upload)
458
+
459
+ if int(inpaint_erode_or_dilate) != 0:
460
+ inpaint_mask = erode_or_dilate(inpaint_mask, inpaint_erode_or_dilate)
461
+
462
+ if invert_mask_checkbox:
463
+ inpaint_mask = 255 - inpaint_mask
464
+
465
+ inpaint_image = HWC3(inpaint_image)
466
+ if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
467
+ and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0):
468
+ progressbar(async_task, 1, 'Downloading upscale models ...')
469
+ config.downloading_upscale_model()
470
+ if inpaint_parameterized:
471
+ progressbar(async_task, 1, 'Downloading inpainter ...')
472
+ inpaint_head_model_path, inpaint_patch_model_path = config.downloading_inpaint_models(
473
+ inpaint_engine)
474
+ base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
475
+ logger.std_info(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
476
+ if refiner_model_name == 'None':
477
+ use_synthetic_refiner = True
478
+ refiner_switch = 0.8
479
+ else:
480
+ inpaint_head_model_path, inpaint_patch_model_path = None, None
481
+ logger.std_info('[Inpaint] Parameterized inpaint is disabled.')
482
+ if inpaint_additional_prompt != '':
483
+ if prompt == '':
484
+ prompt = inpaint_additional_prompt
485
+ else:
486
+ prompt = inpaint_additional_prompt + '\n' + prompt
487
+ goals.append('inpaint')
488
+ if current_tab == 'ip' or \
489
+ mixing_image_prompt_and_vary_upscale or \
490
+ mixing_image_prompt_and_inpaint:
491
+ goals.append('cn')
492
+ progressbar(async_task, 1, 'Downloading control models ...')
493
+ if len(cn_tasks[flags.cn_canny]) > 0:
494
+ controlnet_canny_path = config.downloading_controlnet_canny()
495
+ if len(cn_tasks[flags.cn_cpds]) > 0:
496
+ controlnet_cpds_path = config.downloading_controlnet_cpds()
497
+ if len(cn_tasks[flags.cn_ip]) > 0:
498
+ clip_vision_path, ip_negative_path, ip_adapter_path = config.downloading_ip_adapters('ip')
499
+ if len(cn_tasks[flags.cn_ip_face]) > 0:
500
+ clip_vision_path, ip_negative_path, ip_adapter_face_path = config.downloading_ip_adapters(
501
+ 'face')
502
+ progressbar(async_task, 1, 'Loading control models ...')
503
+
504
+ # Load or unload CNs
505
+ pipeline.refresh_controlnets([controlnet_canny_path, controlnet_cpds_path])
506
+ ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path)
507
+ ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_face_path)
508
+
509
+ if overwrite_step > 0:
510
+ steps = overwrite_step
511
+
512
+ switch = int(round(steps * refiner_switch))
513
+
514
+ if overwrite_switch > 0:
515
+ switch = overwrite_switch
516
+
517
+ if overwrite_width > 0:
518
+ width = overwrite_width
519
+
520
+ if overwrite_height > 0:
521
+ height = overwrite_height
522
+
523
+ logger.std_info(f'[Parameters] Sampler = {sampler_name} - {scheduler_name}')
524
+ logger.std_info(f'[Parameters] Steps = {steps} - {switch}')
525
+
526
+ progressbar(async_task, 1, 'Initializing ...')
527
+
528
+ if not skip_prompt_processing:
529
+
530
+ prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='')
531
+ negative_prompts = remove_empty_str([safe_str(p) for p in negative_prompt.splitlines()], default='')
532
+
533
+ prompt = prompts[0]
534
+ negative_prompt = negative_prompts[0]
535
+
536
+ if prompt == '':
537
+ # disable expansion when empty since it is not meaningful and influences image prompt
538
+ use_expansion = False
539
+
540
+ extra_positive_prompts = prompts[1:] if len(prompts) > 1 else []
541
+ extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []
542
+
543
+ progressbar(async_task, 3, 'Loading models ...')
544
+ lora_filenames = remove_performance_lora(config.lora_filenames, performance_selection)
545
+ loras, prompt = parse_lora_references_from_prompt(prompt, loras, config.default_max_lora_number, lora_filenames=lora_filenames)
546
+ loras += performance_loras
547
+
548
+ pipeline.refresh_everything(
549
+ refiner_model_name=refiner_model_name,
550
+ base_model_name=base_model_name,
551
+ loras=loras,
552
+ base_model_additional_loras=base_model_additional_loras,
553
+ use_synthetic_refiner=use_synthetic_refiner)
554
+
555
+ pipeline.set_clip_skip(clip_skip)
556
+
557
+ progressbar(async_task, 3, 'Processing prompts ...')
558
+ tasks = []
559
+
560
+ for i in range(image_number):
561
+ if disable_seed_increment:
562
+ task_seed = seed % (constants.MAX_SEED + 1)
563
+ else:
564
+ task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not
565
+
566
+ task_rng = random.Random(task_seed) # may bind to inpaint noise in the future
567
+ task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order)
568
+ task_prompt = apply_arrays(task_prompt, i)
569
+ task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order)
570
+ task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
571
+ extra_positive_prompts]
572
+ task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
573
+ extra_negative_prompts]
574
+
575
+ positive_basic_workloads = []
576
+ negative_basic_workloads = []
577
+
578
+ task_styles = style_selections.copy()
579
+ if use_style:
580
+ for index, style in enumerate(task_styles):
581
+ if style == random_style_name:
582
+ style = get_random_style(task_rng)
583
+ task_styles[index] = style
584
+ p, n = apply_style(style, positive=task_prompt)
585
+ positive_basic_workloads = positive_basic_workloads + p
586
+ negative_basic_workloads = negative_basic_workloads + n
587
+ else:
588
+ positive_basic_workloads.append(task_prompt)
589
+
590
+ negative_basic_workloads.append(task_negative_prompt) # Always use independent workload for negative.
591
+
592
+ positive_basic_workloads = positive_basic_workloads + task_extra_positive_prompts
593
+ negative_basic_workloads = negative_basic_workloads + task_extra_negative_prompts
594
+
595
+ positive_basic_workloads = remove_empty_str(positive_basic_workloads, default=task_prompt)
596
+ negative_basic_workloads = remove_empty_str(negative_basic_workloads, default=task_negative_prompt)
597
+
598
+ tasks.append(dict(
599
+ task_seed=task_seed,
600
+ task_prompt=task_prompt,
601
+ task_negative_prompt=task_negative_prompt,
602
+ positive=positive_basic_workloads,
603
+ negative=negative_basic_workloads,
604
+ expansion='',
605
+ c=None,
606
+ uc=None,
607
+ positive_top_k=len(positive_basic_workloads),
608
+ negative_top_k=len(negative_basic_workloads),
609
+ log_positive_prompt='\n'.join([task_prompt] + task_extra_positive_prompts),
610
+ log_negative_prompt='\n'.join([task_negative_prompt] + task_extra_negative_prompts),
611
+ styles=task_styles
612
+ ))
613
+
614
+ if use_expansion:
615
+ for i, t in enumerate(tasks):
616
+ progressbar(async_task, 4, f'Preparing Fooocus text #{i + 1} ...')
617
+ expansion = pipeline.final_expansion(t['task_prompt'], t['task_seed'])
618
+ logger.std_info(f'[Prompt Expansion] {expansion}')
619
+ t['expansion'] = expansion
620
+ t['positive'] = copy.deepcopy(t['positive']) + [expansion] # Deep copy.
621
+
622
+ for i, t in enumerate(tasks):
623
+ progressbar(async_task, 5, f'Encoding positive #{i + 1} ...')
624
+ t['c'] = pipeline.clip_encode(texts=t['positive'], pool_top_k=t['positive_top_k'])
625
+
626
+ for i, t in enumerate(tasks):
627
+ if abs(float(cfg_scale) - 1.0) < 1e-4:
628
+ t['uc'] = pipeline.clone_cond(t['c'])
629
+ else:
630
+ progressbar(async_task, 6, f'Encoding negative #{i + 1} ...')
631
+ t['uc'] = pipeline.clip_encode(texts=t['negative'], pool_top_k=t['negative_top_k'])
632
+
633
+ if len(goals) > 0:
634
+ progressbar(async_task, 7, 'Image processing ...')
635
+
636
+ if 'vary' in goals:
637
+ if 'subtle' in uov_method:
638
+ denoising_strength = 0.5
639
+ if 'strong' in uov_method:
640
+ denoising_strength = 0.85
641
+ if overwrite_vary_strength > 0:
642
+ denoising_strength = overwrite_vary_strength
643
+
644
+ shape_ceil = get_image_shape_ceil(uov_input_image)
645
+ if shape_ceil < 1024:
646
+ logger.std_warn('[Vary] Image is resized because it is too small.')
647
+ shape_ceil = 1024
648
+ elif shape_ceil > 2048:
649
+ logger.std_warn('[Vary] Image is resized because it is too big.')
650
+ shape_ceil = 2048
651
+
652
+ uov_input_image = set_image_shape_ceil(uov_input_image, shape_ceil)
653
+
654
+ initial_pixels = core.numpy_to_pytorch(uov_input_image)
655
+ progressbar(async_task, 8, 'VAE encoding ...')
656
+
657
+ candidate_vae, _ = pipeline.get_candidate_vae(
658
+ steps=steps,
659
+ switch=switch,
660
+ denoise=denoising_strength,
661
+ refiner_swap_method=refiner_swap_method
662
+ )
663
+
664
+ initial_latent = core.encode_vae(vae=candidate_vae, pixels=initial_pixels)
665
+ B, C, H, W = initial_latent['samples'].shape
666
+ width = W * 8
667
+ height = H * 8
668
+ logger.std_info(f'[Vary] Final resolution is {str((height, width))}.')
669
+
670
+ if 'upscale' in goals:
671
+ H, W, C = uov_input_image.shape
672
+ progressbar(async_task, 9, f'Upscaling image from {str((H, W))} ...')
673
+ uov_input_image = perform_upscale(uov_input_image)
674
+ logger.std_info('[Upscale] Image upscale.')
675
+
676
+ if upscale_value is not None and upscale_value > 1.0:
677
+ f = upscale_value
678
+ else:
679
+ if '1.5x' in uov_method:
680
+ f = 1.5
681
+ elif '2x' in uov_method:
682
+ f = 2.0
683
+ else:
684
+ f = 1.0
685
+
686
+ shape_ceil = get_shape_ceil(H * f, W * f)
687
+
688
+ if shape_ceil < 1024:
689
+ logger.std_info('[Upscale] Image is resized because it is too small.')
690
+ uov_input_image = set_image_shape_ceil(uov_input_image, 1024)
691
+ shape_ceil = 1024
692
+ else:
693
+ uov_input_image = resample_image(uov_input_image, width=W * f, height=H * f)
694
+
695
+ image_is_super_large = shape_ceil > 2800
696
+
697
+ if 'fast' in uov_method:
698
+ direct_return = True
699
+ elif image_is_super_large:
700
+ logger.std_info('[Upscale] Image is too large. Directly returned the SR image. '
701
+ 'Usually directly return SR image at 4K resolution '
702
+ 'yields better results than SDXL diffusion.')
703
+ direct_return = True
704
+ else:
705
+ direct_return = False
706
+
707
+ if direct_return:
708
+ # d = [('Upscale (Fast)', '2x')]
709
+ # log(uov_input_image, d, output_format=save_extension)
710
+ if config.default_black_out_nsfw or black_out_nsfw:
711
+ uov_input_image = default_censor(uov_input_image)
712
+ yield_result(async_task, uov_input_image, tasks, save_extension, False, False)
713
+ return
714
+
715
+ tiled = True
716
+ denoising_strength = 0.382
717
+
718
+ if overwrite_upscale_strength > 0:
719
+ denoising_strength = overwrite_upscale_strength
720
+
721
+ initial_pixels = core.numpy_to_pytorch(uov_input_image)
722
+ progressbar(async_task, 10, 'VAE encoding ...')
723
+
724
+ candidate_vae, _ = pipeline.get_candidate_vae(
725
+ steps=steps,
726
+ switch=switch,
727
+ denoise=denoising_strength,
728
+ refiner_swap_method=refiner_swap_method
729
+ )
730
+
731
+ initial_latent = core.encode_vae(
732
+ vae=candidate_vae,
733
+ pixels=initial_pixels, tiled=True)
734
+ B, C, H, W = initial_latent['samples'].shape
735
+ width = W * 8
736
+ height = H * 8
737
+ logger.std_info(f'[Upscale] Final resolution is {str((height, width))}.')
738
+
739
+ if 'inpaint' in goals:
740
+ if len(outpaint_selections) > 0:
741
+ H, W, C = inpaint_image.shape
742
+ if 'top' in outpaint_selections:
743
+ distance_top = int(H * 0.3)
744
+ if outpaint_distance_top > 0:
745
+ distance_top = outpaint_distance_top
746
+
747
+ inpaint_image = np.pad(inpaint_image, [[distance_top, 0], [0, 0], [0, 0]], mode='edge')
748
+ inpaint_mask = np.pad(inpaint_mask, [[distance_top, 0], [0, 0]], mode='constant',
749
+ constant_values=255)
750
+
751
+ if 'bottom' in outpaint_selections:
752
+ distance_bottom = int(H * 0.3)
753
+ if outpaint_distance_bottom > 0:
754
+ distance_bottom = outpaint_distance_bottom
755
+
756
+ inpaint_image = np.pad(inpaint_image, [[0, distance_bottom], [0, 0], [0, 0]], mode='edge')
757
+ inpaint_mask = np.pad(inpaint_mask, [[0, distance_bottom], [0, 0]], mode='constant',
758
+ constant_values=255)
759
+
760
+ H, W, C = inpaint_image.shape
761
+ if 'left' in outpaint_selections:
762
+ distance_left = int(W * 0.3)
763
+ if outpaint_distance_left > 0:
764
+ distance_left = outpaint_distance_left
765
+
766
+ inpaint_image = np.pad(inpaint_image, [[0, 0], [distance_left, 0], [0, 0]], mode='edge')
767
+ inpaint_mask = np.pad(inpaint_mask, [[0, 0], [distance_left, 0]], mode='constant',
768
+ constant_values=255)
769
+
770
+ if 'right' in outpaint_selections:
771
+ distance_right = int(W * 0.3)
772
+ if outpaint_distance_right > 0:
773
+ distance_right = outpaint_distance_right
774
+
775
+ inpaint_image = np.pad(inpaint_image, [[0, 0], [0, distance_right], [0, 0]], mode='edge')
776
+ inpaint_mask = np.pad(inpaint_mask, [[0, 0], [0, distance_right]], mode='constant',
777
+ constant_values=255)
778
+
779
+ inpaint_image = np.ascontiguousarray(inpaint_image.copy())
780
+ inpaint_mask = np.ascontiguousarray(inpaint_mask.copy())
781
+ inpaint_strength = 1.0
782
+ inpaint_respective_field = 1.0
783
+
784
+ denoising_strength = inpaint_strength
785
+
786
+ inpaint_worker.current_task = inpaint_worker.InpaintWorker(
787
+ image=inpaint_image,
788
+ mask=inpaint_mask,
789
+ use_fill=denoising_strength > 0.99,
790
+ k=inpaint_respective_field
791
+ )
792
+
793
+ if debugging_inpaint_preprocessor:
794
+ yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), tasks,
795
+ black_out_nsfw)
796
+ return
797
+
798
+ progressbar(async_task, 11, 'VAE Inpaint encoding ...')
799
+
800
+ inpaint_pixel_fill = core.numpy_to_pytorch(inpaint_worker.current_task.interested_fill)
801
+ inpaint_pixel_image = core.numpy_to_pytorch(inpaint_worker.current_task.interested_image)
802
+ inpaint_pixel_mask = core.numpy_to_pytorch(inpaint_worker.current_task.interested_mask)
803
+
804
+ candidate_vae, candidate_vae_swap = pipeline.get_candidate_vae(
805
+ steps=steps,
806
+ switch=switch,
807
+ denoise=denoising_strength,
808
+ refiner_swap_method=refiner_swap_method
809
+ )
810
+
811
+ latent_inpaint, latent_mask = core.encode_vae_inpaint(
812
+ mask=inpaint_pixel_mask,
813
+ vae=candidate_vae,
814
+ pixels=inpaint_pixel_image)
815
+
816
+ latent_swap = None
817
+ if candidate_vae_swap is not None:
818
+ progressbar(async_task, 12, 'VAE SD15 encoding ...')
819
+ latent_swap = core.encode_vae(
820
+ vae=candidate_vae_swap,
821
+ pixels=inpaint_pixel_fill)['samples']
822
+
823
+ progressbar(async_task, 13, 'VAE encoding ...')
824
+ latent_fill = core.encode_vae(
825
+ vae=candidate_vae,
826
+ pixels=inpaint_pixel_fill)['samples']
827
+
828
+ inpaint_worker.current_task.load_latent(
829
+ latent_fill=latent_fill, latent_mask=latent_mask, latent_swap=latent_swap)
830
+
831
+ if inpaint_parameterized:
832
+ pipeline.final_unet = inpaint_worker.current_task.patch(
833
+ inpaint_head_model_path=inpaint_head_model_path,
834
+ inpaint_latent=latent_inpaint,
835
+ inpaint_latent_mask=latent_mask,
836
+ model=pipeline.final_unet
837
+ )
838
+
839
+ if not inpaint_disable_initial_latent:
840
+ initial_latent = {'samples': latent_fill}
841
+
842
+ B, C, H, W = latent_fill.shape
843
+ height, width = H * 8, W * 8
844
+ final_height, final_width = inpaint_worker.current_task.image.shape[:2]
845
+ logger.std_info(f'[Inpaint] Final resolution is {str((final_height, final_width))}, latent is {str((height, width))}.')
846
+
847
+ if 'cn' in goals:
848
+ for task in cn_tasks[flags.cn_canny]:
849
+ cn_img, cn_stop, cn_weight = task
850
+ cn_img = resize_image(HWC3(cn_img), width=width, height=height)
851
+
852
+ if not skipping_cn_preprocessor:
853
+ cn_img = preprocessors.canny_pyramid(cn_img, canny_low_threshold, canny_high_threshold)
854
+
855
+ cn_img = HWC3(cn_img)
856
+ task[0] = core.numpy_to_pytorch(cn_img)
857
+ if debugging_cn_preprocessor:
858
+ yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
859
+ return
860
+ for task in cn_tasks[flags.cn_cpds]:
861
+ cn_img, cn_stop, cn_weight = task
862
+ cn_img = resize_image(HWC3(cn_img), width=width, height=height)
863
+
864
+ if not skipping_cn_preprocessor:
865
+ cn_img = preprocessors.cpds(cn_img)
866
+
867
+ cn_img = HWC3(cn_img)
868
+ task[0] = core.numpy_to_pytorch(cn_img)
869
+ if debugging_cn_preprocessor:
870
+ yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
871
+ return
872
+ for task in cn_tasks[flags.cn_ip]:
873
+ cn_img, cn_stop, cn_weight = task
874
+ cn_img = HWC3(cn_img)
875
+
876
+ # https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75
877
+ cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0)
878
+
879
+ task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path)
880
+ if debugging_cn_preprocessor:
881
+ yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
882
+ return
883
+ for task in cn_tasks[flags.cn_ip_face]:
884
+ cn_img, cn_stop, cn_weight = task
885
+ cn_img = HWC3(cn_img)
886
+
887
+ if not skipping_cn_preprocessor:
888
+ cn_img = face_crop.crop_image(cn_img)
889
+
890
+ # https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75
891
+ cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0)
892
+
893
+ task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path)
894
+ if debugging_cn_preprocessor:
895
+ yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
896
+ return
897
+
898
+ all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face]
899
+
900
+ if len(all_ip_tasks) > 0:
901
+ pipeline.final_unet = ip_adapter.patch_model(pipeline.final_unet, all_ip_tasks)
902
+
903
+ if freeu_enabled:
904
+ logger.std_info('[Fooocus] FreeU is enabled!')
905
+ pipeline.final_unet = core.apply_freeu(
906
+ pipeline.final_unet,
907
+ freeu_b1,
908
+ freeu_b2,
909
+ freeu_s1,
910
+ freeu_s2
911
+ )
912
+
913
+ all_steps = steps * image_number
914
+
915
+ logger.std_info(f'[Parameters] Denoising Strength = {denoising_strength}')
916
+
917
+ if isinstance(initial_latent, dict) and 'samples' in initial_latent:
918
+ log_shape = initial_latent['samples'].shape
919
+ else:
920
+ log_shape = f'Image Space {(height, width)}'
921
+
922
+ logger.std_info(f'[Parameters] Initial Latent shape: {log_shape}')
923
+
924
+ preparation_time = time.perf_counter() - execution_start_time
925
+ logger.std_info(f'[Fooocus] Preparation time: {preparation_time:.2f} seconds')
926
+
927
+ final_sampler_name = sampler_name
928
+ final_scheduler_name = scheduler_name
929
+
930
+ if scheduler_name in ['lcm', 'tcd']:
931
+ final_scheduler_name = 'sgm_uniform'
932
+
933
def patch_discrete(unet):
    """Apply discrete model-sampling (lcm/tcd) patching to *unet*."""
    # Fix: patch the model that was passed in, not pipeline.final_unet.
    # This helper is called for both the base and the refiner UNet; the
    # original body ignored its argument, so the base UNet was patched
    # twice and the refiner UNet was never patched (the sibling patch_edm
    # helper correctly uses its `unet` argument).
    return core.opModelSamplingDiscrete.patch(
        unet,
        sampling=scheduler_name,
        zsnr=False)[0]
938
+
939
+ if pipeline.final_unet is not None:
940
+ pipeline.final_unet = patch_discrete(pipeline.final_unet)
941
+ if pipeline.final_refiner_unet is not None:
942
+ pipeline.final_refiner_unet = patch_discrete(pipeline.final_refiner_unet)
943
+ logger.std_info(f'[Fooocus] Using {scheduler_name} scheduler.')
944
+ elif scheduler_name == 'edm_playground_v2.5':
945
+ final_scheduler_name = 'karras'
946
+
947
+ def patch_edm(unet):
948
+ return core.opModelSamplingContinuousEDM.patch(
949
+ unet,
950
+ sampling=scheduler_name,
951
+ sigma_max=120.0,
952
+ sigma_min=0.002)[0]
953
+
954
+ if pipeline.final_unet is not None:
955
+ pipeline.final_unet = patch_edm(pipeline.final_unet)
956
+ if pipeline.final_refiner_unet is not None:
957
+ pipeline.final_refiner_unet = patch_edm(pipeline.final_refiner_unet)
958
+
959
+ logger.std_info(f'[Fooocus] Using {scheduler_name} scheduler.')
960
+
961
+ outputs.append(['preview', (13, 'Moving model to GPU ...', None)])
962
+
963
+ def callback(step, x0, x, total_steps, y):
964
+ """callback, used for progress and preview"""
965
+ done_steps = current_task_id * steps + step
966
+ outputs.append(['preview', (
967
+ int(15.0 + 85.0 * float(done_steps) / float(all_steps)),
968
+ f'Step {step}/{total_steps} in the {current_task_id + 1}-th Sampling',
969
+ y)])
970
+
971
+ for current_task_id, task in enumerate(tasks):
972
+ execution_start_time = time.perf_counter()
973
+
974
+ try:
975
+ positive_cond, negative_cond = task['c'], task['uc']
976
+
977
+ if 'cn' in goals:
978
+ for cn_flag, cn_path in [
979
+ (flags.cn_canny, controlnet_canny_path),
980
+ (flags.cn_cpds, controlnet_cpds_path)
981
+ ]:
982
+ for cn_img, cn_stop, cn_weight in cn_tasks[cn_flag]:
983
+ positive_cond, negative_cond = core.apply_controlnet(
984
+ positive_cond, negative_cond,
985
+ pipeline.loaded_ControlNets[cn_path], cn_img, cn_weight, 0, cn_stop)
986
+
987
+ imgs = pipeline.process_diffusion(
988
+ positive_cond=positive_cond,
989
+ negative_cond=negative_cond,
990
+ steps=steps,
991
+ switch=switch,
992
+ width=width,
993
+ height=height,
994
+ image_seed=task['task_seed'],
995
+ callback=callback,
996
+ sampler_name=final_sampler_name,
997
+ scheduler_name=final_scheduler_name,
998
+ latent=initial_latent,
999
+ denoise=denoising_strength,
1000
+ tiled=tiled,
1001
+ cfg_scale=cfg_scale,
1002
+ refiner_swap_method=refiner_swap_method,
1003
+ disable_preview=disable_preview
1004
+ )
1005
+
1006
+ del task['c'], task['uc'], positive_cond, negative_cond # Save memory
1007
+
1008
+ if inpaint_worker.current_task is not None:
1009
+ imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
1010
+
1011
+ # Fooocus async_worker.py code end
1012
+
1013
+ results += imgs
1014
+ except model_management.InterruptProcessingException as e:
1015
+ logger.std_warn("[Fooocus] User stopped")
1016
+ results = []
1017
+ results.append(ImageGenerationResult(
1018
+ im=None, seed=task['task_seed'], finish_reason=GenerationFinishReason.user_cancel))
1019
+ async_task.set_result(results, True, str(e))
1020
+ break
1021
+ except Exception as e:
1022
+ logger.std_error(f'[Fooocus] Process error: {e}')
1023
+ logging.exception(e)
1024
+ results = []
1025
+ results.append(ImageGenerationResult(
1026
+ im=None, seed=task['task_seed'], finish_reason=GenerationFinishReason.error))
1027
+ async_task.set_result(results, True, str(e))
1028
+ break
1029
+
1030
+ execution_time = time.perf_counter() - execution_start_time
1031
+ logger.std_info(f'[Fooocus] Generating and saving time: {execution_time:.2f} seconds')
1032
+
1033
+ if async_task.finish_with_error:
1034
+ worker_queue.finish_task(async_task.job_id)
1035
+ return async_task.task_result
1036
+ yield_result(None, results, tasks, save_extension, black_out_nsfw)
1037
+ return
1038
+ except Exception as e:
1039
+ logger.std_error(f'[Fooocus] Worker error: {e}')
1040
+
1041
+ if not async_task.is_finished:
1042
+ async_task.set_result([], True, str(e))
1043
+ worker_queue.finish_task(async_task.job_id)
1044
+ logger.std_info(f"[Task Queue] Finish task with error, job_id={async_task.job_id}")
main.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """ Entry for Fooocus API.
4
+
5
+ Use for starting Fooocus API.
6
+ python main.py --help for more usage
7
+
8
+ @file: main.py
9
+ @author: Konie
10
+ @update: 2024-03-22
11
+ """
12
+ import argparse
13
+ import os
14
+ import re
15
+ import shutil
16
+ import sys
17
+ from threading import Thread
18
+
19
+ from fooocusapi.utils.logger import logger
20
+ from fooocusapi.utils.tools import run_pip, check_torch_cuda, requirements_check
21
+ from fooocus_api_version import version
22
+
23
+ script_path = os.path.dirname(os.path.realpath(__file__))
24
+ module_path = os.path.join(script_path, "repositories/Fooocus")
25
+
26
+ sys.path.append(script_path)
27
+ sys.path.append(module_path)
28
+
29
+ logger.std_info("[System ARGV] " + str(sys.argv))
30
+
31
+ try:
32
+ index = sys.argv.index('--gpu-device-id')
33
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(sys.argv[index + 1])
34
+ logger.std_info(f"[Fooocus] Set device to: {str(sys.argv[index + 1])}")
35
+ except ValueError:
36
+ pass
37
+
38
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
39
+ os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
40
+
41
+ python = sys.executable
42
+ default_command_live = True
43
+ index_url = os.environ.get("INDEX_URL", "")
44
+ re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
45
+
46
+
47
def install_dependents(skip: bool = False):
    """Ensure python requirements and a CUDA-enabled torch are installed.

    Args:
        skip: when True, bypass every pip installation step.
    """
    if skip:
        return

    torch_wheel_index = os.environ.get(
        "TORCH_INDEX_URL", "https://download.pytorch.org/whl/cu121"
    )

    # Install the project requirements only when something is missing.
    if not requirements_check():
        run_pip("install -r requirements.txt", "requirements")

    # A CUDA-capable torch build is required; install the pinned wheels
    # from the (overridable) torch wheel index when it is absent.
    if not check_torch_cuda():
        run_pip(
            f"install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_wheel_index}",
            desc="torch",
        )
69
+
70
+
71
def preload_pipeline():
    """Warm up Fooocus by importing its default pipeline module."""
    logger.std_info("[Fooocus-API] Preloading pipeline ...")
    # The import itself does the work: loading the module initializes the
    # default pipeline as an import side effect; the binding is unused.
    import modules.default_pipeline as _
75
+
76
+
77
def prepare_environments(args) -> bool:
    """
    Prepare runtime environments for the API server.

    Fills in a client-reachable base URL, refreshes the preset folder,
    mirrors Fooocus config defaults into the API defaults module,
    downloads models and initializes the worker task queue.

    Args:
        args: parsed command line arguments (host, port, base_url,
            queue_size, queue_history, webhook_url, persistent, ...)
    Returns:
        True on success (always; failures raise).
    """

    # 0.0.0.0 is a bind address, not something a client can call back to,
    # so substitute loopback when deriving the default base URL.
    if args.base_url is None or len(args.base_url.strip()) == 0:
        host = args.host
        if host == "0.0.0.0":
            host = "127.0.0.1"
        args.base_url = f"http://{host}:{args.port}"

    # NOTE(review): presumably cleared so the Fooocus modules imported
    # below do not re-parse our CLI flags on import — TODO confirm.
    sys.argv = [sys.argv[0]]

    # Remove and copy preset folder (always refresh from the repository copy)
    origin_preset_folder = os.path.abspath(os.path.join(module_path, "presets"))
    preset_folder = os.path.abspath(os.path.join(script_path, "presets"))
    if os.path.exists(preset_folder):
        shutil.rmtree(preset_folder)
    shutil.copytree(origin_preset_folder, preset_folder)

    # Imported lazily: these modules have heavy import-time side effects
    # and must run after the sys.argv reset above.
    from modules import config
    from fooocusapi.configs import default
    from fooocusapi.utils.model_loader import download_models

    # Mirror the Fooocus config defaults into the API's defaults module so
    # request models pick them up.
    default.default_inpaint_engine_version = config.default_inpaint_engine_version
    default.default_styles = config.default_styles
    default.default_base_model_name = config.default_base_model_name
    default.default_refiner_model_name = config.default_refiner_model_name
    default.default_refiner_switch = config.default_refiner_switch
    default.default_loras = config.default_loras
    default.default_cfg_scale = config.default_cfg_scale
    default.default_prompt_negative = config.default_prompt_negative
    default.default_aspect_ratio = default.get_aspect_ratio_value(
        config.default_aspect_ratio
    )
    default.available_aspect_ratios = [
        default.get_aspect_ratio_value(a) for a in config.available_aspect_ratios
    ]

    download_models()

    # Init task queue (stored on the worker module for the schedule loop)
    from fooocusapi import worker
    from fooocusapi.task_queue import TaskQueue

    worker.worker_queue = TaskQueue(
        queue_size=args.queue_size,
        history_size=args.queue_history,
        webhook_url=args.webhook_url,
        persistent=args.persistent,
    )

    logger.std_info(f"[Fooocus-API] Task queue size: {args.queue_size}")
    logger.std_info(f"[Fooocus-API] Queue history size: {args.queue_history}")
    logger.std_info(f"[Fooocus-API] Webhook url: {args.webhook_url}")

    return True
136
+
137
+
138
def pre_setup():
    """Bootstrap the API for Replicate: fixed args, env prep, task loop."""

    class ReplicateArgs(object):
        """Hard-coded command line arguments used by the Replicate runtime."""
        host = "0.0.0.0"
        port = 8000
        base_url = None
        sync_repo = "skip"
        disable_image_log = True
        skip_pip = True
        preload_pipeline = True
        queue_size = 100
        queue_history = 0
        preset = "default"
        webhook_url = None
        persistent = False
        always_gpu = False
        all_in_fp16 = False
        gpu_device_id = None
        apikey = None

    print("[Pre Setup] Prepare environments")

    arguments = ReplicateArgs()
    # Keep only the program name, then force image logging off.
    sys.argv = [sys.argv[0], "--disable-image-log"]

    install_dependents(arguments.skip_pip)
    prepare_environments(arguments)

    # Consume queued generation jobs in the background.
    from fooocusapi.worker import task_schedule_loop

    task_thread = Thread(target=task_schedule_loop, daemon=True)
    task_thread.start()

    print("[Pre Setup] Finished")
181
+
182
+
183
+ if __name__ == "__main__":
184
+ logger.std_info(f"[Fooocus API] Python {sys.version}")
185
+ logger.std_info(f"[Fooocus API] Fooocus API version: {version}")
186
+
187
+ from fooocusapi.base_args import add_base_args
188
+
189
+ parser = argparse.ArgumentParser()
190
+ add_base_args(parser, True)
191
+
192
+ args, _ = parser.parse_known_args()
193
+ install_dependents(skip=args.skip_pip)
194
+
195
+ from fooocusapi.args import args
196
+
197
+ if prepare_environments(args):
198
+ sys.argv = [sys.argv[0]]
199
+
200
+ # Load pipeline in new thread
201
+ preload_pipeline_thread = Thread(target=preload_pipeline, daemon=True)
202
+ preload_pipeline_thread.start()
203
+
204
+ # Start task schedule thread
205
+ from fooocusapi.worker import task_schedule_loop
206
+
207
+ task_schedule_thread = Thread(target=task_schedule_loop, daemon=True)
208
+ task_schedule_thread.start()
209
+
210
+ # Start api server
211
+ from fooocusapi.api import start_app
212
+
213
+ start_app(args)
mannequin_to_model.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Created By: ishwor subedi
3
+ Date: 2024-07-03
4
+ """
5
+ import os
6
+ import cv2
7
+ import base64
8
+ import requests
9
+ import numpy as np
10
+ from io import BytesIO
11
+ from PIL import Image
12
+ from fastapi import File, UploadFile, Form, Depends
13
+ from fastapi import APIRouter
14
+ from fastapi.responses import JSONResponse
15
+
16
+ from fooocusapi.utils.api_utils import api_key_auth
17
+ from src.pipeline.main_pipeline import MainPipeline
18
+ from supabase import create_client
19
+
20
+ secure_router = APIRouter(dependencies=[Depends(api_key_auth)])
21
+
22
+ pipeline = MainPipeline()
23
+
24
+ SUPABASE_KEY = os.getenv("SUPABASE_KEY")
25
+ SUPABASE_URL = os.getenv("SUPABASE_URL")
26
+
27
+ supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
28
+
29
+
30
+ @secure_router.post("/mannequin_to_model")
31
+ async def mannequinToModel(
32
+ store_name: str = Form(...),
33
+ clothing_category: str = Form(...),
34
+ product_id: str = Form(...),
35
+ body_structure: str = Form(...),
36
+ skin_complexion: str = Form(...),
37
+ facial_structure: str = Form(...),
38
+ person_img: UploadFile = File(...)
39
+ ):
40
+ if body_structure == "medium":
41
+ body_structure = "fat"
42
+
43
+ person_image = await person_img.read()
44
+
45
+ mannequin_image_url = f"{SUPABASE_URL}/storage/v1/object/public/ClothingTryOn/{store_name}/{clothing_category}/{product_id}/{product_id}_{skin_complexion}_{facial_structure}_{body_structure}.webp"
46
+ mannequin_img = read_return(mannequin_image_url)
47
+
48
+ try:
49
+ person_image, mannequin_image = Image.open(BytesIO(person_image)), Image.open(BytesIO(mannequin_img))
50
+ except:
51
+
52
+ error_message = {
53
+ "code": 404,
54
+ "error": "The requested resource is not available. Please verify the availability and try again."
55
+
56
+ }
57
+
58
+ return JSONResponse(content=error_message, status_code=404)
59
+
60
+ mannequin_image = cv2.cvtColor(np.array(mannequin_image), cv2.COLOR_RGB2BGR)
61
+ person_image = cv2.cvtColor(np.array(person_image), cv2.COLOR_RGB2BGR)
62
+ result = pipeline.face_swap(mannequin_image, person_image, enhance=False)
63
+ result = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
64
+ inMemFile = BytesIO()
65
+ result.save(inMemFile, format="WEBP", quality=85)
66
+ outputBytes = inMemFile.getvalue()
67
+ response = {
68
+ "code": 200,
69
+ "output": f"data:image/WEBP;base64,{base64.b64encode(outputBytes).decode('utf-8')}",
70
+ }
71
+
72
+ return response
73
+
74
+
75
+ @secure_router.get("/mannequin_catalogue")
76
+ async def returnJsonData(gender: str):
77
+ folderImageURL = supabase.storage.get_bucket("JSON").create_signed_url(
78
+ path=os.path.join("MannequinInfo.json"), expires_in=3600)["signedURL"]
79
+ r = requests.get(folderImageURL).content.decode()
80
+
81
+ mannequin_data = json.loads(r)
82
+
83
+ if gender.lower() == "female":
84
+ res = [item for item in mannequin_data if item["gender"] == "female"]
85
+ elif gender.lower() == "male":
86
+ res = [item for item in mannequin_data if item["gender"] == "male"]
87
+ else:
88
+ res = []
89
+
90
+ return res
predict.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prediction interface for Cog ⚙️
3
+ https://github.com/replicate/cog/blob/main/docs/python.md
4
+ """
5
+
6
+ import copy
7
+ import os
8
+ from typing import List
9
+ import numpy as np
10
+
11
+ from PIL import Image
12
+ from cog import BasePredictor, BaseModel, Input, Path
13
+ from fooocusapi.utils.lora_manager import LoraManager
14
+ from fooocusapi.utils.file_utils import output_dir
15
+ from fooocusapi.models.common.task import GenerationFinishReason
16
+ from fooocusapi.configs.default import (
17
+ available_aspect_ratios,
18
+ uov_methods,
19
+ outpaint_expansions,
20
+ default_styles,
21
+ default_base_model_name,
22
+ default_refiner_model_name,
23
+ default_loras,
24
+ default_refiner_switch,
25
+ default_cfg_scale,
26
+ default_prompt_negative
27
+ )
28
+
29
+ from fooocusapi.parameters import ImageGenerationParams
30
+ from fooocusapi.task_queue import TaskType
31
+
32
+
33
class Output(BaseModel):
    """
    Output model returned by the Cog predictor.
    """
    # Seeds used for the generated images, as strings.
    seeds: List[str]
    # Paths of the generated image files; presumably parallel to `seeds`
    # (one seed per path) — TODO confirm against Predictor.predict.
    paths: List[Path]
39
+
40
+
41
+ class Predictor(BasePredictor):
42
+ """Predictor"""
43
+ def setup(self) -> None:
44
+ """
45
+ Load the model into memory to make running multiple predictions efficient
46
+ """
47
+ from main import pre_setup
48
+ pre_setup()
49
+
50
+ def predict(
51
+ self,
52
+ prompt: str = Input(
53
+ default='',
54
+ description="Prompt for image generation"),
55
+ negative_prompt: str = Input(
56
+ default=default_prompt_negative,
57
+ description="Negative prompt for image generation"),
58
+ style_selections: str = Input(
59
+ default=','.join(default_styles),
60
+ description="Fooocus styles applied for image generation, separated by comma"),
61
+ performance_selection: str = Input(
62
+ default='Speed',
63
+ choices=['Speed', 'Quality', 'Extreme Speed', 'Lightning'],
64
+ description="Performance selection"),
65
+ aspect_ratios_selection: str = Input(
66
+ default='1152*896',
67
+ choices=available_aspect_ratios,
68
+ description="The generated image's size"),
69
+ image_number: int = Input(
70
+ default=1,
71
+ ge=1, le=8,
72
+ description="How many image to generate"),
73
+ image_seed: int = Input(
74
+ default=-1,
75
+ description="Seed to generate image, -1 for random"),
76
+ use_default_loras: bool = Input(
77
+ default=True,
78
+ description="Use default LoRAs"),
79
+ loras_custom_urls: str = Input(
80
+ default="",
81
+ description="Custom LoRAs URLs in the format 'url,weight' provide multiple separated by ; (example 'url1,0.3;url2,0.1')"),
82
+ sharpness: float = Input(
83
+ default=2.0,
84
+ ge=0.0, le=30.0),
85
+ guidance_scale: float = Input(
86
+ default=default_cfg_scale,
87
+ ge=1.0, le=30.0),
88
+ refiner_switch: float = Input(
89
+ default=default_refiner_switch,
90
+ ge=0.1, le=1.0),
91
+ uov_input_image: Path = Input(
92
+ default=None,
93
+ description="Input image for upscale or variation, keep None for not upscale or variation"),
94
+ uov_method: str = Input(
95
+ default='Disabled',
96
+ choices=uov_methods),
97
+ uov_upscale_value: float = Input(
98
+ default=0,
99
+ description="Only when Upscale (Custom)"),
100
+ inpaint_additional_prompt: str = Input(
101
+ default='',
102
+ description="Prompt for image generation"),
103
+ inpaint_input_image: Path = Input(
104
+ default=None,
105
+ description="Input image for inpaint or outpaint, keep None for not inpaint or outpaint. Please noticed, `uov_input_image` has bigger priority is not None."),
106
+ inpaint_input_mask: Path = Input(
107
+ default=None,
108
+ description="Input mask for inpaint"),
109
+ outpaint_selections: str = Input(
110
+ default='',
111
+ description="Outpaint expansion selections, literal 'Left', 'Right', 'Top', 'Bottom' separated by comma"),
112
+ outpaint_distance_left: int = Input(
113
+ default=0,
114
+ description="Outpaint expansion distance from Left of the image"),
115
+ outpaint_distance_top: int = Input(
116
+ default=0,
117
+ description="Outpaint expansion distance from Top of the image"),
118
+ outpaint_distance_right: int = Input(
119
+ default=0,
120
+ description="Outpaint expansion distance from Right of the image"),
121
+ outpaint_distance_bottom: int = Input(
122
+ default=0,
123
+ description="Outpaint expansion distance from Bottom of the image"),
124
+ cn_img1: Path = Input(
125
+ default=None,
126
+ description="Input image for image prompt. If all cn_img[n] are None, image prompt will not applied."),
127
+ cn_stop1: float = Input(
128
+ default=None,
129
+ ge=0, le=1,
130
+ description="Stop at for image prompt, None for default value"),
131
+ cn_weight1: float = Input(
132
+ default=None,
133
+ ge=0, le=2,
134
+ description="Weight for image prompt, None for default value"),
135
+ cn_type1: str = Input(
136
+ default='ImagePrompt',
137
+ choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
138
+ description="ControlNet type for image prompt"),
139
+ cn_img2: Path = Input(
140
+ default=None,
141
+ description="Input image for image prompt. If all cn_img[n] are None, image prompt will not applied."),
142
+ cn_stop2: float = Input(
143
+ default=None,
144
+ ge=0, le=1,
145
+ description="Stop at for image prompt, None for default value"),
146
+ cn_weight2: float = Input(
147
+ default=None,
148
+ ge=0, le=2,
149
+ description="Weight for image prompt, None for default value"),
150
+ cn_type2: str = Input(
151
+ default='ImagePrompt',
152
+ choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
153
+ description="ControlNet type for image prompt"),
154
+ cn_img3: Path = Input(
155
+ default=None,
156
+ description="Input image for image prompt. If all cn_img[n] are None, image prompt will not applied."),
157
+ cn_stop3: float = Input(
158
+ default=None,
159
+ ge=0, le=1,
160
+ description="Stop at for image prompt, None for default value"),
161
+ cn_weight3: float = Input(
162
+ default=None,
163
+ ge=0, le=2,
164
+ description="Weight for image prompt, None for default value"),
165
+ cn_type3: str = Input(
166
+ default='ImagePrompt',
167
+ choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
168
+ description="ControlNet type for image prompt"),
169
+ cn_img4: Path = Input(
170
+ default=None,
171
+ description="Input image for image prompt. If all cn_img[n] are None, image prompt will not applied."),
172
+ cn_stop4: float = Input(
173
+ default=None,
174
+ ge=0, le=1,
175
+ description="Stop at for image prompt, None for default value"),
176
+ cn_weight4: float = Input(
177
+ default=None,
178
+ ge=0, le=2,
179
+ description="Weight for image prompt, None for default value"),
180
+ cn_type4: str = Input(
181
+ default='ImagePrompt',
182
+ choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
183
+ description="ControlNet type for image prompt")
184
+ ) -> Output:
185
+ """Run a single prediction on the model"""
186
+ from modules import flags
187
+ from modules.sdxl_styles import legal_style_names
188
+ from fooocusapi.worker import blocking_get_task_result, worker_queue
189
+
190
+ base_model_name = default_base_model_name
191
+ refiner_model_name = default_refiner_model_name
192
+
193
+ lora_manager = LoraManager()
194
+
195
+ # Use default loras if selected
196
+ loras = copy.copy(default_loras) if use_default_loras else []
197
+
198
+ # add custom user loras if provided
199
+ if loras_custom_urls:
200
+ urls = [url.strip() for url in loras_custom_urls.split(';')]
201
+
202
+ loras_with_weights = [url.split(',') for url in urls]
203
+
204
+ custom_lora_paths = lora_manager.check([lw[0] for lw in loras_with_weights])
205
+ custom_loras = [[path, float(lw[1]) if len(lw) > 1 else 1.0] for path, lw in
206
+ zip(custom_lora_paths, loras_with_weights)]
207
+
208
+ loras.extend(custom_loras)
209
+
210
+ style_selections_arr = []
211
+ for s in style_selections.strip().split(','):
212
+ style = s.strip()
213
+ if style in legal_style_names:
214
+ style_selections_arr.append(style)
215
+
216
+ if uov_input_image is not None:
217
+ im = Image.open(str(uov_input_image))
218
+ uov_input_image = np.array(im)
219
+
220
+ inpaint_input_image_dict = None
221
+ if inpaint_input_image is not None:
222
+ im = Image.open(str(inpaint_input_image))
223
+ inpaint_input_image = np.array(im)
224
+
225
+ if inpaint_input_mask is not None:
226
+ im = Image.open(str(inpaint_input_mask))
227
+ inpaint_input_mask = np.array(im)
228
+
229
+ inpaint_input_image_dict = {
230
+ 'image': inpaint_input_image,
231
+ 'mask': inpaint_input_mask
232
+ }
233
+
234
+ outpaint_selections_arr = []
235
+ for e in outpaint_selections.strip().split(','):
236
+ expansion = e.strip()
237
+ if expansion in outpaint_expansions:
238
+ outpaint_selections_arr.append(expansion)
239
+
240
+ image_prompts = []
241
+ image_prompt_config = [
242
+ (cn_img1, cn_stop1, cn_weight1, cn_type1),
243
+ (cn_img2, cn_stop2, cn_weight2, cn_type2),
244
+ (cn_img3, cn_stop3, cn_weight3, cn_type3),
245
+ (cn_img4, cn_stop4, cn_weight4, cn_type4)]
246
+ for config in image_prompt_config:
247
+ cn_img, cn_stop, cn_weight, cn_type = config
248
+ if cn_img is not None:
249
+ im = Image.open(str(cn_img))
250
+ cn_img = np.array(im)
251
+ if cn_stop is None:
252
+ cn_stop = flags.default_parameters[cn_type][0]
253
+ if cn_weight is None:
254
+ cn_weight = flags.default_parameters[cn_type][1]
255
+ image_prompts.append((cn_img, cn_stop, cn_weight, cn_type))
256
+
257
+ advanced_params = None
258
+
259
+ params = ImageGenerationParams(
260
+ prompt=prompt,
261
+ negative_prompt=negative_prompt,
262
+ style_selections=style_selections_arr,
263
+ performance_selection=performance_selection,
264
+ aspect_ratios_selection=aspect_ratios_selection,
265
+ image_number=image_number,
266
+ image_seed=image_seed,
267
+ sharpness=sharpness,
268
+ guidance_scale=guidance_scale,
269
+ base_model_name=base_model_name,
270
+ refiner_model_name=refiner_model_name,
271
+ refiner_switch=refiner_switch,
272
+ loras=loras,
273
+ uov_input_image=uov_input_image,
274
+ uov_method=uov_method,
275
+ upscale_value=uov_upscale_value,
276
+ outpaint_selections=outpaint_selections_arr,
277
+ inpaint_input_image=inpaint_input_image_dict,
278
+ image_prompts=image_prompts,
279
+ advanced_params=advanced_params,
280
+ inpaint_additional_prompt=inpaint_additional_prompt,
281
+ outpaint_distance_left=outpaint_distance_left,
282
+ outpaint_distance_top=outpaint_distance_top,
283
+ outpaint_distance_right=outpaint_distance_right,
284
+ outpaint_distance_bottom=outpaint_distance_bottom,
285
+ save_meta=True,
286
+ meta_scheme='fooocus',
287
+ save_extension='png',
288
+ save_name='',
289
+ require_base64=False,
290
+ )
291
+
292
+ print(f"[Predictor Predict] Params: {params.__dict__}")
293
+
294
+ async_task = worker_queue.add_task(
295
+ TaskType.text_2_img,
296
+ params)
297
+
298
+ if async_task is None:
299
+ print("[Task Queue] The task queue has reached limit")
300
+ raise Exception("The task queue has reached limit.")
301
+
302
+ results = blocking_get_task_result(async_task.job_id)
303
+
304
+ output_paths: List[Path] = []
305
+ output_seeds: List[str] = []
306
+ for r in results:
307
+ if r.finish_reason == GenerationFinishReason.success and r.im is not None:
308
+ output_seeds.append(r.seed)
309
+ output_paths.append(Path(os.path.join(output_dir, r.im)))
310
+
311
+ print(f"[Predictor Predict] Finished with {len(output_paths)} images")
312
+
313
+ if len(output_paths) == 0:
314
+ raise Exception("Process failed.")
315
+
316
+ return Output(seeds=output_seeds, paths=output_paths)
repositories/Fooocus/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ Created By: ishwor subedi
3
+ Date: 2024-07-19
4
+ """
repositories/Fooocus/args_manager.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
UI-level command-line arguments for Fooocus.

Registers Gradio/UI-specific flags on the shared ``ldm_patched`` argument
parser, parses the command line at import time, and exposes the parsed
namespace as the module-level ``args``.
"""
import ldm_patched.modules.args_parser as args_parser

# --- Gradio / UI flags -------------------------------------------------------
args_parser.parser.add_argument("--share", action='store_true', help="Set whether to share on Gradio.")

args_parser.parser.add_argument("--preset", type=str, default=None, help="Apply specified UI preset.")
args_parser.parser.add_argument("--disable-preset-selection", action='store_true',
                                help="Disables preset selection in Gradio.")

args_parser.parser.add_argument("--language", type=str, default='default',
                                help="Translate UI using json files in [language] folder. "
                                     "For example, [--language example] will use [language/example.json] for translation.")

# For example, https://github.com/lllyasviel/Fooocus/issues/849
args_parser.parser.add_argument("--disable-offload-from-vram", action="store_true",
                                help="Force loading models to vram when the unload can be avoided. "
                                     "Some Mac users may need this.")

args_parser.parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
args_parser.parser.add_argument("--disable-image-log", action='store_true',
                                help="Prevent writing images and logs to hard drive.")

args_parser.parser.add_argument("--disable-analytics", action='store_true',
                                help="Disables analytics for Gradio.")

args_parser.parser.add_argument("--disable-metadata", action='store_true',
                                help="Disables saving metadata to images.")

args_parser.parser.add_argument("--disable-preset-download", action='store_true',
                                help="Disables downloading models for presets", default=False)

# BUGFIX: this flag *enables* the feature (store_true, default False); the help
# text previously said "Disables", contradicting the flag name.
args_parser.parser.add_argument("--enable-describe-uov-image", action='store_true',
                                help="Enables automatic description of uov images when prompt is empty", default=False)

args_parser.parser.add_argument("--always-download-new-model", action='store_true',
                                help="Always download newer models ", default=False)

# Override defaults inherited from the shared ldm_patched parser.
args_parser.parser.set_defaults(
    disable_cuda_malloc=True,
    in_browser=True,
    port=None
)

# Parse sys.argv once at import time; the result is shared via args_parser.args.
args_parser.args = args_parser.parser.parse_args()

# (Disable by default because of issues like https://github.com/lllyasviel/Fooocus/issues/724)
args_parser.args.always_offload_from_vram = not args_parser.args.disable_offload_from_vram

if args_parser.args.disable_analytics:
    import os
    os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

if args_parser.args.disable_in_browser:
    args_parser.args.in_browser = False

# Re-export for convenient `from args_manager import args`.
args = args_parser.args
repositories/Fooocus/extras/BLIP/configs/bert_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30522,
19
+ "encoder_width": 768,
20
+ "add_cross_attention": true
21
+ }
repositories/Fooocus/extras/BLIP/configs/caption_coco.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/coco/images/'
2
+ ann_root: 'annotation'
3
+ coco_gt_root: 'annotation/coco_gt'
4
+
5
+ # set pretrained as a file path or an url
6
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
7
+
8
+ # size of vit model; base or large
9
+ vit: 'base'
10
+ vit_grad_ckpt: False
11
+ vit_ckpt_layer: 0
12
+ batch_size: 32
13
+ init_lr: 1e-5
14
+
15
+ # vit: 'large'
16
+ # vit_grad_ckpt: True
17
+ # vit_ckpt_layer: 5
18
+ # batch_size: 16
19
+ # init_lr: 2e-6
20
+
21
+ image_size: 384
22
+
23
+ # generation configs
24
+ max_length: 20
25
+ min_length: 5
26
+ num_beams: 3
27
+ prompt: 'a picture of '
28
+
29
+ # optimizer
30
+ weight_decay: 0.05
31
+ min_lr: 0
32
+ max_epoch: 5
33
+
repositories/Fooocus/extras/BLIP/configs/med_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30524,
19
+ "encoder_width": 768,
20
+ "add_cross_attention": true
21
+ }
repositories/Fooocus/extras/BLIP/configs/nlvr.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/NLVR2/'
2
+ ann_root: 'annotation'
3
+
4
+ # set pretrained as a file path or an url
5
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
6
+
7
+ #size of vit model; base or large
8
+ vit: 'base'
9
+ batch_size_train: 16
10
+ batch_size_test: 64
11
+ vit_grad_ckpt: False
12
+ vit_ckpt_layer: 0
13
+ max_epoch: 15
14
+
15
+ image_size: 384
16
+
17
+ # optimizer
18
+ weight_decay: 0.05
19
+ init_lr: 3e-5
20
+ min_lr: 0
21
+
repositories/Fooocus/extras/BLIP/configs/nocaps.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/nocaps/'
2
+ ann_root: 'annotation'
3
+
4
+ # set pretrained as a file path or an url
5
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
6
+
7
+ vit: 'base'
8
+ batch_size: 32
9
+
10
+ image_size: 384
11
+
12
+ max_length: 20
13
+ min_length: 5
14
+ num_beams: 3
15
+ prompt: 'a picture of '
repositories/Fooocus/extras/BLIP/configs/pretrain.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
2
+ '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
3
+ ]
4
+ laion_path: ''
5
+
6
+ # size of vit model; base or large
7
+ vit: 'base'
8
+ vit_grad_ckpt: False
9
+ vit_ckpt_layer: 0
10
+
11
+ image_size: 224
12
+ batch_size: 75
13
+
14
+ queue_size: 57600
15
+ alpha: 0.4
16
+
17
+ # optimizer
18
+ weight_decay: 0.05
19
+ init_lr: 3e-4
20
+ min_lr: 1e-6
21
+ warmup_lr: 1e-6
22
+ lr_decay_rate: 0.9
23
+ max_epoch: 20
24
+ warmup_steps: 3000
25
+
26
+
27
+
repositories/Fooocus/extras/BLIP/configs/retrieval_coco.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/coco/images/'
2
+ ann_root: 'annotation'
3
+ dataset: 'coco'
4
+
5
+ # set pretrained as a file path or an url
6
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
7
+
8
+ # size of vit model; base or large
9
+
10
+ vit: 'base'
11
+ batch_size_train: 32
12
+ batch_size_test: 64
13
+ vit_grad_ckpt: True
14
+ vit_ckpt_layer: 4
15
+ init_lr: 1e-5
16
+
17
+ # vit: 'large'
18
+ # batch_size_train: 16
19
+ # batch_size_test: 32
20
+ # vit_grad_ckpt: True
21
+ # vit_ckpt_layer: 12
22
+ # init_lr: 5e-6
23
+
24
+ image_size: 384
25
+ queue_size: 57600
26
+ alpha: 0.4
27
+ k_test: 256
28
+ negative_all_rank: True
29
+
30
+ # optimizer
31
+ weight_decay: 0.05
32
+ min_lr: 0
33
+ max_epoch: 6
34
+
repositories/Fooocus/extras/BLIP/configs/retrieval_flickr.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/flickr30k/'
2
+ ann_root: 'annotation'
3
+ dataset: 'flickr'
4
+
5
+ # set pretrained as a file path or an url
6
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
7
+
8
+ # size of vit model; base or large
9
+
10
+ vit: 'base'
11
+ batch_size_train: 32
12
+ batch_size_test: 64
13
+ vit_grad_ckpt: True
14
+ vit_ckpt_layer: 4
15
+ init_lr: 1e-5
16
+
17
+ # vit: 'large'
18
+ # batch_size_train: 16
19
+ # batch_size_test: 32
20
+ # vit_grad_ckpt: True
21
+ # vit_ckpt_layer: 10
22
+ # init_lr: 5e-6
23
+
24
+ image_size: 384
25
+ queue_size: 57600
26
+ alpha: 0.4
27
+ k_test: 128
28
+ negative_all_rank: False
29
+
30
+ # optimizer
31
+ weight_decay: 0.05
32
+ min_lr: 0
33
+ max_epoch: 6
34
+
repositories/Fooocus/extras/BLIP/configs/retrieval_msrvtt.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
2
+ ann_root: 'annotation'
3
+
4
+ # set pretrained as a file path or an url
5
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
6
+
7
+ # size of vit model; base or large
8
+ vit: 'base'
9
+ batch_size: 64
10
+ k_test: 128
11
+ image_size: 384
12
+ num_frm_test: 8
repositories/Fooocus/extras/BLIP/configs/vqa.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
2
+ vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
3
+ train_files: ['vqa_train','vqa_val','vg_qa']
4
+ ann_root: 'annotation'
5
+
6
+ # set pretrained as a file path or an url
7
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
8
+
9
+ # size of vit model; base or large
10
+ vit: 'base'
11
+ batch_size_train: 16
12
+ batch_size_test: 32
13
+ vit_grad_ckpt: False
14
+ vit_ckpt_layer: 0
15
+ init_lr: 2e-5
16
+
17
+ image_size: 480
18
+
19
+ k_test: 128
20
+ inference: 'rank'
21
+
22
+ # optimizer
23
+ weight_decay: 0.05
24
+ min_lr: 0
25
+ max_epoch: 10
repositories/Fooocus/extras/BLIP/models/bert_tokenizer/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "gradient_checkpointing": false,
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.1,
9
+ "hidden_size": 768,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 3072,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 512,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 12,
16
+ "num_hidden_layers": 12,
17
+ "pad_token_id": 0,
18
+ "position_embedding_type": "absolute",
19
+ "transformers_version": "4.6.0.dev0",
20
+ "type_vocab_size": 2,
21
+ "use_cache": true,
22
+ "vocab_size": 30522
23
+ }
repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "do_lower_case": true
3
+ }
repositories/Fooocus/extras/BLIP/models/bert_tokenizer/vocab.txt ADDED
The diff for this file is too large to render. See raw diff