Spaces:
Paused
Paused
from enum import Enum | |
from typing import List, Any, Optional, Union, Tuple, Dict | |
import numpy as np | |
from modules import scripts, processing, shared | |
from scripts import global_state | |
from scripts.processor import preprocessor_sliders_config, model_free_preprocessors | |
from scripts.logging import logger | |
from modules.api import api | |
def get_api_version() -> int:
    """
    Return the version number of this external API.
    """
    api_version = 2
    return api_version
class ControlMode(Enum):
    """
    Control modes of a ControlNet unit (the improved guess mode).

    Each member's value is the label string that `ControlMode(value)` accepts.
    """

    BALANCED = "Balanced"
    PROMPT = "My prompt is more important"
    CONTROL = "ControlNet is more important"
class ResizeMode(Enum):
    """
    Resize modes for ControlNet input images.

    Each member's value is the label string that `ResizeMode(value)` accepts.
    """

    RESIZE = "Just Resize"
    INNER_FIT = "Crop and Resize"
    OUTER_FIT = "Resize and Fill"

    def int_value(self):
        """
        Return the integer code of this mode:
        0 for RESIZE, 1 for INNER_FIT and 2 for OUTER_FIT.
        """
        codes = {
            ResizeMode.RESIZE: 0,
            ResizeMode.INNER_FIT: 1,
            ResizeMode.OUTER_FIT: 2,
        }
        code = codes.get(self)
        # Every declared member is listed above, so a miss is unexpected.
        assert code is not None, "NOTREACHED"
        return code
# Alternative resize mode labels, mapped to the label used as the
# corresponding `ResizeMode` value.
resize_mode_aliases = {
    # Labels that resolve to "Crop and Resize".
    'Inner Fit (Scale to Fit)': 'Crop and Resize',
    'Scale to Fit (Inner Fit)': 'Crop and Resize',
    # Labels that resolve to "Resize and Fill".
    'Outer Fit (Shrink to Fit)': 'Resize and Fill',
    'Envelope (Outer Fit)': 'Resize and Fill',
}
def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode:
    """
    Convert a label string, an integer index or a `ResizeMode` into a `ResizeMode`.

    Strings are first translated through `resize_mode_aliases` and then looked up
    by enum value. Integers index the members in declaration order; the value 3
    and any index past the last member resolve to `ResizeMode.RESIZE`.
    Any other input is returned unchanged.
    """
    if isinstance(value, str):
        canonical_label = resize_mode_aliases.get(value, value)
        return ResizeMode(canonical_label)

    if not isinstance(value, int):
        return value

    assert value >= 0
    if value == 3:  # 'Just Resize (Latent upscale)'
        return ResizeMode.RESIZE

    if value >= len(ResizeMode):
        logger.warning(f'Unrecognized ResizeMode int value {value}. Fall back to RESIZE.')
        return ResizeMode.RESIZE

    members_in_order = list(ResizeMode)
    return members_in_order[value]
def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode:
    """
    Convert a label string, an integer index or a `ControlMode` into a `ControlMode`.

    Strings are looked up by enum value and integers index the members in
    declaration order. Any other input is returned unchanged.
    """
    if isinstance(value, str):
        return ControlMode(value)

    if isinstance(value, int):
        members_in_order = list(ControlMode)
        return members_in_order[value]

    return value
def visualize_inpaint_mask(img):
    """
    Return a copy of a 4-channel image whose last channel is remapped for display.

    When `img` has three dimensions and exactly four channels, channel index 3
    of the copy is replaced by `255 - channel // 2`. The input array is never
    modified. Any other input is returned as is (the same object, not a copy).
    """
    if img.ndim == 3 and img.shape[2] == 4:
        # `copy()` already yields a fresh C-ordered array, so a single copy is enough.
        result = img.copy()
        result[:, :, 3] = 255 - result[:, :, 3] // 2
        # No data is copied here; this only normalizes the return type to a base ndarray.
        return np.ascontiguousarray(result)
    return img
def pixel_perfect_resolution(
    image: np.ndarray,
    target_H: int,
    target_W: int,
    resize_mode: ResizeMode,
) -> int:
    """
    Estimate the resolution to use when resizing an image while preserving aspect ratio.

    Two scaling factors are computed: target height over image height, and target
    width over image width. One of them is multiplied by the shorter side of the
    image:

    - for `ResizeMode.OUTER_FIT` the smaller factor is used, so the whole image
      fits within the target dimensions, potentially leaving some empty space;
    - for every other mode the larger factor is used, so the target dimensions
      are fully filled, potentially cropping the image.

    The inputs and the estimate are reported through `logger.debug`.

    Args:
        image (np.ndarray): A 3D numpy array representing an image. The dimensions represent [height, width, channels].
        target_H (int): The target height for the image.
        target_W (int): The target width for the image.
        resize_mode (ResizeMode): The mode for resizing.

    Returns:
        int: The estimated resolution, rounded to the nearest integer.
    """
    raw_H, raw_W, _ = image.shape

    height_scale = float(target_H) / float(raw_H)
    width_scale = float(target_W) / float(raw_W)

    pick_scale = min if resize_mode == ResizeMode.OUTER_FIT else max
    estimation = pick_scale(height_scale, width_scale) * float(min(raw_H, raw_W))

    logger.debug(f"Pixel Perfect Computation:")
    logger.debug(f"resize_mode = {resize_mode}")
    logger.debug(f"raw_H = {raw_H}")
    logger.debug(f"raw_W = {raw_W}")
    logger.debug(f"target_H = {target_H}")
    logger.debug(f"target_W = {target_W}")
    logger.debug(f"estimation = {estimation}")

    return int(np.round(estimation))
# A single image: a numpy array or a string (presumably an encoded image -- confirm with callers).
InputImage = Union[np.ndarray, str]
# The alias is then rebound to its final, wider form: a dict of single images
# keyed by string, a pair of single images, or one single image.
InputImage = Union[
    Dict[str, InputImage],
    Tuple[InputImage, InputImage],
    InputImage,
]
def _unit_field_equal(left: Any, right: Any) -> bool:
    """
    Compare two unit attribute values without tripping over numpy arrays.

    Plain `==` on arrays yields an element-wise array whose truth value is
    ambiguous, so arrays are compared with `np.array_equal` instead. Dicts,
    lists and tuples are walked recursively because they may contain arrays.
    An array never equals a non-array value.
    """
    if left is right:
        return True
    if isinstance(left, np.ndarray) or isinstance(right, np.ndarray):
        return (
            isinstance(left, np.ndarray)
            and isinstance(right, np.ndarray)
            and bool(np.array_equal(left, right))
        )
    if isinstance(left, dict) and isinstance(right, dict):
        return left.keys() == right.keys() and all(
            _unit_field_equal(left[key], right[key]) for key in left
        )
    for sequence_type in (list, tuple):
        if isinstance(left, sequence_type) and isinstance(right, sequence_type):
            return len(left) == len(right) and all(
                _unit_field_equal(a, b) for a, b in zip(left, right)
            )
    return left == right


class ControlNetUnit:
    """
    Represents an entire ControlNet processing unit.

    The constructor stores every argument on the instance as given; no
    conversion or validation is performed here. Unknown keyword arguments
    are accepted and ignored.
    """

    def __init__(
        self,
        enabled: bool = True,
        module: Optional[str] = None,
        model: Optional[str] = None,
        weight: float = 1.0,
        image: Optional[InputImage] = None,
        resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT,
        low_vram: bool = False,
        processor_res: int = -1,
        threshold_a: float = -1,
        threshold_b: float = -1,
        guidance_start: float = 0.0,
        guidance_end: float = 1.0,
        pixel_perfect: bool = False,
        control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED,
        **_kwargs,
    ):
        self.enabled = enabled
        self.module = module
        self.model = model
        self.weight = weight
        self.image = image
        self.resize_mode = resize_mode
        self.low_vram = low_vram
        self.processor_res = processor_res
        self.threshold_a = threshold_a
        self.threshold_b = threshold_b
        self.guidance_start = guidance_start
        self.guidance_end = guidance_end
        self.pixel_perfect = pixel_perfect
        self.control_mode = control_mode

    def __eq__(self, other):
        """
        Two units are equal when all of their instance attributes are equal.

        Anything that is not a `ControlNetUnit` compares unequal. Array-valued
        attributes (for example `image`) are compared by content, so comparing
        units that hold distinct arrays does not raise.
        """
        if not isinstance(other, ControlNetUnit):
            return False

        return _unit_field_equal(vars(self), vars(other))
def to_base64_nparray(encoding: str):
    """
    Decode a base64 image into the image type the extension uses:
    a numpy array cast to `uint8`.
    """
    decoded_image = api.decode_base64_to_image(encoding)
    pixels = np.array(decoded_image)
    return pixels.astype('uint8')
def get_all_units_in_processing(p: processing.StableDiffusionProcessing) -> List[ControlNetUnit]:
    """
    Fetch ControlNet processing units from a StableDiffusionProcessing.

    Shorthand for `get_all_units` applied to `p.scripts` and `p.script_args`.
    """
    script_runner = p.scripts
    all_script_args = p.script_args
    return get_all_units(script_runner, all_script_args)
def get_all_units(script_runner: scripts.ScriptRunner, script_args: List[Any]) -> List[ControlNetUnit]:
    """
    Fetch ControlNet processing units from an existing script runner.
    Use this function to fetch units from the list of all scripts arguments.

    Returns an empty list when no ControlNet script is found in `script_runner`.
    """
    cn_script = find_cn_script(script_runner)
    if not cn_script:
        return []

    # Only the slice of arguments owned by the ControlNet script is inspected.
    first, last = cn_script.args_from, cn_script.args_to
    return get_all_units_from(script_args[first:last])
def get_all_units_from(script_args: List[Any]) -> List[ControlNetUnit]:
    """
    Fetch ControlNet processing units from ControlNet script arguments.
    Use `external_code.get_all_units` to fetch units from the list of all scripts arguments.

    An argument is treated as a unit when it is a `ControlNetUnit`, a dict, or
    any object whose instance attributes include every attribute of a freshly
    constructed `ControlNetUnit`. Each accepted argument is passed through
    `to_processing_unit`; all other arguments are skipped.

    A warning is logged when no unit is found, and a debug message is logged
    when an argument looks like a `ControlNetUnit` of a different class object
    (for example one created before a module reload).
    """
    # Attribute names a unit-like object must expose; computed once per call
    # instead of once per inspected argument.
    required_fields = set(vars(ControlNetUnit()))

    def is_stale_unit(script_arg: Any) -> bool:
        """ Returns whether the script_arg is potentially a stale version of
        ControlNetUnit created before module reload."""
        return (
            'ControlNetUnit' in type(script_arg).__name__ and
            not isinstance(script_arg, ControlNetUnit)
        )

    def is_controlnet_unit(script_arg: Any) -> bool:
        """ Returns whether the script_arg is ControlNetUnit or anything that
        can be treated like ControlNetUnit. """
        if isinstance(script_arg, (ControlNetUnit, dict)):
            return True
        return (
            hasattr(script_arg, '__dict__') and
            required_fields.issubset(vars(script_arg))
        )

    all_units = [
        to_processing_unit(script_arg)
        for script_arg in script_args
        if is_controlnet_unit(script_arg)
    ]
    if not all_units:
        # Adjacent literals are concatenated verbatim, so each continuation
        # line carries its own leading space.
        logger.warning(
            "No ControlNetUnit detected in args. It is very likely that you are having an extension conflict."
            f" Here are args received by ControlNet: {script_args}.")
    if any(is_stale_unit(script_arg) for script_arg in script_args):
        logger.debug(
            "Stale version of ControlNetUnit detected. The ControlNetUnit received"
            " by ControlNet is created before the newest load of ControlNet extension."
            " They will still be used by ControlNet as long as they provide same fields"
            " defined in the newest version of ControlNetUnit."
        )
    return all_units
def get_single_unit_from(script_args: List[Any], index: int = 0) -> Optional[ControlNetUnit]:
    """
    Fetch a single ControlNet processing unit from ControlNet script arguments.
    The list must not contain script positional arguments. It must only contain processing units.

    Returns `None` when `index` is negative, when it is past the end of
    `script_args`, or when the entry at `index` is `None`. Otherwise the entry
    is converted with `to_processing_unit`.
    """
    # Negative indices are rejected explicitly rather than wrapping around.
    if not 0 <= index < len(script_args):
        return None

    candidate = script_args[index]
    if candidate is None:
        return None

    return to_processing_unit(candidate)
def get_max_models_num():
    """
    Fetch the maximum number of allowed ControlNet models.

    Reads the "control_net_unit_count" option and falls back to 3 when it is not set.
    """
    return shared.opts.data.get("control_net_unit_count", 3)
def _image_is_present(image: Any) -> bool:
    """
    Return whether `image` carries usable content.

    Numpy arrays are judged by element count because `bool()` on a
    multi-element array raises `ValueError`. Everything else uses normal
    truthiness, so `None` and the empty string count as absent.
    """
    if isinstance(image, np.ndarray):
        return image.size > 0
    return bool(image)


def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit:
    """
    Convert different types to processing unit.
    If `unit` is a dict, alternative keys are supported. See `ext_compat_keys` in implementation for details.

    For a dict input (the caller's dict is not modified):
    - alternative keys are renamed to their current names;
    - a `mask` entry is removed and, when it is not `None` and `image` is not
      already a dict, combined with the image into `{'image': ..., 'mask': ...}`;
    - without a mask, an absent image (`None`, empty string, empty array) becomes `None`;
    - a warning is logged when `guess_mode` is present;
    - the result is built with `ControlNetUnit(**unit)`.

    Any non-dict input is returned unchanged.
    """
    ext_compat_keys = {
        'guessmode': 'guess_mode',
        'guidance': 'guidance_end',
        'lowvram': 'low_vram',
        'input_image': 'image'
    }

    if isinstance(unit, dict):
        # Building a new dict keeps the caller's dict untouched.
        unit = {ext_compat_keys.get(k, k): v for k, v in unit.items()}

        mask = unit.pop('mask', None)

        if 'image' in unit and not isinstance(unit['image'], dict):
            image = unit['image']
            if mask is not None:
                unit['image'] = {'image': image, 'mask': mask}
            elif not _image_is_present(image):
                unit['image'] = None

        if 'guess_mode' in unit:
            logger.warning('Guess Mode is removed since 1.1.136. Please use Control Mode instead.')

        unit = ControlNetUnit(**unit)

    # temporary, check #602
    # assert isinstance(unit, ControlNetUnit), f'bad argument to controlnet extension: {unit}\nexpected Union[dict[str, Any], ControlNetUnit]'
    return unit
def update_cn_script_in_processing(
    p: processing.StableDiffusionProcessing,
    cn_units: List[ControlNetUnit],
    **_kwargs,  # for backwards compatibility
):
    """
    Update the arguments of the ControlNet script in `p.script_args`, reading from `cn_units`.
    `cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want.

    `p.script_args` is rebuilt from a list copy and assigned back as a list or
    a tuple: the container type follows the exact type of `cn_units` when that
    is `list` or `tuple`, and is `list` otherwise.

    The script arguments are left as they were if any of the following is true:
    - ControlNet is not present in `p.scripts`
    - `p.script_args` is not filled with script arguments for scripts that are processed before ControlNet
    """
    container_type = list
    if type(cn_units) in (list, tuple):
        container_type = type(cn_units)

    working_args = list(p.script_args)
    update_cn_script_in_place(p.scripts, working_args, cn_units)
    p.script_args = container_type(working_args)
def update_cn_script_in_place(
    script_runner: scripts.ScriptRunner,
    script_args: List[Any],
    cn_units: List[ControlNetUnit],
    **_kwargs,  # for backwards compatibility
):
    """
    Update the arguments of the ControlNet script in `script_args` in place, reading from `cn_units`.
    `cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want.

    `cn_units` may be a list or a tuple. When it holds fewer units than the
    "control_net_unit_count" option (default 3), it is padded with disabled
    units, each a separate instance. The `args_from`/`args_to` ranges of the
    ControlNet script and of the always-on scripts iterated after it are
    adjusted to the new argument count.

    Does not update `script_args` if any of the following is true:
    - ControlNet is not present in `script_runner`
    - `script_args` is not filled with script arguments for scripts that are processed before ControlNet
    """
    cn_script = find_cn_script(script_runner)
    if cn_script is None or len(script_args) < cn_script.args_from:
        return

    # fill in remaining parameters to satisfy max models, just in case script needs it.
    max_models = shared.opts.data.get("control_net_unit_count", 3)
    padding_count = max(max_models - len(cn_units), 0)
    # One fresh instance per slot: list multiplication would repeat a single
    # shared object. `list(...)` also lets tuple input be concatenated.
    cn_units = list(cn_units) + [ControlNetUnit(enabled=False) for _ in range(padding_count)]

    cn_script_args_diff = 0
    for script in script_runner.alwayson_scripts:
        if script is cn_script:
            cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from)
            script_args[script.args_from:script.args_to] = cn_units
            script.args_to = script.args_from + len(cn_units)
        else:
            # Zero until the ControlNet script has been visited, so scripts
            # iterated before it keep their ranges.
            script.args_from += cn_script_args_diff
            script.args_to += cn_script_args_diff
def get_models(update: bool = False) -> List[str]:
    """
    Fetch the list of available models.
    Each value is a valid candidate of `ControlNetUnit.model`.

    Keyword arguments:
    update -- Whether to refresh the list from disk. (default False)
    """
    if update:
        global_state.update_cn_models()

    model_names = global_state.cn_models_names.values()
    return list(model_names)
def get_modules(alias_names: bool = False) -> List[str]:
    """
    Fetch the list of available preprocessors.
    Each value is a valid candidate of `ControlNetUnit.module`.

    Keyword arguments:
    alias_names -- Whether to get the ui alias names instead of internal keys
    """
    internal_keys = list(global_state.cn_preprocessor_modules.keys())
    if not alias_names:
        return internal_keys

    # Keys without a registered alias are kept as they are.
    aliases = global_state.preprocessor_aliases
    return [aliases.get(key, key) for key in internal_keys]
def get_modules_detail(alias_names: bool = False) -> Dict[str, Any]:
    """
    Get the detail of all preprocessors, keyed by preprocessor name.

    Each entry contains:
    model_free -- whether the name is listed in `model_free_preprocessors`
    sliders -- the slider config in Auto1111 webUI

    Preprocessors without a slider config get `model_free` False and an empty slider list.

    Keyword arguments:
    alias_names -- Whether to get the module detail with alias names instead of internal keys
    """
    internal_keys = get_modules(False)
    alias_keys = get_modules(True)
    output_keys = alias_keys if alias_names else internal_keys

    details = {}
    for output_key, internal_key in zip(output_keys, internal_keys):
        if internal_key in preprocessor_sliders_config:
            # NOTE(review): `model_free` is tested with the output name (an alias when
            # `alias_names` is set) while sliders use the internal key -- confirm intended.
            entry = {
                "model_free": output_key in model_free_preprocessors,
                "sliders": preprocessor_sliders_config[internal_key]
            }
        else:
            entry = {
                "model_free": False,
                "sliders": []
            }
        details[output_key] = entry

    return details
def find_cn_script(script_runner: scripts.ScriptRunner) -> Optional[scripts.Script]:
    """
    Find the ControlNet script in `script_runner`. Returns `None` if `script_runner` does not contain a ControlNet script.

    Only the always-on scripts are searched and the first match is returned.
    A `script_runner` of `None` also yields `None`.
    """
    if script_runner is None:
        return None

    matches = (
        script
        for script in script_runner.alwayson_scripts
        if is_cn_script(script)
    )
    return next(matches, None)
def is_cn_script(script: scripts.Script) -> bool:
    """
    Determine whether `script` is a ControlNet script.

    The check is a case-insensitive comparison of `script.title()` with 'controlnet'.
    """
    lowered_title = script.title().lower()
    return lowered_title == 'controlnet'