import streamlit as st
import torch
import copy
import os
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Tuple, Optional
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.captioner.image_captioning import ImageCaptioningModel
from my_model.object_detection import ObjectDetector
import my_model.config.kbvqa_config as config


class KBVQA:
    """
    The KBVQA class encapsulates the functionality for the Knowledge-Based Visual Question Answering (KBVQA) model.
    It integrates an image captioning model, an object detection model, and a fine-tuned language model
    (LLAMA2, fine-tuned on the OK-VQA dataset) to generate answers to visual questions.

    Attributes:
        kbvqa_model_name (str): Name of the fine-tuned language model used for KBVQA.
        quantization (str): The quantization setting for the model ('4bit' or '8bit'; any other value disables
            bitsandbytes quantization).
        max_context_window (int): The maximum number of prompt tokens accepted by generate_answer.
        add_eos_token (bool): Flag to indicate whether to add an end-of-sentence token to the tokenizer.
        trust_remote (bool): Flag to indicate whether to trust remote code when using the tokenizer.
        use_fast (bool): Flag to indicate whether to use the fast version of the tokenizer.
        low_cpu_mem_usage (bool): Configuration flag for low CPU memory usage (read from config).
        kbvqa_tokenizer (Optional[AutoTokenizer]): The tokenizer for the KBVQA model.
        captioner (Optional[ImageCaptioningModel]): The model used for generating image captions.
        detector (Optional[ObjectDetector]): The object detection model.
        detection_model (Optional[str]): The name of the object detection model.
        detection_confidence (Optional[float]): The confidence threshold passed to the object detector.
        kbvqa_model (Optional[AutoModelForCausalLM]): The fine-tuned language model for KBVQA.
        bnb_config (Optional[BitsAndBytesConfig]): Quantization configuration, or None when not quantizing.
        access_token (str): Access token for the Hugging Face API.
        current_prompt_length (Optional[int]): Token count of the most recent prompt built by generate_answer
            (None until the first call).

    Methods:
        create_bnb_config: Creates a BitsAndBytes configuration based on the quantization setting.
        load_caption_model: Loads the image captioning model.
        get_caption: Generates a caption for a given image.
        load_detector: Loads the object detection model.
        detect_objects: Detects objects in a given image.
        load_fine_tuned_model: Loads the fine-tuned KBVQA model along with its tokenizer.
        all_models_loaded: Checks if all the required models are loaded.
        format_prompt: Formats the prompt for the KBVQA model.
        generate_answer: Generates an answer to a given question using the KBVQA model.
    """

    def __init__(self):
        self.kbvqa_model_name: str = config.KBVQA_MODEL_NAME
        self.quantization: str = config.QUANTIZATION
        self.max_context_window: int = config.MAX_CONTEXT_WINDOW
        self.add_eos_token: bool = config.ADD_EOS_TOKEN
        self.trust_remote: bool = config.TRUST_REMOTE
        self.use_fast: bool = config.USE_FAST
        self.low_cpu_mem_usage: bool = config.LOW_CPU_MEM_USAGE
        self.kbvqa_tokenizer: Optional[AutoTokenizer] = None
        self.captioner: Optional[ImageCaptioningModel] = None
        self.detector: Optional[ObjectDetector] = None
        self.detection_model: Optional[str] = None
        self.detection_confidence: Optional[float] = None
        self.kbvqa_model: Optional[AutoModelForCausalLM] = None
        self.bnb_config: Optional[BitsAndBytesConfig] = self.create_bnb_config()
        self.access_token: str = config.HUGGINGFACE_TOKEN
        self.current_prompt_length: Optional[int] = None

    def create_bnb_config(self) -> Optional[BitsAndBytesConfig]:
        """
        Creates a BitsAndBytes configuration based on the quantization setting.

        Returns:
            Optional[BitsAndBytesConfig]: A 4-bit config (NF4, double quantization, bfloat16 compute) when
            self.quantization is '4bit', an 8-bit config when it is '8bit', and None otherwise.
        """
        if self.quantization == '4bit':
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        if self.quantization == '8bit':
            # BitsAndBytesConfig has no bnb_8bit_* options: double quantization, NF4 and the compute dtype
            # are 4-bit-only settings, so the previously passed bnb_8bit_* kwargs were silently ignored.
            return BitsAndBytesConfig(load_in_8bit=True)
        return None

    def load_caption_model(self) -> None:
        """
        Loads the image captioning model into the KBVQA instance.
        """
        self.captioner = ImageCaptioningModel()
        self.captioner.load_model()

    def get_caption(self, img: Image.Image) -> str:
        """
        Generates a caption for a given image using the image captioning model.

        Args:
            img (PIL.Image.Image): The image for which to generate a caption.

        Returns:
            str: The generated caption for the image.
        """
        return self.captioner.generate_caption(img)

    def load_detector(self, model: str) -> None:
        """
        Loads the object detection model.

        Args:
            model (str): The name of the object detection model to load.
        """
        self.detector = ObjectDetector()
        self.detector.load_model(model)

    def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
        """
        Detects objects in a given image using the loaded object detection model.

        Args:
            img (PIL.Image.Image): The image in which to detect objects.

        Returns:
            tuple: The image with detected objects drawn on it, and a string representation of the
            detected objects (filtered by self.detection_confidence).
        """
        image = self.detector.process_image(img)
        detected_objects_string, detected_objects_list = self.detector.detect_objects(
            image, threshold=self.detection_confidence
        )
        # Boxes are drawn on the original image, not on the detector-preprocessed one.
        image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
        return image_with_boxes, detected_objects_string

    def load_fine_tuned_model(self) -> None:
        """
        Loads the fine-tuned KBVQA model along with its tokenizer.
        """
        self.kbvqa_model = AutoModelForCausalLM.from_pretrained(
            self.kbvqa_model_name,
            device_map="auto",
            # Kept True rather than self.low_cpu_mem_usage: loading with a device_map requires
            # low_cpu_mem_usage=True in transformers.
            low_cpu_mem_usage=True,
            quantization_config=self.bnb_config,
            token=self.access_token,
        )
        self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(
            self.kbvqa_model_name,
            use_fast=self.use_fast,
            trust_remote_code=self.trust_remote,
            add_eos_token=self.add_eos_token,
            token=self.access_token,
        )

    @property
    def all_models_loaded(self) -> bool:
        """
        Checks if all the required models (KBVQA, captioner, detector) are loaded.

        Returns:
            bool: True if all models are loaded, False otherwise.
        """
        return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None

    def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None,
                      caption: Optional[str] = None, objects: Optional[str] = None) -> str:
        """
        Formats the prompt for the KBVQA model using the Llama-2 chat layout
        (<s>[INST] <<SYS>>\\n...\\n<</SYS>>\\n\\n...[/INST]) plus custom caption/object/question tags.

        Args:
            current_query (str): The current question to be answered.
            history (str, optional): The history of previous interactions. When given, only the new question
                is appended to it (no system prompt, caption or objects).
            sys_prompt (str, optional): The system prompt; defaults to config.SYSTEM_PROMPT.
            caption (str, optional): The caption of the image.
            objects (str, optional): The detected objects in the image.

        Returns:
            str: The formatted prompt for the KBVQA model.
        """
        # Llama-2 chat delimiters. The tokenizer is later called with add_special_tokens=False,
        # so the BOS token must be present in the text itself.
        B_SENT = '<s>'
        E_SENT = '</s>'
        B_INST = '[INST]'
        E_INST = '[/INST]'
        B_SYS = '<<SYS>>\n'
        E_SYS = '\n<</SYS>>\n\n'
        B_CAP = '[CAP]'
        E_CAP = '[/CAP]'
        B_QES = '[QES]'
        E_QES = '[/QES]'
        B_OBJ = '[OBJ]'
        E_OBJ = '[/OBJ]'

        current_query = current_query.strip()
        if sys_prompt is None:
            sys_prompt = config.SYSTEM_PROMPT.strip()

        if history is not None:
            return f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""

        if objects is None:
            return f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_QES}{current_query}{E_QES}{E_INST}"""

        return f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_OBJ}{objects}{E_OBJ}{B_QES}taking into consideration the objects with high certainty, {current_query}{E_QES}{E_INST}"""

    def generate_answer(self, question: str, caption: str, detected_objects_str: str) -> Optional[str]:
        """
        Generates an answer to a given question using the KBVQA model.

        Args:
            question (str): The question to be answered.
            caption (str): The caption of the image related to the question.
            detected_objects_str (str): The string representation of detected objects in the image.

        Returns:
            Optional[str]: The generated answer (capitalized), or None when the prompt exceeds
            self.max_context_window tokens (a Streamlit warning is shown in that case).
        """
        prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)

        # Encode once and reuse the result for both the length check and generation.
        model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
        input_ids = model_inputs["input_ids"]
        num_tokens = input_ids.shape[1]
        self.current_prompt_length = num_tokens

        if num_tokens > self.max_context_window:
            st.warning(f"Prompt too long with {num_tokens} tokens, consider increasing the confidence threshold for the object detector")
            return None

        # Follow the model's placement (device_map="auto") instead of hard-coding 'cuda'.
        device = self.kbvqa_model.device
        input_ids = input_ids.to(device)
        attention_mask = model_inputs["attention_mask"].to(device)
        output_ids = self.kbvqa_model.generate(input_ids=input_ids, attention_mask=attention_mask)

        # generate() returns prompt + continuation; drop the prompt tokens so only the answer is decoded.
        prompt_length = input_ids.shape[1]
        output_text = self.kbvqa_tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)
        return output_text.capitalize()


def prepare_kbvqa_model(only_reload_detection_model: bool = False, force_reload: bool = False) -> KBVQA:
    """
    Prepares the KBVQA model for use, including loading necessary sub-models.

    Args:
        only_reload_detection_model (bool): If True, only the object detection model is loaded into the new
            KBVQA instance (captioner and language model are left unloaded).
        force_reload (bool): If True, a "force reloading" message is shown while loading.

    Returns:
        KBVQA: A new KBVQA instance with the requested sub-models loaded.
    """
    if force_reload:
        loading_message = 'Force Reloading model.. this should take no more than a few minutes!'
    else:
        loading_message = 'Loading model.. this should take no more than 2 or 3 minutes!'

    # A previously loaded instance lives with the caller (presumably in st.session_state), so it cannot be
    # deleted from here; the caller must drop its reference for that memory to be reclaimed.
    free_gpu_resources()

    kbvqa = KBVQA()
    kbvqa.detection_model = st.session_state.detection_model

    # Progress bar for model loading
    with st.spinner(loading_message):
        if not only_reload_detection_model:
            progress_bar = st.progress(0)
            kbvqa.load_detector(kbvqa.detection_model)
            progress_bar.progress(33)
            kbvqa.load_caption_model()
            free_gpu_resources()
            progress_bar.progress(75)
            st.text('Almost there :)')
            kbvqa.load_fine_tuned_model()
            free_gpu_resources()
            progress_bar.progress(100)
        else:
            free_gpu_resources()
            progress_bar = st.progress(0)
            kbvqa.load_detector(kbvqa.detection_model)
            progress_bar.progress(100)

        if kbvqa.all_models_loaded:
            st.success('Model loaded successfully and ready for inference!')
            kbvqa.kbvqa_model.eval()
            free_gpu_resources()

    return kbvqa