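# NOTE: residue from the file-hosting web page, not part of the module:
#   uploader: QuintW ("QuintW's picture")
#   commit message: Upload 1350 files
#   commit: 3f9c56c
#   page links: raw / history / blame
#   scan result: No virus
#   size: 6.12 kB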
from typing import Any, Dict, List
import unittest
from PIL import Image
import numpy as np
import importlib
# The extension's directory name contains hyphens, so it cannot be written in a
# regular ``import`` statement; the test helpers are loaded by dotted string.
utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils")
# NOTE(review): presumably this has to run before the `scripts` / `modules`
# imports that follow so that they resolve -- confirm against tests/utils.py.
utils.setup_test_env()
from scripts import external_code, processor
from scripts.controlnet import prepare_mask, Script, set_numpy_seed
from modules import processing
class TestPrepareMask(unittest.TestCase):
    """Checks ``prepare_mask`` on solid-colour 10x10 masks."""

    def test_prepare_mask(self):
        """Verify the output mode, inversion, and the no-blur path.

        The same processing object is reused across the three scenarios; only
        the attribute under test is changed between calls.
        """
        p = processing.StableDiffusionProcessing()
        p.inpainting_mask_invert = True
        p.mask_blur = 5

        mask = Image.new("RGB", (10, 10), color="white")
        processed_mask = prepare_mask(mask, p)

        # Check that mask is correctly converted to grayscale.
        # This must be assertEqual: assertTrue(mode, "L") would treat "L" as the
        # failure message and pass for any non-empty mode string.
        self.assertEqual(processed_mask.mode, "L")

        # Check that mask colors are correctly inverted:
        # inverted white should be black.
        self.assertEqual(processed_mask.getpixel((0, 0)), 0)

        p.inpainting_mask_invert = False
        processed_mask = prepare_mask(mask, p)

        # Check that mask colors are not inverted when
        # 'inpainting_mask_invert' is False: white should remain white.
        self.assertEqual(processed_mask.getpixel((0, 0)), 255)

        p.mask_blur = 0
        mask = Image.new("RGB", (10, 10), color="black")
        processed_mask = prepare_mask(mask, p)

        # Check that mask is not blurred when 'mask_blur' is 0:
        # black should remain black.
        self.assertEqual(processed_mask.getpixel((0, 0)), 0)
class TestSetNumpySeed(unittest.TestCase):
    """Checks the value returned by ``set_numpy_seed`` for various seed fields."""

    @staticmethod
    def _make_processing(seed, subseed, all_seeds):
        """Return a processing object with the three seed fields populated."""
        p = processing.StableDiffusionProcessing()
        p.seed = seed
        p.subseed = subseed
        p.all_seeds = all_seeds
        return p

    def test_seed_subseed_minus_one(self):
        """With seed and subseed both -1, the result is built from all_seeds[0]."""
        p = self._make_processing(-1, -1, [123, 456])
        expected_seed = (123 + 123) & 0xFFFFFFFF
        self.assertEqual(set_numpy_seed(p), expected_seed)

    def test_valid_seed_subseed(self):
        """Explicit seed and subseed are summed and masked to 32 bits."""
        p = self._make_processing(50, 100, [123, 456])
        expected_seed = (50 + 100) & 0xFFFFFFFF
        self.assertEqual(set_numpy_seed(p), expected_seed)

    def test_invalid_seed_subseed(self):
        """Non-integer seed values yield None instead of a seed."""
        p = self._make_processing("invalid", 2.5, [123, 456])
        self.assertIsNone(set_numpy_seed(p))

    def test_empty_all_seeds(self):
        """A -1 seed with an empty all_seeds list yields None."""
        p = self._make_processing(-1, 2, [])
        self.assertIsNone(set_numpy_seed(p))

    def test_random_state_change(self):
        """Calling set_numpy_seed changes what the global NumPy RNG draws next."""
        p = self._make_processing(50, 100, [123, 456])
        expected_seed = (50 + 100) & 0xFFFFFFFF

        np.random.seed(0)  # set a known seed
        before_random = np.random.randint(0, 1000)  # get a random integer

        seed = set_numpy_seed(p)
        self.assertEqual(seed, expected_seed)

        after_random = np.random.randint(0, 1000)  # get another random integer
        self.assertNotEqual(before_random, after_random)
class MockImg2ImgProcessing(processing.StableDiffusionProcessing):
    """Stand-in for the WebUI img2img processing class.

    The WebUI version has a dependency on `sd_model`, so the tests use this
    subclass, which only records the initial images on top of the base class.
    """

    def __init__(self, init_images, *args, **kwargs):
        # Everything except `init_images` is forwarded to the base class as-is.
        processing.StableDiffusionProcessing.__init__(self, *args, **kwargs)
        self.init_images = init_images
class TestScript(unittest.TestCase):
    """Checks ``Script.bound_check_params`` and ``Script.choose_input_image``."""

    # PNG data URL used as the ControlNet unit's own image.
    sample_base64_image = (
        "data:image/png;base64,"
        "iVBORw0KGgoAAAANSUhEUgAAARMAAAC3CAIAAAC+MS2jAAAAqUlEQVR4nO3BAQ"
        "0AAADCoPdPbQ8HFAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
        "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
        "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
        "AAAAAAAAAAAAAAAAAAAAAAAA/wZOlAAB5tU+nAAAAABJRU5ErkJggg=="
    )

    # 3x3 uint8 array used as the img2img init image.
    sample_np_image = np.array(
        [[100, 200, 50], [150, 75, 225], [30, 120, 180]], dtype=np.uint8
    )

    def test_bound_check_params(self):
        """Out-of-range values are raised to >= 0 only where a slider config exists."""
        slider_params = ("processor_res", "threshold_a", "threshold_b")

        for module in processor.preprocessor_sliders_config:
            for position, param in enumerate(slider_params):
                with self.subTest(param=param, module=module):
                    unit = external_code.ControlNetUnit(module=module, **{param: -100})
                    Script.bound_check_params(unit)

                    # A parameter counts as required when the module's slider
                    # config has a non-None entry at that parameter's position.
                    configs = processor.preprocessor_sliders_config[module]
                    required = (
                        position < len(configs) and configs[position] is not None
                    )

                    actual = getattr(unit, param)
                    if required:
                        self.assertGreaterEqual(actual, 0)
                    else:
                        self.assertEqual(actual, -100)

    def test_choose_input_image(self):
        """The unit's own image wins; otherwise the A1111 init image is used."""

        def img2img_processing():
            return MockImg2ImgProcessing(init_images=[TestScript.sample_np_image])

        # Neither the processing object nor the unit carries an image.
        with self.subTest(name="no image"), self.assertRaises(ValueError):
            Script.choose_input_image(
                p=processing.StableDiffusionProcessing(),
                unit=external_code.ControlNetUnit(),
                idx=0,
            )

        with self.subTest(name="control net input"):
            result = Script.choose_input_image(
                p=img2img_processing(),
                unit=external_code.ControlNetUnit(
                    image=TestScript.sample_base64_image, module="none"
                ),
                idx=0,
            )
            _, from_a1111 = result
            self.assertFalse(from_a1111)

        with self.subTest(name="A1111 input"):
            result = Script.choose_input_image(
                p=img2img_processing(),
                unit=external_code.ControlNetUnit(module="none"),
                idx=0,
            )
            _, from_a1111 = result
            self.assertTrue(from_a1111)
# Run every TestCase in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()