|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
import logging |
|
import os |
|
import shutil |
|
import sys |
|
import tempfile |
|
|
|
import torch |
|
|
|
from diffusers import VQModel |
|
from diffusers.utils.testing_utils import require_timm |
|
|
|
|
|
# Make the shared example-test helpers (one directory up, next to the other
# example test files) importable; must run before the import below.
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


# Mirror all log output (including subprocess/accelerate chatter captured by
# the root logger) to stdout at DEBUG level so test failures are debuggable.
logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
|
|
|
|
|
@require_timm
class TextToImage(ExamplesTestsAccelerate):
    """End-to-end smoke tests for ``examples/vqgan/train_vqgan.py``.

    Each test launches the training script through ``accelerate`` on a tiny
    dummy dataset with a minimal VQModel/discriminator pair, then inspects the
    artifacts written to the output directory (final weights, checkpoint
    directories, resume-from-checkpoint behavior, checkpoint pruning).

    NOTE(review): the class name ``TextToImage`` is a misnomer (these are
    VQGAN tests); it is kept unchanged so existing test IDs stay stable.
    """

    @property
    def test_vqmodel_config(self):
        """Minimal single-block ``VQModel`` config serialized to JSON for the script.

        NOTE(review): despite the ``test_`` prefix this is a property, not a
        test case — unittest skips it during discovery because property
        objects are not callable on the class.
        """
        return {
            "_class_name": "VQModel",
            "_diffusers_version": "0.17.0.dev0",
            "act_fn": "silu",
            "block_out_channels": [
                32,
            ],
            "down_block_types": [
                "DownEncoderBlock2D",
            ],
            "in_channels": 3,
            "latent_channels": 4,
            "layers_per_block": 2,
            "norm_num_groups": 32,
            "norm_type": "spatial",
            "num_vq_embeddings": 32,
            "out_channels": 3,
            "sample_size": 32,
            "scaling_factor": 0.18215,
            "up_block_types": [
                "UpDecoderBlock2D",
            ],
            "vq_embed_dim": 4,
        }

    @property
    def test_discriminator_config(self):
        """Minimal discriminator config serialized to JSON for the script."""
        return {
            "_class_name": "Discriminator",
            "_diffusers_version": "0.27.0.dev0",
            "in_channels": 3,
            "cond_channels": 0,
            "hidden_channels": 8,
            "depth": 4,
        }

    def get_vq_and_discriminator_configs(self, tmpdir):
        """Write both test configs as JSON files into ``tmpdir``.

        Returns:
            Tuple ``(vqmodel_config_path, discriminator_config_path)``.
        """
        vqmodel_config_path = os.path.join(tmpdir, "vqmodel.json")
        discriminator_config_path = os.path.join(tmpdir, "discriminator.json")
        with open(vqmodel_config_path, "w") as fp:
            json.dump(self.test_vqmodel_config, fp)
        with open(discriminator_config_path, "w") as fp:
            json.dump(self.test_discriminator_config, fp)
        return vqmodel_config_path, discriminator_config_path

    def _train_args(self, tmpdir, vqmodel_config_path, discriminator_config_path, max_train_steps, extra_args=()):
        """Build the CLI argument list shared by every training invocation.

        ``extra_args`` appends test-specific flags (checkpointing, seed, EMA,
        resume). argparse does not care about optional-argument order, so the
        common arguments always come first.
        """
        args = f"""
            examples/vqgan/train_vqgan.py
            --dataset_name hf-internal-testing/dummy_image_text_data
            --resolution 32
            --image_column image
            --train_batch_size 1
            --gradient_accumulation_steps 1
            --max_train_steps {max_train_steps}
            --learning_rate 5.0e-04
            --scale_lr
            --lr_scheduler constant
            --lr_warmup_steps 0
            --model_config_name_or_path {vqmodel_config_path}
            --discriminator_config_name_or_path {discriminator_config_path}
            --output_dir {tmpdir}
            """.split()
        return args + list(extra_args)

    def _checkpoints(self, tmpdir):
        """Return the set of checkpoint directory names currently in ``tmpdir``."""
        return {x for x in os.listdir(tmpdir) if "checkpoint" in x}

    def _assert_vqmodel_runs(self, tmpdir, subfolder="vqmodel"):
        """Load the saved VQModel from ``subfolder`` and run one forward pass as a sanity check."""
        model = VQModel.from_pretrained(tmpdir, subfolder=subfolder)
        image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
        _ = model(image)

    def test_vqmodel(self):
        """Two training steps produce final vqmodel and discriminator safetensors weights."""
        with tempfile.TemporaryDirectory() as tmpdir:
            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
            test_args = self._train_args(tmpdir, vqmodel_config_path, discriminator_config_path, max_train_steps=2)

            run_command(self._launch_args + test_args)

            self.assertTrue(
                os.path.isfile(os.path.join(tmpdir, "discriminator", "diffusion_pytorch_model.safetensors"))
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "vqmodel", "diffusion_pytorch_model.safetensors")))

    def test_vqmodel_checkpointing(self):
        """Checkpoints are written every N steps, are loadable, and training resumes from one."""
        with tempfile.TemporaryDirectory() as tmpdir:
            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)

            # 4 steps with checkpointing_steps=2 -> checkpoints at steps 2 and 4.
            initial_run_args = self._train_args(
                tmpdir,
                vqmodel_config_path,
                discriminator_config_path,
                max_train_steps=4,
                extra_args=["--checkpointing_steps=2", "--seed=0"],
            )
            run_command(self._launch_args + initial_run_args)

            self.assertEqual(self._checkpoints(tmpdir), {"checkpoint-2", "checkpoint-4"})

            # An intermediate checkpoint must be loadable and runnable.
            self._assert_vqmodel_runs(tmpdir, subfolder="checkpoint-2/vqmodel")

            # Delete checkpoint-2 so the post-resume state is unambiguous.
            shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
            self.assertEqual(self._checkpoints(tmpdir), {"checkpoint-4"})

            # Resume from checkpoint-4 and train through step 6.
            resume_run_args = self._train_args(
                tmpdir,
                vqmodel_config_path,
                discriminator_config_path,
                max_train_steps=6,
                extra_args=[
                    "--checkpointing_steps=1",
                    f"--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}",
                    "--seed=0",
                ],
            )
            run_command(self._launch_args + resume_run_args)

            # Final weights must be loadable and runnable.
            self._assert_vqmodel_runs(tmpdir)

            # NOTE(review): with checkpointing_steps=1 one might also expect
            # checkpoint-5; this expectation mirrors the script's observed
            # behavior in the original test — confirm against train_vqgan.py.
            self.assertEqual(self._checkpoints(tmpdir), {"checkpoint-4", "checkpoint-6"})

    def test_vqmodel_checkpointing_use_ema(self):
        """Same checkpoint/resume flow as above, but with EMA weight tracking enabled."""
        with tempfile.TemporaryDirectory() as tmpdir:
            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)

            # 4 steps with checkpointing_steps=2 and --use_ema.
            initial_run_args = self._train_args(
                tmpdir,
                vqmodel_config_path,
                discriminator_config_path,
                max_train_steps=4,
                extra_args=["--checkpointing_steps=2", "--use_ema", "--seed=0"],
            )
            run_command(self._launch_args + initial_run_args)

            # Final weights of the initial run must be loadable and runnable.
            self._assert_vqmodel_runs(tmpdir)

            self.assertEqual(self._checkpoints(tmpdir), {"checkpoint-2", "checkpoint-4"})

            # An intermediate checkpoint must be loadable and runnable.
            self._assert_vqmodel_runs(tmpdir, subfolder="checkpoint-2/vqmodel")

            # Delete checkpoint-2 before resuming.
            shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))

            # Resume from checkpoint-4 and train through step 6, still with EMA.
            resume_run_args = self._train_args(
                tmpdir,
                vqmodel_config_path,
                discriminator_config_path,
                max_train_steps=6,
                extra_args=[
                    "--checkpointing_steps=1",
                    f"--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}",
                    "--use_ema",
                    "--seed=0",
                ],
            )
            run_command(self._launch_args + resume_run_args)

            # Final weights must be loadable and runnable.
            self._assert_vqmodel_runs(tmpdir)

            self.assertEqual(self._checkpoints(tmpdir), {"checkpoint-4", "checkpoint-6"})

    def test_vqmodel_checkpointing_checkpoints_total_limit(self):
        """With a total limit of 2, only the two most recent checkpoints survive."""
        with tempfile.TemporaryDirectory() as tmpdir:
            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)

            # 6 steps, checkpoint every 2, keep at most 2 -> checkpoint-2 is pruned.
            initial_run_args = self._train_args(
                tmpdir,
                vqmodel_config_path,
                discriminator_config_path,
                max_train_steps=6,
                extra_args=["--checkpointing_steps=2", "--checkpoints_total_limit=2", "--seed=0"],
            )
            run_command(self._launch_args + initial_run_args)

            # Final weights must be loadable and runnable.
            self._assert_vqmodel_runs(tmpdir)

            self.assertEqual(self._checkpoints(tmpdir), {"checkpoint-4", "checkpoint-6"})

    def test_vqmodel_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        """Resuming with a total limit prunes all older checkpoints, not just one."""
        with tempfile.TemporaryDirectory() as tmpdir:
            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)

            # 4 steps, checkpoint every 2, no limit yet -> checkpoints 2 and 4.
            initial_run_args = self._train_args(
                tmpdir,
                vqmodel_config_path,
                discriminator_config_path,
                max_train_steps=4,
                extra_args=["--checkpointing_steps=2", "--seed=0"],
            )
            run_command(self._launch_args + initial_run_args)

            # Final weights of the initial run must be loadable and runnable.
            self._assert_vqmodel_runs(tmpdir)

            self.assertEqual(self._checkpoints(tmpdir), {"checkpoint-2", "checkpoint-4"})

            # Resume from checkpoint-4 through step 8 with a limit of 2:
            # checkpoints 2 and 4 must both be removed, leaving only 6 and 8.
            resume_run_args = self._train_args(
                tmpdir,
                vqmodel_config_path,
                discriminator_config_path,
                max_train_steps=8,
                extra_args=[
                    "--checkpointing_steps=2",
                    f"--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}",
                    "--checkpoints_total_limit=2",
                    "--seed=0",
                ],
            )
            run_command(self._launch_args + resume_run_args)

            # Final weights must be loadable and runnable.
            self._assert_vqmodel_runs(tmpdir)

            self.assertEqual(self._checkpoints(tmpdir), {"checkpoint-6", "checkpoint-8"})
|
|