# vision / app.py
# akhaliq's picture
# akhaliq HF staff
# Update app.py
# 86f16e5 verified
# raw
# history blame
# 2.1 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import torch
# Load the processor and model.
# Both come from the same checkpoint with the same loading options, so the
# checkpoint id and the options are defined once and shared by the two calls.
MODEL_ID = 'allenai/Molmo-7B-D-0924'
_LOAD_OPTIONS = {
    'trust_remote_code': True,
    'torch_dtype': 'auto',
    'device_map': 'auto',
}

processor = AutoProcessor.from_pretrained(MODEL_ID, **_LOAD_OPTIONS)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_LOAD_OPTIONS)
def process_image_and_text(image, text):
    """Generate a text answer for one image/prompt pair.

    Args:
        image: Array-like image data accepted by ``Image.fromarray``.
        text: The prompt to pair with the image.

    Returns:
        The decoded text of the newly generated tokens only (the prompt
        tokens are sliced off), with special tokens skipped.
    """
    # Preprocess the single image together with the prompt.
    pil_image = Image.fromarray(image)
    features = processor.process(images=[pil_image], text=text)

    # Move every tensor to the model's device and add a leading batch
    # dimension so the batch has size 1.
    batch = {}
    for name, tensor in features.items():
        batch[name] = tensor.to(model.device).unsqueeze(0)

    # Generate at most 200 new tokens, stopping at the end-of-text marker.
    generation_config = GenerationConfig(
        max_new_tokens=200,
        stop_strings="<|endoftext|>",
    )
    output = model.generate_from_batch(
        batch,
        generation_config,
        tokenizer=processor.tokenizer,
    )

    # Everything before `prompt_length` is the echoed prompt; keep the rest.
    prompt_length = batch['input_ids'].size(1)
    new_tokens = output[0, prompt_length:]
    return processor.tokenizer.decode(new_tokens, skip_special_tokens=True)
def chatbot(image, text, history):
    """Answer a question about the uploaded image and return the chat log.

    Args:
        image: The uploaded image, or ``None`` if nothing was uploaded.
        text: The user's question.
        history: List of ``(user_message, bot_message)`` pairs from earlier
            turns.

    Returns:
        The list of ``(user_message, bot_message)`` pairs to display. When no
        image is present this is a new list ending with a reminder to upload
        one; otherwise it is ``history`` itself with the new turn appended.
    """
    if image is None:
        # The reminder is the assistant's reply, so it goes in the second
        # (bot) slot; the user's text stays in the first slot, or None when
        # the box was empty so no blank user message is added.
        # A new list is built so the reminder is not stored in `history`.
        return history + [(text or None, "Please upload an image first.")]

    response = process_image_and_text(image, text)
    # Appended in place on purpose: the UI passes the history list only as an
    # input, so mutating it is presumably what carries the conversation over
    # to the next turn.
    history.append((text, response))
    return history
# Define the Gradio interface.
# NOTE(review): the nesting below is reconstructed from source that had lost
# its indentation; the Row is assumed to hold the image input and the chat
# window side by side -- confirm against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("# Image Chatbot with Molmo-7B-D-0924")

    with gr.Row():
        image_input = gr.Image(type="numpy")
        chatbot_output = gr.Chatbot()

    text_input = gr.Textbox(placeholder="Ask a question about the image...")
    submit_button = gr.Button("Submit")

    # Conversation history, passed to the handler on every call.
    state = gr.State([])

    # Clicking "Submit" and pressing Enter in the textbox run the same
    # handler with the same inputs and outputs.
    for register in (submit_button.click, text_input.submit):
        register(
            chatbot,
            inputs=[image_input, text_input, state],
            outputs=[chatbot_output],
        )

demo.launch()