|
import functools
import pickle
import random
from typing import List

import numpy as np
import streamlit as st
import torch

from huggingface_hub import cached_download, hf_hub_download, hf_hub_url
|
|
|
# Icon type shown in the UI -> class index used to build the generator's
# one-hot label. Insertion order is also the order of the selectbox options.
ICON_CLASS_MAPPING = dict(
    Fire=8,
    Magic=7,
    Nature=6,
    Lightning=5,
    Ice=4,
    Shadow=3,
    Unholy=2,
    Battle=1,
    Holy=0,
)

# Inclusive upper bound for randomly drawn generation seeds.
MAX_SEED = 10**8
|
|
|
st.title("RPG Icon Generator")

# Fetch the pickled model from the Hugging Face Hub (hf_hub_download reuses the
# local cache after the first download) and keep only its "G_ema" entry.
# hf_hub_download replaces the deprecated cached_download(hf_hub_url(...)).
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the
# file -- only load from a repository you trust.
# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so this load repeats on each rerun; consider st.cache_resource if the
# installed Streamlit version provides it.
model_path = hf_hub_download(
    repo_id="gylleus/rpg-icongen", filename="icongen-model.pkl"
)
with open(model_path, "rb") as f:
    G = pickle.load(f)["G_ema"]
|
|
|
# Choose where inference runs. With CUDA, move the generator onto the GPU.
# Without it, stay on the CPU and bind force_fp32=True into every G(...) call.
if torch.cuda.is_available():
    device = torch.device("cuda")
    G = G.to(device)
else:
    device = torch.device("cpu")
    G.forward = functools.partial(G.forward, force_fp32=True)
|
|
|
|
|
# Seed of the first image rendered in this run; re-drawn right below.
random_seed = 0


def randomize_seed() -> None:
    """Draw a fresh seed in ``[0, MAX_SEED]`` and store it in ``random_seed``.

    The new value is published only through the module-level ``random_seed``
    global; nothing is returned.
    """
    global random_seed
    # random.randint includes both endpoints.
    random_seed = random.randint(0, MAX_SEED)


randomize_seed()
|
|
|
|
|
def get_class_id(class_name: str):
    """Return the class index for an icon type name.

    Names that are not keys of ``ICON_CLASS_MAPPING`` fall back to the index
    registered for "Fire".
    """
    try:
        return ICON_CLASS_MAPPING[class_name]
    except KeyError:
        return ICON_CLASS_MAPPING["Fire"]
|
|
|
|
|
def generate(
    seed: int,
    class_name: str,
    truncation_psi: float = 1,
    noise_mode: str = "const",
) -> np.ndarray:
    """Render one icon of ``class_name`` from the latent drawn with ``seed``.

    Args:
        seed: Seed for ``np.random.RandomState``, which draws the latent ``z``
            of shape ``(1, G.z_dim)``.
        class_name: Icon type; resolved with ``get_class_id``, so unknown
            names fall back to the "Fire" class.
        truncation_psi: Forwarded unchanged to ``G``. The default of 1
            matches the previously hard-coded value.
        noise_mode: Forwarded unchanged to ``G``. The default of "const"
            matches the previously hard-coded value.

    Returns:
        A ``uint8`` array produced by permuting the output with
        ``(0, 2, 3, 1)``, so the leading batch axis of size 1 is kept.
        NOTE(review): this assumes ``G`` returns NCHW values roughly in
        [-1, 1], which the ``* 127.5 + 128`` scaling and the permute rely on.
        Confirm against the model.
    """
    # One-hot conditioning vector selecting the requested icon class.
    label = torch.zeros([1, G.c_dim], device=device)
    label[:, get_class_id(class_name)] = 1

    z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)

    # Inference only: don't build an autograd graph, in case any of G's
    # parameters still require grad.
    with torch.no_grad():
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)

    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()
|
|
|
|
|
def generate_images(seed: int, amount: int, class_name: str) -> List[np.ndarray]:
    """Render ``amount`` icons of ``class_name`` using consecutive seeds.

    The k-th image (0-based) is ``generate(seed + k, class_name)``. A
    non-positive ``amount`` yields an empty list.
    """
    images = []
    for offset in range(amount):
        images.append(generate(seed + offset, class_name))
    return images
|
|
|
|
|
# Pass the callback itself. Writing randomize_seed() would run it while the
# page renders and hand None to on_click.
st.button("Generate", on_click=randomize_seed)

chosen_class = st.selectbox("Choose icon type", tuple(ICON_CLASS_MAPPING.keys()))

image_amount = st.slider("Images to generate", 1, 9, 3)

# Fill a fixed three-column grid row by row.
columns = st.columns(3)
images = generate_images(random_seed, image_amount, chosen_class)
for index, img in enumerate(images):
    columns[index % len(columns)].image(img)
|
|