# Spaces: Runtime error
import argparse | |
import logging | |
import os | |
import sys | |
from dotenv import load_dotenv | |
from knowledgebase import Knowledgebase | |
from utils.constants import ( | |
ENV_FILE, | |
ASSISTANT_TYPE_KEY, | |
AssistantType, | |
OPENAI_API_TOKEN_KEY, | |
OPENAI_KNOWLEDGEBASE_KEY, | |
HUGGINGFACEHUB_API_TOKEN_KEY, | |
HF_KNOWLEDGEBASE_KEY, | |
QUERY_TAG, | |
ANSWER_TAG, | |
SOURCES_TAG, | |
EMBEDDING_TYPE_KEY, | |
APIKeyType, | |
EmbeddingType, | |
) | |
from utils.llm import validate_api_token | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Load the .env file that sits in the current working directory so the
# environment lookups further down can see its values.
load_dotenv(
    dotenv_path=os.path.join(os.getcwd(), ENV_FILE),
)
def main() -> None:
    """Answer a single question from the command line via the knowledgebase.

    Steps:
      1. Parse the command line (one positional question argument).
      2. Pick the assistant and embedding backends, plus their API keys and
         the knowledgebase name, from environment variables.
      3. Validate the API token(s); log the error and exit with status 1 if
         one is rejected.
      4. Query the knowledgebase and print the answer, sources and cost.

    Raises:
        SystemExit: raised by argparse on ``--help`` or a usage error, and
            by this function (status 1) when a token fails validation.
    """
    # Parse the command line first so that `--help` and usage errors are
    # reported without requiring, or spending time validating, any API token.
    parser = argparse.ArgumentParser(description="LLM Website QA - CLI")
    parser.add_argument(
        QUERY_TAG, type=str, help="Question to be asked from the assistant"
    )
    args = parser.parse_args()
    query = args.query

    # Both selectors default to the Hugging Face backend when unset.
    assistant_choice = os.getenv(
        ASSISTANT_TYPE_KEY, AssistantType.HUGGINGFACE.value
    )
    embedding_choice = os.getenv(
        EMBEDDING_TYPE_KEY, EmbeddingType.HUGGINGFACE.value
    )

    if assistant_choice == AssistantType.OPENAI.value:
        assistant_type = AssistantType.OPENAI
        assistant_api_key = os.environ.get(OPENAI_API_TOKEN_KEY, None)
        assistant_api_key_type = APIKeyType.OPENAI
        knowledgebase_name = os.environ.get(OPENAI_KNOWLEDGEBASE_KEY, None)
        if embedding_choice == EmbeddingType.OPENAI.value:
            # OpenAI embeddings reuse the assistant's OpenAI key.
            embedding_type = EmbeddingType.OPENAI
            embedding_api_key = assistant_api_key
            embedding_api_key_type = APIKeyType.OPENAI
        else:
            # Any other value falls back to Hugging Face embeddings, which
            # need their own token.
            embedding_type = EmbeddingType.HUGGINGFACE
            embedding_api_key = os.getenv(HUGGINGFACEHUB_API_TOKEN_KEY, None)
            embedding_api_key_type = APIKeyType.HUGGINGFACE
    else:
        # Any non-OpenAI assistant value selects Hugging Face for both the
        # assistant and the embeddings; the embedding selector is ignored.
        assistant_type = AssistantType.HUGGINGFACE
        assistant_api_key = os.environ.get(HUGGINGFACEHUB_API_TOKEN_KEY, None)
        assistant_api_key_type = APIKeyType.HUGGINGFACE
        knowledgebase_name = os.environ.get(HF_KNOWLEDGEBASE_KEY, None)
        embedding_type = EmbeddingType.HUGGINGFACE
        embedding_api_key = assistant_api_key
        embedding_api_key_type = APIKeyType.HUGGINGFACE

    # NOTE(review): no logging configuration is visible in this file; unless
    # an imported module configures logging, this INFO record is dropped by
    # the default WARNING threshold - confirm.
    logger.info("🗝️ Validating the API tokens...")

    # The assistant credential is checked first. The embedding credential is
    # only checked separately when it differs, so an identical (type, key)
    # pair is not validated twice.
    credentials = [(assistant_api_key_type, assistant_api_key)]
    embedding_credential = (embedding_api_key_type, embedding_api_key)
    if embedding_credential not in credentials:
        credentials.append(embedding_credential)

    for api_key_type, api_key in credentials:
        is_valid, error_message = validate_api_token(
            api_key_type=api_key_type, api_key=api_key
        )
        if not is_valid:
            logger.error(error_message)
            sys.exit(1)

    knowledgebase = Knowledgebase(
        assistant_type=assistant_type,
        embedding_type=embedding_type,
        assistant_api_key=assistant_api_key,
        embedding_api_key=embedding_api_key,
        knowledgebase_name=knowledgebase_name,
    )
    result, metadata = knowledgebase.query_knowledgebase(query=query)

    # `or ""` guards against a key that is present but None; str() is applied
    # before strip() so a non-string value cannot raise AttributeError.
    answer = str(result.get(ANSWER_TAG) or "").strip()
    sources = str(result.get(SOURCES_TAG) or "").strip()
    print(f"\nAnswer: \n{answer}")
    print(f"\nSources: \n{sources}")
    print(f"\nCost: \n{metadata}")


if __name__ == "__main__":
    main()