import logging
import json
import ast
import os
import numpy as np
from aiohttp import ClientSession
from typing import Dict, List, Optional, Union
from tenacity import retry, stop_after_attempt, wait_exponential, RetryCallState

from pydantic import BaseModel, Field

from agentverse.llms.base import LLMResult
from agentverse.logging import logger
from agentverse.message import Message

from . import llm_registry
from .base import BaseChatModel, BaseCompletionModel, BaseModelArgs
from .utils.jsonrepair import JsonRepair

try:
    import openai
    from openai.error import OpenAIError
except ImportError:
    is_openai_available = False
    logging.warning("openai package is not installed")
else:
    # openai.proxy = os.environ.get("http_proxy")
    # if openai.proxy is None:
    #     openai.proxy = os.environ.get("HTTP_PROXY")
    if os.environ.get("OPENAI_API_KEY") is not None:
        openai.api_key = os.environ.get("OPENAI_API_KEY")
        is_openai_available = True
    elif os.environ.get("AZURE_OPENAI_API_KEY") is not None:
        openai.api_type = "azure"
        openai.api_key = os.environ.get("AZURE_OPENAI_API_KEY")
        openai.api_base = os.environ.get("AZURE_OPENAI_API_BASE")
        openai.api_version = "2023-05-15"
        is_openai_available = True
    else:
        logging.warning(
            "OpenAI API key is not set. Please set the environment variable OPENAI_API_KEY"
        )
        is_openai_available = False


def log_retry(retry_state: RetryCallState):
    # Log the exception that triggered a retry, along with the attempt count.
    exception = retry_state.outcome.exception()
    logger.warn(
        f"Retrying {retry_state.fn}\nAttempt: {retry_state.attempt_number}\nException: {exception.__class__.__name__} {exception}",
    )


class OpenAIChatArgs(BaseModelArgs):
    model: str = Field(default="gpt-3.5-turbo")
    deployment_id: Optional[str] = Field(default=None)
    max_tokens: int = Field(default=2048)
    temperature: float = Field(default=1.0)
    top_p: int = Field(default=1)
    n: int = Field(default=1)
    stop: Optional[Union[str, List]] = Field(default=None)
    presence_penalty: int = Field(default=0)
    frequency_penalty: int = Field(default=0)
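

# Example (illustrative, not part of the original API surface): the args are
# plain pydantic fields, so they can be built with overrides and serialized
# like any pydantic model:
#   args = OpenAIChatArgs(temperature=0.2, max_tokens=512)
#   args.dict()  # {"model": "gpt-3.5-turbo", "temperature": 0.2, ...}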

# class OpenAICompletionArgs(OpenAIChatArgs):
#     model: str = Field(default="text-davinci-003")
#     suffix: str = Field(default="")
#     best_of: int = Field(default=1)


# @llm_registry.register("text-davinci-003")
# class OpenAICompletion(BaseCompletionModel):
#     args: OpenAICompletionArgs = Field(default_factory=OpenAICompletionArgs)

#     def __init__(self, max_retry: int = 3, **kwargs):
#         args = OpenAICompletionArgs()
#         args = args.dict()
#         for k, v in args.items():
#             args[k] = kwargs.pop(k, v)
#         if len(kwargs) > 0:
#             logging.warning(f"Unused arguments: {kwargs}")
#         super().__init__(args=args, max_retry=max_retry)

#     def generate_response(self, prompt: str) -> LLMResult:
#         response = openai.Completion.create(prompt=prompt, **self.args.dict())
#         return LLMResult(
#             content=response["choices"][0]["text"],
#             send_tokens=response["usage"]["prompt_tokens"],
#             recv_tokens=response["usage"]["completion_tokens"],
#             total_tokens=response["usage"]["total_tokens"],
#         )

#     async def agenerate_response(self, prompt: str) -> LLMResult:
#         response = await openai.Completion.acreate(prompt=prompt, **self.args.dict())
#         return LLMResult(
#             content=response["choices"][0]["text"],
#             send_tokens=response["usage"]["prompt_tokens"],
#             recv_tokens=response["usage"]["completion_tokens"],
#             total_tokens=response["usage"]["total_tokens"],
#         )
@llm_registry.register("gpt-35-turbo")
@llm_registry.register("gpt-3.5-turbo")
@llm_registry.register("gpt-4")
class OpenAIChat(BaseChatModel):
args: OpenAIChatArgs = Field(default_factory=OpenAIChatArgs)
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
def __init__(self, max_retry: int = 3, **kwargs):
args = OpenAIChatArgs()
args = args.dict()
for k, v in args.items():
args[k] = kwargs.pop(k, v)
if len(kwargs) > 0:
logging.warning(f"Unused arguments: {kwargs}")
super().__init__(args=args, max_retry=max_retry)
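
    # Example (illustrative): recognized kwargs are popped into the args model,
    # so OpenAIChat(temperature=0.2, max_retry=5) sets args.temperature to 0.2
    # while keeping the remaining defaults; unrecognized kwargs only trigger
    # the "Unused arguments" warning above.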

    # def _construct_messages(self, history: List[Message]):
    #     return history + [{"role": "user", "content": query}]

    @retry(
        stop=stop_after_attempt(20),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True,
        before_sleep=log_retry,
    )
    def generate_response(
        self,
        prepend_prompt: str = "",
        history: List[dict] = [],
        append_prompt: str = "",
        functions: List[dict] = [],
    ) -> LLMResult:
        messages = self.construct_messages(prepend_prompt, history, append_prompt)
        logger.log_prompt(messages)
        try:
            if functions != []:
                # Ask the model to choose among the provided functions
                response = openai.ChatCompletion.create(
                    messages=messages,
                    functions=functions,
                    **self.args.dict(),
                )
                if response["choices"][0]["message"].get("function_call") is not None:
                    self.collect_metrics(response)
                    return LLMResult(
                        # content may be None alongside a function_call
                        content=response["choices"][0]["message"].get("content") or "",
                        function_name=response["choices"][0]["message"][
                            "function_call"
                        ]["name"],
                        function_arguments=ast.literal_eval(
                            response["choices"][0]["message"]["function_call"][
                                "arguments"
                            ]
                        ),
                        send_tokens=response["usage"]["prompt_tokens"],
                        recv_tokens=response["usage"]["completion_tokens"],
                        total_tokens=response["usage"]["total_tokens"],
                    )
                else:
                    self.collect_metrics(response)
                    return LLMResult(
                        content=response["choices"][0]["message"]["content"],
                        send_tokens=response["usage"]["prompt_tokens"],
                        recv_tokens=response["usage"]["completion_tokens"],
                        total_tokens=response["usage"]["total_tokens"],
                    )
            else:
                response = openai.ChatCompletion.create(
                    messages=messages,
                    **self.args.dict(),
                )
                self.collect_metrics(response)
                return LLMResult(
                    content=response["choices"][0]["message"]["content"],
                    send_tokens=response["usage"]["prompt_tokens"],
                    recv_tokens=response["usage"]["completion_tokens"],
                    total_tokens=response["usage"]["total_tokens"],
                )
        except (OpenAIError, KeyboardInterrupt, json.decoder.JSONDecodeError):
            raise

    @retry(
        stop=stop_after_attempt(20),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True,
        before_sleep=log_retry,
    )
    async def agenerate_response(
        self,
        prepend_prompt: str = "",
        history: List[dict] = [],
        append_prompt: str = "",
        functions: List[dict] = [],
    ) -> LLMResult:
        messages = self.construct_messages(prepend_prompt, history, append_prompt)
        logger.log_prompt(messages)
        try:
            if functions != []:
                async with ClientSession(trust_env=True) as session:
                    openai.aiosession.set(session)
                    response = await openai.ChatCompletion.acreate(
                        messages=messages,
                        functions=functions,
                        **self.args.dict(),
                    )
                if response["choices"][0]["message"].get("function_call") is not None:
                    function_name = response["choices"][0]["message"]["function_call"][
                        "name"
                    ]
                    # The model sometimes prefixes the function name with the
                    # namespace it was given; strip the prefix before validating.
                    valid_function = False
                    if function_name.startswith("function."):
                        function_name = function_name.replace("function.", "")
                    elif function_name.startswith("functions."):
                        function_name = function_name.replace("functions.", "")
                    for function in functions:
                        if function["name"] == function_name:
                            valid_function = True
                            break
                    if not valid_function:
                        logger.warn(
                            f"The returned function name {function_name} is not in the list of valid functions. Retrying..."
                        )
                        raise ValueError(
                            f"The returned function name {function_name} is not in the list of valid functions."
                        )
                    try:
                        arguments = ast.literal_eval(
                            response["choices"][0]["message"]["function_call"][
                                "arguments"
                            ]
                        )
                    except Exception:
                        try:
                            # Attempt to repair the malformed JSON before giving up
                            arguments = ast.literal_eval(
                                JsonRepair(
                                    response["choices"][0]["message"]["function_call"][
                                        "arguments"
                                    ]
                                ).repair()
                            )
                        except Exception:
                            logger.warn(
                                "The returned argument in function call is not valid json. Retrying..."
                            )
                            raise ValueError(
                                "The returned argument in function call is not valid json."
                            )
                    self.collect_metrics(response)
                    return LLMResult(
                        function_name=function_name,
                        function_arguments=arguments,
                        send_tokens=response["usage"]["prompt_tokens"],
                        recv_tokens=response["usage"]["completion_tokens"],
                        total_tokens=response["usage"]["total_tokens"],
                    )
                else:
                    self.collect_metrics(response)
                    return LLMResult(
                        content=response["choices"][0]["message"]["content"],
                        send_tokens=response["usage"]["prompt_tokens"],
                        recv_tokens=response["usage"]["completion_tokens"],
                        total_tokens=response["usage"]["total_tokens"],
                    )
            else:
                async with ClientSession(trust_env=True) as session:
                    openai.aiosession.set(session)
                    response = await openai.ChatCompletion.acreate(
                        messages=messages,
                        **self.args.dict(),
                    )
                self.collect_metrics(response)
                return LLMResult(
                    content=response["choices"][0]["message"]["content"],
                    send_tokens=response["usage"]["prompt_tokens"],
                    recv_tokens=response["usage"]["completion_tokens"],
                    total_tokens=response["usage"]["total_tokens"],
                )
        except (OpenAIError, KeyboardInterrupt, json.decoder.JSONDecodeError):
            raise
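
    # A minimal usage sketch for the async variant (illustrative; assumes an
    # event loop is available and the API key is configured):
    #   import asyncio
    #   chat = OpenAIChat()
    #   result = asyncio.run(chat.agenerate_response(append_prompt="Hi"))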

    def construct_messages(
        self, prepend_prompt: str, history: List[dict], append_prompt: str
    ):
        messages = []
        if prepend_prompt != "":
            messages.append({"role": "system", "content": prepend_prompt})
        if len(history) > 0:
            messages += history
        if append_prompt != "":
            messages.append({"role": "user", "content": append_prompt})
        return messages
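
    # Example (illustrative):
    #   construct_messages("You are a poet", [], "Write one line")
    # returns
    #   [{"role": "system", "content": "You are a poet"},
    #    {"role": "user", "content": "Write one line"}]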

    def collect_metrics(self, response):
        self.total_prompt_tokens += response["usage"]["prompt_tokens"]
        self.total_completion_tokens += response["usage"]["completion_tokens"]

    def get_spend(self) -> float:
        # Cost per 1K tokens in USD, keyed by model name
        input_cost_map = {
            "gpt-3.5-turbo": 0.0015,
            "gpt-3.5-turbo-16k": 0.003,
            "gpt-3.5-turbo-0613": 0.0015,
            "gpt-3.5-turbo-16k-0613": 0.003,
            "gpt-4": 0.03,
            "gpt-4-0613": 0.03,
            "gpt-4-32k": 0.06,
        }
        output_cost_map = {
            "gpt-3.5-turbo": 0.002,
            "gpt-3.5-turbo-16k": 0.004,
            "gpt-3.5-turbo-0613": 0.002,
            "gpt-3.5-turbo-16k-0613": 0.004,
            "gpt-4": 0.06,
            "gpt-4-0613": 0.06,
            "gpt-4-32k": 0.12,
        }

        model = self.args.model
        if model not in input_cost_map or model not in output_cost_map:
            raise ValueError(f"Model type {model} not supported")

        return (
            self.total_prompt_tokens * input_cost_map[model] / 1000.0
            + self.total_completion_tokens * output_cost_map[model] / 1000.0
        )
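
    # Worked example (illustrative): with model "gpt-3.5-turbo", 1000 prompt
    # tokens and 500 completion tokens cost
    #   1000 * 0.0015 / 1000 + 500 * 0.002 / 1000 = 0.0025 USD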


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    reraise=True,
)
def get_embedding(text: str) -> tuple:
    try:
        # Newlines can degrade embedding quality, so flatten them first
        text = text.replace("\n", " ")
        if openai.api_type == "azure":
            embedding = openai.Embedding.create(
                input=[text], deployment_id="text-embedding-ada-002"
            )["data"][0]["embedding"]
        else:
            embedding = openai.Embedding.create(
                input=[text], model="text-embedding-ada-002"
            )["data"][0]["embedding"]
        return tuple(embedding)
    except Exception as e:
        # Retries are handled by the @retry decorator; just log and re-raise
        logger.error(f"Error {e} when requesting openai models. Retrying")
        raise
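

# A minimal smoke test (illustrative; requires a valid OPENAI_API_KEY and
# network access, so it is kept behind the __main__ guard rather than run on
# import):
if __name__ == "__main__":
    chat = OpenAIChat()
    result = chat.generate_response(append_prompt="Say hello in one word.")
    print(result.content)
    print(f"Spend: ${chat.get_spend():.6f}")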