# Provenance (scraped from the Hugging Face Space file page):
# Kolors1 / app.py — last commit 568d5f3 "Update app.py" by seshubon (verified), 6.3 kB.
import spaces
import os
import torch
import random
from PIL import Image
import io
import base64
import cloudinary
import cloudinary.uploader
from cloudinary.utils import cloudinary_url
from huggingface_hub import snapshot_download
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
StableDiffusionXLPipeline,
)
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from diffusers import UNet2DConditionModel, AutoencoderKL
from diffusers import EulerDiscreteScheduler
import gradio as gr
import requests
# Fetch (or reuse the cached copy of) the Kolors checkpoint; every sub-model
# below is loaded from a sub-folder of this single snapshot directory.
ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")


def _ckpt_path(subfolder):
    """Return the path of *subfolder* inside the downloaded Kolors snapshot."""
    return os.path.join(ckpt_dir, subfolder)


# ChatGLM text encoder + tokenizer (both live in "text_encoder"), cast to fp16.
text_encoder = ChatGLMModel.from_pretrained(
    _ckpt_path("text_encoder"),
    torch_dtype=torch.float16,
).half()
tokenizer = ChatGLMTokenizer.from_pretrained(_ckpt_path("text_encoder"))

# VAE, scheduler and UNet; the VAE and UNet are cast to fp16 as well.
vae = AutoencoderKL.from_pretrained(_ckpt_path("vae"), revision=None).half()
scheduler = EulerDiscreteScheduler.from_pretrained(_ckpt_path("scheduler"))
unet = UNet2DConditionModel.from_pretrained(_ckpt_path("unet"), revision=None).half()

# Assemble the SDXL-style Kolors pipeline and move it onto the GPU.
pipe = StableDiffusionXLPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    force_zeros_for_empty_prompt=False,
).to("cuda")
# Backend that test_func queries for saved per-session form data.
API_URL = "https://bots.spaceship.im"  # Replace with your actual API endpoint

# Module-level Gradio components. They are created outside any gr.Blocks
# context; the Interface / event wiring further down places them on the page.
url_params = gr.JSON(value={}, visible=True, label="URL Params")
prompt = gr.Textbox(label="Prompt")
negative_prompt = gr.Textbox(label="Negative Prompt")
gallery = gr.Gallery(label="Result", elem_id="gallery", show_label=False)

# Generation settings (shown under "Advanced settings").
height = gr.Slider(minimum=512, maximum=2048, value=1024, step=64, label="Height")
width = gr.Slider(minimum=512, maximum=2048, value=1024, step=64, label="Width")
steps = gr.Slider(minimum=1, maximum=50, value=25, step=1, label="Steps")
number_of_images = gr.Slider(
    minimum=1, maximum=4, value=1, step=1, label="Number of images per prompt"
)

# Seed controls: either a fresh random seed or the manual value below.
random_seed = gr.Checkbox(label="Use Random Seed", value=True)
seed = gr.Number(label="Seed", value=0, precision=0)
seed_used = gr.Number(label="Seed Used")
def test_func(request: gr.Request):
    """Populate the UI on page load from the ``uuid`` URL query parameter.

    If the page URL carries a non-empty ``?uuid=<id>``, the saved form state is
    fetched from ``{API_URL}/check_data/<id>``. When that returns HTTP 200 with
    a JSON object whose ``"data"`` member is a dict, that dict's values are
    returned as-is.

    NOTE(review): this relies on the backend returning exactly the ten fields
    wired as this handler's outputs, in that order. Confirm against the API.

    In every other case (no uuid, non-200 status, network/JSON error, missing
    or non-dict ``data``), the components themselves are returned. Gradio
    presumably re-applies their constructor defaults. The URL query
    parameters are returned as a plain dict for the ``url_params`` JSON box.

    Args:
        request: the incoming Gradio request; only ``query_params`` is read.

    Returns:
        A sequence of ten values/components matching the load event outputs.
    """
    # A plain dict is JSON-serialisable; the raw Starlette QueryParams may not be.
    query = dict(request.query_params)
    msg_id = query.get("uuid")
    if msg_id:
        try:
            # Bounded wait so a slow/unreachable backend cannot hang page load.
            response = requests.get(f"{API_URL}/check_data/{msg_id}", timeout=10)
            payload = response.json() if response.status_code == 200 else None
        except (requests.RequestException, ValueError) as err:
            print(f"Warning: could not fetch saved data for uuid={msg_id}: {err}")
            payload = None
        api_data = payload.get("data") if isinstance(payload, dict) else None
        if isinstance(api_data, dict):
            return list(api_data.values())
    return (
        prompt,
        query,
        height,
        width,
        steps,
        number_of_images,
        random_seed,
        seed,
        gallery,
        seed_used,
    )
def _image_to_data_uri(img):
    """Encode a PIL image as a ``data:image/png;base64,...`` URI string."""
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return "data:image/png;base64," + encoded


def _upload_to_cloudinary(data_uri):
    """Upload *data_uri* to Cloudinary and return the resulting ``secure_url``.

    Credentials come from the ``CLOUDINARY_CLOUD_NAME`` / ``CLOUDINARY_API_KEY``
    / ``CLOUDINARY_API_SECRET`` environment variables. They fall back to the
    values that were previously hard-coded here.

    NOTE(review): these credentials are committed in source. Rotate the secret
    and drop the fallbacks once the env vars are configured for the Space.
    """
    import uuid

    cloudinary.config(
        cloud_name=os.environ.get("CLOUDINARY_CLOUD_NAME", "dqougmpti"),
        api_key=os.environ.get("CLOUDINARY_API_KEY", "967712926887747"),
        api_secret=os.environ.get(
            "CLOUDINARY_API_SECRET", "KbMgVBpkTWxU06tX_jSZKilKD0I"
        ),
        secure=True,
    )
    # Unique id per upload. The former fixed id "shoes" (from the Cloudinary
    # quick-start sample) made every generation overwrite the same asset.
    result = cloudinary.uploader.upload(
        data_uri, public_id=f"kolors_{uuid.uuid4().hex}"
    )
    return result["secure_url"]


@spaces.GPU(duration=200)
def generate_image(
    request: gr.Request,
    prompt,
    negative_prompt,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    num_images_per_prompt,
    use_random_seed,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate images with Kolors, upload the first one, and return the results.

    Gradio supplies the value arguments positionally, so the caller's input
    components must be listed in exactly this parameter order.

    Args:
        request: Gradio request; its URL query parameters go into ``save_data``.
        prompt: text prompt forwarded to the pipeline.
        negative_prompt: negative prompt forwarded to the pipeline.
        height: output height in pixels (cast to int).
        width: output width in pixels (cast to int).
        num_inference_steps: number of denoising steps (cast to int).
        guidance_scale: guidance scale forwarded to the pipeline.
        num_images_per_prompt: number of images to generate (cast to int).
        use_random_seed: if truthy, draw a fresh seed in ``[0, 2**32 - 1]``.
        seed: manual seed used when ``use_random_seed`` is false.
        progress: Gradio progress tracker (mirrors tqdm bars into the UI).

    Returns:
        ``(images, seed, secure_url)``: all generated images, the seed that was
        actually used, and the Cloudinary URL of the first image.
    """
    seed = random.randint(0, 2**32 - 1) if use_random_seed else int(seed)
    # UI sliders may deliver floats; the pipeline expects ints here.
    height = int(height)
    width = int(width)
    num_inference_steps = int(num_inference_steps)
    num_images_per_prompt = int(num_images_per_prompt)
    print(f"Debug: Retrieved height = {height}, width = {width}")

    images = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=torch.Generator(pipe.device).manual_seed(seed),
    ).images

    # Only the first image is encoded and uploaded.
    data_uri = _image_to_data_uri(images[0])
    secure_url = _upload_to_cloudinary(data_uri)

    query_data = dict(request.query_params)
    save_data = {
        "prompt": prompt,
        "url_params": query_data,
        "height": height,
        "width": width,
        "steps": num_inference_steps,
        "num_images_per_prompt": num_images_per_prompt,
        "use_random_seed": use_random_seed,
        "seed": seed,
        "output": data_uri,
        "seed_used": seed,
    }
    print(secure_url)
    # Log everything except the (multi-megabyte) base64 payload.
    print({key: value for key, value in save_data.items() if key != "output"})
    print(query_data)
    # TODO: save_data is assembled but not sent. The disabled backend call was:
    #   requests.post(f"https://bots.spaceship.im/save_data/{query_data['uuid']}",
    #                 json=save_data)
    return images, seed, secure_url
# HTML blurb passed as the Interface's ``description=``: the paper title plus
# links to the official website, tech report, model page and GitHub repo.
description = """
<p align="center">Effective Training of Diffusion Model for Photorealistic Text-to-Image Synthesis</p>
<p><center>
<a href="https://kolors.kuaishou.com/" target="_blank">[Official Website]</a>
<a href="https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf" target="_blank">[Tech Report]</a>
<a href="https://huggingface.co/Kwai-Kolors/Kolors" target="_blank">[Model Page]</a>
<a href="https://github.com/Kwai-Kolors/Kolors" target="_blank">[Github]</a>
</center></p>
"""
# Gradio interface. The theme is set on the root Blocks, which is the app that
# actually gets launched.
with gr.Blocks(theme="bethecloud/storj_theme") as demo:
    # generate_image takes (prompt, negative_prompt, height, width,
    # num_inference_steps, guidance_scale, num_images_per_prompt,
    # use_random_seed, seed) positionally, so the inputs below must follow
    # exactly that order. url_params was previously wired in third place,
    # which shifted every later value by one (the JSON dict reached `height`).
    #
    # There was no guidance-scale control, so one is created here. The 5.0
    # default follows the Kolors reference sampling settings (tune as needed).
    # render=False lets the Interface place it inside the accordion.
    guidance_scale_slider = gr.Slider(
        1.0, 10.0, 5.0, step=0.1, label="Guidance Scale", render=False
    )
    iface = gr.Interface(
        fn=generate_image,
        inputs=[prompt, negative_prompt],
        additional_inputs=[
            height,
            width,
            steps,
            guidance_scale_slider,
            number_of_images,
            random_seed,
            seed,
        ],
        additional_inputs_accordion=gr.Accordion(label="Advanced settings", open=False),
        # generate_image returns the Cloudinary URL, not the base64 string.
        outputs=[gallery, seed_used, gr.Textbox(label="Image URL")],
        title="Kolors",
        description=description,
    )
    # url_params is no longer an Interface input. Render it explicitly so the
    # load handler below can still display the URL query parameters.
    url_params.render()
    demo.load(
        fn=test_func,
        outputs=[
            prompt,
            url_params,
            height,
            width,
            steps,
            number_of_images,
            random_seed,
            seed,
            gallery,
            seed_used,
        ],
    )
demo.launch(debug=True)