import pickle as pkl
import unittest
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image

from diffusers.utils.outputs import BaseOutput
from diffusers.utils.testing_utils import require_torch


@dataclass
class CustomOutput(BaseOutput):
    """Minimal ``BaseOutput`` subclass with a single ``images`` field, used as a test fixture."""

    images: Union[List[PIL.Image.Image], np.ndarray]


class ConfigTester(unittest.TestCase):
    """Tests ``BaseOutput`` access, construction, pickling and pytree behaviour via ``CustomOutput``.

    Uses ``self.assert*`` methods rather than bare ``assert`` so the checks are not
    stripped when Python runs with ``-O``.
    """

    def _assert_array_output(self, outputs, shape=(1, 3, 4, 4)):
        """Check every way of getting an ndarray ``images`` field: attribute, key and index.

        Args:
            outputs: A ``CustomOutput`` whose ``images`` field holds a numpy array.
            shape: The expected array shape.
        """
        accessors = {
            "attribute": outputs.images,
            "key": outputs["images"],
            "index": outputs[0],
        }
        for how, value in accessors.items():
            self.assertIsInstance(value, np.ndarray, msg=f"access by {how}")
            self.assertEqual(value.shape, shape, msg=f"access by {how}")

    def _assert_pil_list_output(self, outputs):
        """Check every way of getting a list-of-PIL-images ``images`` field: attribute, key and index.

        Args:
            outputs: A ``CustomOutput`` whose ``images`` field holds a non-empty list of PIL images.
        """
        accessors = {
            "attribute": outputs.images,
            "key": outputs["images"],
            "index": outputs[0],
        }
        for how, value in accessors.items():
            self.assertIsInstance(value, list, msg=f"access by {how}")
            self.assertIsInstance(value[0], PIL.Image.Image, msg=f"access by {how}")

    def test_outputs_single_attribute(self):
        outputs = CustomOutput(images=np.random.rand(1, 3, 4, 4))
        self._assert_array_output(outputs)

        # test with a non-tensor attribute
        outputs = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
        self._assert_pil_list_output(outputs)

    def test_outputs_dict_init(self):
        # test output reinitialization with a `dict` for compatibility with `accelerate`
        outputs = CustomOutput({"images": np.random.rand(1, 3, 4, 4)})
        self._assert_array_output(outputs)

        # test with a non-tensor attribute
        outputs = CustomOutput({"images": [PIL.Image.new("RGB", (4, 4))]})
        self._assert_pil_list_output(outputs)

    def test_outputs_serialization(self):
        outputs_orig = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
        serialized = pkl.dumps(outputs_orig)
        # Safe: only unpickles bytes produced a line above, never external data.
        outputs_copy = pkl.loads(serialized)

        # Check original and copy are equal
        self.assertEqual(dir(outputs_orig), dir(outputs_copy))
        self.assertEqual(dict(outputs_orig), dict(outputs_copy))
        self.assertEqual(vars(outputs_orig), vars(outputs_copy))

    @require_torch
    def test_torch_pytree(self):
        # ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
        # this is important for DistributedDataParallel gradient synchronization with static_graph=True
        import torch.utils._pytree as pytree

        data = np.random.rand(1, 3, 4, 4)
        x = CustomOutput(images=data)
        self.assertFalse(pytree._is_leaf(x))

        expected_flat_outs = [data]
        expected_tree_spec = pytree.TreeSpec(CustomOutput, ["images"], [pytree.LeafSpec()])

        # `==` on ndarrays is elementwise, so the list/mapping comparisons below rely on
        # the flattened leaf being the very same `data` object (containers treat
        # identical elements as equal without calling `==`).
        actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
        self.assertEqual(expected_flat_outs, actual_flat_outs)
        self.assertEqual(expected_tree_spec, actual_tree_spec)

        unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
        self.assertEqual(x, unflattened_x)