File size: 6,089 Bytes
3f9c56c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import unittest
import importlib

import numpy as np

# Load the extension's test helpers by dotted module path and prepare the test
# environment. (The second argument to import_module is `package`, which only
# matters for relative names.)
# NOTE(review): setup_test_env() presumably makes the webui `scripts` / `modules`
# imports below resolvable, so it must stay ahead of them — confirm.
utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'utils')
utils.setup_test_env()

from copy import copy
from scripts import external_code
from scripts import controlnet
from modules import scripts, ui, shared


class TestExternalCodeWorking(unittest.TestCase):
    """Check that ``external_code.update_cn_script_in_place`` resizes the ControlNet
    script's argument window (``args_from``..``args_to``) to fit the given units."""

    # Unit count written to the ``control_net_unit_count`` option during each test.
    max_models = 6
    # Position of the ControlNet script's first argument within ``script_args``.
    args_offset = 10

    def setUp(self):
        self.scripts = copy(scripts.scripts_txt2img)
        self.scripts.initialize_scripts(False)
        ui.create_ui()
        self.cn_script = controlnet.Script()
        self.cn_script.args_from = self.args_offset
        self.cn_script.args_to = self.args_offset + self.max_models
        self.scripts.alwayson_scripts = [self.cn_script]
        # None placeholders for the args_from slots that precede the ControlNet args.
        self.script_args = [None] * self.cn_script.args_from

        # Record whether the option was set at all, so tearDown can restore the
        # exact prior state instead of leaving the fallback value behind.
        self.had_unit_count_option = "control_net_unit_count" in shared.opts.data
        self.initial_max_models = shared.opts.data.get("control_net_unit_count", 3)
        shared.opts.data.update(control_net_unit_count=self.max_models)

    def tearDown(self):
        if self.had_unit_count_option:
            shared.opts.data.update(control_net_unit_count=self.initial_max_models)
        else:
            shared.opts.data.pop("control_net_unit_count", None)

    def get_expected_args_to(self):
        """Return the ``args_to`` expected after the update: room for at least
        ``max_models`` units, or for every unit in ``self.cn_units`` if there are more."""
        args_len = max(self.max_models, len(self.cn_units))
        return self.args_offset + args_len

    def assert_update_in_place_ok(self):
        """Apply ``self.cn_units`` in place and check the resulting ``args_to``."""
        external_code.update_cn_script_in_place(self.scripts, self.script_args, self.cn_units)
        self.assertEqual(self.cn_script.args_to, self.get_expected_args_to())

    def test_empty_resizes_min_args(self):
        # No units: the window is expected to keep its minimum size of max_models.
        self.cn_units = []
        self.assert_update_in_place_ok()

    def test_empty_resizes_extra_args(self):
        # More units than max_models: the window is expected to grow to fit them all.
        extra_models = 1
        # Build distinct instances; `[unit] * n` would alias one object n times.
        self.cn_units = [
            external_code.ControlNetUnit()
            for _ in range(self.max_models + extra_models)
        ]
        self.assert_update_in_place_ok()


class TestControlNetUnitConversion(unittest.TestCase):
    """Exercise ``external_code.to_processing_unit`` on plain-dict inputs."""

    def setUp(self):
        # Placeholder standing in for an encoded image payload.
        self.dummy_image = 'base64...'
        # Baseline case: an empty dict is expected to map to a default unit.
        self.input = {}
        self.expected = external_code.ControlNetUnit()

    def assert_converts_to_expected(self):
        """Convert ``self.input`` and compare every attribute with ``self.expected``."""
        converted = external_code.to_processing_unit(self.input)
        self.assertEqual(vars(converted), vars(self.expected))

    def _expect(self, payload, expected_image):
        """Use ``payload`` as input, expect a unit carrying ``expected_image``, and check."""
        self.input = payload
        self.expected = external_code.ControlNetUnit(image=expected_image)
        self.assert_converts_to_expected()

    def test_empty_dict_works(self):
        self.assert_converts_to_expected()

    def test_image_works(self):
        self._expect({'image': self.dummy_image}, self.dummy_image)

    def test_image_alias_works(self):
        # 'input_image' is expected to be accepted in place of 'image'.
        self._expect({'input_image': self.dummy_image}, self.dummy_image)

    def test_masked_image_works(self):
        # A separate 'mask' key is expected to end up in an {'image', 'mask'} dict.
        payload = {'image': self.dummy_image, 'mask': self.dummy_image}
        self._expect(payload, {'image': self.dummy_image, 'mask': self.dummy_image})


class TestControlNetUnitImageToDict(unittest.TestCase):
    """Check ``controlnet.image_dict_from_any`` on the image forms a unit may carry."""

    def setUp(self):
        # NOTE(review): relative path; assumes tests run from the webui root — confirm.
        self.dummy_image = utils.readImage("test/test_files/img2img_basic.png")
        self.input = external_code.ControlNetUnit()
        self.expected_image = external_code.to_base64_nparray(self.dummy_image)
        # By default the same picture doubles as the expected mask.
        self.expected_mask = external_code.to_base64_nparray(self.dummy_image)

    def assert_dict_is_valid(self):
        """Convert ``self.input.image`` and compare both planes with the expectations."""
        actual_dict = controlnet.image_dict_from_any(self.input.image)
        # assert_array_equal compares in native code and reports mismatching
        # elements, instead of materialising whole images as nested Python lists.
        np.testing.assert_array_equal(actual_dict['image'], self.expected_image)
        np.testing.assert_array_equal(actual_dict['mask'], self.expected_mask)

    def test_none(self):
        # assertIsNone avoids `== None`, which is ambiguous if an array comes back.
        self.assertIsNone(controlnet.image_dict_from_any(self.input.image))

    def test_image_without_mask(self):
        # Without a mask, an all-zero uint8 mask shaped like the image is expected.
        self.input.image = self.dummy_image
        self.expected_mask = np.zeros_like(self.expected_image, dtype=np.uint8)
        self.assert_dict_is_valid()

    def test_masked_image_tuple(self):
        self.input.image = (self.dummy_image, self.dummy_image,)
        self.assert_dict_is_valid()

    def test_masked_image_dict(self):
        self.input.image = {'image': self.dummy_image, 'mask': self.dummy_image}
        self.assert_dict_is_valid()


class TestPixelPerfectResolution(unittest.TestCase):
    """Pin ``external_code.pixel_perfect_resolution`` for a square image and a 50x100 target."""

    def _resolution_for(self, resize_mode):
        """Run the helper on a 100x100x3 zero image against target_H=50, target_W=100."""
        square = np.zeros((100, 100, 3))
        return external_code.pixel_perfect_resolution(square, 50, 100, resize_mode)

    def test_outer_fit(self):
        # Expected value was computed by hand.
        self.assertEqual(self._resolution_for(external_code.ResizeMode.OUTER_FIT), 50)

    def test_inner_fit(self):
        # Expected value was computed by hand.
        self.assertEqual(self._resolution_for(external_code.ResizeMode.INNER_FIT), 100)


class TestGetAllUnitsFrom(unittest.TestCase):
    """Check which script-arg entries ``external_code.get_all_units_from`` keeps as units."""

    def test_none(self):
        # None entries are expected to be skipped.
        self.assertListEqual(external_code.get_all_units_from([None]), [])

    def test_bool(self):
        # Bare booleans are expected to be skipped.
        self.assertListEqual(external_code.get_all_units_from([True]), [])

    def test_inheritance(self):
        # An instance of a ControlNetUnit subclass is expected to be returned as-is.
        class Foo(external_code.ControlNetUnit):
            def __init__(self):
                # super() is already bound to this instance; passing `self` again
                # would forward it as the first positional constructor argument.
                super().__init__()
                self.bar = 'a'

        foo = Foo()
        self.assertListEqual(external_code.get_all_units_from([foo]), [foo])

    def test_dict(self):
        # A plain dict is expected to be converted into a ControlNetUnit.
        units = external_code.get_all_units_from([{}])
        self.assertGreater(len(units), 0)
        self.assertIsInstance(units[0], external_code.ControlNetUnit)

    def test_unitlike(self):
        # A duck-typed object carrying every ControlNetUnit attribute (plus an
        # extra one) is expected to be returned as-is.
        class Foo(object):
            """ bar """

        foo = Foo()
        for key in vars(external_code.ControlNetUnit()):
            setattr(foo, key, True)
        setattr(foo, 'bar', False)
        self.assertListEqual(external_code.get_all_units_from([foo]), [foo])


# Allow running this file directly, in addition to discovery by a test runner.
if __name__ == '__main__':
    unittest.main()