Spaces:
Running
Running
import os | |
import json | |
import requests | |
import logging | |
from models.custom_parsers import CustomStringOutputParser | |
from app_config import ENDPOINT_NAMES | |
from langchain.chains import ConversationChain | |
from langchain_core.callbacks.manager import CallbackManagerForLLMRun | |
from langchain_core.language_models.llms import LLM | |
from langchain.prompts import PromptTemplate | |
from typing import Any, List, Mapping, Optional, Dict | |
class DatabricksCustomBizLLM(LLM):
    """Custom LangChain LLM that calls a Databricks model-serving endpoint.

    Packs the prompt and conversation parameters into the endpoint's
    ``inputs`` payload, POSTs it, and returns ``generated_text`` from the
    first prediction in the JSON response.
    """

    # Conversation parameters forwarded verbatim to the serving endpoint.
    issue: str
    language: str
    temperature: float = 0.8
    max_tokens: int = 128
    # Fully formatted model-serving endpoint URL.
    db_url: str
    # NOTE(review): evaluated once at class-definition (import) time, so
    # DATABRICKS_TOKEN must already be set when this module is imported.
    headers: Mapping[str, str] = {
        'Authorization': f'Bearer {os.environ.get("DATABRICKS_TOKEN")}',
        'Content-Type': 'application/json',
    }

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for logging/serialization.

        Fix: the ``LLM`` base class declares ``_llm_type`` as a property;
        overriding it as a plain method exposed a bound method object
        instead of the string wherever LangChain reads ``self._llm_type``.
        """
        return "custom_databricks_biz"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """POST *prompt* to the Databricks endpoint and return the text.

        Raises:
            Exception: if the endpoint responds with a non-200 status
                (message includes status code and response body).
        """
        payload = {'inputs': {
            'prompt': [prompt],
            'issue': [self.issue],
            'language': [self.language],
            'temperature': [self.temperature],
            'max_tokens': [self.max_tokens],
        }}
        data_json = json.dumps(payload, allow_nan=True)
        # timeout added so a hung endpoint cannot block the caller forever
        response = requests.request(
            method='POST',
            headers=self.headers,
            url=self.db_url,
            data=data_json,
            timeout=60,
        )
        if response.status_code != 200:
            raise Exception(f'Request failed with status {response.status_code}, {response.text}')
        return response.json()["predictions"][0]["generated_text"]
# Prompt template for the Databricks biz chain: prior turns go into
# {history}, the new counselor message into {input}, and the model is
# cued to complete the "texter:" line. "helper:"/"texter:" literals must
# match the stop prefix returned by get_databricks_biz_chain.
_DATABRICKS_TEMPLATE_ = """{history}
helper: {input}
texter:"""
def get_databricks_biz_chain(source, issue, language, memory, temperature=0.8, max_tokens=256):
    """Build a ConversationChain backed by the Databricks business LLM.

    Args:
        source: Key into ENDPOINT_NAMES selecting the serving endpoint;
            falls back to "conversation_simulator" when unknown.
        issue: Issue label forwarded to the model.
        language: Language label forwarded to the model.
        memory: LangChain conversation memory for the chain.
        temperature: Sampling temperature (default 0.8).
        max_tokens: Generation cap (default 256; previously hard-coded,
            now a backward-compatible parameter).

    Returns:
        Tuple of (ConversationChain, "helper:") — the chain and the
        stop/role prefix callers strip from model output.
    """
    prompt = PromptTemplate(
        input_variables=['history', 'input'],
        template=_DATABRICKS_TEMPLATE_,
    )
    llm = DatabricksCustomBizLLM(
        issue=issue,
        language=language,
        temperature=temperature,
        max_tokens=max_tokens,
        db_url=os.environ['DATABRICKS_URL'].format(
            endpoint_name=ENDPOINT_NAMES.get(source, "conversation_simulator")
        ),
    )
    llm_chain = ConversationChain(
        llm=llm,
        prompt=prompt,
        memory=memory,
        output_parser=CustomStringOutputParser(),
    )
    # Plain string: the original used an f-string with no placeholders.
    logging.debug("loaded Databricks Biz model")
    return llm_chain, "helper:"