Spaces:
Runtime error
Runtime error
import os | |
import sys | |
import ctypes | |
import pathlib | |
from typing import Optional, List | |
import enum | |
from pathlib import Path | |
class DataType(enum.IntEnum):
    """Tensor data types understood by minigpt4.cpp (passed as `int data_type`)."""

    # Unquantized / integer types.
    F16 = 0
    F32 = 1
    I32 = 2
    L64 = 3
    # Legacy block-quantized types.
    Q4_0 = 4
    Q4_1 = 5
    Q5_0 = 6
    Q5_1 = 7
    Q8_0 = 8
    Q8_1 = 9
    # K-quant types.
    Q2_K = 10
    Q3_K = 11
    Q4_K = 12
    Q5_K = 13
    Q6_K = 14
    Q8_K = 15

    def __str__(self):
        # Render as the member name (e.g. "Q4_0") rather than the integer value.
        return str(self.name)
class Verbosity(enum.IntEnum):
    """Logging levels accepted by minigpt4.cpp, from quietest to noisiest."""

    SILENT = 0  # no output
    ERR = 1     # errors only
    INFO = 2    # informational messages
    DEBUG = 3   # everything
class ImageFormat(enum.IntEnum):
    """Pixel storage format of a MiniGPT4Image buffer (its `format` field)."""

    UNKNOWN = 0
    F32 = 1  # 32-bit float per channel value
    U8 = 2   # 8-bit unsigned per channel value
# Short aliases for the C scalar types used in the minigpt4.cpp API.
I32 = ctypes.c_int32
U32 = ctypes.c_uint32
F32 = ctypes.c_float
SIZE_T = ctypes.c_size_t
VOID_PTR = ctypes.c_void_p

# Pointer types; ctypes.POINTER caches its results, so these are the same
# objects as spelling the full ctypes.POINTER(ctypes.c_...) chain.
CHAR_PTR = ctypes.POINTER(ctypes.c_char)
FLOAT_PTR = ctypes.POINTER(F32)
INT_PTR = ctypes.POINTER(I32)
CHAR_PTR_PTR = ctypes.POINTER(CHAR_PTR)

# The context is an opaque handle, so it is carried around as a void pointer.
MiniGPT4ContextP = VOID_PTR
class MiniGPT4Context:
    """Handle to a loaded minigpt4.cpp model context."""

    def __init__(self, ptr: int):
        """
        Args:
            ptr (int): Raw address of the C `struct MiniGPT4Context`, as returned
                by a ctypes function whose restype is `c_void_p`.
        """
        self.ptr = ptr

    def __repr__(self) -> str:
        return f'{type(self).__name__}(ptr={self.ptr!r})'
class MiniGPT4Image(ctypes.Structure):
    """ctypes mirror of the C `struct MiniGPT4Image` (field order must match C)."""

    _fields_ = [
        ('data', VOID_PTR),   # pixel buffer
        ('width', I32),       # width in pixels
        ('height', I32),      # height in pixels
        ('channels', I32),    # number of channels
        ('format', I32),      # an ImageFormat value
    ]
class MiniGPT4Embedding(ctypes.Structure):
    """ctypes mirror of the C `struct MiniGPT4Embedding` (field order must match C)."""

    _fields_ = [
        ('data', FLOAT_PTR),       # embedding values
        ('n_embeddings', SIZE_T),  # element count reported by the library
    ]
# Pointer types for passing the structures by reference to the C API.
MiniGPT4ImageP, MiniGPT4EmbeddingP = map(ctypes.POINTER, (MiniGPT4Image, MiniGPT4Embedding))
class MiniGPT4SharedLibrary:
    """
    Python wrapper around minigpt4.cpp shared library.

    Each `minigpt4_*` method converts its Python arguments, calls the C function
    of the same name, and turns a non-zero C return code into a RuntimeError via
    `panic_if_error`.
    """

    def __init__(self, shared_library_path: str):
        """
        Loads the shared library from specified file and declares the C
        signatures (argtypes / restype) of the functions this wrapper calls.

        In case of any error, this method will throw an exception.

        Parameters
        ----------
        shared_library_path : str
            Path to minigpt4.cpp shared library. On Windows, it would look like 'minigpt4.dll'. On UNIX, 'minigpt4.so'.
        """
        self.library = ctypes.cdll.LoadLibrary(shared_library_path)

        # Argument list shared by minigpt4_end_chat_image and minigpt4_end_chat:
        # context, an out-parameter receiving the generated token, and the
        # sampling parameters.
        end_chat_argtypes = [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            CHAR_PTR_PTR,      # const char **token
            I32,               # size_t n_threads  NOTE(review): declared as int32
            F32,               # float temp
            I32,               # int32_t top_k
            F32,               # float top_p
            F32,               # float tfs_z
            F32,               # float typical_p
            I32,               # int32_t repeat_last_n
            F32,               # float repeat_penalty
            F32,               # float alpha_presence
            F32,               # float alpha_frequency
            I32,               # int mirostat
            F32,               # float mirostat_tau
            F32,               # float mirostat_eta
            I32,               # int penalize_nl
        ]

        self._declare('minigpt4_model_load', [
            CHAR_PTR,  # const char *path
            CHAR_PTR,  # const char *llm_model
            I32,       # int verbosity
            I32,       # int seed
            I32,       # int n_ctx
            I32,       # int n_batch
            I32,       # int numa
        ], MiniGPT4ContextP)
        self._declare('minigpt4_image_load_from_file', [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            CHAR_PTR,          # const char *path
            MiniGPT4ImageP,    # struct MiniGPT4Image *image
            I32,               # int flags
        ], I32)
        # Must be declared: without argtypes, ctypes passes ctx.ptr (a Python int)
        # as a 32-bit C int, truncating the pointer on 64-bit platforms. Only
        # declared when exported, so builds lacking the symbol still load.
        if hasattr(self.library, 'minigpt4_preprocess_image'):
            self._declare('minigpt4_preprocess_image', [
                MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
                MiniGPT4ImageP,    # const struct MiniGPT4Image *image
                MiniGPT4ImageP,    # struct MiniGPT4Image *preprocessed_image
                I32,               # int flags
            ], I32)
        self._declare('minigpt4_encode_image', [
            MiniGPT4ContextP,    # struct MiniGPT4Context *ctx
            MiniGPT4ImageP,      # const struct MiniGPT4Image *image
            MiniGPT4EmbeddingP,  # struct MiniGPT4Embedding *embedding
            I32,                 # size_t n_threads
        ], I32)
        self._declare('minigpt4_begin_chat_image', [
            MiniGPT4ContextP,    # struct MiniGPT4Context *ctx
            MiniGPT4EmbeddingP,  # struct MiniGPT4Embedding *embedding
            CHAR_PTR,            # const char *s
            I32,                 # size_t n_threads
        ], I32)
        self._declare('minigpt4_end_chat_image', list(end_chat_argtypes), I32)
        self._declare('minigpt4_system_prompt', [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            I32,               # size_t n_threads
        ], I32)
        self._declare('minigpt4_begin_chat', [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            CHAR_PTR,          # const char *s
            I32,               # size_t n_threads
        ], I32)
        self._declare('minigpt4_end_chat', list(end_chat_argtypes), I32)
        self._declare('minigpt4_reset_chat', [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
        ], I32)
        self._declare('minigpt4_contains_eos_token', [
            CHAR_PTR,  # const char *s
        ], I32)
        self._declare('minigpt4_is_eos', [
            CHAR_PTR,  # const char *s
        ], I32)
        self._declare('minigpt4_free', [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
        ], I32)
        self._declare('minigpt4_free_image', [
            MiniGPT4ImageP,  # struct MiniGPT4Image *image
        ], I32)
        self._declare('minigpt4_free_embedding', [
            MiniGPT4EmbeddingP,  # struct MiniGPT4Embedding *embedding
        ], I32)
        self._declare('minigpt4_error_code_to_string', [
            I32,  # int error_code
        ], CHAR_PTR)
        self._declare('minigpt4_quantize_model', [
            CHAR_PTR,  # const char *in_path
            CHAR_PTR,  # const char *out_path
            I32,       # int data_type
        ], I32)
        self._declare('minigpt4_set_verbosity', [
            I32,  # int verbosity
        ], None)

    def _declare(self, name: str, argtypes: list, restype) -> None:
        """Set argtypes/restype on the exported C function `name` (AttributeError if missing)."""
        function = getattr(self.library, name)
        function.argtypes = argtypes
        function.restype = restype

    @staticmethod
    def _token_to_str(token) -> str:
        """Decode a `const char *` out-parameter filled by an end_chat call as UTF-8."""
        return ctypes.cast(token, ctypes.c_char_p).value.decode('utf-8')

    def panic_if_error(self, error_code: int) -> None:
        """
        Raises an exception if the error code is not 0.

        Parameters
        ----------
        error_code : int
            Error code to check.

        Raises
        ------
        RuntimeError
            With the library's textual description of `error_code`.
        """
        if error_code != 0:
            raise RuntimeError(self.minigpt4_error_code_to_string(error_code))

    def minigpt4_model_load(self, model_path: str, llm_model_path: str, verbosity: int = 1, seed: int = 1337, n_ctx: int = 2048, n_batch: int = 512, numa: int = 0) -> MiniGPT4Context:
        """
        Loads a model from a file.

        Args:
            model_path (str): Path to model file.
            llm_model_path (str): Path to LLM model file.
            verbosity (int): Verbosity level: 0 = silent, 1 = error, 2 = info, 3 = debug. Defaults to 1.
            seed (int): Seed for llm model. Defaults to 1337.
            n_ctx (int): Size of context for llm model. Defaults to 2048.
            n_batch (int): Batch size for llm model. Defaults to 512.
            numa (int): NUMA node to use (0 = NUMA disabled, 1 = NUMA enabled). Defaults to 0.

        Returns:
            MiniGPT4Context: Context.

        Raises:
            AssertionError: If the library returns a NULL context.
        """
        ptr = self.library.minigpt4_model_load(
            model_path.encode('utf-8'),
            llm_model_path.encode('utf-8'),
            I32(verbosity),
            I32(seed),
            I32(n_ctx),
            I32(n_batch),
            I32(numa),
        )
        if ptr is None:
            # Explicit raise instead of `assert` so the check survives `python -O`;
            # the exception type is kept for callers that already catch it.
            raise AssertionError('minigpt4_model_load failed')
        return MiniGPT4Context(ptr)

    def minigpt4_image_load_from_file(self, ctx: MiniGPT4Context, path: str, flags: int) -> MiniGPT4Image:
        """
        Loads an image from a file.

        Args:
            ctx (MiniGPT4Context): Context.
            path (str): Path to the image file.
            flags (int): Flags forwarded to the C function.

        Returns:
            MiniGPT4Image: Image; release it with minigpt4_free_image.
        """
        image = MiniGPT4Image()
        self.panic_if_error(self.library.minigpt4_image_load_from_file(ctx.ptr, path.encode('utf-8'), ctypes.pointer(image), I32(flags)))
        return image

    def minigpt4_preprocess_image(self, ctx: MiniGPT4Context, image: MiniGPT4Image, flags: int = 0) -> MiniGPT4Image:
        """
        Preprocesses an image.

        Args:
            ctx (MiniGPT4Context): Context.
            image (MiniGPT4Image): Image.
            flags (int): Flags. Defaults to 0.

        Returns:
            MiniGPT4Image: Preprocessed image; release it with minigpt4_free_image.
        """
        preprocessed_image = MiniGPT4Image()
        self.panic_if_error(self.library.minigpt4_preprocess_image(ctx.ptr, ctypes.pointer(image), ctypes.pointer(preprocessed_image), I32(flags)))
        return preprocessed_image

    def minigpt4_encode_image(self, ctx: MiniGPT4Context, image: MiniGPT4Image, n_threads: int = 0) -> MiniGPT4Embedding:
        """
        Encodes an image into embedding.

        Args:
            ctx (MiniGPT4Context): Context.
            image (MiniGPT4Image): Image.
            n_threads (int): Number of threads to use, if 0, uses all available. Defaults to 0.

        Returns:
            embedding (MiniGPT4Embedding): Output embedding; release it with minigpt4_free_embedding.
        """
        embedding = MiniGPT4Embedding()
        self.panic_if_error(self.library.minigpt4_encode_image(ctx.ptr, ctypes.pointer(image), ctypes.pointer(embedding), n_threads))
        return embedding

    def minigpt4_begin_chat_image(self, ctx: MiniGPT4Context, image_embedding: MiniGPT4Embedding, s: str, n_threads: int = 0):
        """
        Begins a chat with an image.

        Args:
            ctx (MiniGPT4Context): Context.
            image_embedding (MiniGPT4Embedding): Image embedding.
            s (str): Question to ask about the image.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.

        Returns:
            None
        """
        self.panic_if_error(self.library.minigpt4_begin_chat_image(ctx.ptr, ctypes.pointer(image_embedding), s.encode('utf-8'), n_threads))

    def minigpt4_end_chat_image(self, ctx: MiniGPT4Context, n_threads: int = 0, temp: float = 0.8, top_k: int = 40, top_p: float = 0.9, tfs_z: float = 1.0, typical_p: float = 1.0, repeat_last_n: int = 64, repeat_penalty: float = 1.1, alpha_presence: float = 1.0, alpha_frequency: float = 1.0, mirostat: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 1.0, penalize_nl: int = 1) -> str:
        """
        Generates the next token of a chat started with minigpt4_begin_chat_image.

        Args:
            ctx (MiniGPT4Context): Context.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.
            temp (float, optional): Temperature. Defaults to 0.8.
            top_k (int, optional): Top K. Defaults to 40.
            top_p (float, optional): Top P. Defaults to 0.9.
            tfs_z (float, optional): Tfs Z. Defaults to 1.0.
            typical_p (float, optional): Typical P. Defaults to 1.0.
            repeat_last_n (int, optional): Repeat last N. Defaults to 64.
            repeat_penalty (float, optional): Repeat penalty. Defaults to 1.1.
            alpha_presence (float, optional): Alpha presence. Defaults to 1.0.
            alpha_frequency (float, optional): Alpha frequency. Defaults to 1.0.
            mirostat (int, optional): Mirostat. Defaults to 0.
            mirostat_tau (float, optional): Mirostat Tau. Defaults to 5.0.
            mirostat_eta (float, optional): Mirostat Eta. Defaults to 1.0.
            penalize_nl (int, optional): Penalize NL. Defaults to 1.

        Returns:
            str: Token generated.
        """
        token = CHAR_PTR()
        self.panic_if_error(self.library.minigpt4_end_chat_image(ctx.ptr, ctypes.pointer(token), n_threads, temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl))
        return self._token_to_str(token)

    def minigpt4_system_prompt(self, ctx: MiniGPT4Context, n_threads: int = 0):
        """
        Generates a system prompt.

        Args:
            ctx (MiniGPT4Context): Context.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.
        """
        self.panic_if_error(self.library.minigpt4_system_prompt(ctx.ptr, n_threads))

    def minigpt4_begin_chat(self, ctx: MiniGPT4Context, s: str, n_threads: int = 0):
        """
        Begins a chat continuing after minigpt4_begin_chat_image.

        Args:
            ctx (MiniGPT4Context): Context.
            s (str): Question to ask about the image.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.

        Returns:
            None
        """
        self.panic_if_error(self.library.minigpt4_begin_chat(ctx.ptr, s.encode('utf-8'), n_threads))

    def minigpt4_end_chat(self, ctx: MiniGPT4Context, n_threads: int = 0, temp: float = 0.8, top_k: int = 40, top_p: float = 0.9, tfs_z: float = 1.0, typical_p: float = 1.0, repeat_last_n: int = 64, repeat_penalty: float = 1.1, alpha_presence: float = 1.0, alpha_frequency: float = 1.0, mirostat: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 1.0, penalize_nl: int = 1) -> str:
        """
        Generates the next token of a chat started with minigpt4_begin_chat.

        Args:
            ctx (MiniGPT4Context): Context.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.
            temp (float, optional): Temperature. Defaults to 0.8.
            top_k (int, optional): Top K. Defaults to 40.
            top_p (float, optional): Top P. Defaults to 0.9.
            tfs_z (float, optional): Tfs Z. Defaults to 1.0.
            typical_p (float, optional): Typical P. Defaults to 1.0.
            repeat_last_n (int, optional): Repeat last N. Defaults to 64.
            repeat_penalty (float, optional): Repeat penalty. Defaults to 1.1.
            alpha_presence (float, optional): Alpha presence. Defaults to 1.0.
            alpha_frequency (float, optional): Alpha frequency. Defaults to 1.0.
            mirostat (int, optional): Mirostat. Defaults to 0.
            mirostat_tau (float, optional): Mirostat Tau. Defaults to 5.0.
            mirostat_eta (float, optional): Mirostat Eta. Defaults to 1.0.
            penalize_nl (int, optional): Penalize NL. Defaults to 1.

        Returns:
            str: Token generated.
        """
        token = CHAR_PTR()
        self.panic_if_error(self.library.minigpt4_end_chat(ctx.ptr, ctypes.pointer(token), n_threads, temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl))
        return self._token_to_str(token)

    def minigpt4_reset_chat(self, ctx: MiniGPT4Context):
        """
        Resets the chat.

        Args:
            ctx (MiniGPT4Context): Context.
        """
        self.panic_if_error(self.library.minigpt4_reset_chat(ctx.ptr))

    def minigpt4_contains_eos_token(self, s: str) -> bool:
        """
        Checks if a string contains an EOS token.

        Args:
            s (str): String to check.

        Returns:
            bool: True if the string contains an EOS token, False otherwise.
        """
        return bool(self.library.minigpt4_contains_eos_token(s.encode('utf-8')))

    def minigpt4_is_eos(self, s: str) -> bool:
        """
        Checks if a string is EOS.

        Args:
            s (str): String to check.

        Returns:
            bool: True if the library reports the string as EOS, False otherwise.
        """
        return bool(self.library.minigpt4_is_eos(s.encode('utf-8')))

    def minigpt4_free(self, ctx: MiniGPT4Context) -> None:
        """
        Frees a context.

        Args:
            ctx (MiniGPT4Context): Context.
        """
        self.panic_if_error(self.library.minigpt4_free(ctx.ptr))

    def minigpt4_free_image(self, image: MiniGPT4Image) -> None:
        """
        Frees an image.

        Args:
            image (MiniGPT4Image): Image.
        """
        self.panic_if_error(self.library.minigpt4_free_image(ctypes.pointer(image)))

    def minigpt4_free_embedding(self, embedding: MiniGPT4Embedding) -> None:
        """
        Frees an embedding.

        Args:
            embedding (MiniGPT4Embedding): Embedding.
        """
        self.panic_if_error(self.library.minigpt4_free_embedding(ctypes.pointer(embedding)))

    def minigpt4_error_code_to_string(self, error_code: int) -> str:
        """
        Converts an error code to a string.

        Args:
            error_code (int): Error code.

        Returns:
            str: Error string (a generic message if the library returns NULL).
        """
        # restype is POINTER(c_char), which has no .decode; cast to c_char_p to
        # read the NUL-terminated bytes. A NULL pointer is falsy.
        message = self.library.minigpt4_error_code_to_string(I32(error_code))
        if not message:
            return f'minigpt4 error code {error_code}'
        return ctypes.cast(message, ctypes.c_char_p).value.decode('utf-8', errors='replace')

    def minigpt4_quantize_model(self, in_path: str, out_path: str, data_type: DataType):
        """
        Quantizes a model file.

        Args:
            in_path (str): Path to input model file.
            out_path (str): Path to write output model file.
            data_type (DataType): Must be one DataType enum values.
        """
        self.panic_if_error(self.library.minigpt4_quantize_model(in_path.encode('utf-8'), out_path.encode('utf-8'), data_type))

    def minigpt4_set_verbosity(self, verbosity: Verbosity):
        """
        Sets verbosity.

        Args:
            verbosity (Verbosity): Verbosity.
        """
        self.library.minigpt4_set_verbosity(I32(verbosity))
def load_library() -> MiniGPT4SharedLibrary:
    """
    Locate the minigpt4.cpp shared library and load it.

    Candidate locations are tried in order (relative ones against the current
    working directory); the first existing file wins. If none exists, the last
    candidate (`<cwd>/<library name>`) is attempted anyway so the loader
    reports a meaningful error.
    """
    if any(tag in sys.platform for tag in ('win32', 'cygwin')):
        library_name = 'minigpt4.dll'
    elif 'darwin' in sys.platform:
        library_name = 'libminigpt4.dylib'
    else:
        library_name = 'libminigpt4.so'

    working_dir = pathlib.Path.cwd()
    repo_root = pathlib.Path(os.path.abspath(__file__)).parent.parent

    candidates = [
        f'../bin/Release/{library_name}',        # running from the "minigpt4" directory
        f'bin/Release/{library_name}',           # running from the repo root
        f'build/bin/Release/{library_name}',     # multi-config build directory
        f'build/{library_name}',                 # single-config build directory
        f'../build/{library_name}',
        str(repo_root.joinpath('bin', 'Release', library_name)),  # relative to this file
        str(repo_root / library_name),           # fallbacks
        str(working_dir / library_name),
    ]

    chosen = next((candidate for candidate in candidates if os.path.isfile(candidate)), candidates[-1])
    return MiniGPT4SharedLibrary(chosen)
class MiniGPT4ChatBot:
    """High-level chat session over a single loaded minigpt4.cpp context."""

    def __init__(self, model_path: str, llm_model_path: str, verbosity: Verbosity = Verbosity.SILENT, n_threads: int = 0):
        """
        Creates a new MiniGPT4ChatBot instance.

        Args:
            model_path (str): Path to model file.
            llm_model_path (str): Path to language model model file.
            verbosity (Verbosity, optional): Verbosity. Defaults to Verbosity.SILENT.
            n_threads (int, optional): Number of threads to use. Defaults to 0.
        """
        self.library = load_library()
        self.ctx = self.library.minigpt4_model_load(model_path, llm_model_path, verbosity)
        self.n_threads = n_threads

        # Imported lazily so the low-level bindings work without torchvision.
        from torchvision import transforms
        from torchvision.transforms.functional import InterpolationMode

        self.image_size = 224
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        # Deterministic eval-time preprocessing: resize to a fixed square, as the
        # BLIP-2 eval processor does. RandomResizedCrop is a training
        # augmentation and would give a different embedding on every upload.
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (self.image_size, self.image_size),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.embedding: Optional[MiniGPT4Embedding] = None
        self.is_image_chat = False
        self.chat_history = []

    def free(self):
        """Release the image embedding and the model context; safe to call more than once."""
        if self.embedding is not None:
            self.library.minigpt4_free_embedding(self.embedding)
            self.embedding = None
            self.is_image_chat = False
        if self.ctx:
            self.library.minigpt4_free(self.ctx)
            # Clear the handle so a second free() does not double-free the context.
            self.ctx = None

    def generate(self, message: str, limit: int = 1024, temp: float = 0.8, top_k: int = 40, top_p: float = 0.9, tfs_z: float = 1.0, typical_p: float = 1.0, repeat_last_n: int = 64, repeat_penalty: float = 1.1, alpha_presence: float = 1.0, alpha_frequency: float = 1.0, mirostat: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 1.0, penalize_nl: int = 1):
        """
        Generates a chat response, yielding it token by token.

        The first call after upload_image starts an image chat; later calls
        continue as a text chat. Nothing runs until the generator is iterated.

        Args:
            message (str): Message.
            limit (int, optional): Maximum number of tokens to request. Defaults to 1024.
            temp (float, optional): Temperature. Defaults to 0.8.
            top_k (int, optional): Top K. Defaults to 40.
            top_p (float, optional): Top P. Defaults to 0.9.
            tfs_z (float, optional): TFS Z. Defaults to 1.0.
            typical_p (float, optional): Typical P. Defaults to 1.0.
            repeat_last_n (int, optional): Repeat last N. Defaults to 64.
            repeat_penalty (float, optional): Repeat penalty. Defaults to 1.1.
            alpha_presence (float, optional): Alpha presence. Defaults to 1.0.
            alpha_frequency (float, optional): Alpha frequency. Defaults to 1.0.
            mirostat (int, optional): Mirostat. Defaults to 0.
            mirostat_tau (float, optional): Mirostat tau. Defaults to 5.0.
            mirostat_eta (float, optional): Mirostat eta. Defaults to 1.0.
            penalize_nl (int, optional): Penalize NL. Defaults to 1.

        Yields:
            str: Generated tokens, excluding EOS markers.
        """
        # Positional order matches the end_chat wrappers after (ctx, n_threads).
        sampling = (temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl)
        if self.is_image_chat:
            self.is_image_chat = False
            self.library.minigpt4_begin_chat_image(self.ctx, self.embedding, message, self.n_threads)
            end_chat = self.library.minigpt4_end_chat_image
        else:
            self.library.minigpt4_begin_chat(self.ctx, message, self.n_threads)
            end_chat = self.library.minigpt4_end_chat
        yield from self._stream_tokens(end_chat, limit, sampling)

    def _stream_tokens(self, end_chat, limit: int, sampling: tuple):
        """Yield up to `limit` tokens from `end_chat`, skipping EOS tokens and stopping at EOS."""
        chat = ''
        for _ in range(limit):
            token = end_chat(self.ctx, self.n_threads, *sampling)
            chat += token
            if self.library.minigpt4_contains_eos_token(token):
                continue
            if self.library.minigpt4_is_eos(chat):
                break
            yield token

    def reset_chat(self):
        """
        Resets the chat: frees any image embedding, clears the C-side chat
        state and re-evaluates the system prompt.
        """
        self.is_image_chat = False
        if self.embedding is not None:
            self.library.minigpt4_free_embedding(self.embedding)
            self.embedding = None
        self.library.minigpt4_reset_chat(self.ctx)
        self.library.minigpt4_system_prompt(self.ctx, self.n_threads)

    def upload_image(self, image):
        """
        Uploads an image: resets the chat and encodes the image so the next
        generate() call starts an image chat.

        Args:
            image (PIL.Image.Image): Image accepted by the torchvision transform.
        """
        self.reset_chat()
        tensor = self.transform(image).unsqueeze(0).contiguous()
        # Keep the numpy array in a named local so its buffer stays alive while
        # the C library reads from the raw pointer below.
        pixels = tensor.numpy()
        minigpt4_image = MiniGPT4Image(pixels.ctypes.data_as(ctypes.c_void_p), self.image_size, self.image_size, 3, ImageFormat.F32)
        self.embedding = self.library.minigpt4_encode_image(self.ctx, minigpt4_image, self.n_threads)
        self.is_image_chat = True
if __name__ == "__main__":
    import argparse
    import contextlib

    def print_answer(end_chat, ctx, n_threads: int) -> None:
        """
        Stream tokens from `end_chat` to stdout until the library reports EOS.

        Tokens containing an EOS marker are accumulated but not printed; the
        loop stops once the accumulated text is recognised by minigpt4_is_eos.
        """
        chat = ''
        while True:
            token = end_chat(ctx, n_threads)
            chat += token
            if library.minigpt4_contains_eos_token(token):
                continue
            if library.minigpt4_is_eos(chat):
                break
            print(token, end='')
        print()

    parser = argparse.ArgumentParser(description='Test loading minigpt4')
    parser.add_argument('model_path', help='Path to model file')
    parser.add_argument('llm_model_path', help='Path to llm model file')
    parser.add_argument('-i', '--image_path', help='Image to test', default='images/llama.png')
    parser.add_argument('-p', '--prompts', help='Text to test', default='what is the text in the picture?,what is the color of it?')
    args = parser.parse_args()

    model_path = args.model_path
    llm_model_path = args.llm_model_path
    image_path = args.image_path

    if not Path(model_path).exists():
        print(f'Model does not exist: {model_path}')
        sys.exit(1)
    if not Path(llm_model_path).exists():
        print(f'LLM Model does not exist: {llm_model_path}')
        sys.exit(1)

    prompts = args.prompts.split(',')
    n_threads = 0

    print('Loading minigpt4 shared library...')
    library = load_library()
    print(f'Loaded library {library}')

    # ExitStack releases every native resource (in reverse order, context last)
    # even if a later step raises.
    with contextlib.ExitStack() as cleanup:
        ctx = library.minigpt4_model_load(model_path, llm_model_path, Verbosity.DEBUG)
        cleanup.callback(library.minigpt4_free, ctx)

        image = library.minigpt4_image_load_from_file(ctx, image_path, 0)
        cleanup.callback(library.minigpt4_free_image, image)

        preprocessed_image = library.minigpt4_preprocess_image(ctx, image, 0)
        cleanup.callback(library.minigpt4_free_image, preprocessed_image)

        embedding = library.minigpt4_encode_image(ctx, preprocessed_image, n_threads)
        cleanup.callback(library.minigpt4_free_embedding, embedding)

        library.minigpt4_system_prompt(ctx, n_threads)

        # The first prompt is asked about the image; the rest continue as text chat.
        library.minigpt4_begin_chat_image(ctx, embedding, prompts[0], n_threads)
        print_answer(library.minigpt4_end_chat_image, ctx, n_threads)

        for prompt in prompts[1:]:
            library.minigpt4_begin_chat(ctx, prompt, n_threads)
            print_answer(library.minigpt4_end_chat, ctx, n_threads)