# VisionFlowModule/VisionAtomicFlow.py
# Latest commit 94f0f9e by nbaldwin: "renamed flows to aiflows"

from typing import Any, Dict

import cv2

from aiflows.utils.general_helpers import encode_from_buffer, encode_image
from flow_modules.aiflows.ChatFlowModule import ChatAtomicFlow


class VisionAtomicFlow(ChatAtomicFlow):
    """This class implements the atomic flow for the VisionFlowModule. Given a textual input and a set of images and/or a video, it generates a textual output.
    It uses the litellm library as a backend. See https://docs.litellm.ai/docs/providers for supported models and APIs.

    *Configuration Parameters*:

    - `name` (str): The name of the flow. Default: "VisionAtomicFlow"
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow.
      Default: "A flow that, given a textual input, and a set of images and/or videos, generates a textual output."
    - `enable_cache` (bool): If True, the flow will use the cache. Default: True
    - `n_api_retries` (int): The number of times to retry the API call in case of failure. Default: 6
    - `wait_time_between_api_retries` (int): The time to wait between API retries, in seconds. Default: 20
    - `system_name` (str): The name of the system. Default: "system"
    - `user_name` (str): The name of the user. Default: "user"
    - `assistant_name` (str): The name of the assistant. Default: "assistant"
    - `backend` (Dict[str, Any]): The configuration of the backend, which is used to fetch API keys. Default: LiteLLMBackend with the
      default parameters of ChatAtomicFlow (see the Flow card of ChatAtomicFlowModule), except for the following parameters,
      whose default values are overwritten:
        - `api_infos` (List[Dict[str, Any]]): The list of API infos. Default: no default value; this parameter is required.
        - `model_name` (Union[Dict[str, str], str]): The name of the model to use.
          When using multiple API providers, `model_name` can be a dictionary of the form
          {"provider_name": "model_name"}.
          Default: "gpt-4-vision-preview" (the name must match the model name used by litellm: https://docs.litellm.ai/docs/providers).
        - `n` (int): The number of answers to generate. Default: 1
        - `max_tokens` (int): The maximum number of tokens to generate. Default: 2000
        - `temperature` (float): The sampling temperature. Default: 0.3
        - `top_p` (float): Nucleus sampling, an alternative to sampling with temperature: the model samples only from the
          smallest set of tokens whose cumulative probability reaches `top_p`. Default: 0.2
        - `frequency_penalty` (float): The higher this value, the less likely the model is to repeat itself (tokens are
          penalized in proportion to how often they have already appeared). Default: 0.0
        - `presence_penalty` (float): The higher this value, the more likely the model is to talk about new topics (tokens
          are penalized once they have appeared at all). Default: 0.0
    - `system_message_prompt_template` (Dict[str, Any]): The template used to generate the system message.
      By default, it is of type aiflows.prompt_template.JinjaPrompt.
      None of the parameters of the prompt are defined by default, so they must be defined if one wants to use the system prompt.
      Default parameters are defined in aiflows.prompt_template.jinja2_prompts.JinjaPrompt.
    - `init_human_message_prompt_template` (Dict[str, Any]): The prompt template of the human/user message used to initialize the conversation
      (first time in). It is used to generate the human message, which is passed as the user message to the LLM.
      By default, it is of type aiflows.prompt_template.JinjaPrompt. None of the parameters of the prompt are defined by default, so they must be
      defined if one wants to use the init_human_message_prompt_template. Default parameters are defined in aiflows.prompt_template.jinja2_prompts.JinjaPrompt.
    - `previous_messages` (Dict[str, Any]): Defines which previous messages to include in the input of the LLM. Note that if `first_k` and `last_k` are both None,
      all the messages of the flow's history are added to the input of the LLM. Default:
        - `first_k` (int): If defined, adds the `first_k` earliest messages of the flow's chat history to the input of the LLM. Default: None
        - `last_k` (int): If defined, adds the `last_k` latest messages of the flow's chat history to the input of the LLM. Default: None
    - Other parameters are inherited from the default configuration of ChatAtomicFlow (see the Flow card of ChatAtomicFlowModule).
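
    Example (a hypothetical sketch, not the ChatAtomicFlow source): one reading of `first_k`/`last_k` in which,
    when both are set, the first and last windows are merged without duplicates. `select_previous_messages`
    is an illustrative helper; check ChatAtomicFlow for the exact behavior.

    ```python
    def select_previous_messages(history, first_k=None, last_k=None):
        # Both None: the whole chat history is kept
        if first_k is None and last_k is None:
            return list(history)
        n = len(history)
        keep = set()
        if first_k is not None:
            keep.update(range(min(first_k, n)))  # indices of the first_k earliest messages
        if last_k is not None:
            keep.update(range(max(n - last_k, 0), n))  # indices of the last_k latest messages
        return [history[i] for i in sorted(keep)]

    history = ['m0', 'm1', 'm2', 'm3', 'm4']
    print(select_previous_messages(history, first_k=1, last_k=2))  # ['m0', 'm3', 'm4']
    ```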

    *Input Interface Initialized (Expected input the first time in flow)*:

    - `query` (str): The textual query to run the model on.
    - `data` (Dict[str, Any]): The data (images and/or a video) to run the model on. It can contain the following keys:
        - `images` (List[Dict[str, Any]]): A list of images to run the model on. Each image is a dictionary that contains the following keys:
            - `type` (str): The type of the image. It can be "local_path" or "url".
            - `image` (str): The image. If `type` is "local_path", it is a local path to the image (with one of the extensions
              jpg, jpeg, png, webp or gif); if `type` is "url", it is a URL to the image.
        - `video` (Dict[str, Any]): A video to run the model on. It is a dictionary that contains the following keys:
            - `video_path` (str): The path to the video.
            - `resize` (int): The resize value sent to the model along with every frame. Default: 768
            - `frame_step_size` (int): Only every `frame_step_size`-th frame is sent to the model. Default: 10
            - `start_frame` (int): The index of the first frame considered. Default: 0
            - `end_frame` (int): The index at which frame sampling stops (exclusive, as in a Python slice). Default: None (the end of the video)
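
    Example (placeholder paths and URL; replace them with your own files): an `input_data` dictionary
    with two images and a video, following the interface above.

    ```python
    input_data = {
        'query': 'What do the images and the video have in common?',
        'data': {
            'images': [
                {'type': 'local_path', 'image': 'photos/cat.jpg'},
                {'type': 'url', 'image': 'https://example.com/dog.png'},
            ],
            'video': {
                'video_path': 'clips/cat.mp4',
                'resize': 768,
                'frame_step_size': 10,
                'start_frame': 0,
                'end_frame': 300,  # frames 0, 10, ..., 290 are sent (if the video is long enough)
            },
        },
    }
    ```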

    *Input Interface (Expected input after the first time in flow)*:

    - `query` (str): The textual query to run the model on.
    - `data` (Dict[str, Any]): The data (images and/or a video) to run the model on. It can contain the following keys:
        - `images` (List[Dict[str, Any]]): A list of images to run the model on. Each image is a dictionary that contains the following keys:
            - `type` (str): The type of the image. It can be "local_path" or "url".
            - `image` (str): The image. If `type` is "local_path", it is a local path to the image (with one of the extensions
              jpg, jpeg, png, webp or gif); if `type` is "url", it is a URL to the image.
        - `video` (Dict[str, Any]): A video to run the model on. It is a dictionary that contains the following keys:
            - `video_path` (str): The path to the video.
            - `resize` (int): The resize value sent to the model along with every frame. Default: 768
            - `frame_step_size` (int): Only every `frame_step_size`-th frame is sent to the model. Default: 10
            - `start_frame` (int): The index of the first frame considered. Default: 0
            - `end_frame` (int): The index at which frame sampling stops (exclusive, as in a Python slice). Default: None (the end of the video)

    *Output Interface*:

    - `api_output` (str): The API output of the flow for the given query and data.
    """

    @staticmethod
    def get_image(image):
        """Convert an image dictionary into an `image_url` content part for the API.
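
        Example (a sketch assuming `encode_image` returns the file's base64 text; `to_data_url` is a
        hypothetical helper, not part of aiflows): how a local image becomes a data URL.

        ```python
        import base64

        def to_data_url(raw: bytes, subtype: str) -> str:
            # Stand-in for encode_image: base64-encode the raw file bytes
            encoded = base64.b64encode(raw).decode('utf-8')
            # Data URL layout: data:<mediatype>;base64,<data> (no space after the comma)
            return f'data:image/{subtype};base64,{encoded}'

        part = {'type': 'image_url', 'image_url': {'url': to_data_url(b'abc', 'png')}}
        print(part['image_url']['url'])  # data:image/png;base64,YWJj
        ```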
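
        :param image: The image dictionary, with keys `type` ("local_path" or "url") and `image` (a local path or a URL).
        :type image: Dict[str, Any]
        :return: The content part {"type": "image_url", "image_url": {"url": ...}}. Local images are embedded as base64 data URLs.
        :rtype: Dict[str, Any]
        """
        extension_dict = {
            "jpg": "jpeg",
            "jpeg": "jpeg",
            "png": "png",
            "webp": "webp",
            "gif": "gif",
        }
        supported_image_types = ["local_path", "url"]
        image_type = image.get("type", None)
        assert image_type in supported_image_types, (
            f"Must define a valid image type for every image.\n"
            f"Your type: {image_type}\n"
            f"Supported types: {supported_image_types}"
        )

        url = None
        if image_type == "local_path":
            # Embed the local file as a base64 data URL: data:image/<subtype>;base64,<data>
            image_path = image.get("image")
            image_extension_type = image_path.split(".")[-1].lower()
            assert image_extension_type in extension_dict, (
                f"Unsupported image extension: {image_extension_type}. "
                f"Supported extensions: {list(extension_dict)}"
            )
            processed_image = encode_image(image_path)
            url = f"data:image/{extension_dict[image_extension_type]};base64,{processed_image}"
        elif image_type == "url":
            # Remote image: the URL is passed through unchanged
            url = image.get("image")

        return {"type": "image_url", "image_url": {"url": url}}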

    @staticmethod
    def get_video(video):
        """Sample frames from a video and encode them as base64 JPEG frames for the API.
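
        Example (a sketch; `sampled_indices` is a hypothetical helper whose defaults mirror those used by
        get_video): the frame indices kept by slicing the decoded frames with [start_frame:end_frame:frame_step_size].
        `end_frame` is exclusive, as in any Python slice.

        ```python
        def sampled_indices(n_frames, start_frame=0, end_frame=None, frame_step_size=10):
            # Same slice that get_video applies to the list of encoded frames
            return list(range(n_frames))[start_frame:end_frame:frame_step_size]

        print(sampled_indices(35))  # [0, 10, 20, 30]
        print(sampled_indices(35, start_frame=5, end_frame=25, frame_step_size=5))  # [5, 10, 15, 20]
        ```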
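
        :param video: The video dictionary: `video_path`, and optionally `resize`, `frame_step_size`, `start_frame` and `end_frame`.
        :type video: Dict[str, Any]
        :return: An iterator over {"image": <base64 JPEG>, "resize": <int>} dictionaries, one per sampled frame.
        :rtype: Iterator[Dict[str, Any]]
        """
        video_path = video["video_path"]
        resize = video.get("resize", 768)
        frame_step_size = video.get("frame_step_size", 10)
        start_frame = video.get("start_frame", 0)
        end_frame = video.get("end_frame", None)

        # Decode every frame and re-encode it as a base64 JPEG
        base64_frames = []
        capture = cv2.VideoCapture(video_path)
        while capture.isOpened():
            success, frame = capture.read()
            if not success:
                break
            _, buffer = cv2.imencode(".jpg", frame)
            base64_frames.append(encode_from_buffer(buffer))
        capture.release()

        # Keep every `frame_step_size`-th frame in [start_frame, end_frame)
        selected_frames = base64_frames[start_frame:end_frame:frame_step_size]
        return map(lambda frame: {"image": frame, "resize": resize}, selected_frames)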

    @staticmethod
    def get_user_message(prompt_template, input_data: Dict[str, Any]):
        """Construct the content of the user message (text, video frames and images) to be passed to the API.
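
        Example (a sketch with pre-encoded media; `assemble_content` is a hypothetical helper and 768 stands in
        for the default `resize`): the order in which the content parts are assembled.

        ```python
        def assemble_content(text, video_frames=None, image_urls=None):
            content = [{'type': 'text', 'text': text}]
            if video_frames is not None:
                # Video frames use the {'image': ..., 'resize': ...} layout
                content = [content[0], *({'image': f, 'resize': 768} for f in video_frames)]
            if image_urls is not None:
                # Images are appended last, as image_url parts
                content.extend({'type': 'image_url', 'image_url': {'url': u}} for u in image_urls)
            return content

        msg = assemble_content('Describe the clip.', video_frames=['F0', 'F1'],
                               image_urls=['https://example.com/a.png'])
        print([part.get('type', 'frame') for part in msg])  # ['text', 'frame', 'frame', 'image_url']
        ```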

        :param prompt_template: The prompt template to use.
        :type prompt_template: PromptTemplate
        :param input_data: The input data.
        :type input_data: Dict[str, Any]
        :return: The message content: the text part first, followed by the video frames (if any) and then the images (if any).
        :rtype: List[Dict[str, Any]]
        """
        content = VisionAtomicFlow._get_message(prompt_template=prompt_template, input_data=input_data)
        media_data = input_data["data"]
        if "video" in media_data:
            # Keep the text part first, then add one part per sampled video frame
            content = [content[0], *VisionAtomicFlow.get_video(media_data["video"])]
        if "images" in media_data:
            # Images are appended after the text (and the video frames, if any)
            images = [VisionAtomicFlow.get_image(image) for image in media_data["images"]]
            content.extend(images)
        return content

    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        """Construct the textual part of the message to be passed to the API.
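
        Example (a sketch; `FakePrompt` is a hypothetical stand-in for aiflows' JinjaPrompt that uses str.format
        instead of Jinja): only the variables the template declares are read from `input_data`.

        ```python
        class FakePrompt:
            def __init__(self, template, input_variables):
                self.template = template
                self.input_variables = input_variables

            def format(self, **kwargs):
                return self.template.format(**kwargs)

        prompt = FakePrompt('Answer this: {query}', input_variables=['query'])
        input_data = {'query': 'What is in the image?', 'data': {'images': []}}
        # 'data' is not a declared input variable, so it is ignored here
        kwargs = {name: input_data[name] for name in prompt.input_variables}
        message = [{'type': 'text', 'text': prompt.format(**kwargs)}]
        print(message)  # [{'type': 'text', 'text': 'Answer this: What is in the image?'}]
        ```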

        :param prompt_template: The prompt template to use.
        :type prompt_template: PromptTemplate
        :param input_data: The input data.
        :type input_data: Dict[str, Any]
        :return: A one-element list containing the text content part {"type": "text", "text": ...}.
        :rtype: List[Dict[str, Any]]
        """
        # Fill the template with the input variables it declares
        template_kwargs = {}
        for input_variable in prompt_template.input_variables:
            template_kwargs[input_variable] = input_data[input_variable]
        msg_content = prompt_template.format(**template_kwargs)
        return [{"type": "text", "text": msg_content}]

    def _process_input(self, input_data: Dict[str, Any]):
        """Process the input data: build the user message (text, video frames and images) and add it to the chat history.

        If the conversation has not been initialized yet, it is initialized first (system message, and potentially the demonstrations).

        :param input_data: The input data.
        :type input_data: Dict[str, Any]
        """
        if self._is_conversation_initialized():
            # Construct the message using the human message prompt template
            user_message_content = self.get_user_message(self.human_message_prompt_template, input_data)
        else:
            # Initialize the conversation (add the system message, and potentially the demonstrations)
            self._initialize_conversation(input_data)
            if getattr(self, "init_human_message_prompt_template", None) is not None:
                # Construct the message using the init human message prompt template
                user_message_content = self.get_user_message(self.init_human_message_prompt_template, input_data)
            else:
                user_message_content = self.get_user_message(self.human_message_prompt_template, input_data)

        self._state_update_add_chat_message(role=self.flow_config["user_name"],
                                            content=user_message_content)