import functools
import pickle
import random
from typing import List

import numpy as np
import streamlit as st
import torch
from huggingface_hub import hf_hub_url, cached_download

# Display name -> index of that class in the generator's one-hot label vector.
ICON_CLASS_MAPPING = {
    "Fire": 8,
    "Magic": 7,
    "Nature": 6,
    "Lightning": 5,
    "Ice": 4,
    "Shadow": 3,
    "Unholy": 2,
    "Battle": 1,
    "Holy": 0,
}
# Inclusive upper bound for randomly drawn seeds (random.randint is inclusive).
MAX_SEED = 100000000
# st.session_state key holding the current seed, so it survives widget reruns.
_SEED_KEY = "random_seed"

st.title("RPG Icon Generator")

# NOTE(review): pickle.load can execute arbitrary code from the file. This is
# only acceptable because the checkpoint comes from a fixed, trusted Hugging Face
# repo; never point it at user-supplied files.
# NOTE(review): cached_download is deprecated (and removed in recent
# huggingface_hub releases) in favour of hf_hub_download - confirm against the
# pinned version. The model is also reloaded on every Streamlit rerun; consider
# a Streamlit resource cache if the installed version provides one.
with open(
    cached_download(hf_hub_url("gylleus/rpg-icongen", "icongen-model.pkl")), "rb"
) as f:
    G = pickle.load(f)["G_ema"]  # torch.nn.Module

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
    G = G.to(device)
else:
    # Without CUDA, force the generator's forward pass to run in fp32
    # (presumably its reduced-precision path is GPU-only - confirm).
    G.forward = functools.partial(G.forward, force_fp32=True)

random_seed = 0


def randomize_seed() -> int:
    """Draw a new seed in [0, MAX_SEED], remember it for this session, and return it.

    Used as the "Generate" button's ``on_click`` callback: Streamlit runs the
    callback before rerunning the script, so the rerun reads the new seed from
    session state.
    """
    global random_seed
    random_seed = random.randint(0, MAX_SEED)
    st.session_state[_SEED_KEY] = random_seed
    return random_seed


# Draw a seed only on the first run of a session. Later reruns triggered by
# other widgets (class selector, slider) reuse the same seed; only the
# "Generate" button draws a new one.
if _SEED_KEY not in st.session_state:
    randomize_seed()
random_seed = st.session_state[_SEED_KEY]


def get_class_id(class_name: str) -> int:
    """Return the label index for ``class_name``, falling back to "Fire" if unknown."""
    return ICON_CLASS_MAPPING.get(class_name, ICON_CLASS_MAPPING["Fire"])


@torch.no_grad()  # inference only: don't record an autograd graph
def generate(seed: int, class_name: str) -> np.ndarray:
    """Generate one icon for ``class_name`` from the latent vector seeded by ``seed``.

    Returns a uint8 array of shape (1, H, W, C), i.e. the generator output
    permuted from channels-first to channels-last.
    """
    label = torch.zeros([1, G.c_dim], device=device)
    # set chosen class (one-hot conditioning vector)
    label[:, get_class_id(class_name)] = 1
    truncation_psi = 1
    noise_mode = "const"
    # Same seed -> same latent z -> same image, so seeds are reproducible.
    z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
    img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
    # Map the output (presumably in [-1, 1] - confirm) to [0, 255] and NCHW -> NHWC.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()


def generate_images(seed: int, amount: int, class_name: str) -> List[np.ndarray]:
    """Generate ``amount`` icons using consecutive seeds starting at ``seed``."""
    return [generate(i, class_name) for i in range(seed, seed + amount)]


# Pass the function itself as the callback; calling it here would pass None.
st.button("Generate", on_click=randomize_seed)
chosen_class = st.selectbox("Choose icon type", tuple(ICON_CLASS_MAPPING.keys()))
image_amount = st.slider("Images to generate", 1, 9, 3)

columns = st.columns(3)
# Lay images out row by row across the three columns.
for index, img in enumerate(generate_images(random_seed, image_amount, chosen_class)):
    columns[index % len(columns)].image(img)