|
|
|
import torch |
|
|
|
from diffusers import DiffusionPipeline |
|
|
|
|
|
class UnetSchedulerOneForwardPipeline(DiffusionPipeline): |
|
def __init__(self, unet, scheduler): |
|
super().__init__() |
|
|
|
self.register_modules(unet=unet, scheduler=scheduler) |
|
|
|
def __call__(self): |
|
image = torch.randn( |
|
(1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), |
|
) |
|
timestep = 1 |
|
|
|
model_output = self.unet(image, timestep).sample |
|
scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample |
|
|
|
result = scheduler_output - scheduler_output + torch.ones_like(scheduler_output) |
|
|
|
return result |
|
|