Spaces:
Sleeping
Sleeping
File size: 2,407 Bytes
cb0d40a c046d7f cb0d40a c046d7f cb0d40a c046d7f cb0d40a c046d7f cb0d40a c046d7f cb0d40a c046d7f cb0d40a c046d7f cb0d40a c046d7f cb0d40a c046d7f cb0d40a c046d7f cb0d40a c046d7f cb0d40a c046d7f cb0d40a c046d7f cb0d40a c046d7f cb0d40a c046d7f cb0d40a c046d7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import functools
from pathlib import Path
from types import SimpleNamespace

import torch
import torch.nn.functional as F
import wandb
import streamlit as st

from utilities import ContextUnet, setup_ddpm
from constants import WANDB_API_KEY
@functools.lru_cache(maxsize=1)
def load_model(artifact_name="teamaditya/model-registry/Feature2Sprite:v1"):
    """Load the trained ContextUnet from the W&B model registry.

    The result is memoized (single entry), so repeated calls -- e.g. one per
    ``generate_sprites`` invocation -- reuse the already-built model instead of
    re-querying the registry, re-downloading the artifact and re-instantiating
    the network each time. The same model object is returned on every call, so
    callers must not mutate it.

    Args:
        artifact_name: Fully-qualified W&B artifact reference to load. Defaults
            to the ``Feature2Sprite:v1`` registry entry used previously.

    Returns:
        A ``ContextUnet`` in eval mode, moved to CUDA when available and to the
        CPU otherwise.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    api = wandb.Api(api_key=WANDB_API_KEY)
    artifact = api.artifact(artifact_name, type="model")
    model_path = Path(artifact.download())

    # The run that logged the artifact holds the architecture hyperparameters
    # needed to rebuild the network (presumably the training config -- verify).
    producer_run = artifact.logged_by()

    # NOTE(review): torch.load unpickles the file, so only load artifacts from a
    # trusted registry; consider weights_only=True on torch>=1.13.
    # map_location="cpu" presumably lets a GPU-saved checkpoint load on
    # CPU-only hosts; the model is moved to `device` at the end.
    model_weights = torch.load(model_path / "context_model.pth", map_location="cpu")

    model = ContextUnet(
        in_channels=3,
        n_feat=producer_run.config["n_feat"],
        n_cfeat=producer_run.config["n_cfeat"],
        height=producer_run.config["height"],
    )
    model.load_state_dict(model_weights)
    model.eval()
    return model.to(device)
def show_image(img):
    """Render a single image tensor in the Streamlit app.

    The tensor is reordered from channels-first to channels-last, clipped to
    [-1, 1], moved off the autograd graph and onto the CPU, then linearly mapped
    to [0, 1] before display.

    Args:
        img: A 3-D image tensor, presumably laid out as (C, H, W) -- it is
            permuted with (1, 2, 0) to channels-last.

    Returns:
        The NumPy array that was displayed, laid out channels-last, values in [0, 1].
    """
    channels_last = img.permute(1, 2, 0)
    clipped = channels_last.clip(-1, 1).detach().cpu()
    # Shift [-1, 1] -> [0, 2], then halve to land in [0, 1] for st.image.
    display_array = (clipped.numpy() + 1) / 2
    st.image(display_array, clamp=True)
    return display_array
def generate_sprites(feature_vector):
    """Generate one sprite conditioned on ``feature_vector`` and show it in the app.

    Args:
        feature_vector: Conditioning values for a single sprite (e.g. a list of
            floats, a NumPy array or a 1-D tensor). It is converted to a float32
            tensor with a leading batch dimension of 1. Its length presumably
            must match the model's ``n_cfeat`` -- TODO confirm.

    Returns:
        The displayed image as a channels-last array in [0, 1], as returned by
        ``show_image``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    config = SimpleNamespace(
        # Only sample 0 is ever displayed/returned, so draw exactly one sample;
        # previously 30 full DDPM chains were run and 29 were discarded.
        num_samples=1,
        # ddpm sampler hyperparameters
        timesteps=500,
        beta1=1e-4,
        beta2=0.02,
        # network hyperparameters
        height=16,
    )

    nn_model = load_model()
    _, sample_ddpm_context = setup_ddpm(
        config.beta1, config.beta2, config.timesteps, device
    )

    noises = torch.randn(
        config.num_samples, 3, config.height, config.height, device=device
    )
    # Add a leading batch dim of 1 so there is one context row for the one
    # sample. Unlike torch.tensor([x]), as_tensor also accepts a
    # multi-element tensor as input.
    context = (
        torch.as_tensor(feature_vector, dtype=torch.float32).unsqueeze(0).to(device)
    )

    # Pure inference: do not record autograd history across the sampling steps.
    with torch.no_grad():
        ddpm_samples, _ = sample_ddpm_context(nn_model, noises, context)

    # Upscale the 16x16 samples to 256x256 for display.
    ddpm_samples = F.interpolate(ddpm_samples, size=(256, 256), mode="bilinear")
    return show_image(ddpm_samples[0])
|