# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import pytest

from diffusers import __version__
from diffusers.utils import deprecate
from diffusers.utils.testing_utils import str_to_bool


# Used to test the hub
USER = "__DUMMY_TRANSFORMERS_USER__"
ENDPOINT_STAGING = "https://hub-ci.huggingface.co"

# Not critical, only usable on the sandboxed CI instance.
TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL"


class DeprecateTester(unittest.TestCase):
    """Tests for the warnings, return values and errors produced by `diffusers.utils.deprecate`."""

    # The installed version with its first (major) component incremented by one, so that
    # deprecations targeting it are always still "in the future".
    higher_version = ".".join([str(int(__version__.split(".")[0]) + 1)] + __version__.split(".")[1:])
    # A version used to trigger the "should be removed" error path.
    lower_version = "0.0.1"

    def test_deprecate_function_arg(self):
        """Positional (name, version, message) call: the kwarg's value is returned and a `FutureWarning` is emitted."""
        kwargs = {"deprecated_arg": 4}

        with self.assertWarns(FutureWarning) as warning:
            output = deprecate("deprecated_arg", self.higher_version, "message", take_from=kwargs)

        assert output == 4
        assert (
            str(warning.warning)
            == f"The `deprecated_arg` argument is deprecated and will be removed in version {self.higher_version}."
            " message"
        )

    def test_deprecate_function_arg_tuple(self):
        """Same as the positional form, but the deprecation is passed as a single tuple."""
        kwargs = {"deprecated_arg": 4}

        with self.assertWarns(FutureWarning) as warning:
            output = deprecate(("deprecated_arg", self.higher_version, "message"), take_from=kwargs)

        assert output == 4
        assert (
            str(warning.warning)
            == f"The `deprecated_arg` argument is deprecated and will be removed in version {self.higher_version}."
            " message"
        )

    def test_deprecate_function_args(self):
        """Two deprecation tuples: both values are returned in order and one warning is recorded per argument."""
        kwargs = {"deprecated_arg_1": 4, "deprecated_arg_2": 8}

        with self.assertWarns(FutureWarning) as warning:
            output_1, output_2 = deprecate(
                ("deprecated_arg_1", self.higher_version, "Hey"),
                ("deprecated_arg_2", self.higher_version, "Hey"),
                take_from=kwargs,
            )

        assert output_1 == 4
        assert output_2 == 8
        assert (
            str(warning.warnings[0].message)
            == "The `deprecated_arg_1` argument is deprecated and will be removed in version"
            f" {self.higher_version}. Hey"
        )
        assert (
            str(warning.warnings[1].message)
            == "The `deprecated_arg_2` argument is deprecated and will be removed in version"
            f" {self.higher_version}. Hey"
        )

    def test_deprecate_function_incorrect_arg(self):
        """A kwarg that matches no deprecation tuple raises `TypeError` naming the calling function and the kwarg."""
        kwargs = {"deprecated_arg": 4}

        with self.assertRaises(TypeError) as error:
            deprecate(("wrong_arg", self.higher_version, "message"), take_from=kwargs)

        # The message is expected to mention the caller (this test), a line reference and the offending kwarg.
        assert "test_deprecate_function_incorrect_arg in" in str(error.exception)
        assert "line" in str(error.exception)
        assert "got an unexpected keyword argument `deprecated_arg`" in str(error.exception)

    def test_deprecate_arg_no_kwarg(self):
        """Without `take_from`, the warning uses the generic "`name` is deprecated" wording."""
        with self.assertWarns(FutureWarning) as warning:
            deprecate(("deprecated_arg", self.higher_version, "message"))

        assert (
            str(warning.warning)
            == f"`deprecated_arg` is deprecated and will be removed in version {self.higher_version}. message"
        )

    def test_deprecate_args_no_kwarg(self):
        """Without `take_from`, each of several deprecation tuples produces its own generic warning."""
        with self.assertWarns(FutureWarning) as warning:
            deprecate(
                ("deprecated_arg_1", self.higher_version, "Hey"),
                ("deprecated_arg_2", self.higher_version, "Hey"),
            )

        assert (
            str(warning.warnings[0].message)
            == f"`deprecated_arg_1` is deprecated and will be removed in version {self.higher_version}. Hey"
        )
        assert (
            str(warning.warnings[1].message)
            == f"`deprecated_arg_2` is deprecated and will be removed in version {self.higher_version}. Hey"
        )

    def test_deprecate_class_obj(self):
        """With an object as `take_from`, the attribute value is returned and the warning says "attribute"."""

        class Args:
            arg = 5

        with self.assertWarns(FutureWarning) as warning:
            arg = deprecate(("arg", self.higher_version, "message"), take_from=Args())

        assert arg == 5
        assert (
            str(warning.warning)
            == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message"
        )

    def test_deprecate_class_objs(self):
        """Several attributes taken from one object.

        The third tuple names something that is not an attribute of `Args`; the two-name unpacking below
        only succeeds if no value is returned for it.
        """

        class Args:
            arg = 5
            foo = 7

        with self.assertWarns(FutureWarning) as warning:
            arg_1, arg_2 = deprecate(
                ("arg", self.higher_version, "message"),
                ("foo", self.higher_version, "message"),
                ("does not exist", self.higher_version, "message"),
                take_from=Args(),
            )

        assert arg_1 == 5
        assert arg_2 == 7
        assert (
            str(warning.warning)
            == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message"
        )
        assert (
            str(warning.warnings[0].message)
            == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message"
        )
        assert (
            str(warning.warnings[1].message)
            == f"The `foo` attribute is deprecated and will be removed in version {self.higher_version}. message"
        )

    def test_deprecate_incorrect_version(self):
        """A removal version that is not above the installed version raises `ValueError` asking for cleanup."""
        kwargs = {"deprecated_arg": 4}

        with self.assertRaises(ValueError) as error:
            deprecate(("wrong_arg", self.lower_version, "message"), take_from=kwargs)

        assert (
            str(error.exception)
            == "The deprecation tuple ('wrong_arg', '0.0.1', 'message') should be removed since diffusers' version"
            f" {__version__} is >= {self.lower_version}"
        )

    def test_deprecate_incorrect_no_standard_warn(self):
        """With `standard_warn=False`, the warning text is exactly the supplied message."""
        with self.assertWarns(FutureWarning) as warning:
            deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False)

        assert str(warning.warning) == "This message is better!!!"

    def test_deprecate_stacklevel(self):
        """The recorded warning is attributed to this test file rather than to the module defining `deprecate`."""
        with self.assertWarns(FutureWarning) as warning:
            deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False)

        assert str(warning.warning) == "This message is better!!!"
        # NOTE(review): this path check assumes the tests run from a checkout directory named `diffusers`
        # with POSIX path separators -- confirm this holds on every CI platform.
        assert "diffusers/tests/others/test_utils.py" in warning.filename


def parse_flag_from_env(key, default=False):
    """Read a yes/no flag from the environment variable `key`.

    Args:
        key: Name of the environment variable to look up.
        default: Value returned unchanged when the variable is not set.

    Returns:
        `default` if `key` is absent from `os.environ`, otherwise the result of `str_to_bool` applied to
        the variable's value.

    Raises:
        ValueError: If the variable is set but `str_to_bool` rejects its value. The original error is
            attached as the cause.
    """
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        return default

    # KEY is set, convert it to True or False.
    try:
        return str_to_bool(value)
    except ValueError as err:
        # More values are supported, but let's keep the message simple.
        raise ValueError(f"If set, {key} must be yes or no.") from err


# Evaluated once at import time; controls whether staging tests run or are skipped.
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)


def is_staging_test(test_case):
    """
    Decorator marking a test as a staging test.

    Those tests will run using the staging environment of huggingface.co instead of the real model hub.

    When the `HUGGINGFACE_CO_STAGING` flag is not enabled the test is skipped; otherwise it is tagged with the
    `is_staging_test` pytest marker.
    """
    if not _run_staging:
        return unittest.skip("test is staging test")(test_case)
    return pytest.mark.is_staging_test()(test_case)