Spaces:
Paused
Paused
import os | |
from copy import copy | |
from enum import Enum | |
from typing import Tuple, List | |
from modules import img2img, processing, shared, script_callbacks | |
from scripts import external_code | |
class BatchHijack:
    """Drive ControlNet batch runs by wrapping two web UI entry points.

    ``do_hijack`` replaces ``img2img.process_batch`` and
    ``processing.process_images_inner`` with the ``*_hijack`` methods of this
    class and keeps the originals on the same modules under
    ``__controlnet_original_*`` attributes.  Four callback lists allow other
    code to observe the start/end of a whole batch run and of every single
    iteration within it.
    """

    def __init__(self):
        # State of the batch run in progress; reset by on_postprocess_batch.
        self.is_batch = False
        self.batch_index = 0
        self.batch_size = 1
        # Seeds captured by on_process_batch when seed incrementing is enabled.
        self.init_seed = None
        self.init_subseed = None
        # Callback lists, each invoked in order by dispatch_callbacks.
        self.process_batch_callbacks = [self.on_process_batch]
        self.process_batch_each_callbacks = []
        self.postprocess_batch_each_callbacks = [self.on_postprocess_batch_each]
        self.postprocess_batch_callbacks = [self.on_postprocess_batch]

    def img2img_process_batch_hijack(self, p, *args, **kwargs):
        """Replacement for ``img2img.process_batch``.

        When no ControlNet unit is in batch mode the saved original is called
        unchanged.  Otherwise the original call is bracketed by the
        batch-start and batch-end callbacks.
        """
        cn_is_batch, batches, output_dir, _ = get_cn_batches(p)
        # NOTE: the original is looked up with a string on purpose; writing
        # ``img2img.__controlnet_original_process_batch`` inside a class body
        # would be name-mangled by Python.
        if not cn_is_batch:
            return getattr(img2img, '__controlnet_original_process_batch')(p, *args, **kwargs)

        try:
            # Dispatched inside the ``try`` (as processing_process_images_hijack
            # does) so the end-of-batch callbacks still reset the batch state
            # if a start callback raises part-way through.
            self.dispatch_callbacks(self.process_batch_callbacks, p, batches, output_dir)
            return getattr(img2img, '__controlnet_original_process_batch')(p, *args, **kwargs)
        finally:
            self.dispatch_callbacks(self.postprocess_batch_callbacks, p)

    def processing_process_images_hijack(self, p, *args, **kwargs):
        """Replacement for ``processing.process_images_inner``.

        Three cases are handled:

        * a batch run is already active (``self.is_batch``): run exactly one
          batch iteration;
        * no ControlNet unit is in batch mode: call the saved original;
        * otherwise: run ``self.batch_size`` iterations here, optionally
          collecting the images for the UI and/or saving them to
          ``output_dir``.

        Returns the ``Processed`` object of the last iteration with its
        ``images`` replaced by the collected images, or a fresh empty
        ``processing.Processed(p, [], p.seed)`` when nothing was collected.
        """
        if self.is_batch:
            # we are in img2img batch tab, do a single batch iteration
            return self.process_images_cn_batch(p, *args, **kwargs)

        cn_is_batch, batches, output_dir, input_file_names = get_cn_batches(p)
        if not cn_is_batch:
            # we are not in batch mode, fallback to original function
            return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)

        output_images = []
        try:
            # on_process_batch (registered by default) sets self.batch_size
            # to len(batches), which bounds the loop below.
            self.dispatch_callbacks(self.process_batch_callbacks, p, batches, output_dir)

            for batch_i in range(self.batch_size):
                processed = self.process_images_cn_batch(p, *args, **kwargs)
                # Only the images from index_of_first_image onwards are kept.
                result_images = processed.images[processed.index_of_first_image:]

                if shared.opts.data.get('controlnet_show_batch_images_in_ui', False):
                    output_images.extend(result_images)

                if output_dir:
                    self.save_images(output_dir, input_file_names[batch_i], result_images)

                if shared.state.interrupted:
                    break
        finally:
            self.dispatch_callbacks(self.postprocess_batch_callbacks, p)

        if output_images:
            processed.images = output_images
        else:
            processed = processing.Processed(p, [], p.seed)

        return processed

    def process_images_cn_batch(self, p, *args, **kwargs):
        """Run one batch iteration through the saved ``process_images_inner``.

        The ``control_net_no_detectmap`` option is forced to True for the
        duration of the call and restored afterwards, even on error.  The
        per-iteration callbacks are dispatched before and after the call.
        """
        self.dispatch_callbacks(self.process_batch_each_callbacks, p)
        old_detectmap_output = shared.opts.data.get('control_net_no_detectmap', False)
        try:
            shared.opts.data.update({'control_net_no_detectmap': True})
            processed = getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
        finally:
            shared.opts.data.update({'control_net_no_detectmap': old_detectmap_output})

        # on_postprocess_batch_each (registered by default) advances
        # self.batch_index, which the check below relies on.
        self.dispatch_callbacks(self.postprocess_batch_each_callbacks, p, processed)

        # do not go past control net batch size
        if self.batch_index >= self.batch_size:
            shared.state.interrupted = True

        return processed

    def save_images(self, output_dir, init_image_path, output_images):
        """Save ``output_images`` into ``output_dir``, named after the input.

        The first image takes the base name of ``init_image_path``; the n-th
        following image gets ``-n`` inserted before the extension.  Images
        whose ``mode`` is ``'RGBA'`` are converted to ``"RGB"`` before saving.
        The directory is created if it does not exist.
        """
        os.makedirs(output_dir, exist_ok=True)
        for n, processed_image in enumerate(output_images):
            filename = os.path.basename(init_image_path)

            if n > 0:
                left, right = os.path.splitext(filename)
                filename = f"{left}-{n}{right}"

            if processed_image.mode == 'RGBA':
                processed_image = processed_image.convert("RGB")
            processed_image.save(os.path.join(output_dir, filename))

    def do_hijack(self):
        """Install both hijacks and register ``undo_hijack`` for script unload."""
        script_callbacks.on_script_unloaded(self.undo_hijack)
        hijack_function(
            module=img2img,
            name='process_batch',
            new_name='__controlnet_original_process_batch',
            new_value=self.img2img_process_batch_hijack,
        )
        hijack_function(
            module=processing,
            name='process_images_inner',
            new_name='__controlnet_original_process_images_inner',
            new_value=self.processing_process_images_hijack
        )

    def undo_hijack(self):
        """Restore the original ``process_batch`` and ``process_images_inner``."""
        unhijack_function(
            module=img2img,
            name='process_batch',
            new_name='__controlnet_original_process_batch',
        )
        unhijack_function(
            module=processing,
            name='process_images_inner',
            new_name='__controlnet_original_process_images_inner',
        )

    def adjust_job_count(self, p):
        """Multiply ``shared.state.job_count`` by the batch size.

        A job count of ``-1`` is first replaced by ``p.n_iter``.
        """
        if shared.state.job_count == -1:
            shared.state.job_count = p.n_iter
        shared.state.job_count *= self.batch_size

    def on_process_batch(self, p, batches, output_dir, *args):
        """Default batch-start callback: initialise the batch state.

        Marks the run as a batch, sizes it from ``batches``, fixes the seed,
        remembers the initial seeds when seed incrementing is enabled, scales
        the job count and disables grid saving.  Sample saving is disabled
        only when an ``output_dir`` is given (this class saves them itself
        in that case).
        """
        print('controlnet batch mode')
        self.is_batch = True
        self.batch_index = 0
        self.batch_size = len(batches)
        processing.fix_seed(p)
        if shared.opts.data.get('controlnet_increment_seed_during_batch', False):
            self.init_seed = p.seed
            self.init_subseed = p.subseed
        self.adjust_job_count(p)
        p.do_not_save_grid = True
        p.do_not_save_samples = bool(output_dir)

    def on_postprocess_batch_each(self, p, *args):
        """Default per-iteration callback: advance the batch index.

        When seed incrementing is enabled, the seed and subseed are advanced
        by ``len(p.all_prompts)``.
        """
        self.batch_index += 1
        if shared.opts.data.get('controlnet_increment_seed_during_batch', False):
            p.seed = p.seed + len(p.all_prompts)
            p.subseed = p.subseed + len(p.all_prompts)

    def on_postprocess_batch(self, p, *args):
        """Default batch-end callback: reset state and restore initial seeds.

        Seeds are restored only when seed incrementing is enabled.
        """
        self.is_batch = False
        self.batch_index = 0
        self.batch_size = 1
        if shared.opts.data.get('controlnet_increment_seed_during_batch', False):
            p.seed = self.init_seed
            p.all_seeds = [self.init_seed]
            p.subseed = self.init_subseed
            p.all_subseeds = [self.init_subseed]

    def dispatch_callbacks(self, callbacks, *args):
        """Call every callable in ``callbacks``, in order, with ``*args``."""
        for callback in callbacks:
            callback(*args)
def hijack_function(module, name, new_name, new_value):
    """Replace ``module.<name>`` with ``new_value``, keeping the original.

    The current value of ``module.<name>`` is stored on the module as
    ``new_name`` before being overwritten, so ``unhijack_function`` can put
    it back later.
    """
    # restore original function in case of reload
    unhijack_function(module=module, name=name, new_name=new_name)
    setattr(module, new_name, getattr(module, name))
    setattr(module, name, new_value)
def unhijack_function(module, name, new_name):
    """Undo a hijack: move ``module.<new_name>`` back to ``module.<name>``.

    Does nothing when ``module`` has no ``new_name`` attribute, so it is safe
    to call when no hijack is installed.
    """
    if hasattr(module, new_name):
        setattr(module, name, getattr(module, new_name))
        delattr(module, new_name)
class InputMode(Enum):
    """How a ControlNet unit receives its input.

    ``get_cn_batches`` reads ``image`` from ``SIMPLE`` units and indexes
    ``batch_images`` for ``BATCH`` units.
    """

    SIMPLE = "simple"
    BATCH = "batch"
def get_cn_batches(p: processing.StableDiffusionProcessing) -> Tuple[bool, List[List[str]], str, List[str]]:
    """Collect the per-iteration ControlNet inputs for processing object ``p``.

    Only units whose ``enabled`` attribute is truthy are considered.

    Returns a 4-tuple ``(any_unit_is_batch, batches, output_dir,
    input_file_names)``:

    * ``any_unit_is_batch`` -- True if at least one enabled unit has
      ``input_mode == InputMode.BATCH``;
    * ``batches`` -- one list per iteration holding one entry per enabled
      unit: ``unit.image`` for simple units, ``unit.batch_images[i]``
      otherwise.  The number of iterations is the smallest ``batch_images``
      length among the batch units, or 1 when there is no batch unit;
    * ``output_dir`` / ``input_file_names`` -- taken from the last batch unit
      in iteration order (``''`` and ``[]`` when there is none).
    """
    units = external_code.get_all_units_in_processing(p)
    # Work on shallow copies so the ``batch_images`` reassignment below lands
    # on the copy rather than on the unit object obtained from ``p``.
    units = [copy(unit) for unit in units if getattr(unit, 'enabled', False)]
    any_unit_is_batch = False
    output_dir = ''
    input_file_names = []
    for unit in units:
        if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH:
            any_unit_is_batch = True
            output_dir = getattr(unit, 'output_dir', '')
            if isinstance(unit.batch_images, str):
                # A string is expanded via shared.listfiles -- presumably a
                # directory path turned into its file list; confirm there.
                unit.batch_images = shared.listfiles(unit.batch_images)
            input_file_names = unit.batch_images

    if any_unit_is_batch:
        # Stop at the shortest batch so every batch unit can be indexed.
        cn_batch_size = min(len(getattr(unit, 'batch_images', []))
                            for unit in units
                            if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH)
    else:
        cn_batch_size = 1

    batches = [[] for _ in range(cn_batch_size)]
    for i in range(cn_batch_size):
        for unit in units:
            if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.SIMPLE:
                # Simple units contribute the same image to every iteration.
                batches[i].append(unit.image)
            else:
                batches[i].append(unit.batch_images[i])

    return any_unit_is_batch, batches, output_dir, input_file_names
# Module-level BatchHijack object created at import time.
instance = BatchHijack()