Spaces:
Runtime error
Runtime error
File size: 3,397 Bytes
258d8c9 c74095e 975dc6e c74095e 258d8c9 e8b12f5 880828c c74095e c4964ee 2e09612 c4964ee 3eed896 48c7266 258d8c9 c74095e 258d8c9 090c9fa c74095e 090c9fa c74095e febb26d 090c9fa 8118b09 c74095e febb26d 090c9fa 8118b09 c74095e 090c9fa c74095e 258d8c9 c131c56 880828c 806402e c33c1c5 880828c 466cd5e b6e7466 bf379ef 49e8691 466cd5e 493299e 13c70ff b547a14 bd9a3cd b547a14 466cd5e a600f9f 466cd5e 6303c40 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import cv2
# Load the HTML/JS markup for the keypoint-drawing tool shown in the UI.
# NOTE: despite its name, `lines` holds the whole file as ONE string, because
# gr.HTML expects a str value (readlines() produced a list of lines).
with open("test.html", encoding="utf-8") as f:
    lines = f.read()
def create_key(seed=0):
    """Return a fresh JAX PRNG key derived from the integer ``seed``."""
    prng_key = jax.random.PRNGKey(seed)
    return prng_key
#def addp5sketch(url):
# iframe = f'<iframe src ={url} style="border:none;height:525px;width:100%"/frame>'
# return gr.HTML(iframe)
def wandb_report(url):
    """Return a ``gr.HTML`` component that embeds ``url`` in a full-width iframe.

    Args:
        url: Address of the page to embed (e.g. a Weights & Biases run).

    Returns:
        A ``gr.HTML`` component containing a well-formed ``<iframe>`` element.
    """
    import html

    # Quote and escape the URL so characters such as '&' or '"' cannot break
    # out of the src attribute; close the element with a real </iframe> tag.
    src = html.escape(url, quote=True)
    iframe = f'<iframe src="{src}" style="border:none;height:1024px;width:100%"></iframe>'
    return gr.HTML(iframe)
# Links / assets referenced by the demo.
report_url = 'https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5'
control_img = 'myimage.jpg'

# Load the animal-pose ControlNet weights, then the Stable Diffusion v1.5
# Flax pipeline wired to that ControlNet; both in bfloat16.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "JFoz/dog-cat-pose",
    dtype=jnp.bfloat16,
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    revision="flax",
    dtype=jnp.bfloat16,
)
def infer(prompts, negative_prompts, image):
    """Generate images with the ControlNet pipeline, one sample per JAX device.

    Args:
        prompts: Positive text prompt (a single string from the UI textbox).
        negative_prompts: Negative text prompt (a single string).
        image: Conditioning image as delivered by ``gr.Image`` — presumably an
            HxWxC uint8 numpy array (Gradio's default ``type="numpy"``);
            TODO confirm against the component configuration.

    Returns:
        A list of PIL images, one per local JAX device.

    Raises:
        gr.Error: if no conditioning image was supplied.
    """
    if image is None:
        # Image.fromarray(None) would fail with an opaque error; surface a
        # readable message in the UI instead.
        raise gr.Error("Please provide a conditioning image.")

    params["controlnet"] = controlnet_params

    # shard() splits the leading batch axis across devices, so the batch size
    # must be a multiple of jax.device_count(); use exactly one sample per
    # device (a fixed batch of 1 crashes on multi-device hosts).
    num_samples = jax.device_count()
    rng = create_key(0)
    rng = jax.random.split(rng, jax.device_count())

    image = Image.fromarray(image)
    prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
    processed_image = pipe.prepare_image_inputs([image] * num_samples)

    # NOTE(review): replicating the full parameter tree on every request is
    # costly; consider replicating once at module load.
    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the (devices, per-device batch) leading axes into one batch axis
    # before converting to PIL.
    flat = output.reshape((num_samples,) + output.shape[-3:])
    output_images = pipe.numpy_to_pil(np.asarray(flat))
    return output_images
# Build the Gradio UI. Inside gr.Blocks, event listeners must be wired to
# component objects (not the "text"/"image"/"gallery" string shortcuts that
# only gr.Interface accepts), and examples are declared via gr.Examples.
with gr.Blocks(theme='kfahn/AnimalPose') as demo:
    gr.Markdown(
"""
# Animal Pose Control Net
## This is a demo of Animal Pose ControlNet, which is a model trained on runwayml/stable-diffusion-v1-5 with new type of conditioning.
[Dataset](https://huggingface.co/datasets/JFoz/dog-poses-controlnet-dataset)
[Diffusers model](https://huggingface.co/JFoz/dog-pose)
[Github](https://github.com/fi4cr/animalpose)
[Training Report](https://wandb.ai/john-fozard/AP10K-pose/runs/wn89ezaw)
""")
    with gr.Row():
        with gr.Column():
            prompts = gr.Textbox(label="Prompt")
            negative_prompts = gr.Textbox(label="Negative Prompt")
            conditioning_image = gr.Image(label="Conditioning Image")
        with gr.Column():
            # join() makes this work whether `lines` is a str or a list of lines.
            keypoint_tool = gr.HTML("".join(lines))
    submit_btn = gr.Button("Submit")
    gallery = gr.Gallery(label="Generated Images")

    gr.Examples(
        examples=[["a Labrador crossing the road", "low quality", "myimage.jpg"]],
        inputs=[prompts, negative_prompts, conditioning_image],
    )
    submit_btn.click(
        fn=infer,
        inputs=[prompts, negative_prompts, conditioning_image],
        outputs=gallery,
    )
    #with gr.Row():
    #    report = wandb_report(report_url)

demo.launch()