import base64
import functools
import io
import operator
from typing import Annotated, List, Literal, Sequence, TypedDict

import gradio as gr
import PIL.Image
from dotenv import load_dotenv
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_experimental.utilities import PythonREPL
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode

load_dotenv()
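
# Two models split the work: Gemini 1.5 Flash is multimodal and reads the uploaded
# problem images, while the Groq-hosted Llama 3.1 model is text-only, which is why
# checker_node later strips image parts before calling it.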
llm_coder = ChatGroq(temperature=0, model_name="llama-3.1-8b-instant")
llm_image = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)

search_tool = DuckDuckGoSearchRun()
repl_tool = PythonREPL()
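
# Note: PythonREPL executes arbitrary code in the local process with no sandboxing,
# so this tool should only be exposed in trusted environments. The @tool decorator
# turns the function below into a LangChain tool that bind_tools/ToolNode accept.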
@tool
def python_repl(
    code: Annotated[str, "The python code to execute to answer the question."],
):
    """Use this to execute python code. If you want to see the output of a value,
    you should print it out with `print(...)`. This is visible to the user."""
    try:
        result = repl_tool.run(code)
    except BaseException as e:
        return f"Failed to execute. Error: {repr(e)}"
    result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
    return (
        result_str + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
    )
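
# A hypothetical direct call, to show the tool's contract (the exact Stdout text
# depends on the REPL):
#   python_repl.invoke({"code": "print(2 + 2)"})
#   # -> "Successfully executed:\n```python\nprint(2 + 2)\n```\nStdout: 4\n\n..."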
def create_agent(llm, tools, system_message: str):
    """Build an agent: a role prompt piped into the LLM with its tools bound."""
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI assistant, collaborating with other assistants."
                " Use the provided tools to progress towards answering the question."
                " If you are unable to fully answer, that's OK; another assistant with"
                " different tools will help where you left off. Execute what you can to"
                " make progress."
                " If you or any of the other assistants have the final answer or deliverable,"
                " prefix your response with FINAL ANSWER so the team knows to stop."
                " You have access to the following tools: {tool_names}.\n{system_message}",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    prompt = prompt.partial(system_message=system_message)
    prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
    return prompt | llm.bind_tools(tools)
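
# The returned runnable is the role prompt (system_message and tool_names already
# filled in via .partial) piped into the LLM with its tools bound, so the model
# can emit tool calls that the router dispatches to call_tool.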
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    sender: str
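
# operator.add is the reducer for `messages`: each node returns {"messages": [msg]}
# and LangGraph appends it to the running transcript rather than overwriting it.
# `sender` records which agent spoke last, so tool output can be routed back to it.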
def agent_node(state, agent, name):
    result = agent.invoke(state)
    if not isinstance(result, ToolMessage):
        result = AIMessage(**result.dict(exclude={"type", "name"}), name=name)
    return {
        "messages": [result],
        "sender": name,
    }
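
# Re-wrapping the result as a named AIMessage lets the router and the Gradio output
# attribute each turn to the agent that produced it; ToolMessages pass through as-is.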
problem_agent = create_agent(
    llm_image,
    [],
    system_message="Understand the problem thoroughly and provide a clear description of it, including the edge cases. Do not provide the solution.",
)
problem_node = functools.partial(agent_node, agent=problem_agent, name="problem_agent")

solution_agent = create_agent(
    llm_image,
    [],
    system_message="After understanding the problem, provide a clear and concise Python solution that handles all edge cases, along with the intuition behind it.",
)
solution_node = functools.partial(agent_node, agent=solution_agent, name="solution_agent")

checker_agent = create_agent(
    llm_coder,
    [],
    system_message="Critically analyze the solution provided by the solution agent: check correctness, efficiency, and edge cases. If the solution is correct, say so; otherwise, describe the error and suggest a fix.",
)
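
# checker_node mirrors agent_node, but first flattens multimodal messages to plain
# text: the Groq-hosted Llama model accepts only text content, so image_url parts
# are dropped and just the text of each message is forwarded.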
def checker_node(state):
    text_only_messages = []
    for msg in state["messages"]:
        if isinstance(msg.content, list):
            text_content = [item["text"] for item in msg.content if item["type"] == "text"]
            new_msg = msg.copy()
            new_msg.content = " ".join(text_content)
            text_only_messages.append(new_msg)
        else:
            text_only_messages.append(msg)
    text_only_state = {
        "messages": text_only_messages,
        "sender": state["sender"],
    }
    result = checker_agent.invoke(text_only_state)
    if not isinstance(result, ToolMessage):
        result = AIMessage(**result.dict(exclude={"type", "name"}), name="checker_agent")
    return {
        "messages": [result],
        "sender": "checker_agent",
    }
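
# The prebuilt ToolNode executes any tool calls found on the last AIMessage and
# appends the resulting ToolMessages to the state.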
tools = [search_tool, python_repl]
tool_node = ToolNode(tools)

def router(state) -> Literal["call_tool", "__end__", "continue"]:
    """Route after each agent turn: run requested tools, stop on FINAL ANSWER,
    otherwise hand off to the next agent."""
    messages = state["messages"]
    last_message = messages[-1]
    if last_message.tool_calls:
        return "call_tool"
    if "FINAL ANSWER" in last_message.content:
        return "__end__"
    return "continue"
workflow = StateGraph(AgentState)
workflow.add_node("problem_creator", problem_node)
workflow.add_node("solution_generator", solution_node)
workflow.add_node("checker_agent", checker_node)
workflow.add_node("call_tool", tool_node)
workflow.add_conditional_edges(
    "problem_creator",
    router,
    {"continue": "solution_generator", "call_tool": "call_tool", "__end__": END},
)
workflow.add_conditional_edges(
    "solution_generator",
    router,
    {"continue": "checker_agent", "call_tool": "call_tool", "__end__": END},
)
workflow.add_conditional_edges(
    "checker_agent",
    router,
    {"continue": "problem_creator", "call_tool": "call_tool", "__end__": END},
)
workflow.add_conditional_edges(
    "call_tool",
    # `sender` holds the agent name set in agent_node/checker_node, so it must be
    # mapped back to the corresponding graph node name.
    lambda x: x["sender"],
    {
        "problem_agent": "problem_creator",
        "solution_agent": "solution_generator",
        "checker_agent": "checker_agent",
    },
)
workflow.add_edge(START, "problem_creator")
graph = workflow.compile()
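
# A minimal smoke test for the graph with a text-only question (hypothetical; the
# Gradio app below is the real entry point):
#
#   if __name__ == "__main__":
#       question = "Two Sum: given a list of ints and a target, return the indices of two numbers that add up to it."
#       for chunk in graph.stream(
#           {"messages": [HumanMessage(content=question)]},
#           {"recursion_limit": 10},
#           stream_mode="values",
#       ):
#           print(chunk["messages"][-1].content)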
def process_images(images: List[tuple[PIL.Image.Image, str | None]]):
    """Answer a question about the uploaded images using the agent graph."""
    if not images:
        return "No images uploaded"
    # Convert all images to base64 data URLs
    image_contents = []
    for image, _ in images:
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        image_contents.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{img_str}"},
        })
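
    # LangChain's multimodal message format: one HumanMessage whose content mixes
    # a text part with image_url parts; ChatGoogleGenerativeAI accepts base64 data
    # URLs like the ones built above.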
    # Create the input for the workflow
    input_data = {"messages": [HumanMessage(
        content=[
            {"type": "text", "text": "answer the question about the following images"},
            *image_contents,
        ]
    )]}
    # Run the workflow, streaming each agent's message as it is produced
    output = []
    try:
        for chunk in graph.stream(input_data, {"recursion_limit": 10}, stream_mode="values"):
            message = chunk["messages"][-1]
            output.append(f"{message.name}: {message.content}")
    except Exception as e:
        output.append(f"Error: {repr(e)}")
    return "\n\n".join(output)
# Create the Gradio interface
iface = gr.Interface(
    fn=process_images,
    inputs=[gr.Gallery(label="Upload image(s)", type="pil")],
    outputs=[gr.Markdown(label="Output", show_copy_button=True)],
    title="Image Question Answering",
    description="Upload one or more images of a problem to get it analyzed and answered.",
)

# Launch the interface
iface.launch()