Commit 7de31d7 by aftorresc
Parent(s): e35ec41

Adjustments for llama3

Files changed:
- app_config.py (+6 -3)
- utils/chain_utils.py (+1 -1)
app_config.py CHANGED
@@ -4,18 +4,21 @@ from models.model_seeds import seeds, seed2str
 ISSUES = [k for k,_ in seeds.items()]
 SOURCES = [
     "CTL_llama2",
+    # "CTL_llama3",
     # "CTL_mistral",
     'OA_rolemodel',
     # 'OA_finetuned',
 ]
 SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT3.5',
                "OA_finetuned":'Finetuned OpenAI',
-               "CTL_llama2": "Llama",
+               "CTL_llama2": "Llama 3",
+               #"CTL_llama3": "Llama 3",
                "CTL_mistral": "Mistral",
                }

 ENDPOINT_NAMES = {
-    "CTL_llama2": "
+    "CTL_llama2": "texter_simulator",
+    # "CTL_llama3": "texter_simulator",
     # 'CTL_llama2': "llama2_convo_sim",
     "CTL_mistral": "convo_sim_mistral"
 }
@@ -26,7 +29,7 @@ def source2label(source):
 def issue2label(issue):
     return seed2str.get(issue, "GCT")

-ENVIRON = "
+ENVIRON = "dev"

 DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'
 DB_CONVOS = 'conversations'
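For context on how the changed values flow through the app, here is a minimal sketch, not taken from the repository: only the dictionary contents mirror the diff, while the body of source2label and the print calls are assumptions for illustration. It also shows that with ENVIRON set to "dev", the DB_SCHEMA expression resolves to 'test_db', so this commit keeps the app pointed at the test schema.

# Minimal sketch, assuming the config dicts are consumed by simple lookups.
# Only the dict contents mirror the diff; source2label's body is hypothetical.
SOURCES_LAB = {
    "OA_rolemodel": "OpenAI GPT3.5",
    "OA_finetuned": "Finetuned OpenAI",
    "CTL_llama2": "Llama 3",           # relabeled in this commit
    "CTL_mistral": "Mistral",
}
ENDPOINT_NAMES = {
    "CTL_llama2": "texter_simulator",  # endpoint name introduced in this commit
    "CTL_mistral": "convo_sim_mistral",
}

def source2label(source):
    # Hypothetical implementation of the source2label referenced in the hunk header.
    return SOURCES_LAB.get(source, source)

ENVIRON = "dev"
DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'

print(source2label("CTL_llama2"))    # -> Llama 3
print(ENDPOINT_NAMES["CTL_llama2"])  # -> texter_simulator
print(DB_SCHEMA)                     # -> test_db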
utils/chain_utils.py CHANGED
@@ -12,7 +12,7 @@ def get_chain(issue, language, source, memory, temperature, texter_name=""):
         seed = seeds.get(issue, "GCT")['prompt']
         template = get_template_role_models(issue, language, texter_name=texter_name, seed=seed)
         return get_role_chain(template, memory, temperature)
-    elif source in ('CTL_llama2'):
+    elif source in ('CTL_llama2', 'CTL_llama3'):
         if language == "English":
             language = "en"
         elif language == "Spanish":