tsungtao commited on
Commit
86736fe
1 Parent(s): e9584c2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import jax
3
+ import numpy as np
4
+ import jax.numpy as jnp
5
+ from flax.jax_utils import replicate
6
+ from flax.training.common_utils import shard
7
+ from PIL import Image
8
+ from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
9
+ import cv2
10
+
11
+ def create_key(seed=0):
12
+ return jax.random.PRNGKey(seed)
13
+
14
+ def canny_filter(image):
15
+ gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
16
+ blurred_image = cv2.GaussianBlur(gray_image, (5, 5), 0)
17
+ edges_image = cv2.Canny(blurred_image, 50, 200)
18
+ return edges_image
19
+
20
+ # load control net and stable diffusion v1-5
21
+ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
22
+ "tsungtao/controlnet-mlsd-202305011046", from_flax=True, dtype=jnp.bfloat16
23
+ )
24
+ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
25
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
26
+ )
27
+
28
+ def infer(prompts, negative_prompts, image):
29
+ params["controlnet"] = controlnet_params
30
+
31
+ num_samples = 1 #jax.device_count()
32
+ rng = create_key(0)
33
+ rng = jax.random.split(rng, jax.device_count())
34
+ im = canny_filter(image)
35
+ canny_image = Image.fromarray(im)
36
+
37
+ prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
38
+ negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
39
+ processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
40
+
41
+ p_params = replicate(params)
42
+ prompt_ids = shard(prompt_ids)
43
+ negative_prompt_ids = shard(negative_prompt_ids)
44
+ processed_image = shard(processed_image)
45
+
46
+ output = pipe(
47
+ prompt_ids=prompt_ids,
48
+ image=processed_image,
49
+ params=p_params,
50
+ prng_seed=rng,
51
+ num_inference_steps=50,
52
+ neg_prompt_ids=negative_prompt_ids,
53
+ jit=True,
54
+ ).images
55
+
56
+ output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
57
+ return output_images
58
+
59
+ title = "ControlNet MLSD"
60
+ description = "This is a demo on ControlNet MLSD."
61
+ examples = [["living room with TV", "fan", "image_01.jpg"],
62
+ ["a living room with hardwood floors and a flat screen tv", "sea", "image_02.jpg"],
63
+ ["a living room with a fireplace and a view of the ocean", "pendant", "image_03.jpg"]
64
+ ]
65
+
66
+ with gr.Blocks(css=".gradio-container {background: url('file=sky.jpg')}") as demo:
67
+ gr.Interface(infer, inputs=["text", "text", "image"], outputs="gallery", title = title, description = description, examples = examples, theme='gradio/soft')
68
+
69
+ gr.Markdown(
70
+ """
71
+ * * *
72
+ * [Dataset](https://huggingface.co/datasets/tsungtao/diffusers-testing)
73
+ * [Diffusers model](https://huggingface.co/runwayml/stable-diffusion-v1-5)
74
+ * [Training Report](https://wandb.ai/tsungtao0311/controlnet-mlsd-202305011046/runs/ezfn6bkz?workspace=user-tsungtao0311)
75
+ """)
76
+
77
+ with gr.Accordion("Open for More!"):
78
+ gr.Markdown("Look at me...")
79
+
80
+ gr.Markdown("* * *")
81
+ gr.Markdown(""" <img src='https://huggingface.co/spaces/tsungtao/tsungtao-controlnet-mlsd-202305011046/blob/main/test.png' /> """)
82
+
83
+
84
+ demo.launch()