# Author: Ryukijano
# Last commit: 8d56d1d ("Update app.py")
import gradio as gr
import jax
import numpy as np
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
from transformers import AutoTokenizer
# NOTE(review): the ControlNet repo is assumed to hold only ControlNet weights,
# so a Stable Diffusion base pipeline is loaded separately. The base model and
# revision below are assumptions -- confirm against the model's training config.
BASE_MODEL_NAME = "runwayml/stable-diffusion-v1-5"
BASE_MODEL_REVISION = "flax"


def load_model(
    model_name,
    base_model_name=BASE_MODEL_NAME,
    *,
    revision=BASE_MODEL_REVISION,
    return_params=False,
):
    """Load the fill-circle ControlNet and a Stable Diffusion pipeline using it.

    Args:
        model_name: Hub id (or local path) of the ControlNet weights.
        base_model_name: Hub id (or local path) of the Stable Diffusion base.
        revision: Revision/branch of ``base_model_name`` to load.
        return_params: If True, also return the merged parameter pytree that
            the Flax pipeline must be given at call time.

    Returns:
        ``(tokenizer, controlnet, pipeline)`` (the original 3-tuple shape), or
        ``(tokenizer, controlnet, pipeline, params)`` when ``return_params``
        is True. ``tokenizer`` is the pipeline's own tokenizer.
    """
    # Flax ``from_pretrained`` returns (module, params); the module itself is
    # stateless and the weights travel separately.
    controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(model_name)
    pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        base_model_name, controlnet=controlnet, revision=revision
    )
    # The pipeline's params pytree does not include the ControlNet weights
    # loaded above; they are expected under the "controlnet" key.
    params["controlnet"] = controlnet_params
    # Reuse the tokenizer the pipeline loaded instead of loading one from the
    # ControlNet repo, which is not guaranteed to ship tokenizer files.
    tokenizer = pipeline.tokenizer
    if return_params:
        return tokenizer, controlnet, pipeline, params
    return tokenizer, controlnet, pipeline


model_name = "Ryukijano/controlnet-fill-circle"
tokenizer, controlnet, pipeline, params = load_model(model_name, return_params=True)


def infer_fill_circle(prompt, image, negative_prompt="", num_inference_steps=50, seed=0):
    """Generate one image from ``prompt`` conditioned on the control ``image``.

    Args:
        prompt: Text prompt describing the desired output.
        image: Conditioning image as a ``PIL.Image``; the UI requests
            ``type="pil"`` because the pipeline's image preprocessing expects
            PIL images.
        negative_prompt: Text to steer away from; empty or None means no
            negative prompt.
        num_inference_steps: Number of denoising steps.
        seed: Seed for the JAX PRNG key, so identical inputs reproduce the
            same output.

    Returns:
        The first generated image as a ``PIL.Image``.

    Raises:
        gr.Error: If no conditioning image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")

    prompt_ids = pipeline.prepare_text_inputs(prompt)
    neg_prompt_ids = pipeline.prepare_text_inputs(negative_prompt or "")
    control_image = pipeline.prepare_image_inputs(image)

    # jit=False runs on the default device with unsharded inputs and
    # unreplicated params, so no per-device sharding is needed here.
    output = pipeline(
        prompt_ids=prompt_ids,
        image=control_image,
        params=params,
        prng_seed=jax.random.PRNGKey(seed),
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        jit=False,
    )
    return pipeline.numpy_to_pil(np.asarray(output.images))[0]


with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown("## Stable Diffusion with Fill Circle Control")
    gr.Markdown("In this app, you can find the ControlNet with Fill Circle control.")
    with gr.Tab("ControlNet Fill Circle"):
        prompt_input_fill_circle = gr.Textbox(label="Prompt")
        negative_prompt_fill_circle = gr.Textbox(label="Negative Prompt")
        # PIL output is what the pipeline's prepare_image_inputs accepts.
        fill_circle_input = gr.Image(label="Input Image", type="pil")
        fill_circle_output = gr.Image(label="Output Image")
        submit_btn = gr.Button(value="Submit")
    # Order must match infer_fill_circle's positional parameters
    # (prompt, image, negative_prompt).
    fill_circle_inputs = [
        prompt_input_fill_circle,
        fill_circle_input,
        negative_prompt_fill_circle,
    ]
    submit_btn.click(
        fn=infer_fill_circle,
        inputs=fill_circle_inputs,
        outputs=[fill_circle_output],
    )

demo.launch()