# pecore / presets.py
# Author: gsarti — last commit 44f7b3c "Fix typo in code snippet" (10.7 kB).
import json
# System prompt interpolated into the chat templates of the zephyr, ChatML and Phi-3 presets.
SYSTEM_PROMPT = "You are a helpful assistant that provides concise and accurate answers."
def set_cora_preset():
    """Return the preset for the ``gsarti/cora_mgen`` model.

    Returns:
        A 3-tuple ``(model_name_or_path, input_template,
        input_current_text_template)``.
    """
    model_name_or_path = "gsarti/cora_mgen"
    input_template = "<Q>: {current} <P>: {context}"
    input_current_text_template = "<Q>: {current}"
    return model_name_or_path, input_template, input_current_text_template
def set_default_preset():
    """Return the default ``gpt2`` preset, resetting every configurable field.

    Returns:
        An 11-tuple ``(model_name_or_path, input_template, output_template,
        contextless_input_template, contextless_output_template,
        special_tokens_to_keep, decoder_input_output_separator, model_kwargs,
        tokenizer_kwargs, generation_kwargs, attribution_kwargs)``. Every
        ``*_kwargs`` entry is the empty JSON object string ``"{}"``.
    """
    current_only = "{current}"
    no_kwargs = "{}"
    special_tokens_to_keep = []  # a fresh list on every call
    return (
        "gpt2",
        "{current} {context}",
        current_only,  # output_template
        current_only,  # contextless_input_template
        current_only,  # contextless_output_template
        special_tokens_to_keep,
        "",  # decoder_input_output_separator
        no_kwargs,  # model_kwargs
        no_kwargs,  # tokenizer_kwargs
        no_kwargs,  # generation_kwargs
        no_kwargs,  # attribution_kwargs
    )
def set_zephyr_preset():
    """Return the preset for ``stabilityai/stablelm-2-zephyr-1_6b``.

    Both templates open with a system turn carrying ``SYSTEM_PROMPT``. The
    contextual template puts ``{context}`` and ``{current}`` in the user turn,
    separated by a blank line. The contextless template keeps only
    ``{current}``.

    Returns:
        A 6-tuple ``(model_name_or_path, input_template,
        decoder_input_output_separator, input_current_text_template,
        special_tokens_to_keep, generation_kwargs)``.
    """
    system_and_user_open = "<|system|>" + SYSTEM_PROMPT + "<|endoftext|>\n<|user|>\n"
    user_close_assistant_open = "<|endoftext|>\n<|assistant|>"
    with_context = system_and_user_open + "{context}\n\n{current}" + user_close_assistant_open
    without_context = system_and_user_open + "{current}" + user_close_assistant_open
    # NOTE(review): <|im_start|>/<|im_end|> do not appear in this template,
    # while <|system|>/<|user|>/<|assistant|> do — confirm this list is intended.
    kept_tokens = ["<|im_start|>", "<|im_end|>", "<|endoftext|>"]
    return (
        "stabilityai/stablelm-2-zephyr-1_6b",
        with_context,
        "\n",
        without_context,
        kept_tokens,
        '{\n\t"max_new_tokens": 50\n}',
    )
def set_chatml_preset():
    """Return the ChatML preset for ``Qwen/Qwen1.5-0.5B-Chat``.

    The system turn carries ``SYSTEM_PROMPT``. The contextual user turn holds
    ``{context}`` and ``{current}`` separated by a blank line.

    Returns:
        A 5-tuple ``(model_name_or_path, input_template,
        input_current_text_template, special_tokens_to_keep,
        generation_kwargs)``.
    """
    prefix = "<|im_start|>system\n" + SYSTEM_PROMPT + "<|im_end|>\n<|im_start|>user\n"
    suffix = "<|im_end|>\n<|im_start|>assistant\n"
    return (
        "Qwen/Qwen1.5-0.5B-Chat",
        prefix + "{context}\n\n{current}" + suffix,
        prefix + "{current}" + suffix,
        ["<|im_start|>", "<|im_end|>"],
        '{\n\t"max_new_tokens": 50\n}',
    )
def set_mbart_mmt_preset():
    """Return the mBART-50 one-to-many preset (``en_XX`` -> ``fr_XX``).

    Returns:
        A 4-tuple ``(model_name_or_path, input_template, output_template,
        tokenizer_kwargs)``. Input and output share the same
        context-then-current template.
    """
    context_then_current = "{context} {current}"
    language_pair = '{\n\t"src_lang": "en_XX",\n\t"tgt_lang": "fr_XX"\n}'
    return (
        "facebook/mbart-large-50-one-to-many-mmt",
        context_then_current,
        context_then_current,
        language_pair,
    )
def set_nllb_mmt_preset():
    """Return the NLLB-200 distilled preset (``eng_Latn`` -> ``fra_Latn``).

    Returns:
        A 4-tuple ``(model_name_or_path, input_template, output_template,
        tokenizer_kwargs)``. Input and output share the same
        context-then-current template.
    """
    context_then_current = "{context} {current}"
    language_pair = '{\n\t"src_lang": "eng_Latn",\n\t"tgt_lang": "fra_Latn"\n}'
    return (
        "facebook/nllb-200-distilled-600M",
        context_then_current,
        context_then_current,
        language_pair,
    )
def set_towerinstruct_preset():
    """Return the ``Unbabel/TowerInstruct-7B-v0.1`` English-to-French preset.

    Both prompts start with the source sentence and end with a ``Target:``
    cue. The contextual variant inserts the context and asks the model to use
    it.

    Returns:
        A 5-tuple ``(model_name_or_path, input_template,
        input_current_text_template, special_tokens_to_keep,
        generation_kwargs)``.
    """
    opening = "<|im_start|>user\nSource: {current}\n"
    closing = "\nTarget:<|im_end|>\n<|im_start|>assistant\n"
    contextual_instruction = (
        "Context: {context}\n"
        "Translate the above text into French. Use the context to guide your answer."
    )
    plain_instruction = "Translate the above text into French."
    return (
        "Unbabel/TowerInstruct-7B-v0.1",
        opening + contextual_instruction + closing,
        opening + plain_instruction + closing,
        ["<|im_start|>", "<|im_end|>"],
        '{\n\t"max_new_tokens": 50\n}',
    )
def set_gemma_preset():
    """Return the preset for ``google/gemma-2b-it``.

    Returns:
        A 6-tuple ``(model_name_or_path, input_template,
        decoder_input_output_separator, input_current_text_template,
        special_tokens_to_keep, generation_kwargs)``.
    """
    user_open = "<start_of_turn>user\n"
    model_open = "<end_of_turn>\n<start_of_turn>model"
    turn_markers = ["<start_of_turn>", "<end_of_turn>"]
    return (
        "google/gemma-2b-it",
        f"{user_open}{{context}}\n{{current}}{model_open}",
        "\n",
        f"{user_open}{{current}}{model_open}",
        turn_markers,
        '{\n\t"max_new_tokens": 50\n}',
    )
def set_mistral_instruct_preset():
    """Return the preset for ``mistralai/Mistral-7B-Instruct-v0.2``.

    Returns:
        A 4-tuple ``(model_name_or_path, input_template,
        input_current_text_template, generation_kwargs)``. Both templates are
        wrapped in ``[INST]...[/INST]``.
    """
    open_tag, close_tag = "[INST]", "[/INST]"
    return (
        "mistralai/Mistral-7B-Instruct-v0.2",
        open_tag + "{context}\n{current}" + close_tag,
        open_tag + "{current}" + close_tag,
        '{\n\t"max_new_tokens": 50\n}',
    )
def set_phi3_preset():
    """Return the preset for ``microsoft/Phi-3-mini-4k-instruct``.

    The system turn carries ``SYSTEM_PROMPT``. The contextual user turn holds
    ``{context}`` and ``{current}`` separated by a blank line.

    Returns:
        A 6-tuple ``(model_name_or_path, input_template,
        decoder_input_output_separator, input_current_text_template,
        special_tokens_to_keep, generation_kwargs)``.
    """
    head = "<|system|>\n" + SYSTEM_PROMPT + "<|end|>\n<|user|>\n"
    tail = "<|end|>\n<|assistant|>"
    role_tokens = ["<|system|>", "<|end|>", "<|assistant|>", "<|user|>"]
    return (
        "microsoft/Phi-3-mini-4k-instruct",
        head + "{context}\n\n{current}" + tail,
        "\n",
        head + "{current}" + tail,
        role_tokens,
        '{\n\t"max_new_tokens": 50\n}',
    )
def update_code_snippets_fn(
input_current_text: str,
input_context_text: str,
output_current_text: str,
output_context_text: str,
model_name_or_path: str,
attribution_method: str,
attributed_fn: str | None,
context_sensitivity_metric: str,
context_sensitivity_std_threshold: float,
context_sensitivity_topk: int,
attribution_std_threshold: float,
attribution_topk: int,
input_template: str,
output_template: str,
contextless_input_template: str,
contextless_output_template: str,
special_tokens_to_keep: str | list[str] | None,
decoder_input_output_separator: str,
model_kwargs: str,
tokenizer_kwargs: str,
generation_kwargs: str,
attribution_kwargs: str,
) -> tuple[str, str]:
if not input_current_text:
input_current_text = "<MISSING INPUT CURRENT TEXT, REQUIRED>"
nl = "\n"
tq = "\"\"\""
def escape_quotes(s: str) -> str:
return s.replace('"', '\\"')
def py_get_kwargs_str(kwargs: str, name: str, pad: str = " " * 4) -> str:
kwargs_dict = json.loads(kwargs)
return nl + pad + name + '=' + str(kwargs_dict) + ',' if kwargs_dict else ''
def py_get_if_specified(arg: str | int | float | list | None, name: str, pad: str = " " * 4) -> str:
if arg is None or (isinstance(arg, (str, list)) and not arg) or (isinstance(arg, (int, float)) and arg <= 0):
return ""
elif isinstance(arg, str):
return nl + pad + name + "=" + tq + arg + tq + ","
elif isinstance(arg, list):
return nl + pad + name + "=" + str(arg) + ","
else:
return nl + pad + name + "=" + str(arg) + ","
def sh_get_kwargs_str(kwargs: str, name: str, pad: str = " " * 4) -> str:
return nl + pad + f"--{name} " + '"' + escape_quotes("".join(x.strip() for x in str(kwargs).split("\n"))) + '"' + " \\" if json.loads(kwargs) else ''
def sh_get_if_specified(arg: str | int | float | list | None, name: str, pad: str = " " * 4) -> str:
if arg is None or (isinstance(arg, (str, list)) and not arg) or (isinstance(arg, (int, float)) and arg <= 0):
return ""
elif isinstance(arg, str):
return nl + pad + f"--{name} " + '"' + escape_quotes(arg) + '"' + " \\"
elif isinstance(arg, list):
return nl + pad + f"--{name} " + " ".join(str(arg)) + " \\"
else:
return nl + pad + f"--{name} " + str(arg) + " \\"
# Python
python = f"""#!pip install inseq
import inseq
from inseq.commands.attribute_context.attribute_context import attribute_context_with_model, AttributeContextArgs
inseq_model = inseq.load_model(
"{model_name_or_path}",
"{attribution_method}",{py_get_kwargs_str(model_kwargs, "model_kwargs")}{py_get_kwargs_str(tokenizer_kwargs, "tokenizer_kwargs")}
)
pecore_args = AttributeContextArgs(
model_name_or_path="{model_name_or_path}",
attribution_method="{attribution_method}",
attributed_fn="{attributed_fn}",
context_sensitivity_metric="{context_sensitivity_metric}",
context_sensitivity_std_threshold={context_sensitivity_std_threshold},{py_get_if_specified(context_sensitivity_topk, "context_sensitivity_topk")}
attribution_std_threshold={attribution_std_threshold},{py_get_if_specified(attribution_topk, "attribution_topk")}
input_current_text=\"\"\"{input_current_text}\"\"\",{py_get_if_specified(input_context_text, "input_context_text")}
contextless_input_current_text=\"\"\"{contextless_input_template}\"\"\",
input_template=\"\"\"{input_template}\"\"\",{py_get_if_specified(output_current_text, "output_current_text")}{py_get_if_specified(output_context_text, "output_context_text")}
contextless_output_current_text=\"\"\"{contextless_output_template}\"\"\",
output_template="{output_template}",{py_get_if_specified(special_tokens_to_keep, "special_tokens_to_keep")}{py_get_if_specified(decoder_input_output_separator, "decoder_input_output_separator")}
save_path="pecore_output.json",
viz_path="pecore_output.html",{py_get_kwargs_str(model_kwargs, "model_kwargs")}{py_get_kwargs_str(tokenizer_kwargs, "tokenizer_kwargs")}{py_get_kwargs_str(generation_kwargs, "generation_kwargs")}{py_get_kwargs_str(attribution_kwargs, "attribution_kwargs")}
)
out = attribute_context_with_model(pecore_args, inseq_model)"""
# Bash
bash = f"""# pip install inseq
inseq attribute-context \\
--model_name_or_path "{model_name_or_path}" \\
--attribution_method "{attribution_method}" \\
--attributed_fn "{attributed_fn}" \\
--context_sensitivity_metric "{context_sensitivity_metric}" \\
--context_sensitivity_std_threshold {context_sensitivity_std_threshold} \\{sh_get_if_specified(context_sensitivity_topk, "context_sensitivity_topk")}
--attribution_std_threshold {attribution_std_threshold} \\{sh_get_if_specified(attribution_topk, "attribution_topk")}
--input_current_text "{escape_quotes(input_current_text)}" \\{sh_get_if_specified(input_context_text, "input_context_text")}
--contextless_input_current_text "{escape_quotes(contextless_input_template)}" \\
--input_template "{escape_quotes(input_template)}" \\{sh_get_if_specified(output_current_text, "output_current_text")}{sh_get_if_specified(output_context_text, "output_context_text")}
--contextless_output_current_text "{escape_quotes(contextless_output_template)}" \\
--output_template "{escape_quotes(output_template)}" \\{sh_get_if_specified(special_tokens_to_keep, "special_tokens_to_keep")}{sh_get_if_specified(decoder_input_output_separator, "decoder_input_output_separator")}
--save_path "pecore_output.json" \\
--viz_path "pecore_output.html" \\{sh_get_kwargs_str(model_kwargs, "model_kwargs")}{sh_get_kwargs_str(tokenizer_kwargs, "tokenizer_kwargs")}{sh_get_kwargs_str(generation_kwargs, "generation_kwargs")}{sh_get_kwargs_str(attribution_kwargs, "attribution_kwargs")}""".strip("\\")
return python, bash