import spaces
import os
import torch
import random
from huggingface_hub import snapshot_download
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from diffusers import UNet2DConditionModel, AutoencoderKL
from diffusers import EulerDiscreteScheduler
import gradio as gr
import requests

# Download the Kolors model weights from the Hugging Face Hub
ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")

# Load the individual model components in fp16
text_encoder = ChatGLMModel.from_pretrained(
    os.path.join(ckpt_dir, "text_encoder"),
    torch_dtype=torch.float16).half()
tokenizer = ChatGLMTokenizer.from_pretrained(os.path.join(ckpt_dir, "text_encoder"))
vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), revision=None).half()
scheduler = EulerDiscreteScheduler.from_pretrained(os.path.join(ckpt_dir, "scheduler"))
unet = UNet2DConditionModel.from_pretrained(os.path.join(ckpt_dir, "unet"), revision=None).half()

# Assemble the SDXL-style pipeline with the ChatGLM text encoder
pipe = StableDiffusionXLPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    force_zeros_for_empty_prompt=False)
pipe = pipe.to("cuda")

API_URL = "https://bots.spaceship.im"  # Replace with your actual API endpoint

# UI components
url_params = gr.JSON({}, visible=True, label="URL Params")
prompt = gr.Textbox(label="Prompt")
gallery = gr.Gallery(label="Result", elem_id="gallery", show_label=False)
height = gr.Slider(512, 2048, 1024, step=64, label="Height")
width = gr.Slider(512, 2048, 1024, step=64, label="Width")
steps = gr.Slider(1, 50, 25, step=1, label="Steps")
number_of_images = gr.Slider(1, 4, 1, step=1, label="Number of images per prompt")
random_seed = gr.Checkbox(label="Use Random Seed", value=True)
seed = gr.Number(label="Seed", value=0, precision=0)
seed_used = gr.Number(label="Seed Used")


def test_func(request: gr.Request):
    """Prefill the UI from the external API when a `uuid` query parameter is present."""
    data = request.query_params
    if "uuid" in data:
        msg_id = data["uuid"]
        response = requests.get(f"{API_URL}/check_data/{msg_id}")
        if response.status_code == 200:
            # The API is expected to return the saved values in the same order
            # as the components wired to `demo.load` below.
            api_data = response.json().get("data")
            if api_data:
                return [value for value in api_data.values()]
    # No stored data: return the components unchanged so their current values are kept.
    return prompt, dict(data), height, width, steps, number_of_images, random_seed, seed, gallery, seed_used


@spaces.GPU(duration=200)
def generate_image(prompt, url_params, height, width, num_inference_steps,
                   num_images_per_prompt, use_random_seed, seed,
                   request: gr.Request, progress=gr.Progress(track_tqdm=True)):
    # Neither a negative prompt nor the guidance scale is exposed in the UI;
    # these defaults are assumptions.
    negative_prompt = ""
    guidance_scale = 5.0

    if use_random_seed:
        seed = random.randint(0, 2**32 - 1)
    else:
        seed = int(seed)  # Ensure the seed is an integer

    images = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=torch.Generator(pipe.device).manual_seed(seed)
    ).images

    # Report the generation settings back to the external API.
    query_data = dict(request.query_params)
    save_data = {
        "prompt": prompt,
        "url_params": url_params,
        "height": height,
        "width": width,
        "steps": num_inference_steps,
        "num_images_per_prompt": num_images_per_prompt,
        "use_random_seed": use_random_seed,
        "seed": seed,
        # Hardcoded output URL; the generated images themselves are not uploaded here.
        "output": "https://bots.spaceship.im/static/204d01b0-cfc4-499f-8d55-b0f072d5c285_14.jpg",
        "seed_used": seed
    }
    if "uuid" in query_data:
        url = f"{API_URL}/save_data/{query_data['uuid']}"
        requests.post(url, json=save_data)

    return images, seed


description = """

Effective Training of Diffusion Model for Photorealistic Text-to-Image Synthesis

[Official Website] [Tech Report] [Model Page] [Github]

""" # Gradio interface with gr.Blocks() as demo: iface = gr.Interface( fn=generate_image, inputs=[ prompt, url_params ], additional_inputs=[ height, width, steps, number_of_images, random_seed, seed ], additional_inputs_accordion=gr.Accordion( label="Advanced settings", open=False), outputs=[ gallery, seed_used ], title="Kolors", description=description, theme='bethecloud/storj_theme', ) demo.load(fn=test_func, outputs=[ prompt, url_params, height, width, steps, number_of_images, random_seed, seed, gallery, seed_used]) demo.launch(debug=True)