"""Gradio Space serving the Kwai-Kolors text-to-image model.

On page load, the ``uuid`` query parameter (if present) is used to fetch
previously saved form values from ``API_URL`` and prefill the UI.
"""
# Keep `spaces` as the first import, as in the original file; ZeroGPU is
# presumably sensitive to import order relative to torch — confirm before
# reordering.
import spaces

import base64
import io
import os
import random
import urllib.parse

import gradio as gr
import requests
import torch
from diffusers import AutoencoderKL
from diffusers import EulerDiscreteScheduler
from diffusers import UNet2DConditionModel
from huggingface_hub import snapshot_download
from PIL import Image  # noqa: F401  (kept from the original file)

from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline

# Download the model files.
ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")

# Load the models in half precision and assemble the pipeline on the GPU.
text_encoder = ChatGLMModel.from_pretrained(
    os.path.join(ckpt_dir, 'text_encoder'),
    torch_dtype=torch.float16).half()
tokenizer = ChatGLMTokenizer.from_pretrained(os.path.join(ckpt_dir, 'text_encoder'))
vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), revision=None).half()
scheduler = EulerDiscreteScheduler.from_pretrained(os.path.join(ckpt_dir, "scheduler"))
unet = UNet2DConditionModel.from_pretrained(os.path.join(ckpt_dir, "unet"), revision=None).half()
pipe = StableDiffusionXLPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    force_zeros_for_empty_prompt=False)
pipe = pipe.to("cuda")

API_URL = "https://bots.spaceship.im"  # Replace with your actual API endpoint
API_TIMEOUT_SECONDS = 10  # avoid hanging page load if the API is unreachable
# The UI exposes no guidance-scale control, so a fixed value is used.
# 5.0 is the value used in the Kolors README examples — TODO confirm it is the intended default.
DEFAULT_GUIDANCE_SCALE = 5.0

url_params = gr.JSON({}, visible=True, label="URL Params")
prompt = gr.Textbox(label="Prompt")
negative_prompt = gr.Textbox(label="Negative Prompt")
gallery = gr.Gallery(label="Result", elem_id="gallery", show_label=False)
height = gr.Slider(512, 2048, 1024, step=64, label="Height")
width = gr.Slider(512, 2048, 1024, step=64, label="Width")
steps = gr.Slider(1, 50, 25, step=1, label="Steps")
number_of_images = gr.Slider(
    1, 4, 1, step=1, label="Number of images per prompt")
random_seed = gr.Checkbox(label="Use Random Seed", value=True)
seed = gr.Number(label="Seed", value=0, precision=0)
seed_used = gr.Number(label="Seed Used")

# Components refreshed on page load by `test_func`, in the order its return
# value must follow.
LOAD_OUTPUTS = [
    prompt, url_params, height, width, steps, number_of_images,
    random_seed, seed, gallery, seed_used,
]


def test_func(request: gr.Request):
    """Prefill the form on page load from the ``uuid`` URL query parameter.

    If ``uuid`` is present and ``GET {API_URL}/check_data/<uuid>`` answers 200
    with a JSON object whose ``"data"`` entry is a dict holding exactly one
    value per component in ``LOAD_OUTPUTS``, those values are returned in the
    dict's order.

    Otherwise only the "URL Params" JSON box is filled with the query
    parameters, and every other component is left unchanged.
    """
    query = dict(request.query_params)  # plain dict so gr.JSON can serialize it
    msg_id = query.get("uuid")
    if msg_id:
        # The uuid comes from the URL, i.e. untrusted input: escape it so it
        # stays a single path segment.
        url = f"{API_URL}/check_data/{urllib.parse.quote(msg_id, safe='')}"
        try:
            response = requests.get(url, timeout=API_TIMEOUT_SECONDS)
            payload = response.json() if response.status_code == 200 else None
        except (requests.RequestException, ValueError) as err:
            print(f"Could not load saved data for uuid={msg_id!r}: {err}")
            payload = None
        api_data = payload.get("data") if isinstance(payload, dict) else None
        # NOTE(review): relies on the API returning keys in LOAD_OUTPUTS order.
        if isinstance(api_data, dict) and len(api_data) == len(LOAD_OUTPUTS):
            return list(api_data.values())
    unchanged = gr.update()
    return [unchanged, query] + [unchanged] * (len(LOAD_OUTPUTS) - 2)


def _image_to_base64(image):
    """Return *image* (a PIL image) encoded as a base64 PNG string."""
    with io.BytesIO() as buffer:
        image.save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode('utf-8')


@spaces.GPU(duration=200)
def generate_image(
    request: gr.Request,
    prompt,
    negative_prompt,
    url_params_value,
    height,
    width,
    num_inference_steps,
    num_images_per_prompt,
    use_random_seed,
    seed,
    guidance_scale=DEFAULT_GUIDANCE_SCALE,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the Kolors pipeline and return ``(images, seed_used)``.

    The positional parameters after ``request`` mirror the Interface inputs
    in order:
    ``prompt, negative_prompt, url_params, height, width, steps,
    number_of_images, random_seed, seed``.

    ``guidance_scale`` has no UI control and falls back to
    ``DEFAULT_GUIDANCE_SCALE``.

    A new random seed is drawn when ``use_random_seed`` is set or ``seed`` is
    empty.

    Returns:
        A tuple of the list of generated images (for the gallery) and the
        seed actually used.
    """
    if use_random_seed or seed is None:
        seed = random.randint(0, 2**32 - 1)
    else:
        seed = int(seed)  # gr.Number may deliver a float
    height = int(height)
    width = int(width)
    num_inference_steps = int(num_inference_steps)
    num_images_per_prompt = int(num_images_per_prompt)
    print(f'Debug: Retrieved height = {height}, width = {width}')

    images = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=torch.Generator(pipe.device).manual_seed(seed),
    ).images

    base64_str = _image_to_base64(images[0])

    query_data = dict(request.query_params)
    save_data = {
        "prompt": prompt,
        "url_params": url_params_value,
        "height": height,
        "width": width,
        "steps": num_inference_steps,
        "num_images_per_prompt": num_images_per_prompt,
        "use_random_seed": use_random_seed,
        "seed": seed,
        "output": base64_str,
        "seed_used": seed,
    }
    # Log everything except the (multi-megabyte) base64 payload.
    print({key: value for key, value in save_data.items() if key != "output"})
    print(query_data)
    # NOTE: persisting results to the bots API is currently disabled:
    # url = f"{API_URL}/save_data/{urllib.parse.quote(query_data['uuid'], safe='')}"
    # res = requests.post(url, json=save_data, timeout=API_TIMEOUT_SECONDS)
    return images, seed


description = """

Effective Training of Diffusion Model for Photorealistic Text-to-Image Synthesis

[Official Website] [Tech Report] [Model Page] [Github]

"""

# Gradio interface
with gr.Blocks() as demo:
    iface = gr.Interface(
        fn=generate_image,
        inputs=[
            prompt,
            negative_prompt,
            url_params,
        ],
        additional_inputs=[
            height,
            width,
            steps,
            number_of_images,
            random_seed,
            seed,
        ],
        additional_inputs_accordion=gr.Accordion(
            label="Advanced settings", open=False),
        outputs=[
            gallery,
            seed_used,
        ],
        title="Kolors",
        description=description,
        theme='bethecloud/storj_theme',
    )
    demo.load(fn=test_func, outputs=LOAD_OUTPUTS)

if __name__ == "__main__":
    demo.launch(debug=True)