import inspect
import json
import os
import random
from typing import Literal, cast

import gradio as gr
import torch
from PIL import Image
from gradio.data_classes import InterfaceTypes
from gradio.flagging import CSVLogger
from torchvision import transforms
from transformers import AutoTokenizer, LlamaForCausalLM

from trace_exec import run_program_with_trace, CompileTimeError
from vision_processes import load_models

print("-" * 10, "Loading models...")
load_models()

with open('joint.prompt') as f:
    prompt_template = f.read().strip()

INPUT_TYPE = 'image'
OUTPUT_TYPE = 'str'
SIGNATURE = f'def execute_command({INPUT_TYPE}) -> {OUTPUT_TYPE}:'


def generate(model_name, input_text):
    """Load the requested LLM on demand, greedily generate a completion, then free GPU memory."""
    torch.cuda.empty_cache()
    print("-" * 10, "Before loading LLM:")
    print(torch.cuda.memory_summary())

    dtype = os.environ.get("CODELLAMA_DTYPE")
    assert dtype in ['bfloat16', '8bit', '4bit']
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = LlamaForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        load_in_8bit=dtype == "8bit",
        load_in_4bit=dtype == "4bit",
        torch_dtype=torch.bfloat16 if dtype == "bfloat16" else None,
    )
    print("-" * 10, "LLM loaded:")
    print(model)
    print(torch.cuda.memory_summary())

    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    generated_ids = model.generate(
        input_ids.to('cuda'),
        max_new_tokens=256,
        stop_strings=["\n\n"],
        do_sample=False,
        tokenizer=tokenizer,
    )
    generated_ids = generated_ids[0][input_ids.shape[1]:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    del model
    torch.cuda.empty_cache()
    print("-" * 10, "After unloading LLM:")
    print(torch.cuda.memory_summary())

    return text


def to_custom_trace(result, error, traced):
    """Combine an execution result and its trace into a single text block ("-> result" header plus the trace)."""
    if traced is None:
        assert isinstance(error, CompileTimeError)
        traced = 'Compile Error'
    return "-> {}\n\n--- Trace\n\n{}".format(result, traced)


def answer_from_trace(x):
    """Extract the final answer from the "-> ..." header of a formatted trace."""
    assert x.startswith("->")
    return x[2:].splitlines()[0].strip()


def debug(image, question, code, traced_info):
    """One VDebugger round: the critic judges the program; if wrong, the refiner rewrites it and it is re-executed."""
    # critic
    prompt = f"# Given an image: {question}\n{code}\n\n{traced_info}\n\n# Program is"
    print("--- For debug: critic prompt is ---")
    print(prompt)
    print("---\n")
    critic_out = generate("VDebugger/VDebugger-critic-generalist-7B", prompt)
    incorrect = critic_out.strip().startswith('wrong')
    critic_out = "# Program is" + critic_out

    if not incorrect:
        yield code, traced_info, critic_out, "N/A", "N/A", answer_from_trace(traced_info)
        return
    else:
        yield code, traced_info, critic_out, "RUNNING...", "", ""

    # refiner
    critic_code = ('def execute_command' + critic_out.split('def execute_command')[1]).strip()
    if '# Program is' in code:
        # awkward fix: truncate critic_code at "# Program is" when the original code contains that marker
        critic_code = critic_code.split("# Program is")[0].strip()
    prompt = f"# Given an image: {question}\n{critic_code}\n\n{traced_info}\n\n# Correction"
    print("--- For debug: refiner prompt is ---")
    print(prompt)
    print("---\n")
    refiner_out = generate("VDebugger/VDebugger-refiner-generalist-7B", prompt).strip()
    yield code, traced_info, critic_out, refiner_out, "RUNNING...", ""

    # execute (again)
    result, error, traced = run_program_with_trace(refiner_out, image, INPUT_TYPE, OUTPUT_TYPE)
    traced_info_2 = to_custom_trace(result, error, traced)
    yield code, traced_info, critic_out, refiner_out, traced_info_2, answer_from_trace(traced_info_2)


def predict(image, question):
    """Full pipeline: generate a visual program with CodeLlama, execute it with tracing, then run one debugging round."""
    if image is None:
        gr.Warning("Please provide an image", duration=5)
        return
    image = transforms.Compose([transforms.ToTensor()])(image)

    question = question.strip()
    if question == "":
        gr.Warning("Please provide a question", duration=5)
        return

    # codellama
    prompt = prompt_template.replace("INSERT_QUERY_HERE", f"Given an image: {question}\n{SIGNATURE}")
    code = generate("codellama/CodeLlama-7b-Python-hf", prompt)
    code = (SIGNATURE + code).strip()
    yield code, "RUNNING...", "", "", "", ""

    # execute
    result, error, traced = run_program_with_trace(code, image, INPUT_TYPE, OUTPUT_TYPE)
    traced_info = to_custom_trace(result, error, traced)
    yield code, traced_info, "RUNNING...", "", "", ""

    for tup in debug(image, question, code, traced_info):
        yield tup
    return


def re_debug(image, question, code, traced_info):
    """Run another debugging round, starting from the previously refined program and its execution trace."""
    if code is None or code == "" or traced_info is None or traced_info == "":
        gr.Warning("No prior debugging round", duration=5)
        return
    yield code, traced_info, "RUNNING...", "", "", ""
    for tup in debug(image, question, code, traced_info):
        yield tup
    return


DESCRIPTION = """# VDebugger

| [Paper](https://arxiv.org/abs/2406.13444) | [Project](https://shirley-wu.github.io/vdebugger/) | [Code](https://github.com/shirley-wu/vdebugger/) | [Models and Data](https://huggingface.co/VDebugger) |

**VDebugger** is a novel critic-refiner framework trained to localize and debug *visual programs* by tracking execution step by step. In this demo, we show the visual programs, the outputs of the critic and the refiner, and the final result.

**Warning:** Reduced performance and accuracy may be observed. Due to the resource limitations of Hugging Face Spaces, this demo runs Llama inference in 4-bit quantization and uses smaller foundation VLMs. For full capacity, please use the original code."""


class MyInterface(gr.Interface):
    def __init__(self):
        # Skip gr.Interface.__init__ and initialize the underlying Blocks directly,
        # so the layout and events can be wired manually below.
        super(gr.Interface, self).__init__(
            title=None,
            theme=None,
            analytics_enabled=None,
            mode="tabbed_interface",
            css=None,
            js=None,
            head=None,
        )
        self.interface_type = InterfaceTypes.STANDARD
        self.description = DESCRIPTION
        self.cache_examples = None
        self.examples_per_page = 5
        self.example_labels = None
        self.batch = False
        self.live = False
        self.api_name = "predict"
        self.api_mode = False  # read by the event wiring below (preprocess/postprocess)
        self.max_batch_size = 4
        self.concurrency_limit = 'default'
        self.show_progress = "full"
        self.allow_flagging = 'auto'
        self.flagging_options = [("Flag", ""), ]
        self.flagging_callback = CSVLogger()
        self.flagging_dir = 'flagged'

        # Load examples
        with open('examples/questions.json') as f:
            example_questions = json.load(f)
        self.examples = []
        for question in example_questions:
            self.examples.append([
                Image.open('examples/{}.jpg'.format(question['imageId'])),
                question['question'],
            ])

        def load_random_example():
            image, question = random.choice(self.examples)
            return image, question, "", "", "", "", "", ""

        # Render the Gradio UI
        with self:
            self.render_title_description()

            with gr.Row():
                image = gr.Image(label="Image", type="pil", width="30%", scale=1)
                question = gr.Textbox(label="Question", scale=2)

            with gr.Row():
                _clear_btn = gr.ClearButton(value="Clear", variant="secondary")
                _random_eg_btn = gr.Button("Random Example Input")
                _submit_btn = gr.Button("Submit", variant="primary")
                if inspect.isgeneratorfunction(predict) or inspect.isasyncgenfunction(predict):
                    _stop1_btn = gr.Button("Stop", variant="stop", visible=False)
                _redebug_btn = gr.Button("Debug for Another Round", variant="primary")
                if inspect.isgeneratorfunction(re_debug) or inspect.isasyncgenfunction(re_debug):
                    _stop2_btn = gr.Button("Stop", variant="stop", visible=False)

            with gr.Row():
                o1 = gr.Textbox(label="No debugging: program")
                o2 = gr.Textbox(label="No debugging: execution")
            with gr.Row():
                o3 = gr.Textbox(label="VDebugger: critic")
                o4 = gr.Textbox(label="VDebugger: refiner")
            with gr.Row():
                o5 = gr.Textbox(label="VDebugger: execution")
                o6 = gr.Textbox(label="VDebugger: final answer")

            question.submit(fn=predict, inputs=[image, question], outputs=[o1, o2, o3, o4, o5, o6])
            _random_eg_btn.click(fn=load_random_example, outputs=[image, question, o1, o2, o3, o4, o5, o6])

            async def cleanup():
                return [gr.Button(visible=True), gr.Button(visible=False)]

            # Setup redebug event
            triggers = [_redebug_btn.click, ]
            extra_output = [_redebug_btn, _stop2_btn]
            predict_event = gr.on(
                triggers,
                gr.utils.async_lambda(
                    lambda: (
                        gr.Button(visible=False),
                        gr.Button(visible=True),
                    )
                ),
                inputs=None,
                outputs=[_redebug_btn, _stop2_btn],
                queue=False,
                show_api=False,
            ).then(
                re_debug,
                [image, question, o4, o5],
                [o1, o2, o3, o4, o5, o6],
                api_name=self.api_name,
                scroll_to_output=False,
                preprocess=not (self.api_mode),
                postprocess=not (self.api_mode),
                batch=self.batch,
                max_batch_size=self.max_batch_size,
                concurrency_limit=self.concurrency_limit,
                show_progress=cast(
                    Literal["full", "minimal", "hidden"], self.show_progress
                ),
            )
            redebug_event = predict_event.then(
                cleanup,
                inputs=None,
                outputs=extra_output,  # type: ignore
                queue=False,
                show_api=False,
            )
            _stop2_btn.click(
                cleanup,
                inputs=None,
                outputs=[_redebug_btn, _stop2_btn],
                cancels=predict_event,
                queue=False,
                show_api=False,
            )

            # Setup submit event
            triggers = [_submit_btn.click, question.submit, ]
            extra_output = [_submit_btn, _stop1_btn]
            predict_event = gr.on(
                triggers,
                gr.utils.async_lambda(
                    lambda: (
                        gr.Button(visible=False),
                        gr.Button(visible=True),
                    )
                ),
                inputs=None,
                outputs=[_submit_btn, _stop1_btn],
                queue=False,
                show_api=False,
            ).then(
                predict,
                [image, question],
                [o1, o2, o3, o4, o5, o6],
                api_name=self.api_name,
                scroll_to_output=False,
                preprocess=not (self.api_mode),
                postprocess=not (self.api_mode),
                batch=self.batch,
                max_batch_size=self.max_batch_size,
                concurrency_limit=self.concurrency_limit,
                show_progress=cast(
                    Literal["full", "minimal", "hidden"], self.show_progress
                ),
            )
            submit_event = predict_event.then(
                cleanup,
                inputs=None,
                outputs=extra_output,  # type: ignore
                queue=False,
                show_api=False,
            )
            _stop1_btn.click(
                cleanup,
                inputs=None,
                outputs=[_submit_btn, _stop1_btn],
                cancels=predict_event,
                queue=False,
                show_api=False,
            )

            # Finally borrow Interface stuff
            self.input_components = [image, question]
            self.output_components = [o1, o2, o3, o4, o5, o6]
            self.fn = predict
            self.attach_clear_events(_clear_btn, None)
            self.render_examples()


if __name__ == "__main__":
    MyInterface().launch(share=os.environ.get("SHARE", '') != "")
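
# Launch sketch: the script asserts that CODELLAMA_DTYPE is one of 'bfloat16', '8bit', or
# '4bit' (see generate()) and reads SHARE to decide whether to create a public Gradio link.
# One way to start the demo locally (assuming this file is saved as app.py):
#
#     CODELLAMA_DTYPE=4bit SHARE=1 python app.py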