fix: tokenizer chat template sys msg
#1 by legraphista
Fixes a crash in the tokenizer chat template when the conversation does not start with a system message. To reproduce:
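from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('OpenLLM-Ro/RoGemma-7b-Instruct', trust_remote_code=True)
# Conversation with no leading system message
chat = [
    {'role': 'user', 'content': '...'},
    {'role': 'assistant', 'content': '...'},
    {'role': 'user', 'content': '...'}
]
tokenizer.apply_chat_template(chat)  # raises UndefinedError before this fix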
File /shared/jupyter/.venv/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1851, in PreTrainedTokenizerBase.apply_chat_template(self, conversation, tools, documents, chat_template, add_generation_prompt, tokenize, padding, truncation, max_length, return_tensors, return_dict, tokenizer_kwargs, **kwargs)
1848 if hasattr(chat, "messages"):
1849 # Indicates it's a Conversation object
1850 chat = chat.messages
-> 1851 rendered_chat = compiled_template.render(
1852 messages=chat,
1853 tools=tool_schemas,
1854 documents=documents,
1855 add_generation_prompt=add_generation_prompt,
1856 **template_kwargs,
1857 )
1858 rendered.append(rendered_chat)
1860 if not is_batched:
File /shared/jupyter/.venv/lib/python3.10/site-packages/jinja2/environment.py:1304, in Template.render(self, *args, **kwargs)
1302 return self.environment.concat(self.root_render_func(ctx)) # type: ignore
1303 except Exception:
-> 1304 self.environment.handle_exception()
File /shared/jupyter/.venv/lib/python3.10/site-packages/jinja2/environment.py:939, in Environment.handle_exception(self, source)
934 """Exception handling helper. This is used internally to either raise
935 rewritten exceptions or return a rendered traceback for the template.
936 """
937 from .debug import rewrite_traceback_stack
--> 939 raise rewrite_traceback_stack(source=source)
File <template>:1, in top-level template code()
UndefinedError: 'system_message' is undefined
The error occurs because the template reads system_message outside the {% if %} block that sets it, so the variable is undefined whenever the first message is not a system message.
This PR emits <bos> at the very beginning and moves the system prompt output inside the if block:
{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{{ '<bos>' + system_message }}{% for message in messages %}...
becomes
{{ '<bos>' }}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ system_message }}{% endif %}{% for message in messages %}...
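The difference can be checked without loading the model. The following is a minimal sketch that renders only the two template prefixes quoted above with plain jinja2; the message loop is left out because it is elided in this PR. The names OLD_PREFIX, NEW_PREFIX, and render, along with the sample system message, are illustrative and not part of the repository. transformers renders templates in its own sandboxed Jinja environment, so this shows the template logic rather than the exact runtime setup.

```python
from jinja2 import Environment
from jinja2.exceptions import UndefinedError

# Only the prefixes quoted in this PR; the elided message loop is omitted.
OLD_PREFIX = (
    "{% if messages[0]['role'] == 'system' %}"
    "{% set system_message = messages[0]['content'] %}{% endif %}"
    "{{ '<bos>' + system_message }}"
)
NEW_PREFIX = (
    "{{ '<bos>' }}{% if messages[0]['role'] == 'system' %}"
    "{% set system_message = messages[0]['content'] %}"
    "{{ system_message }}{% endif %}"
)

env = Environment()

# Hypothetical conversations: one without and one with a system message.
no_system = [{'role': 'user', 'content': '...'}]
with_system = [{'role': 'system', 'content': 'Be brief.'}] + no_system


def render(prefix, messages):
    """Render a template prefix against a list of chat messages."""
    return env.from_string(prefix).render(messages=messages)


# Old prefix: system_message is never set, so the concatenation fails.
try:
    render(OLD_PREFIX, no_system)
    old_failed = False
except UndefinedError:
    old_failed = True

print(old_failed)
print(repr(render(NEW_PREFIX, no_system)))
print(repr(render(NEW_PREFIX, with_system)))
```

With a system message present, both prefixes render the same text, so the change should only affect conversations that previously crashed.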
mihaimasala changed pull request status to merged