from pathlib import Path
from types import SimpleNamespace
import torch
import torch.nn.functional as F
import wandb
import streamlit as st

from utilities import ContextUnet, setup_ddpm
from constants import WANDB_API_KEY


def load_model():
    """
    Load the trained ContextUnet from the W&B model registry and return it
    in eval mode on the available device.
    """

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # authenticate with W&B; the API key is passed directly to the Api client,
    # so an explicit wandb.login(key=WANDB_API_KEY) call is not needed here
    api = wandb.Api(api_key=WANDB_API_KEY)
    artifact = api.artifact("teamaditya/model-registry/Feature2Sprite:v1", type="model")
    model_path = Path(artifact.download())

    # recover model info from the registry
    producer_run = artifact.logged_by()

    # load the weights dictionary
    model_weights = torch.load(model_path / "context_model.pth", map_location="cpu")

    # create the model
    model = ContextUnet(
        in_channels=3,
        n_feat=producer_run.config["n_feat"],
        n_cfeat=producer_run.config["n_cfeat"],
        height=producer_run.config["height"],
    )

    # load the weights into the model
    model.load_state_dict(model_weights)

    # set the model to eval mode
    model.eval()
    return model.to(DEVICE)


def show_image(img):
    """
    Display a generated image tensor in the Streamlit app and return it as a
    numpy array scaled to [0, 1].
    """
    img = (img.permute(1, 2, 0).clip(-1, 1).detach().cpu().numpy() + 1) / 2
    st.image(img, clamp=True)
    return img


def generate_sprites(feature_vector):
    """
    Sample sprites from the diffusion model conditioned on the given feature
    vector, display the first one in the Streamlit app, and return it.
    """
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    config = SimpleNamespace(
        # hyperparameters
        num_samples=30,
        # ddpm sampler hyperparameters
        timesteps=500,
        beta1=1e-4,
        beta2=0.02,
        # network hyperparameters
        height=16,
    )
    nn_model = load_model()

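    # build the DDPM noise schedule and get the context-conditioned sampler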
    _, sample_ddpm_context = setup_ddpm(
        config.beta1, config.beta2, config.timesteps, DEVICE
    )

    noises = torch.randn(config.num_samples, 3, config.height, config.height).to(DEVICE)

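    # wrap the conditioning features in a batch dimension and move them to the device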
    feature_vector = torch.tensor([feature_vector]).to(DEVICE).float()
    ddpm_samples, _ = sample_ddpm_context(nn_model, noises, feature_vector)

    # upscale the 16x16 images to 256x256 for display
    ddpm_samples = F.interpolate(ddpm_samples, size=(256, 256), mode="bilinear")
    # show the images
    img = show_image(ddpm_samples[0])
    return img
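

# Usage sketch (an assumption for illustration, not part of the original app):
# the Streamlit entry point is assumed to collect a one-hot style feature
# vector whose length matches the model's n_cfeat and pass it in as a plain
# Python list. The class labels and widgets below are hypothetical.
if __name__ == "__main__":
    st.title("Feature2Sprite")
    labels = ["hero", "non-hero", "food", "spell", "side-facing"]  # hypothetical labels
    feature_vector = [
        st.slider(label, min_value=0.0, max_value=1.0, value=0.0) for label in labels
    ]
    if st.button("Generate sprite"):
        generate_sprites(feature_vector)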