# icongen / app.py
# Karl Gylleus (commit 78ae3cd: "init commit")
# Hugging Face file-view residue: raw | history | blame | contribute | delete | No virus | 2.09 kB
import functools
import pickle
import random
from typing import List
import numpy as np
import streamlit as st
import torch
from huggingface_hub import hf_hub_url, cached_download
# Display name -> class index used to one-hot condition the generator.
# Insertion order (Fire first, Holy last) is also the order the UI selectbox
# lists the choices in.
ICON_CLASS_MAPPING = dict(
    zip(
        (
            "Fire",
            "Magic",
            "Nature",
            "Lightning",
            "Ice",
            "Shadow",
            "Unholy",
            "Battle",
            "Holy",
        ),
        range(8, -1, -1),  # indices 8, 7, ..., 0 paired with the names above
    )
)

# Inclusive upper bound for seeds drawn by random.randint.
MAX_SEED = 100_000_000
st.title("RPG Icon Generator")

# Download (or reuse the locally cached copy of) the pickled model and keep the
# "G_ema" generator.
# NOTE(review): unpickling executes arbitrary code from the file; this is only
# acceptable because the repo is the author's own and trusted.
# NOTE(review): `cached_download` is deprecated in huggingface_hub and removed
# in newer releases (`hf_hub_download` replaces it) -- confirm the pinned version.
with open(
    cached_download(hf_hub_url("gylleus/rpg-icongen", "icongen-model.pkl")), "rb"
) as f:
    G = pickle.load(f)["G_ema"]  # torch.nn.Module

if torch.cuda.is_available():
    # Run on the GPU: move the generator's parameters there.
    device = torch.device("cuda")
    G = G.to(device)
else:
    # CPU path: the model stays where unpickling put it, and every forward call
    # is forced to full precision.
    device = torch.device("cpu")
    G.forward = functools.partial(G.forward, force_fp32=True)
# Seed used for the current script run; overwritten by randomize_seed().
random_seed = 0


def randomize_seed() -> int:
    """Draw a fresh seed in ``[0, MAX_SEED]``, store it globally, and return it.

    The value is written to the module-level ``random_seed`` and also returned,
    so the function honours its ``-> int`` annotation.
    """
    global random_seed
    random_seed = random.randint(0, MAX_SEED)  # randint bounds are inclusive
    return random_seed


# Pick an initial seed each time the module (script) runs.
randomize_seed()
def get_class_id(class_name: str) -> int:
    """Return the generator class index for ``class_name``.

    Names missing from ``ICON_CLASS_MAPPING`` fall back to the "Fire" class.
    """
    try:
        return ICON_CLASS_MAPPING[class_name]
    except KeyError:
        return ICON_CLASS_MAPPING["Fire"]
def generate(
    seed: int,
    class_name: str,
    truncation_psi: float = 1,
    noise_mode: str = "const",
) -> np.ndarray:
    """Render one icon of the requested class from a seeded latent vector.

    Args:
        seed: Seed for ``np.random.RandomState``; a given seed always produces
            the same latent ``z``.
        class_name: Icon type; unknown names fall back to "Fire" via
            ``get_class_id``.
        truncation_psi: Forwarded to the generator (default 1, as before).
        noise_mode: Forwarded to the generator (default "const", as before).

    Returns:
        A uint8 numpy array with channels moved last by ``permute(0, 2, 3, 1)``
        and a batch of one (presumably shape (1, H, W, C) -- depends on G).
    """
    # Pure inference: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        # One-hot class-conditioning vector.
        label = torch.zeros([1, G.c_dim], device=device)
        label[:, get_class_id(class_name)] = 1
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
        # Rescale generator output (assumed roughly [-1, 1] -- TODO confirm)
        # to [0, 255] and convert to bytes.
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()
def generate_images(seed: int, amount: int, class_name: str) -> List[np.ndarray]:
    """Render ``amount`` icons of ``class_name`` using seeds seed, seed+1, ..."""
    images: List[np.ndarray] = []
    for offset in range(amount):
        images.append(generate(seed + offset, class_name))
    return images
# Pass the callback itself: writing `randomize_seed()` here would call it
# immediately and hand Streamlit `None` as the click handler.
# Note: Streamlit reruns the script on every widget interaction and
# randomize_seed() also runs at module level, so every rerun (not only a
# click) draws fresh seeds.
st.button("Generate", on_click=randomize_seed)
chosen_class = st.selectbox("Choose icon type", tuple(ICON_CLASS_MAPPING.keys()))
image_amount = st.slider("Images to generate", 1, 9, 3)

# Fill a fixed three-column grid row by row.
columns = st.columns(3)
images = generate_images(random_seed, image_amount, chosen_class)
for column_index, img in enumerate(images):
    columns[column_index % len(columns)].image(img)