Spaces:
Running
Running
ivnban27-ctl
commited on
Commit
•
5e31980
1
Parent(s):
e35ec41
v2_llama3 (#8)
Browse files- Adjustments for llama3 (7de31d7c6b0fd120e14eacf18386410133774a6a)
- Adjustments for llama3 (f52c6a68c26379d31335cae3d8be3aa4fd22618e)
- app_config.py +5 -2
- utils/chain_utils.py +1 -1
app_config.py
CHANGED
@@ -4,18 +4,21 @@ from models.model_seeds import seeds, seed2str
|
|
4 |
ISSUES = [k for k,_ in seeds.items()]
|
5 |
SOURCES = [
|
6 |
"CTL_llama2",
|
|
|
7 |
# "CTL_mistral",
|
8 |
'OA_rolemodel',
|
9 |
# 'OA_finetuned',
|
10 |
]
|
11 |
SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT3.5',
|
12 |
"OA_finetuned":'Finetuned OpenAI',
|
13 |
-
"CTL_llama2": "Llama",
|
|
|
14 |
"CTL_mistral": "Mistral",
|
15 |
}
|
16 |
|
17 |
ENDPOINT_NAMES = {
|
18 |
-
"CTL_llama2": "
|
|
|
19 |
# 'CTL_llama2': "llama2_convo_sim",
|
20 |
"CTL_mistral": "convo_sim_mistral"
|
21 |
}
|
|
|
4 |
ISSUES = [k for k,_ in seeds.items()]
|
5 |
SOURCES = [
|
6 |
"CTL_llama2",
|
7 |
+
# "CTL_llama3",
|
8 |
# "CTL_mistral",
|
9 |
'OA_rolemodel',
|
10 |
# 'OA_finetuned',
|
11 |
]
|
12 |
SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT3.5',
|
13 |
"OA_finetuned":'Finetuned OpenAI',
|
14 |
+
"CTL_llama2": "Llama 3",
|
15 |
+
#"CTL_llama3": "Llama 3",
|
16 |
"CTL_mistral": "Mistral",
|
17 |
}
|
18 |
|
19 |
ENDPOINT_NAMES = {
|
20 |
+
"CTL_llama2": "texter_simulator",
|
21 |
+
# "CTL_llama3": "texter_simulator",
|
22 |
# 'CTL_llama2': "llama2_convo_sim",
|
23 |
"CTL_mistral": "convo_sim_mistral"
|
24 |
}
|
utils/chain_utils.py
CHANGED
@@ -12,7 +12,7 @@ def get_chain(issue, language, source, memory, temperature, texter_name=""):
|
|
12 |
seed = seeds.get(issue, "GCT")['prompt']
|
13 |
template = get_template_role_models(issue, language, texter_name=texter_name, seed=seed)
|
14 |
return get_role_chain(template, memory, temperature)
|
15 |
-
elif source in ('CTL_llama2'):
|
16 |
if language == "English":
|
17 |
language = "en"
|
18 |
elif language == "Spanish":
|
|
|
12 |
seed = seeds.get(issue, "GCT")['prompt']
|
13 |
template = get_template_role_models(issue, language, texter_name=texter_name, seed=seed)
|
14 |
return get_role_chain(template, memory, temperature)
|
15 |
+
elif source in ('CTL_llama2', 'CTL_llama3'):
|
16 |
if language == "English":
|
17 |
language = "en"
|
18 |
elif language == "Spanish":
|