|
import tempfile |
|
import unittest |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from diffusers import DiffusionPipeline |
|
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor |
|
|
|
|
|
class AttnAddedKVProcessorTests(unittest.TestCase):
    """Checks `Attention` layers configured with `AttnAddedKVProcessor`,
    in particular the `only_cross_attention` switch."""

    def get_constructor_arguments(self, only_cross_attention: bool = False):
        """Return kwargs for building an `Attention` layer with added KV projections.

        When `only_cross_attention` is set, the cross-attention dim deliberately
        differs from the query dim so the two configurations are distinguishable.
        """
        query_dim = 10
        cross_attention_dim = 12 if only_cross_attention else query_dim

        return {
            "query_dim": query_dim,
            "cross_attention_dim": cross_attention_dim,
            "heads": 2,
            "dim_head": 4,
            "added_kv_proj_dim": 6,
            "norm_num_groups": 1,
            "only_cross_attention": only_cross_attention,
            "processor": AttnAddedKVProcessor(),
        }

    def get_forward_arguments(self, query_dim, added_kv_proj_dim):
        """Return random forward-pass inputs matching the constructor dims (batch of 2)."""
        num_samples = 2

        # NOTE: hidden_states is sampled before encoder_hidden_states so the
        # RNG stream is consumed in a fixed, reproducible order.
        return {
            "hidden_states": torch.rand(num_samples, query_dim, 3, 2),
            "encoder_hidden_states": torch.rand(num_samples, 4, added_kv_proj_dim),
            "attention_mask": None,
        }

    def test_only_cross_attention(self):
        # Build one layer with self+cross attention and one with cross
        # attention only, run both on seeded inputs, and verify the outputs
        # differ everywhere.
        torch.manual_seed(0)

        ctor_kwargs = self.get_constructor_arguments(only_cross_attention=False)
        attn = Attention(**ctor_kwargs)

        # With self-attention enabled, the regular K/V projections exist.
        self.assertIsNotNone(attn.to_k)
        self.assertIsNotNone(attn.to_v)

        inputs = self.get_forward_arguments(
            query_dim=ctor_kwargs["query_dim"], added_kv_proj_dim=ctor_kwargs["added_kv_proj_dim"]
        )
        self_and_cross_attn_out = attn(**inputs)

        # Re-seed so the second layer draws from an identical RNG stream.
        torch.manual_seed(0)

        ctor_kwargs = self.get_constructor_arguments(only_cross_attention=True)
        attn = Attention(**ctor_kwargs)

        # Cross-attention-only layers drop the self-attention K/V projections.
        self.assertIsNone(attn.to_k)
        self.assertIsNone(attn.to_v)

        inputs = self.get_forward_arguments(
            query_dim=ctor_kwargs["query_dim"], added_kv_proj_dim=ctor_kwargs["added_kv_proj_dim"]
        )
        only_cross_attn_out = attn(**inputs)

        # Every element should differ between the two configurations.
        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
|
|
|
|
|
class DeprecatedAttentionBlockTests(unittest.TestCase):
    """Regression test: loading with a device map (which triggers attention-block
    conversion) must produce the same images as a plain load, and must survive a
    save/load round trip."""

    def _generate(self, pipe):
        # Run a short, fully-seeded 2-step generation and return the images.
        return pipe(
            "foo",
            num_inference_steps=2,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="np",
        ).images

    def test_conversion_when_using_device_map(self):
        repo_id = "hf-internal-testing/tiny-stable-diffusion-torch"

        # Baseline: plain load without a device map.
        pipe = DiffusionPipeline.from_pretrained(repo_id, safety_checker=None)
        pre_conversion = self._generate(pipe)

        # Loading with a device map should not change the outputs.
        pipe = DiffusionPipeline.from_pretrained(repo_id, device_map="balanced", safety_checker=None)
        conversion = self._generate(pipe)

        # A save/load round trip of the device-mapped pipeline should also match.
        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="balanced", safety_checker=None)
            after_conversion = self._generate(pipe)

        self.assertTrue(np.allclose(pre_conversion, conversion, atol=1e-3))
        self.assertTrue(np.allclose(conversion, after_conversion, atol=1e-3))
|
|