File size: 7,026 Bytes
75a53d9 7553f0c 75a53d9 b9d4498 75a53d9 b434799 ab38e0e 5f4a46b 75a53d9 b9d4498 75a53d9 178416a 75a53d9 b9d4498 398c0e8 75a53d9 1089b06 75a53d9 609d6f1 75a53d9 f711846 75a53d9 609d6f1 75a53d9 609d6f1 b9d4498 75a53d9 b9d4498 609d6f1 aefece3 15d3f2d 54d3921 75a53d9 609d6f1 75a53d9 b9d4498 75a53d9 8f97cdd b9d4498 8f97cdd 609d6f1 8f97cdd 609d6f1 8f97cdd 609d6f1 8f97cdd 5f4a46b 609d6f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import os
import io
import torch
import PIL
from PIL import Image
from typing import Optional, Union, List
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import bitsandbytes
import accelerate
from my_model.config import captioning_config as config
from my_model.utilities.gen_utilities import free_gpu_resources
class ImageCaptioningModel:
    """
    A class to handle image captioning using InstructBlip model.

    Attributes:
        model_type (str): Type of the model to use; load_model() only supports 'i_blip'.
        processor (InstructBlipProcessor or None): The processor for handling image input;
            None until load_model() has run.
        model (InstructBlipForConditionalGeneration or None): The loaded model;
            None until load_model() has run.
        prompt (str): Prompt for the model.
        max_image_size (int): Maximum length, in pixels, of the longer image side. It is used
            when the MAX_IMAGE_SIZE environment variable is not set.
        min_length (int): Minimum length of the generated caption.
        max_new_tokens (int): Maximum number of new tokens to generate.
        model_path (str): Path to the pre-trained model.
        device_map (str): Device map for model loading.
        torch_dtype (torch.dtype): Data type for torch tensors.
        load_in_8bit (bool): Whether to load the model in 8-bit precision.
        load_in_4bit (bool): Whether to load the model in 4-bit precision.
        low_cpu_mem_usage (bool): Whether to optimize for low CPU memory usage.
        skip_special_tokens (bool): Whether to skip special tokens in the generated captions.
    """

    def __init__(self) -> None:
        """
        Initializes the ImageCaptioningModel class with configuration settings.

        The processor and model are not loaded here; call load_model() before
        generating captions.
        """
        self.model_type = config.MODEL_TYPE
        self.processor = None
        self.model = None
        self.prompt = config.PROMPT
        self.max_image_size = config.MAX_IMAGE_SIZE
        self.min_length = config.MIN_LENGTH
        self.max_new_tokens = config.MAX_NEW_TOKENS
        self.model_path = config.MODEL_PATH
        self.device_map = config.DEVICE_MAP
        self.torch_dtype = config.TORCH_DTYPE
        self.load_in_8bit = config.LOAD_IN_8BIT
        self.load_in_4bit = config.LOAD_IN_4BIT
        self.low_cpu_mem_usage = config.LOW_CPU_MEM_USAGE
        self.skip_special_tokens = config.SKIP_SPECIAL_TOKENS

    @property
    def skip_secial_tokens(self) -> bool:
        """Misspelled alias of ``skip_special_tokens``, kept so existing callers keep working."""
        return self.skip_special_tokens

    @skip_secial_tokens.setter
    def skip_secial_tokens(self, value: bool) -> None:
        self.skip_special_tokens = value

    def load_model(self) -> None:
        """
        Loads the InstructBlip model and processor based on the specified configuration.

        If both 4-bit and 8-bit loading are requested, 8-bit wins and load_in_4bit is reset
        to False.

        Raises:
            ValueError: If model_type is not 'i_blip'.
        """
        if self.load_in_4bit and self.load_in_8bit:
            # Only one quantization mode can be active; prefer 8-bit.
            self.load_in_4bit = False
        if self.model_type != 'i_blip':
            raise ValueError(f"Unsupported model_type {self.model_type!r}; expected 'i_blip'.")
        # The processor only preprocesses images and text, so it does not receive the
        # quantization, dtype, or device-placement options meant for the model.
        self.processor = InstructBlipProcessor.from_pretrained(self.model_path)
        free_gpu_resources()
        self.model = InstructBlipForConditionalGeneration.from_pretrained(
            self.model_path,
            load_in_8bit=self.load_in_8bit,
            load_in_4bit=self.load_in_4bit,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=self.low_cpu_mem_usage,
            device_map=self.device_map,
        )
        free_gpu_resources()

    @staticmethod
    def _fit_within(width: int, height: int, max_size: int) -> tuple:
        """Return (width, height) scaled so the longer side is at most max_size, keeping aspect ratio."""
        if max_size <= 0:
            raise ValueError(f"max_image_size must be positive, got {max_size}")
        longest = max(width, height)
        if longest <= max_size:
            return width, height
        # Integer arithmetic makes the longer side exactly max_size (no float rounding), and
        # max(1, ...) keeps extreme aspect ratios from collapsing the shorter side to zero.
        return (
            max(1, int(width * max_size // longest)),
            max(1, int(height * max_size // longest)),
        )

    def resize_image(self, image: Image.Image, max_image_size: Optional[int] = None) -> Image.Image:
        """
        Resizes the image to fit within the specified maximum size while maintaining aspect ratio.

        Images whose longer side already fits are returned unchanged (no upscaling).

        Args:
            image (Image.Image): The input image to resize.
            max_image_size (Optional[int]): Maximum length, in pixels, of the longer side.
                If None, the MAX_IMAGE_SIZE environment variable is used. If that is unset,
                self.max_image_size is used, and failing that, 1024.

        Returns:
            Image.Image: The resized image, or the input image if no resizing was needed.

        Raises:
            ValueError: If the effective maximum size is not positive.
        """
        if max_image_size is None:
            fallback = self.max_image_size if self.max_image_size is not None else 1024
            max_image_size = int(os.getenv("MAX_IMAGE_SIZE", fallback))
        # PIL reports size as (width, height), and resize() expects the same order.
        width, height = image.size
        new_size = self._fit_within(width, height, max_image_size)
        if new_size != (width, height):
            image = image.resize(new_size, resample=PIL.Image.Resampling.LANCZOS)
        return image

    @staticmethod
    def _open_image(source: Union[str, os.PathLike, io.IOBase, Image.Image]) -> Image.Image:
        """Return a PIL image for a path, file-like object, or existing PIL image."""
        if isinstance(source, Image.Image):
            return source
        if isinstance(source, (str, os.PathLike, io.IOBase)) or hasattr(source, "read"):
            # Per Pillow's file-handling docs, loading inside the context manager lets a
            # path-opened file be closed immediately while the image remains usable.
            with Image.open(source) as opened:
                opened.load()
            return opened
        raise TypeError(
            "image_path must be a path, a file-like object, or a PIL Image, "
            f"got {type(source).__name__}"
        )

    def generate_caption(self, image_path: Union[str, io.IOBase, Image.Image]) -> str:
        """
        Generates a caption for the given image.

        Args:
            image_path (Union[str, io.IOBase, Image.Image]): The path to the image
                (str or os.PathLike), a file-like object, or a PIL Image.

        Returns:
            str: The generated caption for the image.

        Raises:
            RuntimeError: If load_model() has not been called successfully.
            TypeError: If image_path is not one of the supported input types.
        """
        if self.processor is None or self.model is None:
            raise RuntimeError("The captioning model is not loaded; call load_model() first.")
        free_gpu_resources()
        free_gpu_resources()
        try:
            image = self.resize_image(self._open_image(image_path))
            # NOTE: the target device is hard-coded to "cuda", so a CUDA device is required.
            inputs = self.processor(image, self.prompt, return_tensors="pt").to("cuda", self.torch_dtype)
            outputs = self.model.generate(
                **inputs, min_length=self.min_length, max_new_tokens=self.max_new_tokens
            )
            caption = self.processor.decode(
                outputs[0], skip_special_tokens=self.skip_special_tokens
            ).strip()
        finally:
            # Run the cleanup even if preprocessing or generation raises (e.g. out of memory).
            free_gpu_resources()
            free_gpu_resources()
        return caption

    def generate_captions_for_multiple_images(self, image_paths: List[Union[str, io.IOBase, Image.Image]]) -> List[str]:
        """
        Generates captions for multiple images, one generate_caption() call per image.

        Args:
            image_paths (List[Union[str, io.IOBase, Image.Image]]): A list of paths to images,
                file-like objects, or PIL Images.

        Returns:
            List[str]: A list of captions, in the same order as the provided images.
        """
        return [self.generate_caption(image_path) for image_path in image_paths]
def get_caption(img: Union[str, io.IOBase, Image.Image]) -> str:
    """
    Loads the captioning model and generates a caption for a single image.

    A fresh ImageCaptioningModel is created and loaded on every call, which is costly when
    captioning many images. For batches, create one instance, call load_model() once, and
    reuse it.

    Args:
        img (Union[str, io.IOBase, Image.Image]): The path to the image, file-like object, or PIL Image.

    Returns:
        str: The generated caption for the image.
    """
    captioner = ImageCaptioningModel()
    free_gpu_resources()
    try:
        captioner.load_model()
        free_gpu_resources()
        caption = captioner.generate_caption(img)
    finally:
        # Drop this frame's reference to the model before the final cleanup, so it is no
        # longer held alive here when free_gpu_resources() runs (also on failure paths).
        del captioner
        free_gpu_resources()
    return caption