import gradio as gr import open_clip import torch from PIL import Image from open_clip import tokenizer from rudalle import get_vae from einops import rearrange from modules import DenoiseUNet model_id = "./model_600000.pt" device = "cuda" if torch.cuda.is_available() else "cpu" batch_size = 4 steps = 11 scale = 5 def to_pil(images): images = images.permute(0, 2, 3, 1).cpu().numpy() images = (images * 255).round().astype("uint8") images = [Image.fromarray(image) for image in images] return images def log(t, eps=1e-20): return torch.log(t + eps) def gumbel_noise(t): noise = torch.zeros_like(t).uniform_(0, 1) return -log(-log(noise)) def gumbel_sample(t, temperature=1., dim=-1): return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim) def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0], typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=-1, renoise_steps=11, renoise_mode='start'): with torch.inference_mode(): r_range = torch.linspace(0, 1, T+1)[:-1][:, None].expand(-1, c.size(0)).to(c.device) temperatures = torch.linspace(temp_range[0], temp_range[1], T) preds = [] if x is None: x = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device) elif mask is not None: noise = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device) x = noise * mask + (1-mask) * x init_x = x.clone() for i in range(starting_t, T): if renoise_mode == 'prev': prev_x = x.clone() r, temp = r_range[i], temperatures[i] logits = model(x, c, r) if classifier_free_scale >= 0: logits_uncond = model(x, torch.zeros_like(c), r) logits = torch.lerp(logits_uncond, logits, classifier_free_scale) x = logits x_flat = x.permute(0, 2, 3, 1).reshape(-1, x.size(1)) if typical_filtering: x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1) x_flat_norm_p = torch.exp(x_flat_norm) entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True) c_flat_shifted = 
torch.abs((-x_flat_norm) - entropy) c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False) x_flat_cumsum = x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1) last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1) sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(1, last_ind.view(-1, 1)) if typical_min_tokens > 1: sorted_indices_to_remove[..., :typical_min_tokens] = 0 indices_to_remove = sorted_indices_to_remove.scatter(1, x_flat_indices, sorted_indices_to_remove) x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf")) # x_flat = torch.multinomial(x_flat.div(temp).softmax(-1), num_samples=1)[:, 0] x_flat = gumbel_sample(x_flat, temperature=temp) x = x_flat.view(x.size(0), *x.shape[2:]) if mask is not None: x = x * mask + (1-mask) * init_x if i < renoise_steps: if renoise_mode == 'start': x, _ = model.add_noise(x, r_range[i+1], random_x=init_x) elif renoise_mode == 'prev': x, _ = model.add_noise(x, r_range[i+1], random_x=prev_x) else: # 'rand' x, _ = model.add_noise(x, r_range[i+1]) preds.append(x.detach()) return preds # Model loading vqmodel = get_vae().to(device) vqmodel.eval().requires_grad_(False) clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k') clip_model = clip_model.to(device).eval().requires_grad_(False) def encode(x): return vqmodel.model.encode((2 * x - 1))[-1][-1] def decode(img_seq, shape=(32,32)): img_seq = img_seq.view(img_seq.shape[0], -1) b, n = img_seq.shape one_hot_indices = torch.nn.functional.one_hot(img_seq, num_classes=vqmodel.num_tokens).float() z = (one_hot_indices @ vqmodel.model.quantize.embed.weight) z = rearrange(z, 'b (h w) c -> b c h w', h=shape[0], w=shape[1]) img = vqmodel.model.decode(z) img = (img.clamp(-1., 1.) 
+ 1) * 0.5 return img state_dict = torch.load(model_id, map_location=device) model = DenoiseUNet(num_labels=8192).to(device) model.load_state_dict(state_dict) model.eval().requires_grad_() # ----- def infer(prompt): latent_shape = (32, 32) tokenized_text = tokenizer.tokenize([prompt] * batch_size).to(device) with torch.inference_mode(): with torch.autocast(device_type="cuda"): clip_embeddings = clip_model.encode_text(tokenized_text) images = sample( model, clip_embeddings, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0], typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=scale, renoise_steps=steps, renoise_mode="start" ) images = decode(images[-1], latent_shape) return to_pil(images) css = """ .gradio-container { font-family: 'IBM Plex Sans', sans-serif; } .gr-button { color: white; border-color: black; background: black; } input[type='range'] { accent-color: black; } .dark input[type='range'] { accent-color: #dfdfdf; } .container { max-width: 730px; margin: auto; padding-top: 1.5rem; } #gallery { min-height: 22rem; margin-bottom: 15px; margin-left: auto; margin-right: auto; border-bottom-right-radius: .5rem !important; border-bottom-left-radius: .5rem !important; } #gallery>div>.h-full { min-height: 20rem; } .details:hover { text-decoration: underline; } .gr-button { white-space: nowrap; } .gr-button:focus { border-color: rgb(147 197 253 / var(--tw-border-opacity)); outline: none; box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); --tw-border-opacity: 1; --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color); --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity)); --tw-ring-opacity: .5; } .footer { margin-bottom: 45px; margin-top: 35px; text-align: center; border-bottom: 1px solid #e5e5e5; } .footer>p { font-size: 
.8rem; display: inline-block; padding: 0 10px; transform: translateY(10px); background: white; } .dark .footer { border-color: #303030; } .dark .footer>p { background: #0b0f19; } .acknowledgments h4{ margin: 1.25em 0 .25em 0; font-weight: bold; font-size: 115%; } .animate-spin { animation: spin 1s linear infinite; } @keyframes spin { from { transform: rotate(0deg); } to { transform: rotate(360deg); } } #share-btn-container { display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem; } #share-btn { all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important; } #share-btn * { all: unset; } .gr-form{ flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0; } #prompt-container{ gap: 0; } """ block = gr.Blocks(css=css) with block: gr.HTML( """
Paella is a novel text-to-image model that uses a compressed quantized latent space, based on an f8 VQGAN, and a masked training objective to achieve fast generation in ~10 inference steps.