|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import re |
|
import shutil |
|
import sys |
|
import tempfile |
|
import unittest |
|
|
|
|
|
# Directory two levels above the one containing this file; taken to be the root of the
# git checkout (the tests below look for `src/diffusers/...` underneath it).
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
# Make modules in `<git_repo_path>/utils` importable by bare name.
sys.path.append(os.path.join(git_repo_path, "utils"))

# Must stay below the `sys.path` tweak above; presumably `check_copies` lives in `utils/`.
import check_copies  # noqa: E402
|
|
|
|
|
|
|
|
|
# Body of a class (docstring plus annotated fields), indented one level so that it can be
# appended directly after a `class Name(...):` header line. `test_find_code_in_diffusers`
# expects it to equal what `check_copies.find_code_in_diffusers` returns for
# `schedulers.scheduling_ddpm.DDPMSchedulerOutput`, so whitespace inside the literal matters.
REFERENCE_CODE = """    \"""
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    \"""

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None
"""
|
|
|
|
|
class CopyCheckTester(unittest.TestCase):
    """Tests for `check_copies`, run against a temporary stand-in for the diffusers source tree."""

    def setUp(self):
        """Create a temporary tree holding `schedulers/scheduling_ddpm.py` and point `check_copies` at it."""
        self.diffusers_dir = tempfile.mkdtemp()
        os.makedirs(os.path.join(self.diffusers_dir, "schedulers/"))
        check_copies.DIFFUSERS_PATH = self.diffusers_dir
        # Copy the real DDPM scheduler module so the `# Copied from diffusers.schedulers.scheduling_ddpm...`
        # references used in the tests have a source file under the redirected path.
        shutil.copy(
            os.path.join(git_repo_path, "src/diffusers/schedulers/scheduling_ddpm.py"),
            os.path.join(self.diffusers_dir, "schedulers/scheduling_ddpm.py"),
        )

    def tearDown(self):
        """Reset `check_copies.DIFFUSERS_PATH` and delete the temporary tree."""
        check_copies.DIFFUSERS_PATH = "src/diffusers"
        shutil.rmtree(self.diffusers_dir)

    def check_copy_consistency(self, comment, class_name, class_code, overwrite_result=None):
        """Write a candidate copy to `new_code.py` in the temporary tree and check it with `check_copies`.

        Args:
            comment: The `# Copied from ...` line placed directly above the class.
            class_name: Name used in the generated `class {class_name}(nn.Module):` header.
            class_code: Class body (already indented) appended after the header.
            overwrite_result: If `None`, assert that `check_copies.is_copy_consistent` returns an
                empty result for the file. Otherwise, call it with `overwrite=True` and assert that
                the file then contains `comment` + header + `overwrite_result`.
        """
        prefix = comment + f"\nclass {class_name}(nn.Module):\n"
        code = check_copies.run_ruff(prefix + class_code)
        fname = os.path.join(self.diffusers_dir, "new_code.py")
        with open(fname, "w", encoding="utf-8", newline="\n") as f:
            f.write(code)

        if overwrite_result is None:
            self.assertEqual(len(check_copies.is_copy_consistent(fname)), 0)
            return

        expected = prefix + overwrite_result
        # Pass the path itself rather than reading `.name` off the already-closed file handle.
        check_copies.is_copy_consistent(fname, overwrite=True)
        with open(fname, "r", encoding="utf-8") as f:
            # `assertEqual`, not `assertTrue`: with `assertTrue` the second argument is only the
            # failure message, so the rewritten file was never actually compared to `expected`.
            self.assertEqual(f.read(), expected)

    def test_find_code_in_diffusers(self):
        """`find_code_in_diffusers` returns the reference body for `DDPMSchedulerOutput`."""
        code = check_copies.find_code_in_diffusers("schedulers.scheduling_ddpm.DDPMSchedulerOutput")
        self.assertEqual(code, REFERENCE_CODE)

    def test_is_copy_consistent(self):
        """Run `check_copy_consistency` over plain, renamed, long-named and overwritten copies."""
        # Plain copy followed by one extra newline.
        self.check_copy_consistency(
            "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput",
            "DDPMSchedulerOutput",
            REFERENCE_CODE + "\n",
        )

        # Plain copy without the extra newline.
        self.check_copy_consistency(
            "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput",
            "DDPMSchedulerOutput",
            REFERENCE_CODE,
        )

        # Copy declared with a `DDPM->Test` rename.
        self.check_copy_consistency(
            "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->Test",
            "TestSchedulerOutput",
            re.sub("DDPM", "Test", REFERENCE_CODE),
        )

        # Copy declared with a rename to a very long class-name prefix.
        long_class_name = "TestClassWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason"
        self.check_copy_consistency(
            f"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->{long_class_name}",
            f"{long_class_name}SchedulerOutput",
            # Substitute the same pattern the comment declares (`DDPM`), not an unrelated one.
            re.sub("DDPM", long_class_name, REFERENCE_CODE),
        )

        # Overwrite mode: the body is written without the rename applied and the expected
        # file content is the renamed version.
        self.check_copy_consistency(
            "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->Test",
            "TestSchedulerOutput",
            REFERENCE_CODE,
            overwrite_result=re.sub("DDPM", "Test", REFERENCE_CODE),
        )
|
|