|
import ast |
|
import importlib |
|
import io |
|
import os |
|
import re |
|
import string |
|
import time |
|
from functools import partial |
|
from typing import List |
|
|
|
import pysnooper |
|
|
|
# Signature of the generated program as presented to the user; process_trace
# rewrites the executed signature back to this form in the returned trace.
FUNCTION_HEAD = "def execute_command({input_type}) -> {output_type}:"

# Signature the program is actually executed with: besides the input, every
# helper the program may call is passed in explicitly as a parameter (see the
# execute_command call in run_program_with_trace).
EXEC_FUNCTION_HEAD = (
    'def execute_command({input_type}, possible_answers, query, ImagePatch, VideoSegment,'
    ' llm_query, bool_to_yesno, distance, best_image_match):'
)
|
|
|
|
|
class CompileTimeError:
    """Error marker: the generated program could not be parsed.

    ``run_program_with_trace`` returns an instance of this class in the
    ``error`` position of its result tuple. It is a plain object, not an
    ``Exception`` subclass, and is never raised in this module.
    """
|
|
|
|
|
class ProgramRuntimeError:
    """Error marker: the generated program failed when it was run.

    ``run_program_with_trace`` returns an instance of this class in the
    ``error`` position of its result tuple. It is a plain object, not an
    ``Exception`` subclass, and is never raised in this module.
    """
|
|
|
|
|
def process_trace(text, function_head, execution_function_head):
    """Reduce raw pysnooper output to the trace of the generated program alone.

    Steps, in order:
      1. Keep only the lines from the *last* line containing
         ``execution_function_head`` onwards. On that line, rewrite the
         signature to ``function_head``, and dedent the kept block by that
         line's leading spaces.
      2. Cut everything from the first ``Source path:`` line that ends with
         this file's path, or the first ``Elapsed time`` line, whichever
         comes first.
      3. Drop the leading ``HH:MM:SS[.ffffff] `` timestamp from event lines.
      4. Abbreviate ``tensor([[[...]]])`` reprs.

    Args:
        text: Raw text captured from ``pysnooper.snoop``.
        function_head: Signature to show in place of the executed one.
        execution_function_head: Signature the program was executed with.

    Returns:
        The cleaned trace joined with newlines, or ``''`` if
        ``execution_function_head`` does not occur in ``text``.
    """
    # Wall-clock prefix pysnooper puts on event lines, e.g. "12:34:56.789012 ".
    # The fractional part is optional so a prefix without microseconds is
    # stripped correctly instead of losing a fixed 16 characters.
    timestamp_re = re.compile(r"\d{2}:\d{2}:\d{2}(?:\.\d+)? ")
    tensor_re = re.compile(r"tensor\(\[\[\[.*?\]\]\]\)")

    def leading_spaces(line: str) -> int:
        """Return the number of leading ' ' characters in ``line``."""
        return len(line) - len(line.lstrip(' '))

    def remove_indent(lines: List[str]) -> List[str]:
        """Dedent every line by the leading-space count of the first line."""
        n_space = leading_spaces(lines[0])
        # Never strip more than a line's own leading spaces: real text is never
        # cut off, and empty lines are handled without indexing line[0].
        return [line[min(n_space, leading_spaces(line)):] for line in lines]

    def remove_pre_context(lines: List[str]) -> List[str]:
        """Keep lines from the last occurrence of the executed signature on."""
        for i in range(len(lines) - 1, -1, -1):
            line = lines[i]
            if execution_function_head in line:
                content = [line.replace(execution_function_head, function_head)] + lines[i + 1:]
                # A head line without leading spaces yields n_space == 0, so
                # remove_indent is then a no-op.
                return remove_indent(content)
        return []

    def remove_post_context(lines: List[str]) -> List[str]:
        """Drop everything from the point where tracing leaves the program."""
        for i, line in enumerate(lines):
            back_in_this_file = line.startswith("Source path:") and line.endswith(__file__)
            if back_in_this_file or line.startswith("Elapsed time"):
                return lines[:i]
        return lines

    def remove_timestamp(lines: List[str]) -> List[str]:
        """Strip the leading timestamp from each line that has one."""
        stripped = []
        for line in lines:
            match = timestamp_re.match(line)
            stripped.append(line[match.end():] if match else line)
        return stripped

    lines = remove_pre_context(text.splitlines())
    lines = remove_post_context(lines)
    lines = remove_timestamp(lines)
    return '\n'.join(tensor_re.sub("tensor([[[...]]])", line) for line in lines)
|
|
|
|
|
# Per-process counter giving each generated program a unique module name
# ("x1", "x2", ...).
cnt = 0


def _import_generated_module(name, retries=20, wait_seconds=60):
    """Import the just-written module ``name``; return it, or ``None`` on failure.

    A ``ModuleNotFoundError`` for ``name`` itself is retried up to ``retries``
    times, sleeping ``wait_seconds`` between attempts. Any other import failure,
    including a missing module imported *by* the generated code, is reported
    and ``None`` is returned at once, because retrying cannot fix it.
    """
    for attempt in range(retries):
        # The .py file was created after the interpreter started, so the path
        # finders' directory caches may not list it until they are invalidated.
        importlib.invalidate_caches()
        try:
            return importlib.import_module(name)
        except ModuleNotFoundError as e:
            if e.name != name:
                print("Import has error:", e)
                return None
            print("Errrr, import error. Wait a bit while.")
            if attempt + 1 < retries:
                time.sleep(wait_seconds)
        except Exception as e:
            print("Import has error:", e)
            return None
    return None


def run_program_with_trace(code, image, input_type_, output_type_):
    """Run a generated ``execute_command`` program on ``image`` and trace it.

    Processing steps:
      - The program text is normalized so that it starts with the
        ``EXEC_FUNCTION_HEAD`` signature.
      - It is written to ``x<N>.py`` in the current working directory and
        imported.
      - It is called under ``pysnooper.snoop`` with the helpers imported from
        ``image_patch``.

    The module file and its ``sys.modules`` entry are removed afterwards, even
    if tracing fails.

    NOTE(review): importing ``x<N>`` relies on the working directory being on
    ``sys.path`` — confirm for the deployment.

    Args:
        code: Program text (converted with ``str``). Either a function starting
            with ``FUNCTION_HEAD``, or the text that follows that signature.
        image: First positional argument passed to ``execute_command``.
        input_type_: Text substituted for ``{input_type}`` in both signatures.
        output_type_: Text substituted for ``{output_type}`` in ``FUNCTION_HEAD``.

    Returns:
        A tuple ``(result, error, trace)``, depending on the outcome:
          - If the code does not parse: ``(None, CompileTimeError(), None)``.
          - If the generated module cannot be imported:
            ``(None, ProgramRuntimeError(), '')``.
          - Otherwise: the program's return value (``None`` if it raised);
            ``None`` or a ``ProgramRuntimeError`` instance; and the trace
            cleaned by ``process_trace``. The raw trace is truncated to
            100000 characters before cleaning.
    """
    import sys

    from image_patch import ImagePatch, llm_query, best_image_match, distance, bool_to_yesno

    global cnt

    function_head = FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_)
    execution_function_head = EXEC_FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_)

    code = str(code)
    if code.startswith("\ndef"):
        code = code[1:]

    if code.startswith('def'):
        if code.startswith(function_head):
            # Drop only the leading signature; str.replace would also delete
            # any later copy of the same text inside the body.
            code = code[len(function_head):]
        else:
            print("--- Code with invalid format\n")
            print(code)
    code = execution_function_head + code
    try:
        # Parsing validates the program; unparsing normalizes its formatting.
        code = ast.unparse(ast.parse(code))
    except Exception:
        return None, CompileTimeError(), None

    cnt += 1
    name = f'x{cnt}'
    path = f'{name}.py'
    # Python decodes source files as UTF-8 by default, so write them that way.
    with open(path, 'w', encoding='utf-8') as source_file:
        source_file.write(code)

    try:
        module = _import_generated_module(name)
        if module is None:
            return None, ProgramRuntimeError(), ''

        queues = [None, None]

        image_patch_partial = partial(ImagePatch, queues=queues)
        video_segment_partial = None
        llm_query_partial = partial(llm_query, queues=queues)

        with io.StringIO() as trace_buffer:
            with pysnooper.snoop(output=trace_buffer, color=False, depth=2, max_variable_length=1000):
                result = None
                error = None
                try:
                    result = module.execute_command(image, None, '', image_patch_partial, video_segment_partial,
                                                    llm_query_partial, bool_to_yesno, distance, best_image_match)
                except (Exception, SystemExit):
                    # Failures of the generated program (even an exit() call)
                    # are reported, not raised; KeyboardInterrupt still propagates.
                    error = ProgramRuntimeError()
            traced = trace_buffer.getvalue()[:100000]
    finally:
        # Always clean up, so modules and files do not accumulate.
        sys.modules.pop(name, None)
        os.remove(path)

    traced_processed = process_trace(traced, function_head, execution_function_head)

    return result, error, traced_processed
|
|