File size: 2,923 Bytes
3b3a783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import time

import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils
from diffusers import AutoencoderKL

from tld.denoiser import Denoiser
from tld.diffusion import DiffusionGenerator, DiffusionTransformer, LTDConfig
from PIL.Image import Image

# Shared torchvision transform used by the tests to turn an image tensor into a PIL image.
to_pil = transforms.ToPILImage()

# Use the first CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def test_outputs(num_imgs=4):
    """Smoke-test a single ``Denoiser`` forward pass.

    Builds a 12-layer denoiser, runs it once on random inputs without
    gradient tracking, prints the parameter count and the elapsed time, and
    asserts that the output has shape ``(num_imgs, 4, 32, 32)``.

    Args:
        num_imgs: Batch size of the random inputs fed to the model.

    Raises:
        AssertionError: If the output shape differs from the expected shape.
    """
    model = Denoiser(
        image_size=32, noise_embed_dims=128, patch_size=2, embed_dim=768, dropout=0.1, n_layers=12
    )
    # Put the module in evaluation mode: torch.no_grad() alone does not turn
    # off train-time layers (presumably the dropout configured above).
    model.eval()

    x = torch.rand(num_imgs, 4, 32, 32)
    noise_level = torch.rand(num_imgs, 1)
    label = torch.rand(num_imgs, 768)

    print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")

    with torch.no_grad():
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        start_time = time.perf_counter()
        output = model(x, noise_level, label)
        end_time = time.perf_counter()

    execution_time = end_time - start_time
    print(f"Model execution took {execution_time:.4f} seconds.")

    expected_shape = torch.Size([num_imgs, 4, 32, 32])
    assert output.shape == expected_shape, f"expected {expected_shape}, got {output.shape}"
    print("Basic tests passed.")

    # model = Denoiser(image_size=16, noise_embed_dims=128, patch_size=2, embed_dim=256, dropout=0.1, n_layers=6)
    # x = torch.rand(8, 4, 32, 32)
    # noise_level = torch.rand(8, 1)
    # label = torch.rand(8, 768)

    # with torch.no_grad():
    #     output = model(x, noise_level, label)

    # assert output.shape == torch.Size([8, 4, 32, 32])
    # print("Upscale tests passed.")


def test_diffusion_generator():
    """Run a short sampling pass through ``DiffusionGenerator`` and save it.

    A small 3-layer ``Denoiser`` and the "madebyollin/sdxl-vae-fp16-fix" VAE
    are wrapped in a ``DiffusionGenerator``. One image is generated from a
    random label tensor of shape ``(1, 768)``, arranged into a grid and
    written to ``test.png`` in the current working directory.

    The test makes no assertions; it passes as long as nothing raises.
    """
    dtype = torch.float32  # float16 will not work on cpu
    n_images = 1
    grid_cols = int(np.sqrt(n_images))

    net = Denoiser(
        image_size=32, noise_embed_dims=128, patch_size=2, embed_dim=256, dropout=0.1, n_layers=3
    )
    param_total = sum(p.numel() for p in net.parameters())
    print(f"Model has {param_total} parameters")

    net.to(dtype)

    autoencoder = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype
    )
    autoencoder = autoencoder.to(device)

    conditioning = torch.rand(n_images, 768)

    generator = DiffusionGenerator(net, autoencoder, device, dtype)

    sampling_options = {
        "labels": conditioning,
        "num_imgs": n_images,
        "class_guidance": 3,
        "seed": 1,
        "n_iter": 5,
        "exponent": 1,
        "scale_factor": 8,
        "sharp_f": 0,
        "bright_f": 0,
    }
    # generate() returns a pair; only the first element is used here.
    samples, _ = generator.generate(**sampling_options)

    # Apply (x + 1) / 2 and clip to [0, 1] before converting to a PIL image
    # (presumably the generator emits values in [-1, 1] - TODO confirm).
    rescaled = (samples + 1) / 2
    grid = vutils.make_grid(rescaled, nrow=grid_cols, padding=4)
    picture = to_pil(grid.float().clip(0, 1))
    picture.save("test.png")
    print("Images generated at test.png")


def test_full_generation_pipeline():
    """Check that the text-to-image pipeline returns a PIL image.

    Builds a ``DiffusionTransformer`` from a default ``LTDConfig``, generates
    an image for the prompt "a cute cat", prints the result and asserts that
    it is a ``PIL.Image.Image``.

    Raises:
        AssertionError: If the pipeline returns anything other than a PIL image.
    """
    ltdconfig = LTDConfig()
    diffusion_transformer = DiffusionTransformer(ltdconfig)

    out = diffusion_transformer.generate_image_from_text(prompt="a cute cat")
    print(out)
    # isinstance rather than an exact type comparison, so that subclasses of
    # PIL.Image.Image are accepted as valid images too.
    assert isinstance(out, Image), f"expected a PIL Image, got {type(out).__name__}"


# TODO: should add tests for train loop and data processing