# from streamlit.logger import get_logger

from models.custom_parsers import CustomStringOutputParser
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

# logger = get_logger(__name__)
# logger.debug("START APP")

# Display name -> fine-tuned OpenAI completion model id. Inactive
# experiments are kept commented out for reference.
finetuned_models = {
    # "olivia_babbage_engine": "babbage:ft-crisis-text-line:exp-olivia-babbage-2023-02-23-19-57-19",
    "Anxiety-English": "curie:ft-crisis-text-line:exp-olivia-curie-2-2023-02-24-00-25-13",
    # "olivia_davinci_engine": "davinci:ft-crisis-text-line:exp-olivia-davinci-2023-02-24-00-02-41",
    # "olivia_augmented_babbage_engine": "babbage:ft-crisis-text-line:exp-olivia-augmented-babbage-2023-02-24-18-35-42",
    # "Olivia-Augmented": "curie:ft-crisis-text-line:exp-olivia-augmented-curie-2023-02-24-20-13-33",
    # "olivia_augmented_davinci_engine": "davinci:ft-crisis-text-line:exp-olivia-augmented-davinci-2023-02-24-23-57-08",
    # "kit_babbage_engine": "babbage:ft-crisis-text-line:exp-kit-babbage-2023-03-06-21-34-10",
    # "kit_curie_engine": "curie:ft-crisis-text-line:exp-kit-curie-2023-03-06-22-01-29",
    "Suicide-English": "curie:ft-crisis-text-line:exp-kit-curie-2-2023-03-08-16-26-48",
    # "kit_davinci_engine": "davinci:ft-crisis-text-line:exp-kit-davinci-2023-03-06-23-09-15",
    # "olivia_es_davinci_engine": "davinci:ft-crisis-text-line:es-olivia-davinci-2023-04-25-17-07-44",
    "Anxiety-Spanish": "curie:ft-crisis-text-line:es-olivia-curie-2023-04-27-15-02-42",
    # "olivia_curie_engine": "curie:ft-crisis-text-line:exp-olivia-curie-2-2023-02-24-00-25-13",
    # "Oscar-Spanish": "curie:ft-crisis-text-line:es-oscar-curie-2023-05-03-21-55-06",
    # "oscar_es_davinci_engine": "davinci:ft-crisis-text-line:es-oscar-davinci-2023-05-03-21-39-29",
}

# Legacy implementation on the raw openai.Completion API, kept for reference:
# def generate_next_response(completion_engine, context, temperature=0.8):
#     completion = openai.Completion.create(
#         engine=completion_engine,
#         prompt=context,
#         temperature=temperature,
#         max_tokens=150,
#         stop="helper:"
#     )
#     completion_text = completion['choices'][0]['text']
#     return completion_text

# def update_memory_completion(helper_input, memory, OA_engine, temperature=0.8):
#     memory.chat_memory.add_user_message(helper_input)
#     context = "## BEGIN ## \n" + memory.load_memory_variables({})['history'] + "\ntexter:"
#     print(context)
#     response = generate_next_response(OA_engine, context, temperature).strip().replace("\n", "")
#     response = response.split("texter:")[0]
#     memory.chat_memory.add_ai_message(response)
#     return response

def get_finetuned_chain(model_name, memory, temperature=0.8):
    """Build an LLMChain that plays the texter role with a fine-tuned model.

    Returns the chain and the stop sequence ("helper:") to pass at generation
    time, so the model does not continue into the helper's next turn.
    """
    _TEXTER_TEMPLATE_ = """The following is a friendly conversation between a volunteer and a person in crisis.
Current conversation:
{history}
helper: {input}
texter:"""
    PROMPT = PromptTemplate(
        input_variables=["history", "input"],
        template=_TEXTER_TEMPLATE_,
    )
    llm = OpenAI(
        temperature=temperature,
        model=model_name,
        max_tokens=150,
    )
    llm_chain = LLMChain(
        llm=llm,
        prompt=PROMPT,
        memory=memory,
        output_parser=CustomStringOutputParser(),
    )
    # logger.debug(f"{__name__}: loaded fine tuned model {model_name}")
    return llm_chain, "helper:"
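
# Example usage: a minimal smoke-test sketch, not wired into the Streamlit app.
# It assumes OPENAI_API_KEY is set, that the memory prefixes should mirror the
# "helper:"/"texter:" labels in the template above, and that this LangChain
# version lets a "stop" input pass straight through predict().
if __name__ == "__main__":
    from langchain.memory import ConversationBufferMemory

    memory = ConversationBufferMemory(human_prefix="helper", ai_prefix="texter")
    chain, stop_seq = get_finetuned_chain(finetuned_models["Anxiety-English"], memory)
    # One turn: the volunteer speaks; the fine-tuned model replies as the texter.
    print(chain.predict(input="Hi! I'm here to listen. What's going on?", stop=[stop_seq]))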