convosim-ui / utils /chain_utils.py
ivnban27-ctl's picture
v0.5 (#4)
20b3b4a verified
raw
history blame
1.99 kB
from models.model_seeds import seeds
from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
from models.openai.role_models import get_role_chain, get_template_role_models
from models.databricks.scenario_sim_biz import get_databricks_biz_chain
from models.databricks.scenario_sim import get_databricks_chain, get_template_databricks_models
# UI language names that the Databricks-backed chains expect as short codes.
# Any other value is passed through unchanged.
_LANGUAGE_CODES = {"English": "en", "Spanish": "es"}


def _to_language_code(language):
    """Map a UI language name ("English"/"Spanish") to "en"/"es"; pass others through."""
    return _LANGUAGE_CODES.get(language, language)


def _lookup_seed_prompt(seed_table, issue, fallback_issue="GCT"):
    """Return the ``"prompt"`` of ``seed_table[issue]``, falling back to ``fallback_issue``.

    The fallback entry is only looked up when ``issue`` is missing, so a table
    without a ``"GCT"`` entry still works for known issues. Raises KeyError if
    neither ``issue`` nor ``fallback_issue`` is present.
    """
    entry = seed_table.get(issue)
    if entry is None:
        entry = seed_table[fallback_issue]
    return entry["prompt"]


def get_chain(issue, language, source, memory, temperature, texter_name=""):
    """Build the conversation chain for the requested model ``source``.

    Args:
        issue: Scenario key. It is used to pick the seed prompt (falling back
            to the "GCT" seed) and, for "OA_finetuned", to form the
            ``f"{issue}-{language}"`` key into ``finetuned_models``.
        language: UI language name. For the Databricks sources ("CTL_llama2",
            "CTL_mistral"), "English"/"Spanish" are converted to "en"/"es"
            first.
        source: Exactly one of "OA_finetuned", "OA_rolemodel", "CTL_llama2",
            or "CTL_mistral".
        memory: Conversation memory, passed through to the chain factory.
        temperature: Sampling temperature, passed through to the chain factory.
        texter_name: Optional texter name used by the template builders.

    Returns:
        The chain produced by the matching factory, or None when ``source``
        is not recognised.
    """
    # Exact comparisons: the previous `source in ("OA_finetuned")` tested a
    # plain string, so any substring (including "") matched.
    if source == "OA_finetuned":
        oa_engine = finetuned_models[f"{issue}-{language}"]
        return get_finetuned_chain(oa_engine, memory, temperature)

    if source == "OA_rolemodel":
        seed = _lookup_seed_prompt(seeds, issue)
        template = get_template_role_models(
            issue, language, texter_name=texter_name, seed=seed
        )
        return get_role_chain(template, memory, temperature)

    if source == "CTL_llama2":
        return get_databricks_biz_chain(
            source, issue, _to_language_code(language), memory, temperature
        )

    if source == "CTL_mistral":
        seed = _lookup_seed_prompt(seeds, issue)
        # The template builder may return an updated texter name, which the
        # chain must receive.
        template, texter_name = get_template_databricks_models(
            issue, _to_language_code(language), texter_name=texter_name, seed=seed
        )
        return get_databricks_chain(source, template, memory, temperature, texter_name)

    # Unrecognised source: keep the original implicit-None result for callers.
    return None
from typing import cast
def custom_chain_predict(llm_chain, input, stop):
    """Run ``llm_chain`` once and record the exchange in its chat memory.

    The chain's input preparation, validation and ``_call`` steps are invoked
    directly. Afterwards the user's input and every AI reply are appended to
    ``llm_chain.memory.chat_memory``.

    NOTE(review): this relies on private chain methods (``_validate_inputs``,
    ``_call``, ``_validate_outputs``), which may change between LangChain
    versions.

    Args:
        llm_chain: Chain exposing ``prep_inputs``, ``_validate_inputs``,
            ``_call``, ``_validate_outputs``, ``output_key`` and
            ``memory.chat_memory``.
        input: User message passed to the chain under the "input" key. The
            name shadows the builtin but is kept for keyword callers.
        stop: Stop sequence(s) passed to the chain under the "stop" key.

    Returns:
        ``outputs[llm_chain.output_key]`` exactly as produced by the chain.
        This may be a single string or a sequence of strings.
    """
    inputs = llm_chain.prep_inputs({"input": input, "stop": stop})
    llm_chain._validate_inputs(inputs)
    outputs = llm_chain._call(inputs)
    llm_chain._validate_outputs(outputs)

    result = outputs[llm_chain.output_key]
    chat_memory = llm_chain.memory.chat_memory
    chat_memory.add_user_message(inputs["input"])
    # A plain string is a single reply. Iterating it directly would store
    # every character as its own AI message.
    replies = [result] if isinstance(result, str) else result
    for reply in replies:
        chat_memory.add_ai_message(reply)
    return result