File size: 8,002 Bytes
c59fc6b
 
9347b1e
139bf60
c59fc6b
 
2997bb2
d26dd8d
c59fc6b
 
 
 
 
 
 
 
 
7fd408f
c59fc6b
 
 
 
 
 
518eb6e
824d7ec
c59fc6b
139bf60
c59fc6b
 
 
 
08ec8d2
 
 
 
 
c59fc6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
567d6d1
c59fc6b
 
 
 
 
 
 
 
 
 
 
6f1c42e
c59fc6b
6f1c42e
c59fc6b
 
 
 
 
ec4889b
 
 
 
 
 
 
 
 
 
 
 
c59fc6b
 
 
 
 
 
e57843e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c59fc6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91f466a
 
c59fc6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c690614
97bc44b
c59fc6b
c690614
c59fc6b
c690614
e57843e
7ea3839
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c59fc6b
 
c62c890
c59fc6b
97bc44b
c59fc6b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import streamlit as st
import torch
import copy
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Optional
from my_model.gen_utilities import free_gpu_resources
from my_model.captioner.image_captioning import ImageCaptioningModel
from my_model.object_detection import ObjectDetector


class KBVQA:
    """
    Knowledge-based visual question answering pipeline.

    Holds an image captioner, an object detector and a fine-tuned causal
    language model, and turns a (question, caption, detected objects) triple
    into a tagged prompt that is answered by the language model.
    """

    def __init__(self):
        """
        Sets the default configuration; no model is loaded here.

        The captioner, detector, language model and tokenizer all start as
        None and are populated by the corresponding ``load_*`` methods.
        """
        self.kbvqa_model_name = "m7mdal7aj/fine_tunned_llama_2_merged"
        self.quantization = '4bit'
        self.bnb_config = self.create_bnb_config()
        # Prompts longer than this many tokens are rejected by generate_answer.
        self.max_context_window = 4000
        self.add_eos_token = False
        self.trust_remote = False
        self.use_fast = True
        self.kbvqa_tokenizer = None
        self.captioner = None
        self.detector = None
        self.detection_model = None
        # Passed as the `threshold` argument to the detector; must be set by
        # the caller before detect_objects is used.
        self.detection_confidence = None
        self.kbvqa_model = None
        self.access_token = os.getenv("HUGGINGFACE_TOKEN")

    def create_bnb_config(self) -> "Optional[BitsAndBytesConfig]":
        """
        Creates a BitsAndBytes configuration based on the quantization setting.

        Returns:
            BitsAndBytesConfig: Configuration for the '4bit' or '8bit' setting,
            or None when ``self.quantization`` is any other value.
        """
        if self.quantization == '4bit':
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        if self.quantization == '8bit':
            # Double quantization, quant type and compute dtype are 4-bit-only
            # options; there are no `bnb_8bit_*` equivalents to pass here.
            return BitsAndBytesConfig(load_in_8bit=True)
        return None

    def load_caption_model(self):
        """Instantiates the image captioning model and loads its weights."""
        self.captioner = ImageCaptioningModel()
        self.captioner.load_model()

    def get_caption(self, img):
        """
        Generates a caption for an image using the loaded captioner.

        Args:
            img: Image to caption, in whatever form the captioner accepts.

        Returns:
            The value returned by the captioner's ``generate_caption``.
        """
        return self.captioner.generate_caption(img)

    def load_detector(self, model):
        """
        Instantiates the object detector and loads the requested model.

        Args:
            model: Identifier of the detection model, forwarded unchanged to
                the detector's ``load_model``.
        """
        self.detector = ObjectDetector()
        self.detector.load_model(model)

    def detect_objects(self, img):
        """
        Runs object detection on an image and draws the detected boxes.

        Detection uses ``self.detection_confidence`` as the threshold. The
        boxes are drawn on the image as passed in, not on the processed copy
        used for detection.

        Args:
            img: Image to analyse, in whatever form the detector accepts.

        Returns:
            tuple: (image with boxes drawn, string describing the detections).
        """
        image = self.detector.process_image(img)
        detected_objects_string, detected_objects_list = self.detector.detect_objects(
            image, threshold=self.detection_confidence
        )
        image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
        return image_with_boxes, detected_objects_string

    def load_fine_tuned_model(self):
        """
        Loads the fine-tuned language model and its tokenizer.

        Both are fetched by ``self.kbvqa_model_name`` using the access token
        read from the HUGGINGFACE_TOKEN environment variable; the model is
        loaded with the quantization config built in ``__init__``.
        """
        self.kbvqa_model = AutoModelForCausalLM.from_pretrained(self.kbvqa_model_name,
                                                                device_map="auto",
                                                                low_cpu_mem_usage=True,
                                                                quantization_config=self.bnb_config,
                                                                token=self.access_token)

        self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(self.kbvqa_model_name,
                                                             use_fast=self.use_fast,
                                                             low_cpu_mem_usage=True,
                                                             trust_remote_code=self.trust_remote,
                                                             add_eos_token=self.add_eos_token,
                                                             token=self.access_token)

    @property
    def all_models_loaded(self):
        """bool: True when the language model, captioner and detector are all set."""
        return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None

    def force_reload_model(self):
        """
        Drops the loaded models and frees GPU resources.

        The attributes are reset to None rather than deleted so that the
        instance stays usable: ``all_models_loaded`` reports False afterwards
        instead of raising AttributeError, and the ``load_*`` methods can be
        called again to reload.
        """
        free_gpu_resources()
        self.kbvqa_model = None
        self.captioner = None
        self.detector = None
        # Called again now that the references are gone.
        free_gpu_resources()

    def format_prompt(self, current_query, history=None, sys_prompt=None, caption=None, objects=None):
        """
        Builds the tagged prompt sent to the language model.

        Args:
            current_query (str): The question; surrounding whitespace is stripped.
            history (str, optional): Previous conversation text. When given,
                the prompt is the history followed by a new instruction that
                contains only the question (system prompt, caption and
                objects are not repeated).
            sys_prompt (str, optional): System prompt; a built-in default is
                used when None. Surrounding whitespace is stripped.
            caption (str, optional): Image caption, wrapped in [CAP] tags.
                Omitted from the prompt when None.
            objects (str, optional): Detected-objects description, wrapped in
                [OBJ] tags. When given, the question is prefixed with a hint
                to focus on high-certainty objects.

        Returns:
            str: The formatted prompt.
        """
        if sys_prompt is None:
            # Kept byte-for-byte (including its spelling): it is runtime text
            # fed to the model, not documentation.
            sys_prompt = "You are a helpful, respectful and honest assistant for visual question answering. you are provided with a caption of an image and a list of objects detected in the image along with their bounding boxes and level of certainty, you will output an answer to the given questions in no more than one sentence. Use logical reasoning to reach to the answer, but do not output your reasoning process unless asked for it. If provided, you will use the [CAP] and [/CAP] tags to indicate the begining and end of the caption respectively. If provided you will use the [OBJ] and [/OBJ] tags to indicate the begining and end of the list of detected objects in the image along with their bounding boxes respectively.if provided, you will use [QES] and [/QES] tags to indicate the begining and end of the question respectively."

        B_SENT = '<s>'
        B_INST = '[INST]'
        E_INST = '[/INST]'
        B_SYS = '<<SYS>>\n'
        E_SYS = '\n<</SYS>>\n\n'
        B_CAP = '[CAP]'
        E_CAP = '[/CAP]'
        B_QES = '[QES]'
        E_QES = '[/QES]'
        B_OBJ = '[OBJ]'
        E_OBJ = '[/OBJ]'

        current_query = current_query.strip()
        sys_prompt = sys_prompt.strip()

        if history is not None:
            return f"{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"

        # A missing caption is left out entirely instead of being rendered as
        # the literal text "None" inside the [CAP] tags.
        caption_section = "" if caption is None else f"{B_CAP}{caption}{E_CAP}"
        header = f"{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{caption_section}"

        if objects is None:
            return f"{header}{B_QES}{current_query}{E_QES}{E_INST}"

        return (
            f"{header}{B_OBJ}{objects}{E_OBJ}"
            f"{B_QES}taking into consideration the objects with high certainty, {current_query}{E_QES}{E_INST}"
        )

    @staticmethod
    def _capitalize_first(text):
        """Upper-cases the first character of `text` and leaves the rest unchanged."""
        return text[:1].upper() + text[1:]

    def generate_answer(self, question, caption, detected_objects_str):
        """
        Answers a question about an image from its caption and detected objects.

        Args:
            question (str): The question to answer.
            caption (str): Caption of the image.
            detected_objects_str (str): Description of the detected objects.

        Returns:
            str or None: The generated answer with its first character
            upper-cased, or None (after writing a notice to the Streamlit
            page) when the prompt exceeds ``self.max_context_window`` tokens.
        """
        prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
        num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
        if num_tokens > self.max_context_window:
            st.write(f"Prompt too long with {num_tokens} tokens, consider increasing the confidence threshold for the object detector")
            return None

        # The prompt already carries its own <s> marker, hence add_special_tokens=False.
        model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
        input_ids = model_inputs["input_ids"]
        output_ids = self.kbvqa_model.generate(input_ids)

        # Skip the echoed prompt tokens so only the newly generated text is decoded.
        prompt_length = input_ids.shape[1]
        output_text = self.kbvqa_tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)

        # str.capitalize() would also lower-case the rest of the answer and
        # mangle proper nouns, so only the first character is changed.
        return self._capitalize_first(output_text)

def prepare_kbvqa_model(only_reload_detection_model: bool = False) -> "Optional[KBVQA]":
    """
    Builds a fresh KBVQA instance and loads its models behind a Streamlit spinner.

    The detection model name is read from ``st.session_state.detection_model``.
    A progress bar is advanced after each model finishes loading.

    Args:
        only_reload_detection_model (bool): When True, only the object
            detector is loaded; otherwise the detector, the captioner and the
            fine-tuned language model are loaded in that order.

    Returns:
        KBVQA or None: The ready-to-use instance (language model switched to
        eval mode) when all three models are loaded, otherwise None.
    """
    free_gpu_resources()
    kbvqa = KBVQA()
    kbvqa.detection_model = st.session_state.detection_model

    # Progress bar for model loading
    with st.spinner('Loading model... this should take no more than a few minutes.'):
        progress_bar = st.progress(0)

        # The detector is needed in both modes, so it is always loaded first.
        kbvqa.load_detector(kbvqa.detection_model)

        if only_reload_detection_model:
            progress_bar.progress(100)
        else:
            progress_bar.progress(33)
            kbvqa.load_caption_model()
            free_gpu_resources()
            progress_bar.progress(66)
            kbvqa.load_fine_tuned_model()
            free_gpu_resources()
            progress_bar.progress(100)

    if kbvqa.all_models_loaded:
        st.success('Model loaded successfully and ready for inference!')
        kbvqa.kbvqa_model.eval()
        free_gpu_resources()
        return kbvqa

    # NOTE(review): with only_reload_detection_model=True the fresh instance
    # never has a captioner or language model, so this path always ends here
    # and the newly loaded detector is discarded - confirm what callers expect.
    return None