# Source: Hugging Face Space "vision", file app.py
# Author: akhaliq (HF staff)
# Last commit: "Update app.py" (ecf825a, verified); original size 2.15 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import torch
import spaces
# Load the processor and model.
# Both come from the same Hub checkpoint with identical loading options, so the
# repo id and keyword arguments are declared once and reused for each call.
# trust_remote_code=True allows transformers to execute code shipped inside the
# checkpoint repository; presumably this is where non-standard methods such as
# the processor's ``process`` and the model's ``generate_from_batch`` come from.
# NOTE(review): ``torch_dtype``/``device_map`` are model-loading options; the
# processor most likely ignores them — confirm before relying on that.
_MOLMO_REPO_ID = 'allenai/Molmo-7B-D-0924'
_LOAD_OPTIONS = {
    'trust_remote_code': True,
    'torch_dtype': 'auto',
    'device_map': 'auto',
}
processor = AutoProcessor.from_pretrained(_MOLMO_REPO_ID, **_LOAD_OPTIONS)
model = AutoModelForCausalLM.from_pretrained(_MOLMO_REPO_ID, **_LOAD_OPTIONS)
@spaces.GPU(duration=120)
def process_image_and_text(image, text, max_new_tokens=200):
    """Ask the Molmo model ``text`` about ``image`` and return its reply.

    Args:
        image: The uploaded picture as an array accepted by
            ``PIL.Image.fromarray`` (the UI supplies it via
            ``gr.Image(type="numpy")``).
        text: The user's prompt/question.
        max_new_tokens: Upper bound on the number of generated tokens.
            Defaults to 200, the value that used to be hard-coded.

    Returns:
        The decoded continuation produced by the model, with special tokens
        removed. The prompt tokens are not included.
    """
    # Preprocess the single image/prompt pair into model input tensors.
    inputs = processor.process(
        images=[Image.fromarray(image)],
        text=text,
    )

    # Move every tensor to the model's device and add a leading batch
    # dimension: generate_from_batch consumes a batch, here of size 1.
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        stop_strings="<|endoftext|>",
    )
    # Pure inference: make sure no autograd graph is built, independent of
    # whether the remote-code generate_from_batch disables gradients itself.
    with torch.no_grad():
        output = model.generate_from_batch(
            inputs,
            generation_config,
            tokenizer=processor.tokenizer,
        )

    # The returned sequence starts with the prompt tokens; slice them off so
    # only newly generated tokens are decoded.
    prompt_length = inputs['input_ids'].size(1)
    generated_tokens = output[0, prompt_length:]
    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
def chatbot(image, text, history):
    """Handle one chat turn and return ``(chat_display, new_history)``.

    The first value is wired to a ``gr.Chatbot`` component, which renders a
    list of ``(user_message, bot_message)`` pairs. A bare string is not a
    valid value for it, so the full conversation is returned instead. The
    second value is stored back into the session ``gr.State``.

    Args:
        image: The uploaded image array, or ``None`` if nothing was uploaded.
        text: The user's question. ``None`` is treated as an empty string.
        history: Prior ``(question, answer)`` pairs from session state.
            ``None`` is treated as an empty history.

    Returns:
        ``(display, history)``. Validation messages (missing image or blank
        question) appear only in ``display`` and are not saved to ``history``.
        The caller's ``history`` list is never mutated.
    """
    # Work on a copy so the incoming state object is never modified in place.
    history = list(history or [])
    text = text or ""

    if image is None:
        return history + [(text, "Please upload an image first.")], history
    # Don't spend a GPU call on an empty prompt (e.g. Enter on a blank box).
    if not text.strip():
        return history + [(text, "Please enter a question about the image.")], history

    response = process_image_and_text(image, text)
    history.append((text, response))
    return history, history
# Define the Gradio interface: the image upload and the chat transcript sit
# side by side, with a question box and a submit button underneath. The
# per-session (question, answer) history lives in a gr.State.
with gr.Blocks() as demo:
    gr.Markdown("# Image Chatbot with Molmo-7B-D-0924")
    with gr.Row():
        image_input = gr.Image(type="numpy")
        chatbot_output = gr.Chatbot()
    text_input = gr.Textbox(placeholder="Ask a question about the image...")
    submit_button = gr.Button("Submit")
    state = gr.State([])

    # Clicking the button and pressing Enter in the textbox trigger the same
    # handler with the same inputs/outputs, so both are registered here.
    handler_inputs = [image_input, text_input, state]
    handler_outputs = [chatbot_output, state]
    for register_event in (submit_button.click, text_input.submit):
        register_event(chatbot, inputs=handler_inputs, outputs=handler_outputs)

demo.launch()