# Xueqing Wu
# init
# e20ef71
import inspect
import json
import os
import random
from typing import Literal, cast
import gradio as gr
import torch
from PIL import Image
from gradio.data_classes import InterfaceTypes
from gradio.flagging import CSVLogger
from torchvision import transforms
from transformers import AutoTokenizer, LlamaForCausalLM
from trace_exec import run_program_with_trace, CompileTimeError
from vision_processes import load_models
# Load the models provided by vision_processes once, at import time.
print("-" * 10, "Loading models...")
load_models()

# Prompt template for the code generator; predict() substitutes the user's
# query for the INSERT_QUERY_HERE placeholder.
with open('joint.prompt') as prompt_file:
    prompt_template = prompt_file.read().strip()

# Argument / return types handed to run_program_with_trace, and the function
# header every generated program starts with.
INPUT_TYPE = 'image'
OUTPUT_TYPE = 'str'
SIGNATURE = 'def execute_command({}) -> {}:'.format(INPUT_TYPE, OUTPUT_TYPE)
def generate(model, input_text):
    """Load a causal LM, greedily continue ``input_text``, then drop the model.

    The model is loaded on every call and deleted before returning, and the
    CUDA cache is emptied before and after, so the weights are not kept
    around between calls.

    Args:
        model: Model name or path passed to ``from_pretrained`` for both the
            tokenizer and ``LlamaForCausalLM``.
        input_text: Prompt to continue.

    Returns:
        The decoded continuation only (prompt tokens are sliced off).

    Raises:
        ValueError: If the ``CODELLAMA_DTYPE`` environment variable is not one
            of ``bfloat16``, ``8bit`` or ``4bit``.
    """
    # Validate the configuration before touching the GPU. An explicit raise
    # (rather than assert) keeps the check alive under ``python -O``.
    dtype = os.environ.get("CODELLAMA_DTYPE")
    allowed_dtypes = ('bfloat16', '8bit', '4bit')
    if dtype not in allowed_dtypes:
        raise ValueError(
            f"CODELLAMA_DTYPE must be one of {allowed_dtypes}, got {dtype!r}"
        )

    torch.cuda.empty_cache()
    print("-" * 10, "Before loading LLM:")
    print(torch.cuda.memory_summary())

    tokenizer = AutoTokenizer.from_pretrained(model)
    llm = LlamaForCausalLM.from_pretrained(
        model,
        device_map="auto",
        load_in_8bit=dtype == "8bit",
        load_in_4bit=dtype == "4bit",
        torch_dtype=torch.bfloat16 if dtype == "bfloat16" else None,
    )
    print("-" * 10, "LLM loaded:")
    print(llm)
    print(torch.cuda.memory_summary())

    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    # Greedy decoding; "\n\n" is passed as a stop string, which requires the
    # tokenizer to be handed to generate() as well.
    generated_ids = llm.generate(
        input_ids.to('cuda'), max_new_tokens=256, stop_strings=["\n\n"], do_sample=False, tokenizer=tokenizer
    )
    # Keep only the newly generated tokens of the single returned sequence.
    generated_ids = generated_ids[0][input_ids.shape[1]:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    # Drop the model reference before emptying the cache so its memory can be
    # released.
    del llm
    torch.cuda.empty_cache()
    print("-" * 10, "After loading LLM:")
    print(torch.cuda.memory_summary())
    return text
def to_custom_trace(result, error, traced):
    """Format an execution result and its trace as one text block.

    The first line is ``-> <result>``, followed by a ``--- Trace`` section
    holding ``traced``. When no trace is available the error is asserted to
    be a ``CompileTimeError`` and the trace text becomes ``Compile Error``.
    """
    trace_text = traced
    if trace_text is None:
        # A missing trace is only expected for programs that failed to compile.
        assert isinstance(error, CompileTimeError)
        trace_text = 'Compile Error'
    return f"-> {result}\n\n--- Trace\n\n{trace_text}"
def answer_from_trace(x):
    """Extract the answer from text produced by ``to_custom_trace``.

    ``x`` must start with the ``->`` marker; the answer is the rest of the
    first line with surrounding whitespace removed.
    """
    marker = "->"
    assert x.startswith(marker)
    after_marker = x[len(marker):]
    first_line = after_marker.splitlines()[0]
    return first_line.strip()
def debug(image, question, code, traced_info):
    """Run one critic -> refiner -> re-execute round on a program.

    This is a generator so the UI can update progressively. Every yielded
    value is a 6-tuple of strings:
    ``(program, execution, critic output, refiner output, refined execution,
    final answer)``.

    Args:
        image: Input passed to ``run_program_with_trace`` when the refined
            program is executed.
        question: The user's question, embedded in both prompts.
        code: The program to be judged.
        traced_info: Execution text for ``code`` in the format produced by
            ``to_custom_trace``.
    """
    # critic
    prompt = f"# Given an image: {question}\n{code}\n\n{traced_info}\n\n# Program is"
    print("--- For debug: critic prompt is ---")
    print(prompt)
    print("---\n")
    critic_out = generate("VDebugger/VDebugger-critic-generalist-7B", prompt)
    incorrect = critic_out.strip().startswith('wrong')
    critic_out = "# Program is" + critic_out
    if not incorrect:
        # The critic accepted the program: its original answer is final.
        yield code, traced_info, critic_out, "N/A", "N/A", answer_from_trace(traced_info)
        return
    yield code, traced_info, critic_out, "RUNNING IN PROGRESS...", "", ""

    # refiner
    pieces = critic_out.split('def execute_command')
    if len(pieces) > 1:
        # Use the program as echoed by the critic (text after the first
        # 'def execute_command', up to the next one if any).
        critic_code = ('def execute_command' + pieces[1]).strip()
    else:
        # The critic said "wrong" without echoing a program; fall back to the
        # original program instead of failing with an IndexError.
        critic_code = code.strip()
    # NOTE(review): this tests `code`, not `critic_code`, for the marker —
    # kept as in the original; confirm that is the intended condition.
    if '# Program is' in code:
        critic_code = critic_code.split("# Program is")[0].strip()  # errr, an awkward fix
    prompt = f"# Given an image: {question}\n{critic_code}\n\n{traced_info}\n\n# Correction"
    print("--- For debug: refiner prompt is ---")
    print(prompt)
    print("---\n")
    refiner_out = generate("VDebugger/VDebugger-refiner-generalist-7B", prompt).strip()
    yield code, traced_info, critic_out, refiner_out, "RUNNING IN PROGRESS...", ""

    # execute (again)
    result, error, traced = run_program_with_trace(refiner_out, image, INPUT_TYPE, OUTPUT_TYPE)
    traced_info_2 = to_custom_trace(result, error, traced)
    yield code, traced_info, critic_out, refiner_out, traced_info_2, answer_from_trace(traced_info_2)
def predict(image, question):
    """Generate a program for the question, execute it, then debug it.

    A generator that yields 6-tuples of strings
    ``(program, execution, critic output, refiner output, refined execution,
    final answer)`` as each stage finishes. If the image is missing or the
    question is blank, a Gradio warning is shown and nothing is yielded.
    """
    if image is None:
        gr.Warning("Please provide an image", duration=5)
        return
    to_tensor = transforms.Compose([transforms.ToTensor()])
    image = to_tensor(image)

    question = question.strip()
    if not question:
        gr.Warning("Please provide a question", duration=5)
        return

    # codellama
    query = f"Given an image: {question}\n{SIGNATURE}"
    prompt = prompt_template.replace("INSERT_QUERY_HERE", query)
    completion = generate("codellama/CodeLlama-7b-Python-hf", prompt)
    # The prompt ends with the signature, so prepend it to the completion.
    code = (SIGNATURE + completion).strip()
    yield code, "RUNNING IN PROGRESS...", "", "", "", ""

    # execute
    result, error, traced = run_program_with_trace(code, image, INPUT_TYPE, OUTPUT_TYPE)
    traced_info = to_custom_trace(result, error, traced)
    yield code, traced_info, "RUNNING IN PROGRESS...", "", "", ""

    # critic / refiner round
    yield from debug(image, question, code, traced_info)
def re_debug(image, question, code, traced_info):
    """Run another debugging round on the result of the previous round.

    A generator yielding the same 6-tuples as ``debug``. The demo UI calls it
    with the previous round's refiner output as ``code`` and the previous
    refined execution as ``traced_info``.

    If either value is missing, or is one of the placeholder strings the
    pipeline writes into those boxes (``"N/A"`` when the critic accepted the
    program, ``"RUNNING IN PROGRESS..."`` when a run was interrupted), there
    is nothing to debug: a Gradio warning is shown and nothing is yielded.
    """
    # Placeholders are neither a program nor a trace; feeding them to the
    # critic would fail later in answer_from_trace.
    placeholders = ("N/A", "RUNNING IN PROGRESS...")
    for value in (code, traced_info):
        if not value or value in placeholders:
            gr.Warning("No prior debugging round", duration=5)
            return
    yield code, traced_info, "RUNNING IN PROGRESS...", "", "", ""
    yield from debug(image, question, code, traced_info)
# Markdown text assigned to MyInterface.description (the demo page header).
DESCRIPTION = """# VDebugger
| [Paper](https://arxiv.org/abs/2406.13444) | [Project](https://shirley-wu.github.io/vdebugger/) | [Code](https://github.com/shirley-wu/vdebugger/) | [Models and Data](https://huggingface.co/VDebugger) |
**VDebugger** is a novel critic-refiner framework trained to localize and debug *visual programs* by tracking execution step by step. In this demo, we show the visual programs, the outputs from both the critic and the refiner, as well as the final result.
**Warning:** Reduced performance and accuracy may be observed. Due to resource limitation of huggingface spaces, this demo runs Llama inference in 4-bit quantization and uses smaller foundation VLMs. For full capacity, please use the original code."""
class MyInterface(gr.Interface):
    """Hand-built Gradio UI for the VDebugger demo.

    Subclasses ``gr.Interface`` to reuse a few of its helpers (title and
    description rendering, clear-button wiring, the examples widget), while
    the layout and event chains are assembled manually in ``__init__``.
    """

    def __init__(self):
        # Deliberately bypass gr.Interface.__init__: super(gr.Interface, self)
        # resolves to the class after gr.Interface in the MRO, so only that
        # base initializer runs.
        super(gr.Interface, self).__init__(
            title=None,
            theme=None,
            analytics_enabled=None,
            mode="tabbed_interface",
            css=None,
            js=None,
            head=None,
        )

        # Attributes assigned by hand because gr.Interface.__init__ is
        # skipped. Presumably read by the borrowed Interface helpers called
        # below — confirm against the installed Gradio version.
        self.interface_type = InterfaceTypes.STANDARD
        self.description = DESCRIPTION
        self.cache_examples = None
        self.examples_per_page = 5
        self.example_labels = None
        self.batch = False
        self.live = False
        self.api_name = "predict"
        self.max_batch_size = 4
        self.concurrency_limit = 'default'
        self.show_progress = "full"
        self.allow_flagging = 'auto'
        self.flagging_options = [("Flag", ""), ]
        self.flagging_callback = CSVLogger()
        self.flagging_dir = 'flagged'

        # Load examples: one [image, question] pair per entry of the JSON file.
        with open('examples/questions.json') as f:
            example_questions = json.load(f)
        self.examples = [
            [Image.open('examples/{}.jpg'.format(item['imageId'])), item['question']]
            for item in example_questions
        ]

        def load_random_example():
            """Pick a random example and blank out all six output boxes."""
            image, question = random.choice(self.examples)
            return image, question, "", "", "", "", "", ""

        # Render the Gradio UI
        with self:
            self.render_title_description()

            with gr.Row():
                image = gr.Image(label="Image", type="pil", width="30%", scale=1)
                question = gr.Textbox(label="Question", scale=2)

            with gr.Row():
                _clear_btn = gr.ClearButton(value="Clear", variant="secondary")
                _random_eg_btn = gr.Button("Random Example Input")
                _submit_btn = gr.Button("Submit", variant="primary")
                # predict and re_debug are generator functions, so each run
                # button is paired with a (initially hidden) stop button.
                _stop1_btn = gr.Button("Stop", variant="stop", visible=False)
                _redebug_btn = gr.Button("Debug for Another Round", variant="primary")
                _stop2_btn = gr.Button("Stop", variant="stop", visible=False)

            with gr.Row():
                o1 = gr.Textbox(label="No debugging: program")
                o2 = gr.Textbox(label="No debugging: execution")
            with gr.Row():
                o3 = gr.Textbox(label="VDebugger: critic")
                o4 = gr.Textbox(label="VDebugger: refiner")
            with gr.Row():
                o5 = gr.Textbox(label="VDebugger: execution")
                o6 = gr.Textbox(label="VDebugger: final answer")

            # question.submit is intentionally NOT registered on its own: it
            # is already a trigger of the submit chain below, and a second
            # listener would run predict twice for every Enter key press.
            _random_eg_btn.click(fn=load_random_example, outputs=[image, question, o1, o2, o3, o4, o5, o6])

            # NOTE(review): both chains are registered with self.api_name;
            # confirm the resulting endpoint names are the intended ones.

            # Setup redebug event: feeds the previous refiner output (o4) and
            # refined execution (o5) back in as program and trace.
            self._wire_run_with_stop(
                triggers=[_redebug_btn.click, ],
                run_btn=_redebug_btn,
                stop_btn=_stop2_btn,
                fn=re_debug,
                inputs=[image, question, o4, o5],
                outputs=[o1, o2, o3, o4, o5, o6],
            )

            # Setup submit event
            self._wire_run_with_stop(
                triggers=[_submit_btn.click, question.submit, ],
                run_btn=_submit_btn,
                stop_btn=_stop1_btn,
                fn=predict,
                inputs=[image, question],
                outputs=[o1, o2, o3, o4, o5, o6],
            )

            # Finally borrow Interface stuff
            self.input_components = [image, question]
            self.output_components = [o1, o2, o3, o4, o5, o6]
            self.fn = predict
            self.attach_clear_events(_clear_btn, None)
            self.render_examples()

    def _wire_run_with_stop(self, triggers, run_btn, stop_btn, fn, inputs, outputs):
        """Wire a run button and its stop button around the generator ``fn``.

        On any of ``triggers``: hide ``run_btn`` and show ``stop_btn``, then
        run ``fn(inputs) -> outputs``, then restore the buttons. Clicking
        ``stop_btn`` cancels the running chain and restores the buttons.
        Must be called inside the ``with self:`` context.
        """

        async def cleanup():
            # Restore the idle state: run button visible, stop button hidden.
            return [gr.Button(visible=True), gr.Button(visible=False)]

        run_event = gr.on(
            triggers,
            gr.utils.async_lambda(
                lambda: (
                    gr.Button(visible=False),
                    gr.Button(visible=True),
                )
            ),
            inputs=None,
            outputs=[run_btn, stop_btn],
            queue=False,
            show_api=False,
        ).then(
            fn,
            inputs,
            outputs,
            api_name=self.api_name,
            scroll_to_output=False,
            preprocess=not (self.api_mode),
            postprocess=not (self.api_mode),
            batch=self.batch,
            max_batch_size=self.max_batch_size,
            concurrency_limit=self.concurrency_limit,
            show_progress=cast(
                Literal["full", "minimal", "hidden"], self.show_progress
            ),
        )
        run_event.then(
            cleanup,
            inputs=None,
            outputs=[run_btn, stop_btn],  # type: ignore
            queue=False,
            show_api=False,
        )
        stop_btn.click(
            cleanup,
            inputs=None,
            outputs=[run_btn, stop_btn],
            cancels=run_event,
            queue=False,
            show_api=False,
        )
if __name__ == "__main__":
    # Any non-empty SHARE environment variable turns on Gradio's share link.
    share_requested = bool(os.environ.get("SHARE", ""))
    demo = MyInterface()
    demo.launch(share=share_requested)