# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import os
import time
import json
import logging
import gc
import torch
from pathlib import Path
from trt_llama_api import TrtLlmAPI
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from collections import defaultdict
from llama_index import ServiceContext
from llama_index.llms.llama_utils import messages_to_prompt, completion_to_prompt
from llama_index import set_global_service_context
from faiss_vector_storage import FaissEmbeddingStorage
from ui.user_interface import MainInterface
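# Locations of the JSON configuration files and the default data source type.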
app_config_file = os.path.join('config', 'app_config.json')
model_config_file = os.path.join('config', 'config.json')
preference_config_file = os.path.join('config', 'preferences.json')
data_source = 'directory'
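# Load and parse a JSON config file; returns None (after printing a message) on any error.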
def read_config(file_name):
try:
with open(file_name, 'r') as file:
return json.load(file)
except FileNotFoundError:
print(f"The file {file_name} was not found.")
except json.JSONDecodeError:
print(f"There was an error decoding the JSON from the file {file_name}.")
except Exception as e:
print(f"An unexpected error occurred: {e}")
return None
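# Resolve the metadata of the requested model; falls back to the first supported
# model when no name is given or the name is not found.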
def get_model_config(config, model_name=None):
models = config["models"]["supported"]
selected_model = next((model for model in models if model["name"] == model_name), models[0])
return {
"model_path": os.path.join(os.getcwd(), selected_model["metadata"]["model_path"]),
"engine": selected_model["metadata"]["engine"],
"tokenizer_path": os.path.join(os.getcwd(), selected_model["metadata"]["tokenizer_path"]),
"max_new_tokens": selected_model["metadata"]["max_new_tokens"],
"max_input_token": selected_model["metadata"]["max_input_token"],
"temperature": selected_model["metadata"]["temperature"]
}
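# Build the absolute path of the configured dataset directory.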
def get_data_path(config):
return os.path.join(os.getcwd(), config["dataset"]["path"])
# read the app specific config
app_config = read_config(app_config_file)
streaming = app_config["streaming"]
similarity_top_k = app_config["similarity_top_k"]
is_chat_engine = app_config["is_chat_engine"]
embedded_model_name = app_config["embedded_model"]
embedded_model = os.path.join(os.getcwd(), "model", embedded_model_name)
embedded_dimension = app_config["embedded_dimension"]
# read model specific config
selected_model_name = None
selected_data_directory = None
config = read_config(model_config_file)
if os.path.exists(preference_config_file):
perf_config = read_config(preference_config_file)
selected_model_name = perf_config.get('models', {}).get('selected')
selected_data_directory = perf_config.get('dataset', {}).get('path')
if selected_model_name is None:
selected_model_name = config["models"].get("selected")
model_config = get_model_config(config, selected_model_name)
trt_engine_path = model_config["model_path"]
trt_engine_name = model_config["engine"]
tokenizer_dir_path = model_config["tokenizer_path"]
data_dir = config["dataset"]["path"] if selected_data_directory is None else selected_data_directory
# create trt_llm engine object
llm = TrtLlmAPI(
model_path=model_config["model_path"],
engine_name=model_config["engine"],
tokenizer_dir=model_config["tokenizer_path"],
temperature=model_config["temperature"],
max_new_tokens=model_config["max_new_tokens"],
context_window=model_config["max_input_token"],
messages_to_prompt=messages_to_prompt,
completion_to_prompt=completion_to_prompt,
verbose=False
)
# create embeddings model object
embed_model = HuggingFaceEmbeddings(model_name=embedded_model)
service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model,
context_window=model_config["max_input_token"], chunk_size=512,
chunk_overlap=200)
set_global_service_context(service_context)
def generate_inference_engine(data, force_rewrite=False):
"""
Initialize and return a FAISS-based inference engine.
Args:
data: The directory where the data for the inference engine is located.
force_rewrite (bool): If True, force rewriting the index.
Returns:
The initialized inference engine.
Raises:
RuntimeError: If unable to generate the inference engine.
"""
try:
global engine
faiss_storage = FaissEmbeddingStorage(data_dir=data,
dimension=embedded_dimension)
faiss_storage.initialize_index(force_rewrite=force_rewrite)
engine = faiss_storage.get_engine(is_chat_engine=is_chat_engine, streaming=streaming,
similarity_top_k=similarity_top_k)
except Exception as e:
raise RuntimeError(f"Unable to generate the inference engine: {e}")
# load the vectorstore index
generate_inference_engine(data_dir)
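# Stream raw completions from the LLM, yielding the accumulated text after each token.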
def call_llm_streamed(query):
partial_response = ""
response = llm.stream_complete(query)
for token in response:
partial_response += token.delta
yield partial_response
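# Non-streaming chat handler: answers from the RAG engine and appends a link to the
# source file with the highest aggregated relevance score.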
def chatbot(query, chat_history, session_id):
if data_source == "nodataset":
yield llm.complete(query).text
return
if is_chat_engine:
response = engine.chat(query)
else:
response = engine.query(query)
# Aggregate scores by file
file_scores = defaultdict(float)
for node in response.source_nodes:
metadata = node.metadata
if 'filename' in metadata:
file_name = metadata['filename']
file_scores[file_name] += node.score
# Find the file with the highest aggregated score
highest_aggregated_score_file = None
if file_scores:
highest_aggregated_score_file = max(file_scores, key=file_scores.get)
file_links = []
seen_files = set() # Set to track unique file names
# Generate links for the file with the highest aggregated score
if highest_aggregated_score_file:
abs_path = Path(os.path.join(os.getcwd(), highest_aggregated_score_file.replace('\\', '/')))
file_name = os.path.basename(abs_path)
file_name_without_ext = abs_path.stem
if file_name not in seen_files: # Ensure the file hasn't already been processed
if data_source == 'directory':
file_link = file_name
else:
exit("Wrong data_source type")
file_links.append(file_link)
seen_files.add(file_name) # Mark file as processed
response_txt = str(response)
if file_links:
response_txt += "
Reference files:
" + "
".join(file_links)
if not highest_aggregated_score_file: # If no file with a high score was found
response_txt = llm.complete(query).text
yield response_txt
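# Streaming chat handler: yields partial responses token by token and appends the
# top-scoring reference file once generation finishes.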
def stream_chatbot(query, chat_history, session_id):
if data_source == "nodataset":
for response in call_llm_streamed(query):
yield response
return
if is_chat_engine:
response = engine.stream_chat(query)
else:
response = engine.query(query)
partial_response = ""
if len(response.source_nodes) == 0:
response = llm.stream_complete(query)
for token in response:
partial_response += token.delta
yield partial_response
else:
# Aggregate scores by file
file_scores = defaultdict(float)
for node in response.source_nodes:
if 'filename' in node.metadata:
file_name = node.metadata['filename']
file_scores[file_name] += node.score
# Find the file with the highest aggregated score
highest_score_file = max(file_scores, key=file_scores.get, default=None)
file_links = []
seen_files = set()
for token in response.response_gen:
partial_response += token
yield partial_response
time.sleep(0.05)
time.sleep(0.2)
if highest_score_file:
abs_path = Path(os.path.join(os.getcwd(), highest_score_file.replace('\\', '/')))
file_name = os.path.basename(abs_path)
file_name_without_ext = abs_path.stem
if file_name not in seen_files: # Check if file_name is already seen
if data_source == 'directory':
file_link = file_name
else:
exit("Wrong data_source type")
file_links.append(file_link)
seen_files.add(file_name) # Add file_name to the set
if file_links:
partial_response += "
Reference files:
" + "
".join(file_links)
yield partial_response
# call garbage collector after inference
torch.cuda.empty_cache()
gc.collect()
interface = MainInterface(chatbot=stream_chatbot if streaming else chatbot, streaming=streaming)
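# Unload the TRT-LLM model and free memory when the UI shuts down.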
def on_shutdown_handler(session_id):
global llm, service_context, embed_model, faiss_storage, engine
if llm is not None:
llm.unload_model()
del llm
# Force a garbage collection cycle
gc.collect()
interface.on_shutdown(on_shutdown_handler)
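# Clear the chat engine's conversation memory when the user resets the chat.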
def reset_chat_handler(session_id):
global faiss_storage
global engine
print('reset chat called', session_id)
    if is_chat_engine:
faiss_storage.reset_engine(engine)
interface.on_reset_chat(reset_chat_handler)
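# Rebuild the inference engine when the dataset directory changes.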
def on_dataset_path_updated_handler(source, new_directory, video_count, session_id):
    print('dataset path updated to', source, new_directory, video_count, session_id)
global engine
global data_dir
if source == 'directory':
if data_dir != new_directory:
data_dir = new_directory
            generate_inference_engine(data_dir)
interface.on_dataset_path_updated(on_dataset_path_updated_handler)
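# Swap in a newly selected model: unload the old one, create a new TrtLlmAPI
# instance from the model metadata, and rebuild the RAG engine.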
def on_model_change_handler(model, metadata, session_id):
    # Read the raw metadata values first so the guard below runs before any path joining.
    model_path = metadata.get('model_path', None)
    engine_name = metadata.get('engine', None)
    tokenizer_path = metadata.get('tokenizer_path', None)
    if not model_path or not engine_name:
        print("Model path or engine not provided in metadata")
        return
    model_path = os.path.join(os.getcwd(), model_path)
    tokenizer_path = os.path.join(os.getcwd(), tokenizer_path) if tokenizer_path else None
global llm, embedded_model, engine, data_dir, service_context
if llm is not None:
llm.unload_model()
del llm
llm = TrtLlmAPI(
model_path=model_path,
engine_name=engine_name,
tokenizer_dir=tokenizer_path,
temperature=metadata.get('temperature', 0.1),
max_new_tokens=metadata.get('max_new_tokens', 512),
context_window=metadata.get('max_input_token', 512),
messages_to_prompt=messages_to_prompt,
completion_to_prompt=completion_to_prompt,
verbose=False
)
service_context = ServiceContext.from_service_context(service_context=service_context, llm=llm)
set_global_service_context(service_context)
    generate_inference_engine(data_dir)
interface.on_model_change(on_model_change_handler)
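# Switch between the "directory" dataset source and "nodataset" (plain LLM chat).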
def on_dataset_source_change_handler(source, path, session_id):
global data_source, data_dir, engine
data_source = source
if data_source == "nodataset":
        print('No dataset source selected', session_id)
return
print('dataset source updated ', source, path, session_id)
if data_source == "directory":
data_dir = path
else:
print("Wrong data type selected")
    generate_inference_engine(data_dir)
interface.on_dataset_source_updated(on_dataset_source_change_handler)
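# Force a rebuild of the FAISS index for the given dataset path.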
def handle_regenerate_index(source, path, session_id):
    generate_inference_engine(path, force_rewrite=True)
print("on regenerate index", source, path, session_id)
interface.on_regenerate_index(handle_regenerate_index)
# render the interface
interface.render()