Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +2 -8
- api.py +33 -0
- llamafactory.egg-info/PKG-INFO +815 -0
- llamafactory.egg-info/SOURCES.txt +123 -0
- llamafactory.egg-info/dependency_links.txt +1 -0
- llamafactory.egg-info/entry_points.txt +3 -0
- llamafactory.egg-info/requires.txt +82 -0
- llamafactory.egg-info/top_level.txt +1 -0
- llamafactory/__init__.py +46 -0
- llamafactory/__pycache__/__init__.cpython-311.pyc +0 -0
- llamafactory/api/__init__.py +0 -0
- llamafactory/api/app.py +134 -0
- llamafactory/api/chat.py +237 -0
- llamafactory/api/common.py +34 -0
- llamafactory/api/protocol.py +153 -0
- llamafactory/chat/__init__.py +19 -0
- llamafactory/chat/__pycache__/__init__.cpython-311.pyc +0 -0
- llamafactory/chat/__pycache__/base_engine.cpython-311.pyc +0 -0
- llamafactory/chat/__pycache__/chat_model.cpython-311.pyc +0 -0
- llamafactory/chat/__pycache__/hf_engine.cpython-311.pyc +0 -0
- llamafactory/chat/__pycache__/vllm_engine.cpython-311.pyc +0 -0
- llamafactory/chat/base_engine.py +102 -0
- llamafactory/chat/chat_model.py +187 -0
- llamafactory/chat/hf_engine.py +343 -0
- llamafactory/chat/vllm_engine.py +230 -0
- llamafactory/cli.py +121 -0
- llamafactory/data/__init__.py +37 -0
- llamafactory/data/__pycache__/__init__.cpython-311.pyc +0 -0
- llamafactory/data/__pycache__/aligner.cpython-311.pyc +0 -0
- llamafactory/data/__pycache__/collator.cpython-311.pyc +0 -0
- llamafactory/data/__pycache__/data_utils.cpython-311.pyc +0 -0
- llamafactory/data/__pycache__/formatter.cpython-311.pyc +0 -0
- llamafactory/data/__pycache__/loader.cpython-311.pyc +0 -0
- llamafactory/data/__pycache__/mm_plugin.cpython-311.pyc +0 -0
- llamafactory/data/__pycache__/parser.cpython-311.pyc +0 -0
- llamafactory/data/__pycache__/preprocess.cpython-311.pyc +0 -0
- llamafactory/data/__pycache__/template.cpython-311.pyc +0 -0
- llamafactory/data/__pycache__/tool_utils.cpython-311.pyc +0 -0
- llamafactory/data/aligner.py +258 -0
- llamafactory/data/collator.py +189 -0
- llamafactory/data/data_utils.py +92 -0
- llamafactory/data/formatter.py +148 -0
- llamafactory/data/loader.py +292 -0
- llamafactory/data/mm_plugin.py +627 -0
- llamafactory/data/parser.py +154 -0
- llamafactory/data/preprocess.py +111 -0
- llamafactory/data/processors/__init__.py +0 -0
- llamafactory/data/processors/__pycache__/__init__.cpython-311.pyc +0 -0
- llamafactory/data/processors/__pycache__/feedback.cpython-311.pyc +0 -0
- llamafactory/data/processors/__pycache__/pairwise.cpython-311.pyc +0 -0
README.md
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: yellow
|
5 |
-
colorTo: blue
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.44.1
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: src
|
3 |
+
app_file: webui.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 4.44.1
|
|
|
|
|
6 |
---
|
|
|
|
api.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
|
17 |
+
import uvicorn
|
18 |
+
|
19 |
+
from llamafactory.api.app import create_app
|
20 |
+
from llamafactory.chat import ChatModel
|
21 |
+
|
22 |
+
|
23 |
+
def main():
    """Launch the LlamaFactory OpenAI-style API server via uvicorn.

    A ``ChatModel`` is built with no explicit arguments (presumably it picks
    up its own configuration elsewhere -- verify against ``ChatModel``), the
    web application is created from it with ``create_app``, and uvicorn serves
    that application.

    Environment variables:
        API_HOST: bind address, default ``"0.0.0.0"`` (all interfaces).
        API_PORT: bind port, default ``"8000"``; passed through ``int`` so a
            non-numeric value raises ``ValueError``.
    """
    model = ChatModel()
    application = create_app(model)

    env = os.environ
    host = env.get("API_HOST", "0.0.0.0")
    port = int(env.get("API_PORT", "8000"))

    # The hint always points at localhost, whatever API_HOST is set to.
    print(f"Visit http://localhost:{port}/docs for API document.")
    uvicorn.run(application, host=host, port=port)


if __name__ == "__main__":
    main()
|
llamafactory.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,815 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: llamafactory
|
3 |
+
Version: 0.9.1.dev0
|
4 |
+
Summary: Easy-to-use LLM fine-tuning framework
|
5 |
+
Home-page: https://github.com/hiyouga/LLaMA-Factory
|
6 |
+
Author: hiyouga
|
7 |
+
Author-email: [email protected]
|
8 |
+
License: Apache 2.0 License
|
9 |
+
Keywords: LLaMA,BLOOM,Falcon,LLM,ChatGPT,transformer,pytorch,deep learning
|
10 |
+
Classifier: Development Status :: 4 - Beta
|
11 |
+
Classifier: Intended Audience :: Developers
|
12 |
+
Classifier: Intended Audience :: Education
|
13 |
+
Classifier: Intended Audience :: Science/Research
|
14 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
15 |
+
Classifier: Operating System :: OS Independent
|
16 |
+
Classifier: Programming Language :: Python :: 3
|
17 |
+
Classifier: Programming Language :: Python :: 3.8
|
18 |
+
Classifier: Programming Language :: Python :: 3.9
|
19 |
+
Classifier: Programming Language :: Python :: 3.10
|
20 |
+
Classifier: Programming Language :: Python :: 3.11
|
21 |
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
22 |
+
Requires-Python: >=3.8.0
|
23 |
+
Description-Content-Type: text/markdown
|
24 |
+
License-File: LICENSE
|
25 |
+
Requires-Dist: transformers<=4.45.0,>=4.41.2
|
26 |
+
Requires-Dist: datasets<=2.21.0,>=2.16.0
|
27 |
+
Requires-Dist: accelerate<=0.34.2,>=0.30.1
|
28 |
+
Requires-Dist: peft<=0.12.0,>=0.11.1
|
29 |
+
Requires-Dist: trl<=0.9.6,>=0.8.6
|
30 |
+
Requires-Dist: gradio>=4.0.0
|
31 |
+
Requires-Dist: pandas>=2.0.0
|
32 |
+
Requires-Dist: scipy
|
33 |
+
Requires-Dist: einops
|
34 |
+
Requires-Dist: sentencepiece
|
35 |
+
Requires-Dist: tiktoken
|
36 |
+
Requires-Dist: protobuf
|
37 |
+
Requires-Dist: uvicorn
|
38 |
+
Requires-Dist: pydantic
|
39 |
+
Requires-Dist: fastapi
|
40 |
+
Requires-Dist: sse-starlette
|
41 |
+
Requires-Dist: matplotlib>=3.7.0
|
42 |
+
Requires-Dist: fire
|
43 |
+
Requires-Dist: packaging
|
44 |
+
Requires-Dist: pyyaml
|
45 |
+
Requires-Dist: numpy<2.0.0
|
46 |
+
Requires-Dist: av
|
47 |
+
Provides-Extra: torch
|
48 |
+
Requires-Dist: torch>=1.13.1; extra == "torch"
|
49 |
+
Provides-Extra: torch-npu
|
50 |
+
Requires-Dist: torch==2.1.0; extra == "torch-npu"
|
51 |
+
Requires-Dist: torch-npu==2.1.0.post3; extra == "torch-npu"
|
52 |
+
Requires-Dist: decorator; extra == "torch-npu"
|
53 |
+
Provides-Extra: metrics
|
54 |
+
Requires-Dist: nltk; extra == "metrics"
|
55 |
+
Requires-Dist: jieba; extra == "metrics"
|
56 |
+
Requires-Dist: rouge-chinese; extra == "metrics"
|
57 |
+
Provides-Extra: deepspeed
|
58 |
+
Requires-Dist: deepspeed<=0.14.4,>=0.10.0; extra == "deepspeed"
|
59 |
+
Provides-Extra: liger-kernel
|
60 |
+
Requires-Dist: liger-kernel; extra == "liger-kernel"
|
61 |
+
Provides-Extra: bitsandbytes
|
62 |
+
Requires-Dist: bitsandbytes>=0.39.0; extra == "bitsandbytes"
|
63 |
+
Provides-Extra: hqq
|
64 |
+
Requires-Dist: hqq; extra == "hqq"
|
65 |
+
Provides-Extra: eetq
|
66 |
+
Requires-Dist: eetq; extra == "eetq"
|
67 |
+
Provides-Extra: gptq
|
68 |
+
Requires-Dist: optimum>=1.17.0; extra == "gptq"
|
69 |
+
Requires-Dist: auto-gptq>=0.5.0; extra == "gptq"
|
70 |
+
Provides-Extra: awq
|
71 |
+
Requires-Dist: autoawq; extra == "awq"
|
72 |
+
Provides-Extra: aqlm
|
73 |
+
Requires-Dist: aqlm[gpu]>=1.1.0; extra == "aqlm"
|
74 |
+
Provides-Extra: vllm
|
75 |
+
Requires-Dist: vllm<=0.6.2,>=0.4.3; extra == "vllm"
|
76 |
+
Provides-Extra: galore
|
77 |
+
Requires-Dist: galore-torch; extra == "galore"
|
78 |
+
Provides-Extra: badam
|
79 |
+
Requires-Dist: badam>=1.2.1; extra == "badam"
|
80 |
+
Provides-Extra: adam-mini
|
81 |
+
Requires-Dist: adam-mini; extra == "adam-mini"
|
82 |
+
Provides-Extra: qwen
|
83 |
+
Requires-Dist: transformers_stream_generator; extra == "qwen"
|
84 |
+
Provides-Extra: modelscope
|
85 |
+
Requires-Dist: modelscope; extra == "modelscope"
|
86 |
+
Provides-Extra: dev
|
87 |
+
Requires-Dist: ruff; extra == "dev"
|
88 |
+
Requires-Dist: pytest; extra == "dev"
|
89 |
+
|
90 |
+
![# LLaMA Factory](assets/logo.png)
|
91 |
+
|
92 |
+
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
93 |
+
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
|
94 |
+
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
95 |
+
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
|
96 |
+
[![Citation](https://img.shields.io/badge/citation-91-green)](#projects-using-llama-factory)
|
97 |
+
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
98 |
+
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
|
99 |
+
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
|
100 |
+
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
101 |
+
[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
102 |
+
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
103 |
+
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
104 |
+
|
105 |
+
[![GitHub Trend](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
|
106 |
+
|
107 |
+
👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg).
|
108 |
+
|
109 |
+
\[ English | [中文](README_zh.md) \]
|
110 |
+
|
111 |
+
**Fine-tuning a large language model can be as easy as...**
|
112 |
+
|
113 |
+
https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3
|
114 |
+
|
115 |
+
Choose your path:
|
116 |
+
|
117 |
+
- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
118 |
+
- **PAI-DSW**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
|
119 |
+
- **Local machine**: Please refer to [usage](#getting-started)
|
120 |
+
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/zh-cn/latest/
|
121 |
+
|
122 |
+
> [!NOTE]
|
123 |
+
> Except for the above links, all other websites are unauthorized third-party websites. Please use them with caution.
|
124 |
+
|
125 |
+
## Table of Contents
|
126 |
+
|
127 |
+
- [Features](#features)
|
128 |
+
- [Benchmark](#benchmark)
|
129 |
+
- [Changelog](#changelog)
|
130 |
+
- [Supported Models](#supported-models)
|
131 |
+
- [Supported Training Approaches](#supported-training-approaches)
|
132 |
+
- [Provided Datasets](#provided-datasets)
|
133 |
+
- [Requirement](#requirement)
|
134 |
+
- [Getting Started](#getting-started)
|
135 |
+
- [Projects using LLaMA Factory](#projects-using-llama-factory)
|
136 |
+
- [License](#license)
|
137 |
+
- [Citation](#citation)
|
138 |
+
- [Acknowledgement](#acknowledgement)
|
139 |
+
|
140 |
+
## Features
|
141 |
+
|
142 |
+
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
143 |
+
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
144 |
+
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
|
145 |
+
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
|
146 |
+
- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
|
147 |
+
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
148 |
+
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
149 |
+
|
150 |
+
## Benchmark
|
151 |
+
|
152 |
+
Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization, LLaMA Factory's QLoRA further improves GPU memory efficiency.
|
153 |
+
|
154 |
+
![benchmark](assets/benchmark.svg)
|
155 |
+
|
156 |
+
<details><summary>Definitions</summary>
|
157 |
+
|
158 |
+
- **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024)
|
159 |
+
- **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
|
160 |
+
- **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
|
161 |
+
- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA Factory's LoRA tuning.
|
162 |
+
|
163 |
+
</details>
|
164 |
+
|
165 |
+
## Changelog
|
166 |
+
|
167 |
+
[24/09/19] We support fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
|
168 |
+
|
169 |
+
[24/08/30] We support fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
|
170 |
+
|
171 |
+
[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
|
172 |
+
|
173 |
+
[24/08/09] We support **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
|
174 |
+
|
175 |
+
<details><summary>Full Changelog</summary>
|
176 |
+
|
177 |
+
[24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
|
178 |
+
|
179 |
+
[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
|
180 |
+
|
181 |
+
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
|
182 |
+
|
183 |
+
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
184 |
+
|
185 |
+
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models; you need to fine-tune them with the `paligemma` template for chat completion.
|
186 |
+
|
187 |
+
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
188 |
+
|
189 |
+
[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
|
190 |
+
|
191 |
+
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
|
192 |
+
|
193 |
+
[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
|
194 |
+
|
195 |
+
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
|
196 |
+
|
197 |
+
[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)** optimizer. See [examples](examples/README.md) for usage.
|
198 |
+
|
199 |
+
[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
|
200 |
+
|
201 |
+
[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See [examples](examples/README.md) for usage.
|
202 |
+
|
203 |
+
[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
|
204 |
+
|
205 |
+
[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See [examples](examples/README.md) for usage.
|
206 |
+
|
207 |
+
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
|
208 |
+
|
209 |
+
[24/03/07] We supported **[GaLore](https://arxiv.org/abs/2403.03507)** optimizer. See [examples](examples/README.md) for usage.
|
210 |
+
|
211 |
+
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
|
212 |
+
|
213 |
+
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `use_dora: true` to activate DoRA training.
|
214 |
+
|
215 |
+
[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See [examples](examples/README.md) for usage.
|
216 |
+
|
217 |
+
[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
|
218 |
+
|
219 |
+
[24/01/18] We supported **agent tuning** for most models, equipping models with tool-using abilities by fine-tuning with `dataset: glaive_toolcall_en`.
|
220 |
+
|
221 |
+
[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `use_unsloth: true` argument to activate unsloth patch. It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
|
222 |
+
|
223 |
+
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
|
224 |
+
|
225 |
+
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage.
|
226 |
+
|
227 |
+
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
|
228 |
+
|
229 |
+
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `shift_attn: true` argument to enable shift short attention.
|
230 |
+
|
231 |
+
[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [examples](examples/README.md) for usage.
|
232 |
+
|
233 |
+
[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `flash_attn: fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
|
234 |
+
|
235 |
+
[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `rope_scaling: linear` argument in training and `rope_scaling: dynamic` argument at inference to extrapolate the position embeddings.
|
236 |
+
|
237 |
+
[23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README.md) for usage.
|
238 |
+
|
239 |
+
[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode.
|
240 |
+
|
241 |
+
[23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
|
242 |
+
|
243 |
+
[23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
|
244 |
+
|
245 |
+
[23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
|
246 |
+
|
247 |
+
[23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets, see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details.
|
248 |
+
|
249 |
+
[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
|
250 |
+
|
251 |
+
[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). See [examples](examples/README.md) for usage.
|
252 |
+
|
253 |
+
</details>
|
254 |
+
|
255 |
+
## Supported Models
|
256 |
+
|
257 |
+
| Model | Model size | Template |
|
258 |
+
| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
|
259 |
+
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
260 |
+
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
261 |
+
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
262 |
+
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
263 |
+
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
264 |
+
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
265 |
+
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
|
266 |
+
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
|
267 |
+
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
268 |
+
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
269 |
+
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
270 |
+
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
|
271 |
+
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
272 |
+
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
273 |
+
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
274 |
+
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
275 |
+
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
276 |
+
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
277 |
+
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
278 |
+
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
279 |
+
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
280 |
+
| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
281 |
+
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
|
282 |
+
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
283 |
+
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
284 |
+
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
285 |
+
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
286 |
+
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
287 |
+
|
288 |
+
> [!NOTE]
|
289 |
+
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
290 |
+
>
|
291 |
+
> Remember to use the **SAME** template in training and inference.
|
292 |
+
|
293 |
+
Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of models we support.
|
294 |
+
|
295 |
+
You can also add a custom chat template to [template.py](src/llamafactory/data/template.py).
|
296 |
+
|
297 |
+
## Supported Training Approaches
|
298 |
+
|
299 |
+
| Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA |
|
300 |
+
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
|
301 |
+
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
302 |
+
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
303 |
+
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
304 |
+
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
305 |
+
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
306 |
+
| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
307 |
+
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
308 |
+
| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
309 |
+
|
310 |
+
> [!TIP]
|
311 |
+
> The implementation details of PPO can be found in [this blog](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html).
|
312 |
+
|
313 |
+
## Provided Datasets
|
314 |
+
|
315 |
+
<details><summary>Pre-training datasets</summary>
|
316 |
+
|
317 |
+
- [Wiki Demo (en)](data/wiki_demo.txt)
|
318 |
+
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
319 |
+
- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
|
320 |
+
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
|
321 |
+
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
322 |
+
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
|
323 |
+
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
|
324 |
+
- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
|
325 |
+
- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
|
326 |
+
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
|
327 |
+
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
328 |
+
|
329 |
+
</details>
|
330 |
+
|
331 |
+
<details><summary>Supervised fine-tuning datasets</summary>
|
332 |
+
|
333 |
+
- [Identity (en&zh)](data/identity.json)
|
334 |
+
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
335 |
+
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
|
336 |
+
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
337 |
+
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
338 |
+
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
339 |
+
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
340 |
+
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
341 |
+
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
342 |
+
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
343 |
+
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
344 |
+
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
345 |
+
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
346 |
+
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
347 |
+
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
348 |
+
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
349 |
+
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
350 |
+
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
|
351 |
+
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
|
352 |
+
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
|
353 |
+
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
354 |
+
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
|
355 |
+
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
356 |
+
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
357 |
+
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
358 |
+
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
|
359 |
+
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
360 |
+
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
361 |
+
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
362 |
+
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
|
363 |
+
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
364 |
+
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
365 |
+
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
366 |
+
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
367 |
+
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
368 |
+
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
369 |
+
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
370 |
+
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
|
371 |
+
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
|
372 |
+
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
|
373 |
+
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
374 |
+
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
|
375 |
+
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
376 |
+
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
377 |
+
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
|
378 |
+
- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
|
379 |
+
- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
|
380 |
+
- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
|
381 |
+
- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
|
382 |
+
- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
|
383 |
+
- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
|
384 |
+
|
385 |
+
</details>
|
386 |
+
|
387 |
+
<details><summary>Preference datasets</summary>
|
388 |
+
|
389 |
+
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
390 |
+
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
|
391 |
+
- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
|
392 |
+
- [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
|
393 |
+
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
394 |
+
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
395 |
+
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
396 |
+
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
397 |
+
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
|
398 |
+
|
399 |
+
</details>
|
400 |
+
|
401 |
+
Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
|
402 |
+
|
403 |
+
```bash
|
404 |
+
pip install --upgrade huggingface_hub
|
405 |
+
huggingface-cli login
|
406 |
+
```
|
407 |
+
|
408 |
+
## Requirement
|
409 |
+
|
410 |
+
| Mandatory | Minimum | Recommend |
|
411 |
+
| ------------ | ------- | --------- |
|
412 |
+
| python | 3.8 | 3.11 |
|
413 |
+
| torch | 1.13.1 | 2.4.0 |
|
414 |
+
| transformers | 4.41.2 | 4.43.4 |
|
415 |
+
| datasets | 2.16.0 | 2.20.0 |
|
416 |
+
| accelerate | 0.30.1 | 0.32.0 |
|
417 |
+
| peft | 0.11.1 | 0.12.0 |
|
418 |
+
| trl | 0.8.6 | 0.9.6 |
|
419 |
+
|
420 |
+
| Optional | Minimum | Recommend |
|
421 |
+
| ------------ | ------- | --------- |
|
422 |
+
| CUDA | 11.6 | 12.2 |
|
423 |
+
| deepspeed | 0.10.0 | 0.14.0 |
|
424 |
+
| bitsandbytes | 0.39.0 | 0.43.1 |
|
425 |
+
| vllm | 0.4.3 | 0.5.0 |
|
426 |
+
| flash-attn | 2.3.0 | 2.6.3 |
|
427 |
+
|
428 |
+
### Hardware Requirement
|
429 |
+
|
430 |
+
\* *estimated*
|
431 |
+
|
432 |
+
| Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
|
433 |
+
| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
|
434 |
+
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
|
435 |
+
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
|
436 |
+
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
|
437 |
+
| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
|
438 |
+
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
|
439 |
+
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
|
440 |
+
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
|
441 |
+
|
442 |
+
## Getting Started
|
443 |
+
|
444 |
+
### Installation
|
445 |
+
|
446 |
+
> [!IMPORTANT]
|
447 |
+
> Installation is mandatory.
|
448 |
+
|
449 |
+
```bash
|
450 |
+
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
451 |
+
cd LLaMA-Factory
|
452 |
+
pip install -e ".[torch,metrics]"
|
453 |
+
```
|
454 |
+
|
455 |
+
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, quality
|
456 |
+
|
457 |
+
> [!TIP]
|
458 |
+
> Use `pip install --no-deps -e .` to resolve package conflicts.
|
459 |
+
|
460 |
+
<details><summary>For Windows users</summary>
|
461 |
+
|
462 |
+
If you want to enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2. Please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
|
463 |
+
|
464 |
+
```bash
|
465 |
+
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
|
466 |
+
```
|
467 |
+
|
468 |
+
To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
|
469 |
+
|
470 |
+
</details>
|
471 |
+
|
472 |
+
<details><summary>For Ascend NPU users</summary>
|
473 |
+
|
474 |
+
To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
475 |
+
|
476 |
+
```bash
|
477 |
+
# replace the url according to your CANN version and devices
|
478 |
+
# install CANN Toolkit
|
479 |
+
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
|
480 |
+
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
|
481 |
+
|
482 |
+
# install CANN Kernels
|
483 |
+
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
|
484 |
+
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
|
485 |
+
|
486 |
+
# set env variables
|
487 |
+
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
488 |
+
```
|
489 |
+
|
490 |
+
| Requirement | Minimum | Recommend |
|
491 |
+
| ------------ | ------- | ----------- |
|
492 |
+
| CANN | 8.0.RC1 | 8.0.RC1 |
|
493 |
+
| torch | 2.1.0 | 2.1.0 |
|
494 |
+
| torch-npu | 2.1.0 | 2.1.0.post3 |
|
495 |
+
| deepspeed | 0.13.2 | 0.13.2 |
|
496 |
+
|
497 |
+
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
|
498 |
+
|
499 |
+
If you cannot run inference on NPU devices, try setting `do_sample: false` in the configuration.
|
500 |
+
|
501 |
+
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
502 |
+
|
503 |
+
</details>
|
504 |
+
|
505 |
+
### Data Preparation
|
506 |
+
|
507 |
+
Please refer to [data/README.md](data/README.md) for details about the format of dataset files. You can either use datasets from the Hugging Face / ModelScope hub or load datasets from local disk.
|
508 |
+
|
509 |
+
> [!NOTE]
|
510 |
+
> Please update `data/dataset_info.json` to use your custom dataset.
|
511 |
+
|
512 |
+
### Quickstart
|
513 |
+
|
514 |
+
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
|
515 |
+
|
516 |
+
```bash
|
517 |
+
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
518 |
+
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
519 |
+
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
520 |
+
```
|
521 |
+
|
522 |
+
See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
|
523 |
+
|
524 |
+
> [!TIP]
|
525 |
+
> Use `llamafactory-cli help` to show help information.
|
526 |
+
|
527 |
+
### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
|
528 |
+
|
529 |
+
```bash
|
530 |
+
llamafactory-cli webui
|
531 |
+
```
|
532 |
+
|
533 |
+
### Build Docker
|
534 |
+
|
535 |
+
For CUDA users:
|
536 |
+
|
537 |
+
```bash
|
538 |
+
cd docker/docker-cuda/
|
539 |
+
docker compose up -d
|
540 |
+
docker compose exec llamafactory bash
|
541 |
+
```
|
542 |
+
|
543 |
+
For Ascend NPU users:
|
544 |
+
|
545 |
+
```bash
|
546 |
+
cd docker/docker-npu/
|
547 |
+
docker compose up -d
|
548 |
+
docker compose exec llamafactory bash
|
549 |
+
```
|
550 |
+
|
551 |
+
For AMD ROCm users:
|
552 |
+
|
553 |
+
```bash
|
554 |
+
cd docker/docker-rocm/
|
555 |
+
docker compose up -d
|
556 |
+
docker compose exec llamafactory bash
|
557 |
+
```
|
558 |
+
|
559 |
+
<details><summary>Build without Docker Compose</summary>
|
560 |
+
|
561 |
+
For CUDA users:
|
562 |
+
|
563 |
+
```bash
|
564 |
+
docker build -f ./docker/docker-cuda/Dockerfile \
|
565 |
+
--build-arg INSTALL_BNB=false \
|
566 |
+
--build-arg INSTALL_VLLM=false \
|
567 |
+
--build-arg INSTALL_DEEPSPEED=false \
|
568 |
+
--build-arg INSTALL_FLASHATTN=false \
|
569 |
+
--build-arg PIP_INDEX=https://pypi.org/simple \
|
570 |
+
-t llamafactory:latest .
|
571 |
+
|
572 |
+
docker run -dit --gpus=all \
|
573 |
+
-v ./hf_cache:/root/.cache/huggingface \
|
574 |
+
-v ./ms_cache:/root/.cache/modelscope \
|
575 |
+
-v ./data:/app/data \
|
576 |
+
-v ./output:/app/output \
|
577 |
+
-p 7860:7860 \
|
578 |
+
-p 8000:8000 \
|
579 |
+
--shm-size 16G \
|
580 |
+
--name llamafactory \
|
581 |
+
llamafactory:latest
|
582 |
+
|
583 |
+
docker exec -it llamafactory bash
|
584 |
+
```
|
585 |
+
|
586 |
+
For Ascend NPU users:
|
587 |
+
|
588 |
+
```bash
|
589 |
+
# Choose the docker image based on your environment
|
590 |
+
docker build -f ./docker/docker-npu/Dockerfile \
|
591 |
+
--build-arg INSTALL_DEEPSPEED=false \
|
592 |
+
--build-arg PIP_INDEX=https://pypi.org/simple \
|
593 |
+
-t llamafactory:latest .
|
594 |
+
|
595 |
+
# Change `--device` according to your resources
|
596 |
+
docker run -dit \
|
597 |
+
-v ./hf_cache:/root/.cache/huggingface \
|
598 |
+
-v ./ms_cache:/root/.cache/modelscope \
|
599 |
+
-v ./data:/app/data \
|
600 |
+
-v ./output:/app/output \
|
601 |
+
-v /usr/local/dcmi:/usr/local/dcmi \
|
602 |
+
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
603 |
+
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
|
604 |
+
-v /etc/ascend_install.info:/etc/ascend_install.info \
|
605 |
+
-p 7860:7860 \
|
606 |
+
-p 8000:8000 \
|
607 |
+
--device /dev/davinci0 \
|
608 |
+
--device /dev/davinci_manager \
|
609 |
+
--device /dev/devmm_svm \
|
610 |
+
--device /dev/hisi_hdc \
|
611 |
+
--shm-size 16G \
|
612 |
+
--name llamafactory \
|
613 |
+
llamafactory:latest
|
614 |
+
|
615 |
+
docker exec -it llamafactory bash
|
616 |
+
```
|
617 |
+
|
618 |
+
For AMD ROCm users:
|
619 |
+
|
620 |
+
```bash
|
621 |
+
docker build -f ./docker/docker-rocm/Dockerfile \
|
622 |
+
--build-arg INSTALL_BNB=false \
|
623 |
+
--build-arg INSTALL_VLLM=false \
|
624 |
+
--build-arg INSTALL_DEEPSPEED=false \
|
625 |
+
--build-arg INSTALL_FLASHATTN=false \
|
626 |
+
--build-arg PIP_INDEX=https://pypi.org/simple \
|
627 |
+
-t llamafactory:latest .
|
628 |
+
|
629 |
+
docker run -dit \
|
630 |
+
-v ./hf_cache:/root/.cache/huggingface \
|
631 |
+
-v ./ms_cache:/root/.cache/modelscope \
|
632 |
+
-v ./data:/app/data \
|
633 |
+
-v ./output:/app/output \
|
634 |
+
-v ./saves:/app/saves \
|
635 |
+
-p 7860:7860 \
|
636 |
+
-p 8000:8000 \
|
637 |
+
--device /dev/kfd \
|
638 |
+
--device /dev/dri \
|
639 |
+
--shm-size 16G \
|
640 |
+
--name llamafactory \
|
641 |
+
llamafactory:latest
|
642 |
+
|
643 |
+
docker exec -it llamafactory bash
|
644 |
+
```
|
645 |
+
|
646 |
+
</details>
|
647 |
+
|
648 |
+
<details><summary>Details about volume</summary>
|
649 |
+
|
650 |
+
- `hf_cache`: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
|
651 |
+
- `ms_cache`: Similar to Hugging Face cache but for ModelScope users.
|
652 |
+
- `data`: Place datasets in this directory on the host machine so that they can be selected in the LLaMA Board GUI.
|
653 |
+
- `output`: Set the export directory to this location so that the merged result can be accessed directly on the host machine.
|
654 |
+
|
655 |
+
</details>
|
656 |
+
|
657 |
+
### Deploy with OpenAI-style API and vLLM
|
658 |
+
|
659 |
+
```bash
|
660 |
+
API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
661 |
+
```
|
662 |
+
|
663 |
+
> [!TIP]
|
664 |
+
> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for the API documentation.
|
665 |
+
|
666 |
+
### Download from ModelScope Hub
|
667 |
+
|
668 |
+
If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.
|
669 |
+
|
670 |
+
```bash
|
671 |
+
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
|
672 |
+
```
|
673 |
+
|
674 |
+
Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
|
675 |
+
|
676 |
+
### Use W&B Logger
|
677 |
+
|
678 |
+
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
|
679 |
+
|
680 |
+
```yaml
|
681 |
+
report_to: wandb
|
682 |
+
run_name: test_run # optional
|
683 |
+
```
|
684 |
+
|
685 |
+
Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
|
686 |
+
|
687 |
+
## Projects using LLaMA Factory
|
688 |
+
|
689 |
+
If you have a project that should be incorporated, please contact us via email or create a pull request.
|
690 |
+
|
691 |
+
<details><summary>Click to show</summary>
|
692 |
+
|
693 |
+
1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
|
694 |
+
1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
|
695 |
+
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
|
696 |
+
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
|
697 |
+
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
|
698 |
+
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
|
699 |
+
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
|
700 |
+
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
|
701 |
+
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
|
702 |
+
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
|
703 |
+
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
|
704 |
+
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
|
705 |
+
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
|
706 |
+
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
|
707 |
+
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
|
708 |
+
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
|
709 |
+
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
|
710 |
+
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
|
711 |
+
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
712 |
+
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
713 |
+
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
714 |
+
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
|
715 |
+
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
|
716 |
+
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
717 |
+
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
718 |
+
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
|
719 |
+
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
|
720 |
+
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
|
721 |
+
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
|
722 |
+
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
|
723 |
+
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
|
724 |
+
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
|
725 |
+
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
|
726 |
+
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
|
727 |
+
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
|
728 |
+
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
|
729 |
+
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
|
730 |
+
1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
|
731 |
+
1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
|
732 |
+
1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
|
733 |
+
1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
|
734 |
+
1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
|
735 |
+
1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
|
736 |
+
1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
|
737 |
+
1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
|
738 |
+
1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
|
739 |
+
1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
|
740 |
+
1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
|
741 |
+
1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
|
742 |
+
1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
|
743 |
+
1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
|
744 |
+
1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
|
745 |
+
1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
|
746 |
+
1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
|
747 |
+
1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
|
748 |
+
1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
|
749 |
+
1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
|
750 |
+
1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
|
751 |
+
1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
|
752 |
+
1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
|
753 |
+
1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
|
754 |
+
1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
|
755 |
+
1. Li et al. Calibrating LLMs with Preference Optimization on Thought Trees for Generating Rationale in Science Question Scoring. 2024. [[arxiv]](https://arxiv.org/abs/2406.19949)
|
756 |
+
1. Yang et al. Financial Knowledge Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2407.00365)
|
757 |
+
1. Lin et al. DogeRM: Equipping Reward Models with Domain Knowledge through Model Merging. 2024. [[arxiv]](https://arxiv.org/abs/2407.01470)
|
758 |
+
1. Bako et al. Evaluating the Semantic Profiling Abilities of LLMs for Natural Language Utterances in Data Visualization. 2024. [[arxiv]](https://arxiv.org/abs/2407.06129)
|
759 |
+
1. Huang et al. RoLoRA: Fine-tuning Rotated Outlier-free LLMs for Effective Weight-Activation Quantization. 2024. [[arxiv]](https://arxiv.org/abs/2407.08044)
|
760 |
+
1. Jiang et al. LLM-Collaboration on Automatic Science Journalism for the General Audience. 2024. [[arxiv]](https://arxiv.org/abs/2407.09756)
|
761 |
+
1. Inouye et al. Applied Auto-tuning on LoRA Hyperparameters. 2024. [[paper]](https://scholarcommons.scu.edu/cseng_senior/272/)
|
762 |
+
1. Qi et al. Research on Tibetan Tourism Viewpoints information generation system based on LLM. 2024. [[arxiv]](https://arxiv.org/abs/2407.13561)
|
763 |
+
1. Xu et al. Course-Correction: Safety Alignment Using Synthetic Preferences. 2024. [[arxiv]](https://arxiv.org/abs/2407.16637)
|
764 |
+
1. Sun et al. LAMBDA: A Large Model Based Data Agent. 2024. [[arxiv]](https://arxiv.org/abs/2407.17535)
|
765 |
+
1. Zhu et al. CollectiveSFT: Scaling Large Language Models for Chinese Medical Benchmark with Collective Instructions in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2407.19705)
|
766 |
+
1. Yu et al. Correcting Negative Bias in Large Language Models through Negative Attention Score Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2408.00137)
|
767 |
+
1. Xie et al. The Power of Personalized Datasets: Advancing Chinese Composition Writing for Elementary School through Targeted Model Fine-Tuning. IALP 2024. [[paper]](https://www.asianlp.sg/conferences/ialp2024/proceedings/papers/IALP2024_P055.pdf)
|
768 |
+
1. Liu et al. Instruct-Code-Llama: Improving Capabilities of Language Model in Competition Level Code Generation by Online Judge Feedback. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_11)
|
769 |
+
1. Wang et al. Cybernetic Sentinels: Unveiling the Impact of Safety Data Selection on Model Security in Supervised Fine-Tuning. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_23)
|
770 |
+
1. Xia et al. Understanding the Performance and Estimating the Cost of LLM Fine-Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2408.04693)
|
771 |
+
1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
|
772 |
+
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
|
773 |
+
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
|
774 |
+
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
775 |
+
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in the Chinese legal domain, based on Baichuan-13B, capable of retrieving and reasoning on legal knowledge.
|
776 |
+
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in the Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
777 |
+
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for the Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
|
778 |
+
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
|
779 |
+
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generating metadata for Stable Diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
780 |
+
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in the Chinese medical domain, based on LLaVA-1.5-7B.
|
781 |
+
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
|
782 |
+
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PCs with NVIDIA RTX GPUs.
|
783 |
+
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way to build multi-agent LLM applications, with support for model fine-tuning via LLaMA Factory.
|
784 |
+
|
785 |
+
</details>
|
786 |
+
|
787 |
+
## License
|
788 |
+
|
789 |
+
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
790 |
+
|
791 |
+
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
792 |
+
|
793 |
+
## Citation
|
794 |
+
|
795 |
+
If this work is helpful, please kindly cite as:
|
796 |
+
|
797 |
+
```bibtex
|
798 |
+
@inproceedings{zheng2024llamafactory,
|
799 |
+
title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
|
800 |
+
author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
|
801 |
+
booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
|
802 |
+
address={Bangkok, Thailand},
|
803 |
+
publisher={Association for Computational Linguistics},
|
804 |
+
year={2024},
|
805 |
+
url={http://arxiv.org/abs/2403.13372}
|
806 |
+
}
|
807 |
+
```
|
808 |
+
|
809 |
+
## Acknowledgement
|
810 |
+
|
811 |
+
This repo benefits from [PEFT](https://github.com/huggingface/peft), [TRL](https://github.com/huggingface/trl), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful work.
|
812 |
+
|
813 |
+
## Star History
|
814 |
+
|
815 |
+
![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date)
|
llamafactory.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
LICENSE
|
2 |
+
MANIFEST.in
|
3 |
+
README.md
|
4 |
+
pyproject.toml
|
5 |
+
requirements.txt
|
6 |
+
setup.py
|
7 |
+
src/llamafactory/__init__.py
|
8 |
+
src/llamafactory/cli.py
|
9 |
+
src/llamafactory/launcher.py
|
10 |
+
src/llamafactory.egg-info/PKG-INFO
|
11 |
+
src/llamafactory.egg-info/SOURCES.txt
|
12 |
+
src/llamafactory.egg-info/dependency_links.txt
|
13 |
+
src/llamafactory.egg-info/entry_points.txt
|
14 |
+
src/llamafactory.egg-info/requires.txt
|
15 |
+
src/llamafactory.egg-info/top_level.txt
|
16 |
+
src/llamafactory/api/__init__.py
|
17 |
+
src/llamafactory/api/app.py
|
18 |
+
src/llamafactory/api/chat.py
|
19 |
+
src/llamafactory/api/common.py
|
20 |
+
src/llamafactory/api/protocol.py
|
21 |
+
src/llamafactory/chat/__init__.py
|
22 |
+
src/llamafactory/chat/base_engine.py
|
23 |
+
src/llamafactory/chat/chat_model.py
|
24 |
+
src/llamafactory/chat/hf_engine.py
|
25 |
+
src/llamafactory/chat/vllm_engine.py
|
26 |
+
src/llamafactory/data/__init__.py
|
27 |
+
src/llamafactory/data/aligner.py
|
28 |
+
src/llamafactory/data/collator.py
|
29 |
+
src/llamafactory/data/data_utils.py
|
30 |
+
src/llamafactory/data/formatter.py
|
31 |
+
src/llamafactory/data/loader.py
|
32 |
+
src/llamafactory/data/mm_plugin.py
|
33 |
+
src/llamafactory/data/parser.py
|
34 |
+
src/llamafactory/data/preprocess.py
|
35 |
+
src/llamafactory/data/template.py
|
36 |
+
src/llamafactory/data/tool_utils.py
|
37 |
+
src/llamafactory/data/processors/__init__.py
|
38 |
+
src/llamafactory/data/processors/feedback.py
|
39 |
+
src/llamafactory/data/processors/pairwise.py
|
40 |
+
src/llamafactory/data/processors/pretrain.py
|
41 |
+
src/llamafactory/data/processors/processor_utils.py
|
42 |
+
src/llamafactory/data/processors/supervised.py
|
43 |
+
src/llamafactory/data/processors/unsupervised.py
|
44 |
+
src/llamafactory/eval/__init__.py
|
45 |
+
src/llamafactory/eval/evaluator.py
|
46 |
+
src/llamafactory/eval/template.py
|
47 |
+
src/llamafactory/extras/__init__.py
|
48 |
+
src/llamafactory/extras/constants.py
|
49 |
+
src/llamafactory/extras/env.py
|
50 |
+
src/llamafactory/extras/logging.py
|
51 |
+
src/llamafactory/extras/misc.py
|
52 |
+
src/llamafactory/extras/packages.py
|
53 |
+
src/llamafactory/extras/ploting.py
|
54 |
+
src/llamafactory/hparams/__init__.py
|
55 |
+
src/llamafactory/hparams/data_args.py
|
56 |
+
src/llamafactory/hparams/evaluation_args.py
|
57 |
+
src/llamafactory/hparams/finetuning_args.py
|
58 |
+
src/llamafactory/hparams/generating_args.py
|
59 |
+
src/llamafactory/hparams/model_args.py
|
60 |
+
src/llamafactory/hparams/parser.py
|
61 |
+
src/llamafactory/model/__init__.py
|
62 |
+
src/llamafactory/model/adapter.py
|
63 |
+
src/llamafactory/model/loader.py
|
64 |
+
src/llamafactory/model/patcher.py
|
65 |
+
src/llamafactory/model/model_utils/__init__.py
|
66 |
+
src/llamafactory/model/model_utils/attention.py
|
67 |
+
src/llamafactory/model/model_utils/checkpointing.py
|
68 |
+
src/llamafactory/model/model_utils/embedding.py
|
69 |
+
src/llamafactory/model/model_utils/liger_kernel.py
|
70 |
+
src/llamafactory/model/model_utils/longlora.py
|
71 |
+
src/llamafactory/model/model_utils/misc.py
|
72 |
+
src/llamafactory/model/model_utils/mod.py
|
73 |
+
src/llamafactory/model/model_utils/moe.py
|
74 |
+
src/llamafactory/model/model_utils/packing.py
|
75 |
+
src/llamafactory/model/model_utils/quantization.py
|
76 |
+
src/llamafactory/model/model_utils/rope.py
|
77 |
+
src/llamafactory/model/model_utils/unsloth.py
|
78 |
+
src/llamafactory/model/model_utils/valuehead.py
|
79 |
+
src/llamafactory/model/model_utils/visual.py
|
80 |
+
src/llamafactory/train/__init__.py
|
81 |
+
src/llamafactory/train/callbacks.py
|
82 |
+
src/llamafactory/train/test_utils.py
|
83 |
+
src/llamafactory/train/trainer_utils.py
|
84 |
+
src/llamafactory/train/tuner.py
|
85 |
+
src/llamafactory/train/dpo/__init__.py
|
86 |
+
src/llamafactory/train/dpo/trainer.py
|
87 |
+
src/llamafactory/train/dpo/workflow.py
|
88 |
+
src/llamafactory/train/kto/__init__.py
|
89 |
+
src/llamafactory/train/kto/trainer.py
|
90 |
+
src/llamafactory/train/kto/workflow.py
|
91 |
+
src/llamafactory/train/ppo/__init__.py
|
92 |
+
src/llamafactory/train/ppo/ppo_utils.py
|
93 |
+
src/llamafactory/train/ppo/trainer.py
|
94 |
+
src/llamafactory/train/ppo/workflow.py
|
95 |
+
src/llamafactory/train/pt/__init__.py
|
96 |
+
src/llamafactory/train/pt/trainer.py
|
97 |
+
src/llamafactory/train/pt/workflow.py
|
98 |
+
src/llamafactory/train/rm/__init__.py
|
99 |
+
src/llamafactory/train/rm/metric.py
|
100 |
+
src/llamafactory/train/rm/trainer.py
|
101 |
+
src/llamafactory/train/rm/workflow.py
|
102 |
+
src/llamafactory/train/sft/__init__.py
|
103 |
+
src/llamafactory/train/sft/metric.py
|
104 |
+
src/llamafactory/train/sft/trainer.py
|
105 |
+
src/llamafactory/train/sft/workflow.py
|
106 |
+
src/llamafactory/webui/__init__.py
|
107 |
+
src/llamafactory/webui/chatter.py
|
108 |
+
src/llamafactory/webui/common.py
|
109 |
+
src/llamafactory/webui/css.py
|
110 |
+
src/llamafactory/webui/engine.py
|
111 |
+
src/llamafactory/webui/interface.py
|
112 |
+
src/llamafactory/webui/locales.py
|
113 |
+
src/llamafactory/webui/manager.py
|
114 |
+
src/llamafactory/webui/runner.py
|
115 |
+
src/llamafactory/webui/utils.py
|
116 |
+
src/llamafactory/webui/components/__init__.py
|
117 |
+
src/llamafactory/webui/components/chatbot.py
|
118 |
+
src/llamafactory/webui/components/data.py
|
119 |
+
src/llamafactory/webui/components/eval.py
|
120 |
+
src/llamafactory/webui/components/export.py
|
121 |
+
src/llamafactory/webui/components/infer.py
|
122 |
+
src/llamafactory/webui/components/top.py
|
123 |
+
src/llamafactory/webui/components/train.py
|
llamafactory.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
llamafactory.egg-info/entry_points.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[console_scripts]
|
2 |
+
llamafactory-cli = llamafactory.cli:main
|
3 |
+
lmf = llamafactory.cli:main
|
llamafactory.egg-info/requires.txt
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers<=4.45.0,>=4.41.2
|
2 |
+
datasets<=2.21.0,>=2.16.0
|
3 |
+
accelerate<=0.34.2,>=0.30.1
|
4 |
+
peft<=0.12.0,>=0.11.1
|
5 |
+
trl<=0.9.6,>=0.8.6
|
6 |
+
gradio>=4.0.0
|
7 |
+
pandas>=2.0.0
|
8 |
+
scipy
|
9 |
+
einops
|
10 |
+
sentencepiece
|
11 |
+
tiktoken
|
12 |
+
protobuf
|
13 |
+
uvicorn
|
14 |
+
pydantic
|
15 |
+
fastapi
|
16 |
+
sse-starlette
|
17 |
+
matplotlib>=3.7.0
|
18 |
+
fire
|
19 |
+
packaging
|
20 |
+
pyyaml
|
21 |
+
numpy<2.0.0
|
22 |
+
av
|
23 |
+
|
24 |
+
[adam-mini]
|
25 |
+
adam-mini
|
26 |
+
|
27 |
+
[aqlm]
|
28 |
+
aqlm[gpu]>=1.1.0
|
29 |
+
|
30 |
+
[awq]
|
31 |
+
autoawq
|
32 |
+
|
33 |
+
[badam]
|
34 |
+
badam>=1.2.1
|
35 |
+
|
36 |
+
[bitsandbytes]
|
37 |
+
bitsandbytes>=0.39.0
|
38 |
+
|
39 |
+
[deepspeed]
|
40 |
+
deepspeed<=0.14.4,>=0.10.0
|
41 |
+
|
42 |
+
[dev]
|
43 |
+
ruff
|
44 |
+
pytest
|
45 |
+
|
46 |
+
[eetq]
|
47 |
+
eetq
|
48 |
+
|
49 |
+
[galore]
|
50 |
+
galore-torch
|
51 |
+
|
52 |
+
[gptq]
|
53 |
+
optimum>=1.17.0
|
54 |
+
auto-gptq>=0.5.0
|
55 |
+
|
56 |
+
[hqq]
|
57 |
+
hqq
|
58 |
+
|
59 |
+
[liger-kernel]
|
60 |
+
liger-kernel
|
61 |
+
|
62 |
+
[metrics]
|
63 |
+
nltk
|
64 |
+
jieba
|
65 |
+
rouge-chinese
|
66 |
+
|
67 |
+
[modelscope]
|
68 |
+
modelscope
|
69 |
+
|
70 |
+
[qwen]
|
71 |
+
transformers_stream_generator
|
72 |
+
|
73 |
+
[torch]
|
74 |
+
torch>=1.13.1
|
75 |
+
|
76 |
+
[torch-npu]
|
77 |
+
torch==2.1.0
|
78 |
+
torch-npu==2.1.0.post3
|
79 |
+
decorator
|
80 |
+
|
81 |
+
[vllm]
|
82 |
+
vllm<=0.6.2,>=0.4.3
|
llamafactory.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
llamafactory
|
llamafactory/__init__.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
r"""
|
16 |
+
Efficient fine-tuning of large language models.
|
17 |
+
|
18 |
+
Level:
|
19 |
+
api, webui > chat, eval, train > data, model > hparams > extras
|
20 |
+
|
21 |
+
Dependency graph:
|
22 |
+
main:
|
23 |
+
transformers>=4.41.2,<=4.45.0
|
24 |
+
datasets>=2.16.0,<=2.21.0
|
25 |
+
accelerate>=0.30.1,<=0.34.2
|
26 |
+
peft>=0.11.1,<=0.12.0
|
27 |
+
trl>=0.8.6,<=0.9.6
|
28 |
+
attention:
|
29 |
+
transformers>=4.42.4 (gemma+fa2)
|
30 |
+
longlora:
|
31 |
+
transformers>=4.41.2,<=4.45.0
|
32 |
+
packing:
|
33 |
+
transformers>=4.41.2,<=4.45.0
|
34 |
+
|
35 |
+
Disable version checking: DISABLE_VERSION_CHECK=1
|
36 |
+
Enable VRAM recording: RECORD_VRAM=1
|
37 |
+
Force check imports: FORCE_CHECK_IMPORTS=1
|
38 |
+
Force using torchrun: FORCE_TORCHRUN=1
|
39 |
+
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
|
40 |
+
Use modelscope: USE_MODELSCOPE_HUB=1
|
41 |
+
"""
|
42 |
+
|
43 |
+
from .extras.env import VERSION
|
44 |
+
|
45 |
+
|
46 |
+
# Conventional package version attribute, re-exported from the VERSION constant in extras.env.
__version__ = VERSION
|
llamafactory/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (959 Bytes). View file
|
|
llamafactory/api/__init__.py
ADDED
File without changes
|
llamafactory/api/app.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import asyncio
|
16 |
+
import os
|
17 |
+
from contextlib import asynccontextmanager
|
18 |
+
from functools import partial
|
19 |
+
from typing import Optional
|
20 |
+
|
21 |
+
from typing_extensions import Annotated
|
22 |
+
|
23 |
+
from ..chat import ChatModel
|
24 |
+
from ..extras.misc import torch_gc
|
25 |
+
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
|
26 |
+
from .chat import (
|
27 |
+
create_chat_completion_response,
|
28 |
+
create_score_evaluation_response,
|
29 |
+
create_stream_chat_completion_response,
|
30 |
+
)
|
31 |
+
from .protocol import (
|
32 |
+
ChatCompletionRequest,
|
33 |
+
ChatCompletionResponse,
|
34 |
+
ModelCard,
|
35 |
+
ModelList,
|
36 |
+
ScoreEvaluationRequest,
|
37 |
+
ScoreEvaluationResponse,
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
if is_fastapi_available():
|
42 |
+
from fastapi import Depends, FastAPI, HTTPException, status
|
43 |
+
from fastapi.middleware.cors import CORSMiddleware
|
44 |
+
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
45 |
+
|
46 |
+
|
47 |
+
if is_starlette_available():
|
48 |
+
from sse_starlette import EventSourceResponse
|
49 |
+
|
50 |
+
|
51 |
+
if is_uvicorn_available():
|
52 |
+
import uvicorn
|
53 |
+
|
54 |
+
|
55 |
+
async def sweeper() -> None:
    """Call ``torch_gc`` every 300 seconds, looping forever (until the task is cancelled)."""
    gc_interval_seconds = 300
    while True:
        torch_gc()
        await asyncio.sleep(gc_interval_seconds)
|
59 |
+
|
60 |
+
|
61 |
+
@asynccontextmanager
async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU memory
    r"""
    FastAPI lifespan hook that keeps accelerator memory tidy.

    For the ``huggingface`` engine, a background ``sweeper`` task periodically calls ``torch_gc``
    while the app is running. On shutdown the task is cancelled and ``torch_gc`` runs one last time,
    even if the application exits with an error.

    Args:
        app: The FastAPI application (unused; required by the lifespan signature).
        chat_model: The served model; its ``engine_type`` decides whether the sweeper is started.
    """
    sweeper_task = None
    if chat_model.engine_type == "huggingface":
        # Keep a strong reference: the event loop only holds weak references to tasks,
        # so an unreferenced task can be garbage-collected before it ever finishes.
        sweeper_task = asyncio.create_task(sweeper())

    try:
        yield
    finally:
        if sweeper_task is not None:
            sweeper_task.cancel()
            try:
                await sweeper_task
            except asyncio.CancelledError:
                pass  # expected: we just cancelled it

        torch_gc()
|
68 |
+
|
69 |
+
|
70 |
+
def create_app(chat_model: "ChatModel") -> "FastAPI":
    r"""
    Builds the OpenAI-compatible FastAPI application around ``chat_model``.

    Routes:
        GET  /v1/models            -- lists a single model card (name from ``API_MODEL_NAME``).
        POST /v1/chat/completions  -- chat completion (optionally streamed as SSE); generative engines only.
        POST /v1/score/evaluation  -- reward scoring; non-generative engines only.

    Environment:
        FASTAPI_ROOT_PATH -- ASGI root path (default: empty).
        API_KEY           -- when set, every route requires ``Authorization: Bearer <API_KEY>``.
    """
    import secrets  # constant-time comparison of the API key

    root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
    app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    api_key = os.environ.get("API_KEY", None)
    security = HTTPBearer(auto_error=False)

    async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
        # Authentication is disabled when no API key is configured.
        if not api_key:
            return

        # compare_digest avoids leaking the key through response-timing differences;
        # compare as UTF-8 bytes because compare_digest rejects non-ASCII str arguments.
        if auth is None or not secrets.compare_digest(
            auth.credentials.encode("utf-8"), api_key.encode("utf-8")
        ):
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")

    @app.get(
        "/v1/models",
        response_model=ModelList,
        status_code=status.HTTP_200_OK,
        dependencies=[Depends(verify_api_key)],
    )
    async def list_models():
        model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
        return ModelList(data=[model_card])

    @app.post(
        "/v1/chat/completions",
        response_model=ChatCompletionResponse,
        status_code=status.HTTP_200_OK,
        dependencies=[Depends(verify_api_key)],
    )
    async def create_chat_completion(request: ChatCompletionRequest):
        # Chat completion needs a generative engine (reward models can only score).
        if not chat_model.engine.can_generate:
            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")

        if request.stream:
            generate = create_stream_chat_completion_response(request, chat_model)
            return EventSourceResponse(generate, media_type="text/event-stream")
        else:
            return await create_chat_completion_response(request, chat_model)

    @app.post(
        "/v1/score/evaluation",
        response_model=ScoreEvaluationResponse,
        status_code=status.HTTP_200_OK,
        dependencies=[Depends(verify_api_key)],
    )
    async def create_score_evaluation(request: ScoreEvaluationRequest):
        # Scoring is only offered by non-generative (reward) engines.
        if chat_model.engine.can_generate:
            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")

        return await create_score_evaluation_response(request, chat_model)

    return app
|
126 |
+
|
127 |
+
|
128 |
+
def run_api() -> None:
    r"""
    Loads the chat model and serves the API with uvicorn.

    The bind address comes from ``API_HOST`` (default ``0.0.0.0``) and ``API_PORT`` (default ``8000``).
    """
    server_app = create_app(ChatModel())
    host = os.environ.get("API_HOST", "0.0.0.0")
    port = int(os.environ.get("API_PORT", "8000"))
    docs_hint = "Visit http://localhost:{}/docs for API document.".format(port)
    print(docs_hint)
    uvicorn.run(server_app, host=host, port=port)
|
llamafactory/api/chat.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import base64
|
16 |
+
import io
|
17 |
+
import json
|
18 |
+
import os
|
19 |
+
import re
|
20 |
+
import uuid
|
21 |
+
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
22 |
+
|
23 |
+
from ..data import Role as DataRole
|
24 |
+
from ..extras.logging import get_logger
|
25 |
+
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
26 |
+
from .common import dictify, jsonify
|
27 |
+
from .protocol import (
|
28 |
+
ChatCompletionMessage,
|
29 |
+
ChatCompletionResponse,
|
30 |
+
ChatCompletionResponseChoice,
|
31 |
+
ChatCompletionResponseUsage,
|
32 |
+
ChatCompletionStreamResponse,
|
33 |
+
ChatCompletionStreamResponseChoice,
|
34 |
+
Finish,
|
35 |
+
Function,
|
36 |
+
FunctionCall,
|
37 |
+
Role,
|
38 |
+
ScoreEvaluationResponse,
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
if is_fastapi_available():
|
43 |
+
from fastapi import HTTPException, status
|
44 |
+
|
45 |
+
|
46 |
+
if is_pillow_available():
|
47 |
+
from PIL import Image
|
48 |
+
|
49 |
+
|
50 |
+
if is_requests_available():
|
51 |
+
import requests
|
52 |
+
|
53 |
+
|
54 |
+
if TYPE_CHECKING:
|
55 |
+
from ..chat import ChatModel
|
56 |
+
from ..data.mm_plugin import ImageInput
|
57 |
+
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
58 |
+
|
59 |
+
|
60 |
+
logger = get_logger(__name__)

# Translates roles of the OpenAI-compatible protocol into the role strings of the data module.
# Note the protocol's "tool" role (tool results) corresponds to the data module's OBSERVATION role.
ROLE_MAPPING = {
    api_role: data_role.value
    for api_role, data_role in (
        (Role.USER, DataRole.USER),
        (Role.ASSISTANT, DataRole.ASSISTANT),
        (Role.SYSTEM, DataRole.SYSTEM),
        (Role.FUNCTION, DataRole.FUNCTION),
        (Role.TOOL, DataRole.OBSERVATION),
    )
}
|
68 |
+
|
69 |
+
|
70 |
+
def _load_image(image_url: str) -> "Image.Image":
    r"""
    Loads the image referenced by ``image_url`` and converts it to RGB.

    ``image_url`` may be a base64 ``data:image/...`` URI, a path to a local file, or a web URL.
    ``convert`` returns a new in-memory image, so the source file / HTTP response can be closed
    before returning.
    """
    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
        image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
        return Image.open(image_stream).convert("RGB")

    if os.path.isfile(image_url):  # local file
        # NOTE(review): this lets any API client read image files from the server's disk;
        # consider restricting it to an allow-listed directory.
        with open(image_url, "rb") as image_file:
            return Image.open(image_file).convert("RGB")

    # web uri
    # NOTE(review): no timeout is set, so a slow remote host can stall this request indefinitely.
    with requests.get(image_url, stream=True) as response:
        return Image.open(response.raw).convert("RGB")


def _process_request(
    request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
    r"""
    Converts an OpenAI-style chat completion request into ``ChatModel`` inputs.

    An optional leading system message is split off. The remaining messages must alternate
    user/tool and assistant/function turns and have odd length, so they end on a user/tool turn.
    The request object itself is not modified.

    Returns:
        A tuple ``(input_messages, system, tools, image)``:
        - ``tools`` is the JSON-encoded list of tool definitions, or ``None``.
        - ``image`` is the last image found in the messages, or ``None``.

    Raises:
        HTTPException: 400 in any of these cases:
            - there are no messages;
            - the turn count or role order is wrong;
            - the declared tools cannot be serialized.
    """
    logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))

    if len(request.messages) == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")

    # Work on a local view instead of popping from request.messages (avoids mutating the caller's request).
    messages = request.messages
    if messages[0].role == Role.SYSTEM:
        system = messages[0].content
        messages = messages[1:]
    else:
        system = None

    if len(messages) % 2 == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

    input_messages = []
    image = None
    for i, message in enumerate(messages):
        # Even positions are user/tool turns, odd positions are assistant/function turns.
        if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
        elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

        if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
            # Assistant tool calls in the history are re-encoded as a JSON "function" turn.
            tool_calls = [
                {"name": tool_call.function.name, "arguments": tool_call.function.arguments}
                for tool_call in message.tool_calls
            ]
            content = json.dumps(tool_calls, ensure_ascii=False)
            input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
        elif isinstance(message.content, list):
            for input_item in message.content:
                if input_item.type == "text":
                    input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
                else:
                    # Only one image is kept: a later image replaces an earlier one.
                    image = _load_image(input_item.image_url.url)
        else:
            input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})

    tool_list = request.tools
    if isinstance(tool_list, list) and len(tool_list):
        try:
            tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
        except (AttributeError, TypeError, ValueError) as err:
            # json.dumps raises TypeError/ValueError (never JSONDecodeError); a missing
            # ``function`` (None) makes dictify raise AttributeError.
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") from err
    else:
        tools = None

    return input_messages, system, tools, image
|
128 |
+
|
129 |
+
|
130 |
+
def _create_stream_chat_completion_chunk(
    completion_id: str,
    model: str,
    delta: "ChatCompletionMessage",
    index: Optional[int] = 0,
    finish_reason: Optional["Finish"] = None,
) -> str:
    r"""
    Serializes one streaming chunk (a single-choice ``ChatCompletionStreamResponse``) to a JSON string.

    Args:
        completion_id: Identifier shared by every chunk of the same stream.
        model: Model name echoed back to the client.
        delta: Incremental message content for this chunk.
        index: Choice index (defaults to 0).
        finish_reason: Set on the final chunk only; ``None`` otherwise.
    """
    return jsonify(
        ChatCompletionStreamResponse(
            id=completion_id,
            model=model,
            choices=[ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)],
        )
    )
|
140 |
+
|
141 |
+
|
142 |
+
async def create_chat_completion_response(
    request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
    r"""
    Generates a (non-streaming) chat completion for ``request``.

    Each generated text is handled as follows:
    - If the request declared tools, the template's ``extract_tool`` parses it. A list result is
      reported as ``tool_calls`` with finish reason ``tool_calls``.
    - Any other result becomes assistant content, with finish reason ``stop`` when generation
      stopped naturally and ``length`` otherwise.

    Usage takes the prompt length from the last response and sums the completion lengths.
    """
    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
    input_messages, system, tools, image = _process_request(request)
    generation_kwargs = dict(
        do_sample=request.do_sample,
        temperature=request.temperature,
        top_p=request.top_p,
        max_new_tokens=request.max_tokens,
        num_return_sequences=request.n,
        stop=request.stop,
    )
    responses = await chat_model.achat(input_messages, system, tools, image, **generation_kwargs)

    choices = []
    prompt_tokens, completion_tokens = 0, 0
    for index, response in enumerate(responses):
        text = response.response_text
        result = chat_model.engine.template.extract_tool(text) if tools else text

        if not isinstance(result, list):
            message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
            finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
        else:
            tool_calls = []
            for tool in result:
                # Each extracted tool is indexed as (name, arguments).
                function = Function(name=tool[0], arguments=tool[1])
                tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))

            message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
            finish_reason = Finish.TOOL

        choices.append(ChatCompletionResponseChoice(index=index, message=message, finish_reason=finish_reason))
        # All sequences come from the same prompt, so the last reported prompt length is kept.
        prompt_tokens = response.prompt_length
        completion_tokens += response.response_length

    usage = ChatCompletionResponseUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
|
191 |
+
|
192 |
+
|
193 |
+
async def create_stream_chat_completion_response(
    request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
    r"""
    Streams a chat completion as server-sent-event payloads.

    Yields, in order:
    - an opening chunk announcing the assistant role with empty content;
    - one chunk per non-empty generated token;
    - a closing chunk with finish reason ``stop``;
    - the literal ``"[DONE]"`` sentinel.

    Raises:
        HTTPException: 400 when the request declares tools or asks for more than one choice.
    """
    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
    input_messages, system, tools, image = _process_request(request)
    # NOTE(review): since this is an async generator, these checks only run once the SSE response
    # starts iterating it, which is probably too late to become a clean HTTP 400. Confirm.
    if tools:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")

    if request.n > 1:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")

    def make_chunk(delta: "ChatCompletionMessage", finish_reason: Optional["Finish"] = None) -> str:
        # Every chunk of one stream shares the completion id and the requested model name.
        return _create_stream_chat_completion_chunk(
            completion_id=completion_id, model=request.model, delta=delta, finish_reason=finish_reason
        )

    yield make_chunk(ChatCompletionMessage(role=Role.ASSISTANT, content=""))

    token_stream = chat_model.astream_chat(
        input_messages,
        system,
        tools,
        image,
        do_sample=request.do_sample,
        temperature=request.temperature,
        top_p=request.top_p,
        max_new_tokens=request.max_tokens,
        stop=request.stop,
    )
    async for new_token in token_stream:
        if len(new_token) == 0:
            continue  # nothing new to send

        yield make_chunk(ChatCompletionMessage(content=new_token))

    yield make_chunk(ChatCompletionMessage(), finish_reason=Finish.STOP)
    yield "[DONE]"
|
227 |
+
|
228 |
+
|
229 |
+
async def create_score_evaluation_response(
    request: "ScoreEvaluationRequest", chat_model: "ChatModel"
) -> "ScoreEvaluationResponse":
    r"""
    Computes reward scores for ``request.messages`` via ``chat_model.aget_scores``.

    Raises:
        HTTPException: 400 when the request contains no messages.
    """
    score_id = "scoreval-{}".format(uuid.uuid4().hex)
    messages = request.messages
    if len(messages) == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

    return ScoreEvaluationResponse(
        id=score_id,
        model=request.model,
        scores=await chat_model.aget_scores(messages, max_length=request.max_length),
    )
|
llamafactory/api/common.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
from typing import TYPE_CHECKING, Any, Dict
|
17 |
+
|
18 |
+
|
19 |
+
if TYPE_CHECKING:
|
20 |
+
from pydantic import BaseModel
|
21 |
+
|
22 |
+
|
23 |
+
def dictify(data: "BaseModel") -> Dict[str, Any]:
    r"""
    Converts a pydantic model to a dict containing only explicitly set fields.

    Uses ``model_dump`` (pydantic v2) and falls back to ``dict`` (pydantic v1) when the former is missing.
    """
    try:
        as_dict = data.model_dump(exclude_unset=True)  # pydantic v2
    except AttributeError:
        as_dict = data.dict(exclude_unset=True)  # pydantic v1

    return as_dict
|
28 |
+
|
29 |
+
|
30 |
+
def jsonify(data: "BaseModel") -> str:
    r"""
    Serializes a pydantic model to a JSON string containing only explicitly set fields.

    Non-ASCII characters are emitted as-is. Uses ``model_dump`` + ``json.dumps`` (pydantic v2)
    and falls back to the model's own ``json`` method (pydantic v1).
    """
    try:
        as_dict = data.model_dump(exclude_unset=True)  # pydantic v2
        serialized = json.dumps(as_dict, ensure_ascii=False)
    except AttributeError:
        serialized = data.json(exclude_unset=True, ensure_ascii=False)  # pydantic v1

    return serialized
|
llamafactory/api/protocol.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import time
|
16 |
+
from enum import Enum, unique
|
17 |
+
from typing import Any, Dict, List, Optional, Union
|
18 |
+
|
19 |
+
from pydantic import BaseModel, Field
|
20 |
+
from typing_extensions import Literal
|
21 |
+
|
22 |
+
|
23 |
+
# Message roles accepted by the OpenAI-compatible API. The ``str`` mixin lets members
# compare equal to (and serialize as) their raw string values.
@unique
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"  # assistant turn carrying tool calls
    TOOL = "tool"  # tool result; api/chat.py maps it to the data module's observation role
|
30 |
+
|
31 |
+
|
32 |
+
@unique
|
33 |
+
class Finish(str, Enum):
|
34 |
+
STOP = "stop"
|
35 |
+
LENGTH = "length"
|
36 |
+
TOOL = "tool_calls"
|
37 |
+
|
38 |
+
|
39 |
+
class ModelCard(BaseModel):
|
40 |
+
id: str
|
41 |
+
object: Literal["model"] = "model"
|
42 |
+
created: int = Field(default_factory=lambda: int(time.time()))
|
43 |
+
owned_by: Literal["owner"] = "owner"
|
44 |
+
|
45 |
+
|
46 |
+
class ModelList(BaseModel):
|
47 |
+
object: Literal["list"] = "list"
|
48 |
+
data: List[ModelCard] = []
|
49 |
+
|
50 |
+
|
51 |
+
class Function(BaseModel):
|
52 |
+
name: str
|
53 |
+
arguments: str
|
54 |
+
|
55 |
+
|
56 |
+
class FunctionDefinition(BaseModel):
|
57 |
+
name: str
|
58 |
+
description: str
|
59 |
+
parameters: Dict[str, Any]
|
60 |
+
|
61 |
+
|
62 |
+
class FunctionAvailable(BaseModel):
|
63 |
+
type: Literal["function", "code_interpreter"] = "function"
|
64 |
+
function: Optional[FunctionDefinition] = None
|
65 |
+
|
66 |
+
|
67 |
+
class FunctionCall(BaseModel):
|
68 |
+
id: str
|
69 |
+
type: Literal["function"] = "function"
|
70 |
+
function: Function
|
71 |
+
|
72 |
+
|
73 |
+
class ImageURL(BaseModel):
|
74 |
+
url: str
|
75 |
+
|
76 |
+
|
77 |
+
class MultimodalInputItem(BaseModel):
|
78 |
+
type: Literal["text", "image_url"]
|
79 |
+
text: Optional[str] = None
|
80 |
+
image_url: Optional[ImageURL] = None
|
81 |
+
|
82 |
+
|
83 |
+
class ChatMessage(BaseModel):
|
84 |
+
role: Role
|
85 |
+
content: Optional[Union[str, List[MultimodalInputItem]]] = None
|
86 |
+
tool_calls: Optional[List[FunctionCall]] = None
|
87 |
+
|
88 |
+
|
89 |
+
class ChatCompletionMessage(BaseModel):
|
90 |
+
role: Optional[Role] = None
|
91 |
+
content: Optional[str] = None
|
92 |
+
tool_calls: Optional[List[FunctionCall]] = None
|
93 |
+
|
94 |
+
|
95 |
+
class ChatCompletionRequest(BaseModel):
|
96 |
+
model: str
|
97 |
+
messages: List[ChatMessage]
|
98 |
+
tools: Optional[List[FunctionAvailable]] = None
|
99 |
+
do_sample: Optional[bool] = None
|
100 |
+
temperature: Optional[float] = None
|
101 |
+
top_p: Optional[float] = None
|
102 |
+
n: int = 1
|
103 |
+
max_tokens: Optional[int] = None
|
104 |
+
stop: Optional[Union[str, List[str]]] = None
|
105 |
+
stream: bool = False
|
106 |
+
|
107 |
+
|
108 |
+
class ChatCompletionResponseChoice(BaseModel):
|
109 |
+
index: int
|
110 |
+
message: ChatCompletionMessage
|
111 |
+
finish_reason: Finish
|
112 |
+
|
113 |
+
|
114 |
+
class ChatCompletionStreamResponseChoice(BaseModel):
|
115 |
+
index: int
|
116 |
+
delta: ChatCompletionMessage
|
117 |
+
finish_reason: Optional[Finish] = None
|
118 |
+
|
119 |
+
|
120 |
+
class ChatCompletionResponseUsage(BaseModel):
|
121 |
+
prompt_tokens: int
|
122 |
+
completion_tokens: int
|
123 |
+
total_tokens: int
|
124 |
+
|
125 |
+
|
126 |
+
class ChatCompletionResponse(BaseModel):
|
127 |
+
id: str
|
128 |
+
object: Literal["chat.completion"] = "chat.completion"
|
129 |
+
created: int = Field(default_factory=lambda: int(time.time()))
|
130 |
+
model: str
|
131 |
+
choices: List[ChatCompletionResponseChoice]
|
132 |
+
usage: ChatCompletionResponseUsage
|
133 |
+
|
134 |
+
|
135 |
+
class ChatCompletionStreamResponse(BaseModel):
|
136 |
+
id: str
|
137 |
+
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
138 |
+
created: int = Field(default_factory=lambda: int(time.time()))
|
139 |
+
model: str
|
140 |
+
choices: List[ChatCompletionStreamResponseChoice]
|
141 |
+
|
142 |
+
|
143 |
+
class ScoreEvaluationRequest(BaseModel):
|
144 |
+
model: str
|
145 |
+
messages: List[str]
|
146 |
+
max_length: Optional[int] = None
|
147 |
+
|
148 |
+
|
149 |
+
class ScoreEvaluationResponse(BaseModel):
|
150 |
+
id: str
|
151 |
+
object: Literal["score.evaluation"] = "score.evaluation"
|
152 |
+
model: str
|
153 |
+
scores: List[float]
|
llamafactory/chat/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .base_engine import BaseEngine
|
16 |
+
from .chat_model import ChatModel
|
17 |
+
|
18 |
+
|
19 |
+
__all__ = ["BaseEngine", "ChatModel"]
|
llamafactory/chat/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (321 Bytes). View file
|
|
llamafactory/chat/__pycache__/base_engine.cpython-311.pyc
ADDED
Binary file (4.03 kB). View file
|
|
llamafactory/chat/__pycache__/chat_model.cpython-311.pyc
ADDED
Binary file (8.77 kB). View file
|
|
llamafactory/chat/__pycache__/hf_engine.cpython-311.pyc
ADDED
Binary file (18.8 kB). View file
|
|
llamafactory/chat/__pycache__/vllm_engine.cpython-311.pyc
ADDED
Binary file (11.1 kB). View file
|
|
llamafactory/chat/base_engine.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from abc import ABC, abstractmethod
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
|
18 |
+
|
19 |
+
|
20 |
+
if TYPE_CHECKING:
|
21 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
22 |
+
from vllm import AsyncLLMEngine
|
23 |
+
|
24 |
+
from ..data import Template
|
25 |
+
from ..data.mm_plugin import ImageInput, VideoInput
|
26 |
+
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
27 |
+
|
28 |
+
|
29 |
+
@dataclass
|
30 |
+
class Response:
|
31 |
+
response_text: str
|
32 |
+
response_length: int
|
33 |
+
prompt_length: int
|
34 |
+
finish_reason: Literal["stop", "length"]
|
35 |
+
|
36 |
+
|
37 |
+
class BaseEngine(ABC):
|
38 |
+
r"""
|
39 |
+
Base class for inference engine of chat models.
|
40 |
+
|
41 |
+
Must implements async methods: chat(), stream_chat() and get_scores().
|
42 |
+
"""
|
43 |
+
|
44 |
+
model: Union["PreTrainedModel", "AsyncLLMEngine"]
|
45 |
+
tokenizer: "PreTrainedTokenizer"
|
46 |
+
can_generate: bool
|
47 |
+
template: "Template"
|
48 |
+
generating_args: Dict[str, Any]
|
49 |
+
|
50 |
+
@abstractmethod
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
model_args: "ModelArguments",
|
54 |
+
data_args: "DataArguments",
|
55 |
+
finetuning_args: "FinetuningArguments",
|
56 |
+
generating_args: "GeneratingArguments",
|
57 |
+
) -> None:
|
58 |
+
r"""
|
59 |
+
Initializes an inference engine.
|
60 |
+
"""
|
61 |
+
...
|
62 |
+
|
63 |
+
@abstractmethod
|
64 |
+
async def chat(
|
65 |
+
self,
|
66 |
+
messages: Sequence[Dict[str, str]],
|
67 |
+
system: Optional[str] = None,
|
68 |
+
tools: Optional[str] = None,
|
69 |
+
image: Optional["ImageInput"] = None,
|
70 |
+
video: Optional["VideoInput"] = None,
|
71 |
+
**input_kwargs,
|
72 |
+
) -> List["Response"]:
|
73 |
+
r"""
|
74 |
+
Gets a list of responses of the chat model.
|
75 |
+
"""
|
76 |
+
...
|
77 |
+
|
78 |
+
@abstractmethod
|
79 |
+
async def stream_chat(
|
80 |
+
self,
|
81 |
+
messages: Sequence[Dict[str, str]],
|
82 |
+
system: Optional[str] = None,
|
83 |
+
tools: Optional[str] = None,
|
84 |
+
image: Optional["ImageInput"] = None,
|
85 |
+
video: Optional["VideoInput"] = None,
|
86 |
+
**input_kwargs,
|
87 |
+
) -> AsyncGenerator[str, None]:
|
88 |
+
r"""
|
89 |
+
Gets the response token-by-token of the chat model.
|
90 |
+
"""
|
91 |
+
...
|
92 |
+
|
93 |
+
@abstractmethod
|
94 |
+
async def get_scores(
|
95 |
+
self,
|
96 |
+
batch_input: List[str],
|
97 |
+
**input_kwargs,
|
98 |
+
) -> List[float]:
|
99 |
+
r"""
|
100 |
+
Gets a list of scores of the reward model.
|
101 |
+
"""
|
102 |
+
...
|
llamafactory/chat/chat_model.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 THUDM and the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# This code is inspired by the THUDM's ChatGLM implementation.
|
4 |
+
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import asyncio
|
19 |
+
import os
|
20 |
+
from threading import Thread
|
21 |
+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
22 |
+
|
23 |
+
from ..extras.misc import torch_gc
|
24 |
+
from ..hparams import get_infer_args
|
25 |
+
from .hf_engine import HuggingfaceEngine
|
26 |
+
from .vllm_engine import VllmEngine
|
27 |
+
|
28 |
+
|
29 |
+
if TYPE_CHECKING:
|
30 |
+
from ..data.mm_plugin import ImageInput, VideoInput
|
31 |
+
from .base_engine import BaseEngine, Response
|
32 |
+
|
33 |
+
|
34 |
+
def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
|
35 |
+
asyncio.set_event_loop(loop)
|
36 |
+
loop.run_forever()
|
37 |
+
|
38 |
+
|
39 |
+
class ChatModel:
|
40 |
+
r"""
|
41 |
+
General class for chat models. Backed by huggingface or vllm engines.
|
42 |
+
|
43 |
+
Supports both sync and async methods.
|
44 |
+
Sync methods: chat(), stream_chat() and get_scores().
|
45 |
+
Async methods: achat(), astream_chat() and aget_scores().
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
49 |
+
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
|
50 |
+
self.engine_type = model_args.infer_backend
|
51 |
+
if model_args.infer_backend == "huggingface":
|
52 |
+
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
53 |
+
elif model_args.infer_backend == "vllm":
|
54 |
+
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
55 |
+
else:
|
56 |
+
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
|
57 |
+
|
58 |
+
self._loop = asyncio.new_event_loop()
|
59 |
+
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
60 |
+
self._thread.start()
|
61 |
+
|
62 |
+
def chat(
|
63 |
+
self,
|
64 |
+
messages: Sequence[Dict[str, str]],
|
65 |
+
system: Optional[str] = None,
|
66 |
+
tools: Optional[str] = None,
|
67 |
+
image: Optional["ImageInput"] = None,
|
68 |
+
video: Optional["VideoInput"] = None,
|
69 |
+
**input_kwargs,
|
70 |
+
) -> List["Response"]:
|
71 |
+
r"""
|
72 |
+
Gets a list of responses of the chat model.
|
73 |
+
"""
|
74 |
+
task = asyncio.run_coroutine_threadsafe(
|
75 |
+
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
|
76 |
+
)
|
77 |
+
return task.result()
|
78 |
+
|
79 |
+
async def achat(
|
80 |
+
self,
|
81 |
+
messages: Sequence[Dict[str, str]],
|
82 |
+
system: Optional[str] = None,
|
83 |
+
tools: Optional[str] = None,
|
84 |
+
image: Optional["ImageInput"] = None,
|
85 |
+
video: Optional["VideoInput"] = None,
|
86 |
+
**input_kwargs,
|
87 |
+
) -> List["Response"]:
|
88 |
+
r"""
|
89 |
+
Asynchronously gets a list of responses of the chat model.
|
90 |
+
"""
|
91 |
+
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
|
92 |
+
|
93 |
+
def stream_chat(
|
94 |
+
self,
|
95 |
+
messages: Sequence[Dict[str, str]],
|
96 |
+
system: Optional[str] = None,
|
97 |
+
tools: Optional[str] = None,
|
98 |
+
image: Optional["ImageInput"] = None,
|
99 |
+
video: Optional["VideoInput"] = None,
|
100 |
+
**input_kwargs,
|
101 |
+
) -> Generator[str, None, None]:
|
102 |
+
r"""
|
103 |
+
Gets the response token-by-token of the chat model.
|
104 |
+
"""
|
105 |
+
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
|
106 |
+
while True:
|
107 |
+
try:
|
108 |
+
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
|
109 |
+
yield task.result()
|
110 |
+
except StopAsyncIteration:
|
111 |
+
break
|
112 |
+
|
113 |
+
async def astream_chat(
|
114 |
+
self,
|
115 |
+
messages: Sequence[Dict[str, str]],
|
116 |
+
system: Optional[str] = None,
|
117 |
+
tools: Optional[str] = None,
|
118 |
+
image: Optional["ImageInput"] = None,
|
119 |
+
video: Optional["VideoInput"] = None,
|
120 |
+
**input_kwargs,
|
121 |
+
) -> AsyncGenerator[str, None]:
|
122 |
+
r"""
|
123 |
+
Asynchronously gets the response token-by-token of the chat model.
|
124 |
+
"""
|
125 |
+
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
|
126 |
+
yield new_token
|
127 |
+
|
128 |
+
def get_scores(
|
129 |
+
self,
|
130 |
+
batch_input: List[str],
|
131 |
+
**input_kwargs,
|
132 |
+
) -> List[float]:
|
133 |
+
r"""
|
134 |
+
Gets a list of scores of the reward model.
|
135 |
+
"""
|
136 |
+
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
|
137 |
+
return task.result()
|
138 |
+
|
139 |
+
async def aget_scores(
|
140 |
+
self,
|
141 |
+
batch_input: List[str],
|
142 |
+
**input_kwargs,
|
143 |
+
) -> List[float]:
|
144 |
+
r"""
|
145 |
+
Asynchronously gets a list of scores of the reward model.
|
146 |
+
"""
|
147 |
+
return await self.engine.get_scores(batch_input, **input_kwargs)
|
148 |
+
|
149 |
+
|
150 |
+
def run_chat() -> None:
|
151 |
+
if os.name != "nt":
|
152 |
+
try:
|
153 |
+
import readline # noqa: F401
|
154 |
+
except ImportError:
|
155 |
+
print("Install `readline` for a better experience.")
|
156 |
+
|
157 |
+
chat_model = ChatModel()
|
158 |
+
messages = []
|
159 |
+
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
160 |
+
|
161 |
+
while True:
|
162 |
+
try:
|
163 |
+
query = input("\nUser: ")
|
164 |
+
except UnicodeDecodeError:
|
165 |
+
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
|
166 |
+
continue
|
167 |
+
except Exception:
|
168 |
+
raise
|
169 |
+
|
170 |
+
if query.strip() == "exit":
|
171 |
+
break
|
172 |
+
|
173 |
+
if query.strip() == "clear":
|
174 |
+
messages = []
|
175 |
+
torch_gc()
|
176 |
+
print("History has been removed.")
|
177 |
+
continue
|
178 |
+
|
179 |
+
messages.append({"role": "user", "content": query})
|
180 |
+
print("Assistant: ", end="", flush=True)
|
181 |
+
|
182 |
+
response = ""
|
183 |
+
for new_text in chat_model.stream_chat(messages):
|
184 |
+
print(new_text, end="", flush=True)
|
185 |
+
response += new_text
|
186 |
+
print()
|
187 |
+
messages.append({"role": "assistant", "content": response})
|
llamafactory/chat/hf_engine.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import asyncio
|
16 |
+
import concurrent.futures
|
17 |
+
import os
|
18 |
+
from threading import Thread
|
19 |
+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from transformers import GenerationConfig, TextIteratorStreamer
|
23 |
+
from typing_extensions import override
|
24 |
+
|
25 |
+
from ..data import get_template_and_fix_tokenizer
|
26 |
+
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
27 |
+
from ..extras.logging import get_logger
|
28 |
+
from ..extras.misc import get_logits_processor
|
29 |
+
from ..model import load_model, load_tokenizer
|
30 |
+
from .base_engine import BaseEngine, Response
|
31 |
+
|
32 |
+
|
33 |
+
if TYPE_CHECKING:
|
34 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
35 |
+
from trl import PreTrainedModelWrapper
|
36 |
+
|
37 |
+
from ..data import Template
|
38 |
+
from ..data.mm_plugin import ImageInput, VideoInput
|
39 |
+
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
40 |
+
|
41 |
+
|
42 |
+
logger = get_logger(__name__)
|
43 |
+
|
44 |
+
|
45 |
+
class HuggingfaceEngine(BaseEngine):
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
model_args: "ModelArguments",
|
49 |
+
data_args: "DataArguments",
|
50 |
+
finetuning_args: "FinetuningArguments",
|
51 |
+
generating_args: "GeneratingArguments",
|
52 |
+
) -> None:
|
53 |
+
self.can_generate = finetuning_args.stage == "sft"
|
54 |
+
tokenizer_module = load_tokenizer(model_args)
|
55 |
+
self.tokenizer = tokenizer_module["tokenizer"]
|
56 |
+
self.processor = tokenizer_module["processor"]
|
57 |
+
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
58 |
+
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
|
59 |
+
self.model = load_model(
|
60 |
+
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
61 |
+
) # must after fixing tokenizer to resize vocab
|
62 |
+
self.generating_args = generating_args.to_dict()
|
63 |
+
try:
|
64 |
+
asyncio.get_event_loop()
|
65 |
+
except RuntimeError:
|
66 |
+
logger.warning("There is no current event loop, creating a new one.")
|
67 |
+
loop = asyncio.new_event_loop()
|
68 |
+
asyncio.set_event_loop(loop)
|
69 |
+
|
70 |
+
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
|
71 |
+
|
72 |
+
@staticmethod
|
73 |
+
def _process_args(
|
74 |
+
model: "PreTrainedModel",
|
75 |
+
tokenizer: "PreTrainedTokenizer",
|
76 |
+
processor: Optional["ProcessorMixin"],
|
77 |
+
template: "Template",
|
78 |
+
generating_args: Dict[str, Any],
|
79 |
+
messages: Sequence[Dict[str, str]],
|
80 |
+
system: Optional[str] = None,
|
81 |
+
tools: Optional[str] = None,
|
82 |
+
image: Optional["ImageInput"] = None,
|
83 |
+
video: Optional["VideoInput"] = None,
|
84 |
+
input_kwargs: Optional[Dict[str, Any]] = {},
|
85 |
+
) -> Tuple[Dict[str, Any], int]:
|
86 |
+
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
|
87 |
+
if image is not None:
|
88 |
+
mm_input_dict.update({"images": [image], "imglens": [1]})
|
89 |
+
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
|
90 |
+
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
|
91 |
+
|
92 |
+
if video is not None:
|
93 |
+
mm_input_dict.update({"videos": [video], "vidlens": [1]})
|
94 |
+
if VIDEO_PLACEHOLDER not in messages[0]["content"]:
|
95 |
+
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
|
96 |
+
|
97 |
+
messages = template.mm_plugin.process_messages(
|
98 |
+
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
|
99 |
+
)
|
100 |
+
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
101 |
+
system = system or generating_args["default_system"]
|
102 |
+
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
|
103 |
+
prompt_ids, _ = template.mm_plugin.process_token_ids(
|
104 |
+
prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
|
105 |
+
)
|
106 |
+
prompt_length = len(prompt_ids)
|
107 |
+
inputs = torch.tensor([prompt_ids], device=model.device)
|
108 |
+
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
|
109 |
+
|
110 |
+
do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
|
111 |
+
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
112 |
+
top_p: Optional[float] = input_kwargs.pop("top_p", None)
|
113 |
+
top_k: Optional[float] = input_kwargs.pop("top_k", None)
|
114 |
+
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
|
115 |
+
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
|
116 |
+
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
|
117 |
+
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
118 |
+
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
119 |
+
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
120 |
+
|
121 |
+
if stop is not None:
|
122 |
+
logger.warning("Stop parameter is not supported by the huggingface engine yet.")
|
123 |
+
|
124 |
+
generating_args = generating_args.copy()
|
125 |
+
generating_args.update(
|
126 |
+
dict(
|
127 |
+
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
128 |
+
temperature=temperature if temperature is not None else generating_args["temperature"],
|
129 |
+
top_p=top_p if top_p is not None else generating_args["top_p"],
|
130 |
+
top_k=top_k if top_k is not None else generating_args["top_k"],
|
131 |
+
num_return_sequences=num_return_sequences,
|
132 |
+
repetition_penalty=repetition_penalty
|
133 |
+
if repetition_penalty is not None
|
134 |
+
else generating_args["repetition_penalty"],
|
135 |
+
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
|
136 |
+
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
|
137 |
+
pad_token_id=tokenizer.pad_token_id,
|
138 |
+
)
|
139 |
+
)
|
140 |
+
|
141 |
+
if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0
|
142 |
+
generating_args["do_sample"] = True
|
143 |
+
generating_args["temperature"] = generating_args["temperature"] or 1.0
|
144 |
+
|
145 |
+
if not generating_args["temperature"]:
|
146 |
+
generating_args["do_sample"] = False
|
147 |
+
|
148 |
+
if not generating_args["do_sample"]:
|
149 |
+
generating_args.pop("temperature", None)
|
150 |
+
generating_args.pop("top_p", None)
|
151 |
+
|
152 |
+
if max_length:
|
153 |
+
generating_args.pop("max_new_tokens", None)
|
154 |
+
generating_args["max_length"] = max_length
|
155 |
+
|
156 |
+
if max_new_tokens:
|
157 |
+
generating_args.pop("max_length", None)
|
158 |
+
generating_args["max_new_tokens"] = max_new_tokens
|
159 |
+
|
160 |
+
gen_kwargs = dict(
|
161 |
+
inputs=inputs,
|
162 |
+
attention_mask=attention_mask,
|
163 |
+
generation_config=GenerationConfig(**generating_args),
|
164 |
+
logits_processor=get_logits_processor(),
|
165 |
+
)
|
166 |
+
|
167 |
+
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
|
168 |
+
for key, value in mm_inputs.items():
|
169 |
+
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
|
170 |
+
gen_kwargs[key] = value.to(model.device)
|
171 |
+
|
172 |
+
return gen_kwargs, prompt_length
|
173 |
+
|
174 |
+
@staticmethod
|
175 |
+
@torch.inference_mode()
|
176 |
+
def _chat(
|
177 |
+
model: "PreTrainedModel",
|
178 |
+
tokenizer: "PreTrainedTokenizer",
|
179 |
+
processor: Optional["ProcessorMixin"],
|
180 |
+
template: "Template",
|
181 |
+
generating_args: Dict[str, Any],
|
182 |
+
messages: Sequence[Dict[str, str]],
|
183 |
+
system: Optional[str] = None,
|
184 |
+
tools: Optional[str] = None,
|
185 |
+
image: Optional["ImageInput"] = None,
|
186 |
+
video: Optional["VideoInput"] = None,
|
187 |
+
input_kwargs: Optional[Dict[str, Any]] = {},
|
188 |
+
) -> List["Response"]:
|
189 |
+
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
190 |
+
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
|
191 |
+
)
|
192 |
+
generate_output = model.generate(**gen_kwargs)
|
193 |
+
response_ids = generate_output[:, prompt_length:]
|
194 |
+
response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
195 |
+
results = []
|
196 |
+
for i in range(len(response)):
|
197 |
+
eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
|
198 |
+
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
|
199 |
+
results.append(
|
200 |
+
Response(
|
201 |
+
response_text=response[i],
|
202 |
+
response_length=response_length,
|
203 |
+
prompt_length=prompt_length,
|
204 |
+
finish_reason="stop" if len(eos_index) else "length",
|
205 |
+
)
|
206 |
+
)
|
207 |
+
|
208 |
+
return results
|
209 |
+
|
210 |
+
@staticmethod
|
211 |
+
@torch.inference_mode()
|
212 |
+
def _stream_chat(
|
213 |
+
model: "PreTrainedModel",
|
214 |
+
tokenizer: "PreTrainedTokenizer",
|
215 |
+
processor: Optional["ProcessorMixin"],
|
216 |
+
template: "Template",
|
217 |
+
generating_args: Dict[str, Any],
|
218 |
+
messages: Sequence[Dict[str, str]],
|
219 |
+
system: Optional[str] = None,
|
220 |
+
tools: Optional[str] = None,
|
221 |
+
image: Optional["ImageInput"] = None,
|
222 |
+
video: Optional["VideoInput"] = None,
|
223 |
+
input_kwargs: Optional[Dict[str, Any]] = {},
|
224 |
+
) -> Callable[[], str]:
|
225 |
+
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
226 |
+
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
|
227 |
+
)
|
228 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
229 |
+
gen_kwargs["streamer"] = streamer
|
230 |
+
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
|
231 |
+
thread.start()
|
232 |
+
|
233 |
+
def stream():
|
234 |
+
try:
|
235 |
+
return streamer.__next__()
|
236 |
+
except StopIteration:
|
237 |
+
raise StopAsyncIteration()
|
238 |
+
|
239 |
+
return stream
|
240 |
+
|
241 |
+
@staticmethod
|
242 |
+
@torch.inference_mode()
|
243 |
+
def _get_scores(
|
244 |
+
model: "PreTrainedModelWrapper",
|
245 |
+
tokenizer: "PreTrainedTokenizer",
|
246 |
+
batch_input: List[str],
|
247 |
+
input_kwargs: Optional[Dict[str, Any]] = {},
|
248 |
+
) -> List[float]:
|
249 |
+
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
250 |
+
device = getattr(model.pretrained_model, "device", "cuda")
|
251 |
+
inputs: Dict[str, "torch.Tensor"] = tokenizer(
|
252 |
+
batch_input,
|
253 |
+
padding=True,
|
254 |
+
truncation=True,
|
255 |
+
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
|
256 |
+
return_tensors="pt",
|
257 |
+
add_special_tokens=False,
|
258 |
+
).to(device)
|
259 |
+
values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
|
260 |
+
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
|
261 |
+
return scores
|
262 |
+
|
263 |
+
@override
|
264 |
+
async def chat(
|
265 |
+
self,
|
266 |
+
messages: Sequence[Dict[str, str]],
|
267 |
+
system: Optional[str] = None,
|
268 |
+
tools: Optional[str] = None,
|
269 |
+
image: Optional["ImageInput"] = None,
|
270 |
+
video: Optional["VideoInput"] = None,
|
271 |
+
**input_kwargs,
|
272 |
+
) -> List["Response"]:
|
273 |
+
if not self.can_generate:
|
274 |
+
raise ValueError("The current model does not support `chat`.")
|
275 |
+
|
276 |
+
loop = asyncio.get_running_loop()
|
277 |
+
input_args = (
|
278 |
+
self.model,
|
279 |
+
self.tokenizer,
|
280 |
+
self.processor,
|
281 |
+
self.template,
|
282 |
+
self.generating_args,
|
283 |
+
messages,
|
284 |
+
system,
|
285 |
+
tools,
|
286 |
+
image,
|
287 |
+
video,
|
288 |
+
input_kwargs,
|
289 |
+
)
|
290 |
+
async with self.semaphore:
|
291 |
+
with concurrent.futures.ThreadPoolExecutor() as pool:
|
292 |
+
return await loop.run_in_executor(pool, self._chat, *input_args)
|
293 |
+
|
294 |
+
@override
async def stream_chat(
    self,
    messages: Sequence[Dict[str, str]],
    system: Optional[str] = None,
    tools: Optional[str] = None,
    image: Optional["ImageInput"] = None,
    video: Optional["VideoInput"] = None,
    **input_kwargs,
) -> AsyncGenerator[str, None]:
    r"""
    Yields generated text pieces one at a time.

    ``self._stream_chat`` returns a zero-argument callable. Each call runs in a worker
    thread and returns the next piece; ``StopAsyncIteration`` signals the end.

    Raises:
        ValueError: if the loaded model cannot generate text.
    """
    if not self.can_generate:
        raise ValueError("The current model does not support `stream_chat`.")

    event_loop = asyncio.get_running_loop()
    stream_args = (
        self.model, self.tokenizer, self.processor, self.template, self.generating_args,
        messages, system, tools, image, video, input_kwargs,
    )
    async with self.semaphore:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            next_piece = self._stream_chat(*stream_args)
            while True:
                try:
                    yield await event_loop.run_in_executor(executor, next_piece)
                except StopAsyncIteration:
                    break
|
329 |
+
|
330 |
+
@override
async def get_scores(
    self,
    batch_input: List[str],
    **input_kwargs,
) -> List[float]:
    r"""
    Scores ``batch_input`` by running ``self._get_scores`` in a worker thread.

    Raises:
        ValueError: if the loaded model is a generative (auto-regressive) one.
    """
    if self.can_generate:
        raise ValueError("Cannot get scores using an auto-regressive model.")

    event_loop = asyncio.get_running_loop()
    score_args = (self.model, self.tokenizer, batch_input, input_kwargs)
    async with self.semaphore:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            result = await event_loop.run_in_executor(executor, self._get_scores, *score_args)
            return result
|
llamafactory/chat/vllm_engine.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import uuid
|
16 |
+
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
17 |
+
|
18 |
+
from typing_extensions import override
|
19 |
+
|
20 |
+
from ..data import get_template_and_fix_tokenizer
|
21 |
+
from ..extras.constants import IMAGE_PLACEHOLDER
|
22 |
+
from ..extras.logging import get_logger
|
23 |
+
from ..extras.misc import get_device_count
|
24 |
+
from ..extras.packages import is_pillow_available, is_vllm_available
|
25 |
+
from ..model import load_config, load_tokenizer
|
26 |
+
from ..model.model_utils.quantization import QuantizationMethod
|
27 |
+
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
|
28 |
+
from .base_engine import BaseEngine, Response
|
29 |
+
|
30 |
+
|
31 |
+
if is_pillow_available():
|
32 |
+
from PIL import Image
|
33 |
+
from PIL.Image import Image as ImageObject
|
34 |
+
|
35 |
+
|
36 |
+
if is_vllm_available():
|
37 |
+
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
|
38 |
+
from vllm.lora.request import LoRARequest
|
39 |
+
|
40 |
+
|
41 |
+
if TYPE_CHECKING:
|
42 |
+
from ..data.mm_plugin import ImageInput, VideoInput
|
43 |
+
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
44 |
+
|
45 |
+
|
46 |
+
# Module-level logger for the vLLM engine (created through the project's ``get_logger``).
logger = get_logger(__name__)
|
47 |
+
|
48 |
+
|
49 |
+
class VllmEngine(BaseEngine):
    r"""
    Inference engine that serves chat requests through vLLM's ``AsyncLLMEngine``.

    ``chat`` returns complete responses and ``stream_chat`` yields text deltas.
    ``get_scores`` is not supported.
    """

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        r"""
        Loads config, tokenizer and template, then starts the vLLM engine.

        A LoRA request is prepared when an adapter path is given.
        """
        config = load_config(model_args)  # may download model from ms hub
        if getattr(config, "quantization_config", None):  # gptq models should use float16
            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
            quant_method = quantization_config.get("quant_method", "")
            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                model_args.infer_dtype = "float16"

        # Generation is only enabled for models of the "sft" stage.
        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.generating_args = generating_args.to_dict()

        engine_args = {
            "model": model_args.model_name_or_path,
            "trust_remote_code": True,
            "download_dir": model_args.cache_dir,
            "dtype": model_args.infer_dtype,
            "max_model_len": model_args.vllm_maxlen,
            "tensor_parallel_size": get_device_count() or 1,  # falls back to 1 when no device is counted
            "gpu_memory_utilization": model_args.vllm_gpu_util,
            "disable_log_stats": True,
            "disable_log_requests": True,
            "enforce_eager": model_args.vllm_enforce_eager,
            "enable_lora": model_args.adapter_name_or_path is not None,
            "max_lora_rank": model_args.vllm_max_lora_rank,
        }

        if getattr(config, "is_yi_vl_derived_model", None):
            import vllm.model_executor.models.llava

            logger.info("Detected Yi-VL model, applying projector patch.")
            # Swap vLLM's LLaVA projector class for the Yi-VL variant before the engine is built.
            vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM

        self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
        if model_args.adapter_name_or_path is not None:
            # Only the first adapter path is attached to generation requests.
            self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
        else:
            self.lora_request = None

    async def _generate(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["ImageInput"] = None,
        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> AsyncIterator["RequestOutput"]:
        r"""
        Encodes the conversation and submits a generation request to vLLM.

        These ``input_kwargs`` override the engine defaults for this request only:
        temperature, top_p, top_k, num_return_sequences, repetition_penalty,
        length_penalty, max_length, max_new_tokens and stop.

        ``video`` is accepted for interface compatibility but not used here. The caller's
        ``messages`` (the list and its dicts) are never modified.

        Returns:
            The async iterator of ``RequestOutput`` produced by ``AsyncLLMEngine.generate``.
        """
        request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
        # Shallow copy: this also accepts tuples (``tuple + list`` used to fail below).
        messages = list(messages)
        if image is not None and IMAGE_PLACEHOLDER not in messages[0]["content"]:
            # Build a new first message instead of editing the caller's dict. Otherwise a
            # reused history keeps the placeholder on later turns that carry no image.
            messages[0] = {**messages[0], "content": IMAGE_PLACEHOLDER + messages[0]["content"]}

        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or self.generating_args["default_system"]
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
        prompt_length = len(prompt_ids)

        use_beam_search: bool = self.generating_args["num_beams"] > 1
        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[int] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

        # NOTE(review): assumes generating_args contains max_new_tokens or max_length.
        # Otherwise max_tokens stays unbound unless a per-request limit is passed.
        if "max_new_tokens" in self.generating_args:
            max_tokens = self.generating_args["max_new_tokens"]
        elif "max_length" in self.generating_args:
            if self.generating_args["max_length"] > prompt_length:
                max_tokens = self.generating_args["max_length"] - prompt_length
            else:
                max_tokens = 1

        # Per-request limits override the defaults; max_new_tokens wins over max_length.
        if max_length:
            max_tokens = max_length - prompt_length if max_length > prompt_length else 1

        if max_new_tokens:
            max_tokens = max_new_tokens

        sampling_params = SamplingParams(
            n=num_return_sequences,
            repetition_penalty=(
                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
            )
            or 1.0,  # repetition_penalty must > 0
            temperature=temperature if temperature is not None else self.generating_args["temperature"],
            top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
            top_k=top_k if top_k is not None else self.generating_args["top_k"],
            use_beam_search=use_beam_search,
            length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
            stop=stop,
            stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
            max_tokens=max_tokens,
            skip_special_tokens=True,
        )

        if image is not None:  # add image features
            if not isinstance(image, (str, ImageObject)):
                raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))

            if isinstance(image, str):
                image = Image.open(image).convert("RGB")

            multi_modal_data = {"image": image}
        else:
            multi_modal_data = None

        result_generator = self.model.generate(
            inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
            sampling_params=sampling_params,
            request_id=request_id,
            lora_request=self.lora_request,
        )
        return result_generator

    @override
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["ImageInput"] = None,
        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
        Runs a request to completion.

        Returns:
            One ``Response`` per sequence in the final ``RequestOutput``.
        """
        final_output = None
        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
        async for request_output in generator:
            final_output = request_output  # only the last (complete) output is kept

        results = []
        for output in final_output.outputs:
            results.append(
                Response(
                    response_text=output.text,
                    response_length=len(output.token_ids),
                    prompt_length=len(final_output.prompt_token_ids),
                    finish_reason=output.finish_reason,
                )
            )

        return results

    @override
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["ImageInput"] = None,
        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        r"""
        Yields the newly generated text of the first sequence after each engine update.
        """
        generated_text = ""
        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
        async for result in generator:
            # Each update carries the full text so far; emit only the unseen suffix.
            delta_text = result.outputs[0].text[len(generated_text) :]
            generated_text = result.outputs[0].text
            yield delta_text

    @override
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        r"""
        Not supported by this engine.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("vLLM engine does not support get_scores.")
|
llamafactory/cli.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import random
|
17 |
+
import subprocess
|
18 |
+
import sys
|
19 |
+
from enum import Enum, unique
|
20 |
+
|
21 |
+
from . import launcher
|
22 |
+
from .api.app import run_api
|
23 |
+
from .chat.chat_model import run_chat
|
24 |
+
from .eval.evaluator import run_eval
|
25 |
+
from .extras.env import VERSION, print_env
|
26 |
+
from .extras.logging import get_logger
|
27 |
+
from .extras.misc import get_device_count
|
28 |
+
from .train.tuner import export_model, run_exp
|
29 |
+
from .webui.interface import run_web_demo, run_web_ui
|
30 |
+
|
31 |
+
|
32 |
+
# Help text that ``main`` prints for the ``help`` command. ``help`` is also the default
# when no command is given.
# NOTE(review): the ``env`` command (Command.ENV) is not listed here.
USAGE = (
    "-" * 70
    + "\n"
    + "| Usage: |\n"
    + "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
    + "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
    + "| llamafactory-cli eval -h: evaluate models |\n"
    + "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
    + "| llamafactory-cli train -h: train models |\n"
    + "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
    + "| llamafactory-cli webui: launch LlamaBoard |\n"
    + "| llamafactory-cli version: show version info |\n"
    + "-" * 70
)

# Banner that ``main`` prints for the ``version`` command. The space padding after the
# version string presumably keeps the right border aligned for short version strings.
WELCOME = (
    "-" * 58
    + "\n"
    + "| Welcome to LLaMA Factory, version {}".format(VERSION)
    + " " * (21 - len(VERSION))
    + "|\n|"
    + " " * 56
    + "|\n"
    + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
    + "-" * 58
)

# Module-level logger for the CLI.
logger = get_logger(__name__)
|
60 |
+
|
61 |
+
|
62 |
+
@unique
class Command(str, Enum):
    r"""
    Sub-commands accepted by ``llamafactory-cli``.

    Each value is the literal first argument typed on the command line. Because of the
    ``str`` mixin, the plain argument string compares equal to its member in ``main``.
    """

    API = "api"
    CHAT = "chat"
    ENV = "env"
    EVAL = "eval"
    EXPORT = "export"
    TRAIN = "train"
    WEBDEMO = "webchat"  # note: the typed command is "webchat"
    WEBUI = "webui"
    VER = "version"
    HELP = "help"
|
74 |
+
|
75 |
+
|
76 |
+
def main():
    r"""
    Entry point of ``llamafactory-cli``.

    Pops the sub-command from ``sys.argv`` (defaulting to ``help`` when none is given) and
    dispatches it. The remaining ``sys.argv`` entries are left for the sub-command to parse.

    For ``train``, if ``FORCE_TORCHRUN`` is "true"/"1" or more than one device is counted,
    the launcher script is re-run under ``torchrun`` and this process exits with its
    return code.

    Raises:
        NotImplementedError: for an unknown sub-command.
    """
    command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
    if command == Command.API:
        run_api()
    elif command == Command.CHAT:
        run_chat()
    elif command == Command.ENV:
        print_env()
    elif command == Command.EVAL:
        run_eval()
    elif command == Command.EXPORT:
        export_model()
    elif command == Command.TRAIN:
        force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
        if force_torchrun or get_device_count() > 1:
            master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
            master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
            logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
            # Use an argument list instead of a shell string. The launcher path and user
            # arguments are then forwarded verbatim, even with spaces or shell
            # metacharacters; the previous shell=True + " ".join(...) split or interpreted them.
            torchrun_command = [
                "torchrun",
                "--nnodes", os.environ.get("NNODES", "1"),
                "--node_rank", os.environ.get("RANK", "0"),
                "--nproc_per_node", os.environ.get("NPROC_PER_NODE", str(get_device_count())),
                "--master_addr", master_addr,
                "--master_port", master_port,
                launcher.__file__,
                *sys.argv[1:],
            ]
            process = subprocess.run(torchrun_command)
            sys.exit(process.returncode)
        else:
            run_exp()
    elif command == Command.WEBDEMO:
        run_web_demo()
    elif command == Command.WEBUI:
        run_web_ui()
    elif command == Command.VER:
        print(WELCOME)
    elif command == Command.HELP:
        print(USAGE)
    else:
        raise NotImplementedError("Unknown command: {}.".format(command))
|
llamafactory/data/__init__.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .collator import (
|
16 |
+
KTODataCollatorWithPadding,
|
17 |
+
MultiModalDataCollatorForSeq2Seq,
|
18 |
+
PairwiseDataCollatorWithPadding,
|
19 |
+
SFTDataCollatorWith4DAttentionMask,
|
20 |
+
)
|
21 |
+
from .data_utils import Role, split_dataset
|
22 |
+
from .loader import get_dataset
|
23 |
+
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
|
24 |
+
|
25 |
+
|
26 |
+
# Public names re-exported by the ``data`` package (all imported above).
__all__ = [
    "KTODataCollatorWithPadding",
    "MultiModalDataCollatorForSeq2Seq",
    "PairwiseDataCollatorWithPadding",
    "SFTDataCollatorWith4DAttentionMask",
    "Role",
    "split_dataset",
    "get_dataset",
    "TEMPLATES",
    "Template",
    "get_template_and_fix_tokenizer",
]
|
llamafactory/data/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (766 Bytes). View file
|
|
llamafactory/data/__pycache__/aligner.cpython-311.pyc
ADDED
Binary file (11.6 kB). View file
|
|
llamafactory/data/__pycache__/collator.cpython-311.pyc
ADDED
Binary file (9.38 kB). View file
|
|
llamafactory/data/__pycache__/data_utils.cpython-311.pyc
ADDED
Binary file (4.43 kB). View file
|
|
llamafactory/data/__pycache__/formatter.cpython-311.pyc
ADDED
Binary file (8.97 kB). View file
|
|
llamafactory/data/__pycache__/loader.cpython-311.pyc
ADDED
Binary file (13.8 kB). View file
|
|
llamafactory/data/__pycache__/mm_plugin.cpython-311.pyc
ADDED
Binary file (32.6 kB). View file
|
|
llamafactory/data/__pycache__/parser.cpython-311.pyc
ADDED
Binary file (7.31 kB). View file
|
|
llamafactory/data/__pycache__/preprocess.cpython-311.pyc
ADDED
Binary file (3.63 kB). View file
|
|
llamafactory/data/__pycache__/template.cpython-311.pyc
ADDED
Binary file (41.7 kB). View file
|
|
llamafactory/data/__pycache__/tool_utils.cpython-311.pyc
ADDED
Binary file (8.97 kB). View file
|
|
llamafactory/data/aligner.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
from functools import partial
|
17 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
18 |
+
|
19 |
+
from ..extras.logging import get_logger
|
20 |
+
from .data_utils import Role
|
21 |
+
|
22 |
+
|
23 |
+
if TYPE_CHECKING:
|
24 |
+
from datasets import Dataset, IterableDataset
|
25 |
+
from transformers import Seq2SeqTrainingArguments
|
26 |
+
|
27 |
+
from ..hparams import DataArguments
|
28 |
+
from .mm_plugin import ImageInput, VideoInput
|
29 |
+
from .parser import DatasetAttr
|
30 |
+
|
31 |
+
|
32 |
+
# Module-level logger; used for warnings about malformed examples during conversion.
logger = get_logger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
def _convert_images(
    images: Union["ImageInput", Sequence["ImageInput"]],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Optional[List["ImageInput"]]:
    r"""
    Optionally concatenates image path to dataset dir when loading from local disk.

    Args:
        images: a list/tuple of images, or a single image (treated as a one-element list).
            ``None`` or an empty sequence yields ``None``.
        dataset_attr: paths are only resolved when ``load_from`` is "script" or "file".
        data_args: supplies ``dataset_dir``, the base directory for relative paths.

    Returns:
        A new list; the input is never modified. Every string entry naming an existing
        file under ``dataset_dir`` is replaced by the joined path. Returns ``None`` when
        there are no images.
    """
    if images is None:  # e.g. a row whose image column is empty
        return None

    # A lone image (such as a path string) used to leak through as a non-list, and
    # tuples broke the in-place assignment below; normalize both to a fresh list.
    images = list(images) if isinstance(images, (list, tuple)) else [images]
    if len(images) == 0:
        return None

    if dataset_attr.load_from in ["script", "file"]:
        for i, image in enumerate(images):
            if isinstance(image, str):
                local_path = os.path.join(data_args.dataset_dir, image)
                if os.path.isfile(local_path):
                    images[i] = local_path

    return images
|
53 |
+
|
54 |
+
|
55 |
+
def _convert_videos(
    videos: Union["VideoInput", Sequence["VideoInput"]],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Optional[List["VideoInput"]]:
    r"""
    Optionally concatenates video path to dataset dir when loading from local disk.

    Args:
        videos: a list/tuple of videos, or a single video (treated as a one-element list).
            ``None`` or an empty sequence yields ``None``.
        dataset_attr: paths are only resolved when ``load_from`` is "script" or "file".
        data_args: supplies ``dataset_dir``, the base directory for relative paths.

    Returns:
        A new list; the input is never modified. Every string entry naming an existing
        file under ``dataset_dir`` is replaced by the joined path. Returns ``None`` when
        there are no videos.
    """
    if videos is None:  # e.g. a row whose video column is empty
        return None

    # A lone video (such as a path string) used to leak through as a non-list, and
    # tuples broke the in-place assignment below; normalize both to a fresh list.
    videos = list(videos) if isinstance(videos, (list, tuple)) else [videos]
    if len(videos) == 0:
        return None

    if dataset_attr.load_from in ["script", "file"]:
        for i, video in enumerate(videos):
            if isinstance(video, str):
                local_path = os.path.join(data_args.dataset_dir, video)
                if os.path.isfile(local_path):
                    videos[i] = local_path

    return videos
|
73 |
+
|
74 |
+
|
75 |
+
def convert_alpaca(
    example: Dict[str, Any],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Dict[str, Any]:
    r"""
    Converts alpaca format dataset to the standard format.

    History pairs become alternating user/assistant turns. The current user turn joins
    the non-empty prompt and query columns with a newline.

    The response layout depends on the example type:
    - KTO: the answer plus a blank turn, in an order set by the boolean tag.
    - Pairwise: [chosen, rejected].
    - Normal: [answer].
    - Unsupervised: [].
    """
    user_role = Role.USER.value
    assistant_role = Role.ASSISTANT.value

    def _turn(role: str, content: Any) -> Dict[str, Any]:
        return {"role": role, "content": content}

    prompt: List[Dict[str, Any]] = []
    history_column = dataset_attr.history
    if history_column and isinstance(example[history_column], list):
        for past_query, past_answer in example[history_column]:
            prompt.extend((_turn(user_role, past_query), _turn(assistant_role, past_answer)))

    # "prompt\nquery": keep whichever of the two columns is configured and non-empty.
    query_parts = [
        example[column]
        for column in (dataset_attr.prompt, dataset_attr.query)
        if column and example[column]
    ]
    prompt.append(_turn(user_role, "\n".join(query_parts)))

    kto_column = dataset_attr.kto_tag
    if kto_column and isinstance(example[kto_column], bool):  # kto example
        answer = _turn(assistant_role, example[dataset_attr.response])
        blank = _turn(assistant_role, "")
        # A True tag puts the real answer first; a False tag puts it second.
        response = [answer, blank] if example[kto_column] else [blank, answer]
    elif (
        dataset_attr.ranking
        and isinstance(example[dataset_attr.chosen], str)
        and isinstance(example[dataset_attr.rejected], str)
    ):  # pairwise example
        response = [
            _turn(assistant_role, example[dataset_attr.chosen]),
            _turn(assistant_role, example[dataset_attr.rejected]),
        ]
    elif dataset_attr.response and isinstance(example[dataset_attr.response], str):  # normal example
        response = [_turn(assistant_role, example[dataset_attr.response])]
    else:  # unsupervised
        response = []

    return {
        "_prompt": prompt,
        "_response": response,
        "_system": example[dataset_attr.system] if dataset_attr.system else "",
        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
        "_images": (
            _convert_images(example[dataset_attr.images], dataset_attr=dataset_attr, data_args=data_args)
            if dataset_attr.images
            else None
        ),
        "_videos": (
            _convert_videos(example[dataset_attr.videos], dataset_attr=dataset_attr, data_args=data_args)
            if dataset_attr.videos
            else None
        ),
    }
|
129 |
+
|
130 |
+
|
131 |
+
def convert_sharegpt(
    example: Dict[str, Any],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Dict[str, Any]:
    r"""
    Converts sharegpt format dataset to the standard format.

    Conversation turns must alternate. Positions 0, 2, 4, ... need a user or observation
    tag; positions 1, 3, 5, ... need an assistant or function tag.

    A leading system-tagged message (when ``dataset_attr.system_tag`` is set) is moved
    into ``_system``.

    Examples with an invalid role tag or an invalid turn count are logged and returned
    with empty ``_prompt`` and ``_response``. Previously, an unknown tag crashed with a
    KeyError instead.
    """
    tag_mapping = {
        dataset_attr.user_tag: Role.USER.value,
        dataset_attr.assistant_tag: Role.ASSISTANT.value,
        dataset_attr.observation_tag: Role.OBSERVATION.value,
        dataset_attr.function_tag: Role.FUNCTION.value,
        dataset_attr.system_tag: Role.SYSTEM.value,
    }
    odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)  # 1st, 3rd, ... turns
    even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)  # 2nd, 4th, ... turns
    accept_tags = (odd_tags, even_tags)
    messages = example[dataset_attr.messages]
    if (
        dataset_attr.system_tag
        and len(messages) != 0
        and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
    ):
        system = messages[0][dataset_attr.content_tag]
        messages = messages[1:]
    else:
        system = example[dataset_attr.system] if dataset_attr.system else ""

    aligned_messages = []
    broken_data = False
    for turn_idx, message in enumerate(messages):
        if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
            logger.warning("Invalid role tag in {}.".format(messages))
            broken_data = True
            # Stop at the first bad turn. An unknown tag has no entry in ``tag_mapping``
            # (the lookup below would raise KeyError), and the example is discarded anyway.
            break

        aligned_messages.append(
            {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
        )

    # Non-ranking data must hold complete pairs; ranking data must end on a user-side turn.
    # Skipped after a bad tag, since the count of a truncated list says nothing useful.
    if not broken_data and (
        (not dataset_attr.ranking and len(aligned_messages) % 2 != 0)
        or (dataset_attr.ranking and len(aligned_messages) % 2 == 0)
    ):
        logger.warning("Invalid message count in {}.".format(messages))
        broken_data = True

    if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
        prompt = aligned_messages[:-1]
        response = aligned_messages[-1:]
        if example[dataset_attr.kto_tag]:
            response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
        else:
            response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
    elif (
        dataset_attr.ranking
        and isinstance(example[dataset_attr.chosen], dict)
        and isinstance(example[dataset_attr.rejected], dict)
    ):  # pairwise example
        chosen = example[dataset_attr.chosen]
        rejected = example[dataset_attr.rejected]
        if (
            chosen[dataset_attr.role_tag] not in accept_tags[-1]
            or rejected[dataset_attr.role_tag] not in accept_tags[-1]
        ):
            logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
            broken_data = True

        prompt = aligned_messages
        if broken_data:
            # Discarded below; also avoids a KeyError for tags missing from ``tag_mapping``.
            response = []
        else:
            response = [
                {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
                {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
            ]
    else:  # normal example
        prompt = aligned_messages[:-1]
        response = aligned_messages[-1:]

    if broken_data:
        logger.warning("Skipping this abnormal example.")
        prompt, response = [], []

    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
    output = {
        "_prompt": prompt,
        "_response": response,
        "_system": system,
        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
        "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
    }
    return output
|
222 |
+
|
223 |
+
|
224 |
+
def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Maps every raw example of ``dataset`` onto the unified (aligned) schema and drops all original columns.

    The converter is chosen by ``dataset_attr.formatting``: "alpaca" selects ``convert_alpaca``,
    any other value selects ``convert_sharegpt``.

    Aligned dataset:
        _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        _system: "..."
        _tools: "...",
        _images: [],
        _videos: [],

    Args:
        dataset: the raw dataset (map-style or iterable).
        dataset_attr: dataset attributes, forwarded to the converter.
        data_args: data arguments; ``streaming``, ``preprocessing_num_workers`` and ``overwrite_cache``
            control how ``map`` is invoked.
        training_args: only used to decide whether non-main local processes reuse the cache.

    Returns:
        The converted dataset returned by ``dataset.map``.
    """
    converter = convert_alpaca if dataset_attr.formatting == "alpaca" else convert_sharegpt
    convert_func = partial(converter, dataset_attr=dataset_attr, data_args=data_args)

    # Read the raw column names from the first example so they can all be removed after conversion.
    # Raises StopIteration if the dataset is empty.
    first_example = next(iter(dataset))
    column_names = list(first_example.keys())

    map_kwargs = {}
    if not data_args.streaming:
        # Only the local main process may overwrite the cache; every other process reuses it.
        reuse_cache = (not data_args.overwrite_cache) or (training_args.local_process_index != 0)
        map_kwargs = {
            "num_proc": data_args.preprocessing_num_workers,
            "load_from_cache_file": reuse_cache,
            "desc": "Converting format of dataset",
        }

    return dataset.map(convert_func, batched=False, remove_columns=column_names, **map_kwargs)
|
llamafactory/data/collator.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# This code is inspired by the OpenAccess AI Collective's axolotl library.
|
4 |
+
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from transformers import DataCollatorForSeq2Seq
|
23 |
+
|
24 |
+
|
25 |
+
if TYPE_CHECKING:
|
26 |
+
from transformers import ProcessorMixin
|
27 |
+
|
28 |
+
from .template import Template
|
29 |
+
|
30 |
+
|
31 |
+
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
    r"""
    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
    while handling packed sequences and transforming the mask to lower triangular form to prevent future peeking.

    Positions that share the same non-zero index may attend to each other causally. Index 0 marks padding:
    a padding row attends to nothing, and padding columns are never attended to.

    e.g.
    ```python
    # input
    [[1, 1, 2, 2, 2, 0]]
    # output
    [
        [
            [
                [o, x, x, x, x, x],
                [o, o, x, x, x, x],
                [x, x, o, x, x, x],
                [x, x, o, o, x, x],
                [x, x, o, o, o, x],
                [x, x, x, x, x, x],
            ]
        ]
    ]
    ```
    where `o` equals to `0.0`, `x` equals to `min_dtype`.

    Args:
        attention_mask_with_indices: tensor of shape (batch_size, seq_len) holding per-position sequence indices.
        dtype: floating point dtype of the result; masked entries are ``torch.finfo(dtype).min``.

    Returns:
        A (batch_size, 1, seq_len, seq_len) tensor of ``dtype`` on the same device as the input.
    """
    bsz, seq_len = attention_mask_with_indices.size()
    device = attention_mask_with_indices.device
    min_dtype = torch.finfo(dtype).min
    expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
    padding_mask = torch.where(expanded_mask != 0, 1, 0)
    # Create a block-diagonal mask.
    attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
    # Use the lower triangular mask to zero out the upper triangular part. It is built on the input's
    # device so that non-CPU inputs do not fail with a device mismatch.
    attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long, device=device))
    # Invert the attention mask.
    attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype, device=device), min_dtype)
    return attention_mask_4d
|
68 |
+
|
69 |
+
|
70 |
+
@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    r"""
    Data collator that supports VLMs.

    Features should contain input_ids, attention_mask, labels and images.

    Collation steps:
    - The optional "images"/"videos" entries are popped from each feature, so the input dicts are mutated.
    - These entries are flattened across the batch and passed to ``template.mm_plugin.get_mm_inputs``,
      together with per-sample image/video counts and sequence lengths.
    - Any "token_type_ids" returned by the plugin are written back into the per-sample features before the
      standard seq2seq padding.
    - All other multimodal inputs are merged into the padded batch.

    ``template`` must be provided even though it defaults to None, because ``__call__`` dereferences it.
    """

    template: Optional["Template"] = None
    processor: Optional["ProcessorMixin"] = None

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        all_images, all_videos = [], []
        image_counts, video_counts, seq_lengths = [], [], []
        for sample in features:
            sample_images = sample.pop("images", None) or []
            sample_videos = sample.pop("videos", None) or []
            all_images.extend(sample_images)
            all_videos.extend(sample_videos)
            image_counts.append(len(sample_images))
            video_counts.append(len(sample_videos))
            seq_lengths.append(len(sample["input_ids"]))

        mm_inputs = self.template.mm_plugin.get_mm_inputs(
            all_images, all_videos, image_counts, video_counts, seq_lengths, self.processor
        )

        # Token type ids are per-sample sequences, so they must go through the text padding below
        # instead of being attached to the batch as-is.
        if "token_type_ids" in mm_inputs:
            per_sample_token_types = mm_inputs.pop("token_type_ids")
            for index, sample in enumerate(features):
                sample["token_type_ids"] = per_sample_token_types[index]

        batch: Dict[str, "torch.Tensor"] = super().__call__(features)
        batch.update(mm_inputs)
        return batch
|
103 |
+
|
104 |
+
|
105 |
+
@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for 4d attention mask.

    After the regular multimodal seq2seq collation, this collator optionally replaces "attention_mask" with
    the block-diagonal causal mask built by ``prepare_4d_attention_mask``. That happens only when
    ``block_diag_attn`` is set and ``attn_implementation`` is not "flash_attention_2". Otherwise the
    collated batch is returned untouched.
    """

    block_diag_attn: bool = False  # build the 4D block-diagonal mask (for packed sequences)
    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
    compute_dtype: "torch.dtype" = torch.float32  # dtype of the generated 4D mask

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        batch = super().__call__(features)
        wants_4d_mask = self.block_diag_attn and self.attn_implementation != "flash_attention_2"
        if wants_4d_mask:
            batch["attention_mask"] = prepare_4d_attention_mask(batch["attention_mask"], self.compute_dtype)

        return batch
|
121 |
+
|
122 |
+
|
123 |
+
@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.

    Each input feature carries "chosen_*" and "rejected_*" token fields plus shared "images"/"videos".
    These are unrolled into two plain features before the multimodal seq2seq collation.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        # Outer loop over the side and inner loop over the samples gives the layout
        # [chosen_0 .. chosen_{n-1}, rejected_0 .. rejected_{n-1}].
        concatenated_features = [
            {
                "input_ids": sample["{}_input_ids".format(side)],
                "attention_mask": sample["{}_attention_mask".format(side)],
                "labels": sample["{}_labels".format(side)],
                "images": sample["images"],
                "videos": sample["videos"],
            }
            for side in ("chosen", "rejected")
            for sample in features
        ]
        return super().__call__(concatenated_features)
|
149 |
+
|
150 |
+
|
151 |
+
@dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.

    The target fields and the "kl_*" fields of each feature are collated as two separate batches.
    The KL batch's input_ids / attention_mask / labels (and token_type_ids, if present) are copied into
    the target batch under a "kl_" prefix, and the per-sample "kto_tags" are stacked into a tensor.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        target_features, kl_features, kto_tags = [], [], []
        for sample in features:
            target_features.append(
                {
                    "input_ids": sample["input_ids"],
                    "attention_mask": sample["attention_mask"],
                    "labels": sample["labels"],
                    "images": sample["images"],
                    "videos": sample["videos"],
                }
            )
            kl_features.append(
                {
                    "input_ids": sample["kl_input_ids"],
                    "attention_mask": sample["kl_attention_mask"],
                    "labels": sample["kl_labels"],
                    "images": sample["images"],
                    "videos": sample["videos"],
                }
            )
            kto_tags.append(sample["kto_tags"])

        batch = super().__call__(target_features)
        kl_batch = super().__call__(kl_features)
        # Only the text fields of the KL batch are kept; its multimodal inputs are not copied over.
        for field_name in ("input_ids", "attention_mask", "labels"):
            batch["kl_" + field_name] = kl_batch[field_name]
        if "token_type_ids" in kl_batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch
|
llamafactory/data/data_utils.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from enum import Enum, unique
|
16 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
|
17 |
+
|
18 |
+
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
|
19 |
+
|
20 |
+
from ..extras.logging import get_logger
|
21 |
+
|
22 |
+
|
23 |
+
if TYPE_CHECKING:
|
24 |
+
from datasets import Dataset, IterableDataset
|
25 |
+
|
26 |
+
from ..hparams import DataArguments
|
27 |
+
|
28 |
+
|
29 |
+
# Module-level logger; merge_dataset uses it for its mixing-strategy warnings.
logger = get_logger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
# Type alias for a sequence of template slots, as produced/consumed by the formatters.
# Each slot is a plain string, a set of strings, or a str -> str dict.
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
33 |
+
|
34 |
+
|
35 |
+
@unique
class Role(str, Enum):
    r"""
    Message roles used in aligned conversations.

    Because the enum mixes in ``str``, each member compares equal to its string value
    (e.g. ``Role.USER == "user"``). ``@unique`` guarantees that no two members share a value.
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    OBSERVATION = "observation"
|
42 |
+
|
43 |
+
|
44 |
+
class DatasetModule(TypedDict):
    r"""
    Dictionary returned by the dataset loader, holding the training and evaluation datasets.
    """

    # Either entry may be None (see the Optional annotations).
    train_dataset: Optional[Union["Dataset", "IterableDataset"]]
    eval_dataset: Optional[Union["Dataset", "IterableDataset"]]
|
47 |
+
|
48 |
+
|
49 |
+
def merge_dataset(
    all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Merges multiple datasets to a unified dataset.

    A list with exactly one dataset returns that dataset unchanged. Otherwise ``data_args.mix_strategy``
    decides:
    - "concat" concatenates the datasets.
    - Any strategy starting with "interleave" interleaves them with ``data_args.interleave_probs`` and
      ``seed``. The stopping strategy is "first_exhausted" when the strategy ends with "under",
      otherwise "all_exhausted".

    Raises:
        ValueError: if the mixing strategy is not recognised.
    """
    if len(all_datasets) == 1:
        return all_datasets[0]

    strategy = data_args.mix_strategy
    if strategy == "concat":
        if data_args.streaming:
            logger.warning("The samples between different datasets will not be mixed in streaming mode.")

        return concatenate_datasets(all_datasets)

    if strategy.startswith("interleave"):
        if not data_args.streaming:
            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")

        stop_rule = "first_exhausted" if strategy.endswith("under") else "all_exhausted"
        return interleave_datasets(
            datasets=all_datasets,
            probabilities=data_args.interleave_probs,
            seed=seed,
            stopping_strategy=stop_rule,
        )

    raise ValueError("Unknown mixing strategy: {}.".format(strategy))
|
74 |
+
|
75 |
+
|
76 |
+
def split_dataset(
    dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
) -> "DatasetDict":
    r"""
    Splits the dataset and returns a dataset dict containing train set and validation set.

    Supports both map dataset and iterable dataset.

    Non-streaming: a ``val_size`` greater than 1 is cast to int and used as an absolute example count.
    Any other value is passed unchanged to ``train_test_split``.

    Streaming: the dataset is shuffled with ``data_args.buffer_size``. The first ``int(val_size)`` examples
    form the validation set and the remaining ones form the training set.
    """
    if not data_args.streaming:
        test_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
        splits = dataset.train_test_split(test_size=test_size, seed=seed)
        return DatasetDict({"train": splits["train"], "validation": splits["test"]})

    shuffled = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
    num_val = int(data_args.val_size)
    return DatasetDict({"train": shuffled.skip(num_val), "validation": shuffled.take(num_val)})
|
llamafactory/data/formatter.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import re
|
17 |
+
from abc import ABC, abstractmethod
|
18 |
+
from dataclasses import dataclass, field
|
19 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
20 |
+
|
21 |
+
from typing_extensions import override
|
22 |
+
|
23 |
+
from .data_utils import SLOTS
|
24 |
+
from .tool_utils import get_tool_utils
|
25 |
+
|
26 |
+
|
27 |
+
if TYPE_CHECKING:
|
28 |
+
from .tool_utils import FunctionCall
|
29 |
+
|
30 |
+
|
31 |
+
@dataclass
class Formatter(ABC):
    r"""
    Abstract base class of the template formatters.

    Attributes:
        slots: the slot sequence the formatter works with (empty by default).
        tool_format: optional tool-format name; the tool-aware subclasses resolve it via ``get_tool_utils``.
    """

    slots: SLOTS = field(default_factory=list)
    tool_format: Optional[str] = None

    @abstractmethod
    def apply(self, **kwargs) -> SLOTS:
        r"""
        Forms a list of slots according to the inputs to encode.
        """

    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
        r"""
        Extract a list of tuples from the response message if using tools.

        Each tuple consists of function name and function arguments.

        Raises:
            NotImplementedError: always, unless a subclass overrides this method.
        """
        raise NotImplementedError
|
50 |
+
|
51 |
+
|
52 |
+
@dataclass
class EmptyFormatter(Formatter):
    r"""
    Formatter whose slots are emitted verbatim.

    None of its string slots may contain a ``{{placeholder}}``.
    """

    def __post_init__(self):
        placeholder_pattern = re.compile(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}")
        string_slots = [slot for slot in self.slots if isinstance(slot, str)]
        if any(placeholder_pattern.search(slot) for slot in string_slots):
            raise ValueError("Empty formatter should not contain any placeholder.")

    @override
    def apply(self, **kwargs) -> SLOTS:
        r"""Returns the stored slots unchanged; keyword arguments are ignored."""
        return self.slots
|
66 |
+
|
67 |
+
|
68 |
+
@dataclass
class StringFormatter(Formatter):
    r"""
    Formatter that fills ``{{name}}`` placeholders in its string slots.

    At least one string slot must contain a placeholder. Dict and set slots are passed through unchanged.
    """

    def __post_init__(self):
        placeholder_pattern = re.compile(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}")
        string_slots = [slot for slot in self.slots if isinstance(slot, str)]
        if not any(placeholder_pattern.search(slot) for slot in string_slots):
            raise ValueError("A placeholder is required in the string formatter.")

    @override
    def apply(self, **kwargs) -> SLOTS:
        r"""
        Substitutes each keyword argument into the string slots.

        For every string slot, the first occurrence of ``{{key}}`` is replaced by the value of each keyword
        argument ``key``.

        Raises:
            RuntimeError: if a keyword value is not a string, or a slot is not a str, set or dict.
        """
        formatted = []
        for slot in self.slots:
            if isinstance(slot, (dict, set)):
                formatted.append(slot)
                continue

            if not isinstance(slot, str):
                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))

            for key, replacement in kwargs.items():
                if not isinstance(replacement, str):
                    raise RuntimeError("Expected a string, got {}".format(replacement))

                slot = slot.replace("{{" + key + "}}", replacement, 1)

            formatted.append(slot)

        return formatted
|
96 |
+
|
97 |
+
|
98 |
+
@dataclass
class FunctionFormatter(Formatter):
    r"""
    Formatter for function-call messages.

    On construction, the function slots of the configured tool format
    (``get_tool_utils(tool_format).get_function_slots()``) are prepended to the given slots.
    """

    def __post_init__(self):
        self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots

    @override
    def apply(self, **kwargs) -> SLOTS:
        r"""
        Renders the function call(s) encoded as JSON in the ``content`` keyword argument.

        ``content`` may hold a single ``{"name": ..., "arguments": ...}`` object or a list of such objects
        (parallel function calls). The whole slot sequence is emitted once per call, in order. In each copy,
        ``{{name}}`` is replaced by the call name and ``{{arguments}}`` by the arguments re-serialized with
        ``ensure_ascii=False``.

        Raises:
            RuntimeError: if ``content`` is not valid JSON, or a slot is not a str, set or dict.
        """
        content = kwargs.pop("content")
        try:
            tool_calls = json.loads(content)
        except json.JSONDecodeError as err:
            raise RuntimeError(
                "Invalid JSON format in function message: {}".format(str([content]))  # flat string
            ) from err

        if not isinstance(tool_calls, list):  # a single call: wrap it so single and parallel calls share one path
            tool_calls = [tool_calls]

        functions: List[Tuple[str, str]] = [
            (tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)) for tool_call in tool_calls
        ]

        elements = []
        for name, arguments in functions:
            for slot in self.slots:
                if isinstance(slot, str):
                    slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
                    elements.append(slot)
                elif isinstance(slot, (dict, set)):
                    elements.append(slot)
                else:
                    raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))

        return elements
|
130 |
+
|
131 |
+
|
132 |
+
@dataclass
class ToolFormatter(Formatter):
    r"""
    Formatter for tool descriptions.

    Rendering and extraction are delegated to the tool utils resolved from ``tool_format``.
    """

    def __post_init__(self):
        self.tool_utils = get_tool_utils(self.tool_format)

    @override
    def apply(self, **kwargs) -> SLOTS:
        r"""
        Renders the tools encoded as JSON in the ``content`` keyword argument.

        Returns a one-element slot list containing either the formatted tool description or ``""`` when the
        decoded value is empty.

        Raises:
            RuntimeError: if ``content`` is not valid JSON.
        """
        content = kwargs.pop("content")
        # Only the JSON decoding is guarded, so that errors raised by the tool formatter itself are not
        # misreported as invalid JSON.
        try:
            tools = json.loads(content)
        except json.JSONDecodeError as err:
            raise RuntimeError(
                "Invalid JSON format in tool description: {}".format(str([content]))  # flat string
            ) from err

        return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]

    @override
    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
        r"""Delegates to the tool utils' ``tool_extractor`` to parse tool calls out of ``content``."""
        return self.tool_utils.tool_extractor(content)
|
llamafactory/data/loader.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import sys
|
17 |
+
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
from datasets import DatasetDict, load_dataset, load_from_disk
|
21 |
+
from transformers.utils.versions import require_version
|
22 |
+
|
23 |
+
from ..extras.constants import FILEEXT2TYPE
|
24 |
+
from ..extras.logging import get_logger
|
25 |
+
from ..extras.misc import has_tokenized_data
|
26 |
+
from .aligner import align_dataset
|
27 |
+
from .data_utils import merge_dataset, split_dataset
|
28 |
+
from .parser import get_dataset_list
|
29 |
+
from .preprocess import get_preprocess_and_print_func
|
30 |
+
|
31 |
+
|
32 |
+
if TYPE_CHECKING:
|
33 |
+
from datasets import Dataset, IterableDataset
|
34 |
+
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
35 |
+
|
36 |
+
from ..hparams import DataArguments, ModelArguments
|
37 |
+
from .data_utils import DatasetModule
|
38 |
+
from .parser import DatasetAttr
|
39 |
+
from .template import Template
|
40 |
+
|
41 |
+
|
42 |
+
# Module-level logger for the dataset loading helpers (progress info and warnings).
logger = get_logger(__name__)
|
43 |
+
|
44 |
+
|
45 |
+
def _load_single_dataset(
    dataset_attr: "DatasetAttr",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Loads a single dataset and aligns it to the standard format.

    The source is selected by ``dataset_attr.load_from``:
      * "hf_hub" / "ms_hub": a hub dataset named ``dataset_name`` (optional ``subset`` / ``folder``);
      * "script": a loading script at ``dataset_dir/dataset_name``;
      * "file": one file, or a directory of files of one type, at ``dataset_dir/dataset_name``.

    After loading, the dataset is processed in three steps:
      * It is resampled to ``dataset_attr.num_samples`` examples, in non-streaming mode only.
      * It is truncated to ``data_args.max_samples`` examples.
      * It is converted by ``align_dataset``.

    Raises:
        ValueError: if a local path is missing, mixes file types, or has an unsupported extension.
        NotImplementedError: for an unknown ``load_from`` value.
    """
    logger.info("Loading dataset {}...".format(dataset_attr))
    data_path, data_name, data_dir, data_files = None, None, None, None
    if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
        data_path = dataset_attr.dataset_name
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder

    elif dataset_attr.load_from == "script":
        data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder

    elif dataset_attr.load_from == "file":
        data_files = []
        local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        if os.path.isdir(local_path):  # is directory
            # Sort so that the file (and therefore example) order does not depend on the filesystem.
            for file_name in sorted(os.listdir(local_path)):
                data_files.append(os.path.join(local_path, file_name))
                file_type = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
                if data_path is None:
                    data_path = file_type
                elif data_path != file_type:
                    raise ValueError("File types should be identical.")
        elif os.path.isfile(local_path):  # is file
            data_files.append(local_path)
            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
        else:
            raise ValueError("File {} not found.".format(local_path))

        if data_path is None:
            raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
    else:
        raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))

    # Local files are loaded eagerly and converted to an iterable dataset below instead.
    use_streaming = data_args.streaming and (dataset_attr.load_from != "file")
    if dataset_attr.load_from == "ms_hub":
        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
        from modelscope import MsDataset
        from modelscope.utils.config_ds import MS_DATASETS_CACHE

        cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
        dataset = MsDataset.load(
            dataset_name=data_path,
            subset_name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=dataset_attr.split,
            cache_dir=cache_dir,
            token=model_args.ms_hub_token,
            use_streaming=use_streaming,
        )
        if isinstance(dataset, MsDataset):
            dataset = dataset.to_hf_dataset()
    else:
        dataset = load_dataset(
            path=data_path,
            name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=dataset_attr.split,
            cache_dir=model_args.cache_dir,
            token=model_args.hf_hub_token,
            streaming=use_streaming,
            trust_remote_code=True,
        )

    if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
        dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter

    if dataset_attr.num_samples is not None and not data_args.streaming:
        target_num = dataset_attr.num_samples
        indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
        target_num -= len(indexes)
        if target_num > 0:
            # Oversample with replacement when more samples are requested than the dataset holds.
            expand_indexes = np.random.choice(len(dataset), target_num)
            indexes = np.concatenate((indexes, expand_indexes), axis=0)

        assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
        dataset = dataset.select(indexes)
        logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))

    if data_args.max_samples is not None:  # truncate dataset
        if data_args.streaming:
            # Iterable datasets support neither len() nor select(); take() yields at most max_samples examples.
            dataset = dataset.take(data_args.max_samples)
        else:
            max_samples = min(data_args.max_samples, len(dataset))
            dataset = dataset.select(range(max_samples))

    return align_dataset(dataset, dataset_attr, data_args, training_args)
|
138 |
+
|
139 |
+
|
140 |
+
def _get_merged_dataset(
|
141 |
+
dataset_names: Optional[Sequence[str]],
|
142 |
+
model_args: "ModelArguments",
|
143 |
+
data_args: "DataArguments",
|
144 |
+
training_args: "Seq2SeqTrainingArguments",
|
145 |
+
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
146 |
+
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
147 |
+
r"""
|
148 |
+
Gets the merged datasets in the standard format.
|
149 |
+
"""
|
150 |
+
if dataset_names is None:
|
151 |
+
return None
|
152 |
+
|
153 |
+
datasets = []
|
154 |
+
for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
|
155 |
+
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
156 |
+
raise ValueError("The dataset is not applicable in the current training stage.")
|
157 |
+
|
158 |
+
datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
|
159 |
+
|
160 |
+
return merge_dataset(datasets, data_args, seed=training_args.seed)
|
161 |
+
|
162 |
+
|
163 |
+
def _get_preprocessed_dataset(
|
164 |
+
dataset: Optional[Union["Dataset", "IterableDataset"]],
|
165 |
+
data_args: "DataArguments",
|
166 |
+
training_args: "Seq2SeqTrainingArguments",
|
167 |
+
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
168 |
+
template: "Template",
|
169 |
+
tokenizer: "PreTrainedTokenizer",
|
170 |
+
processor: Optional["ProcessorMixin"] = None,
|
171 |
+
is_eval: bool = False,
|
172 |
+
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
173 |
+
r"""
|
174 |
+
Preprocesses the dataset, including format checking and tokenization.
|
175 |
+
"""
|
176 |
+
if dataset is None:
|
177 |
+
return None
|
178 |
+
|
179 |
+
preprocess_func, print_function = get_preprocess_and_print_func(
|
180 |
+
data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
|
181 |
+
)
|
182 |
+
column_names = list(next(iter(dataset)).keys())
|
183 |
+
kwargs = {}
|
184 |
+
if not data_args.streaming:
|
185 |
+
kwargs = dict(
|
186 |
+
num_proc=data_args.preprocessing_num_workers,
|
187 |
+
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
188 |
+
desc="Running tokenizer on dataset",
|
189 |
+
)
|
190 |
+
|
191 |
+
dataset = dataset.map(
|
192 |
+
preprocess_func,
|
193 |
+
batched=True,
|
194 |
+
batch_size=data_args.preprocessing_batch_size,
|
195 |
+
remove_columns=column_names,
|
196 |
+
**kwargs,
|
197 |
+
)
|
198 |
+
|
199 |
+
if training_args.should_log:
|
200 |
+
try:
|
201 |
+
print("eval example:" if is_eval else "training example:")
|
202 |
+
print_function(next(iter(dataset)))
|
203 |
+
except StopIteration:
|
204 |
+
if stage == "pt":
|
205 |
+
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
|
206 |
+
else:
|
207 |
+
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
208 |
+
|
209 |
+
return dataset
|
210 |
+
|
211 |
+
|
212 |
+
def get_dataset(
    template: "Template",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
    r"""
    Gets the train dataset and optionally gets the evaluation dataset.

    If ``data_args.tokenized_path`` points to an already tokenized dataset, that dataset is loaded
    from disk and all other data arguments are ignored. Otherwise the raw datasets are merged and
    preprocessed; if ``tokenized_path`` is set, the result is saved there and the process exits.

    Returns:
        A dict with the optional keys ``train_dataset`` and ``eval_dataset``.

    Raises:
        ValueError: if ``streaming`` is enabled while a tokenized dataset is to be saved to disk.
    """

    def _select_splits(splits) -> "DatasetModule":
        # Maps the "train"/"validation" splits onto the keys of the returned module.
        selected = {}
        if "train" in splits:
            selected["train_dataset"] = splits["train"]

        if "validation" in splits:
            selected["eval_dataset"] = splits["validation"]

        return selected

    if data_args.tokenized_path is not None:
        if has_tokenized_data(data_args.tokenized_path):
            logger.warning("Loading dataset from disk will ignore other data arguments.")
            dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))

            loaded_module = _select_splits(dataset_dict)
            if data_args.streaming:
                loaded_module = {key: split.to_iterable_dataset() for key, split in loaded_module.items()}

            return loaded_module

        # The dataset built below will be saved to disk, which is not done for streaming datasets.
        if data_args.streaming:
            raise ValueError("Turn off `streaming` when saving dataset to disk.")

    # Load the raw datasets.
    with training_args.main_process_first(desc="load dataset"):
        dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
        eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)

    # Tokenize them and assemble the splits.
    with training_args.main_process_first(desc="pre-process dataset"):
        dataset = _get_preprocessed_dataset(
            dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
        )
        eval_dataset = _get_preprocessed_dataset(
            eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
        )

        if data_args.val_size > 1e-6:
            dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
        else:
            splits = {}
            for split_name, split in (("train", dataset), ("validation", eval_dataset)):
                if split is None:
                    continue

                if data_args.streaming:
                    split = split.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)

                splits[split_name] = split

            dataset_dict = DatasetDict(splits)

        if data_args.tokenized_path is not None:
            if training_args.should_save:
                dataset_dict.save_to_disk(data_args.tokenized_path)
                logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
                logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))

            # Every process stops here; training must be restarted from the saved dataset.
            sys.exit(0)

        return _select_splits(dataset_dict)
|
llamafactory/data/mm_plugin.py
ADDED
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from copy import deepcopy
|
3 |
+
from io import BytesIO
|
4 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from transformers.image_utils import get_image_size, to_numpy_array
|
8 |
+
from typing_extensions import override
|
9 |
+
|
10 |
+
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
11 |
+
from ..extras.packages import is_pillow_available, is_pyav_available
|
12 |
+
|
13 |
+
|
14 |
+
if is_pillow_available():
|
15 |
+
from PIL import Image
|
16 |
+
from PIL.Image import Image as ImageObject
|
17 |
+
|
18 |
+
|
19 |
+
if is_pyav_available():
|
20 |
+
import av
|
21 |
+
|
22 |
+
|
23 |
+
if TYPE_CHECKING:
|
24 |
+
import torch
|
25 |
+
from av.stream import Stream
|
26 |
+
from transformers import PreTrainedTokenizer, ProcessorMixin
|
27 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
28 |
+
|
29 |
+
# An image given as a dict: ``_regularize_images`` decodes ``bytes`` when it is not None,
# otherwise it opens ``path``.
EncodedImage = TypedDict("EncodedImage", {"path": Optional[str], "bytes": Optional[bytes]})
|
32 |
+
|
33 |
+
# A video is given as a string that is passed to ``av.open``.
VideoInput = str
# An image is given as a file path, an ``EncodedImage`` dict, or an already opened PIL image.
ImageInput = Union[str, EncodedImage, ImageObject]
|
35 |
+
|
36 |
+
|
37 |
+
def _get_paligemma_token_type_ids(
|
38 |
+
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
|
39 |
+
) -> List[List[int]]:
|
40 |
+
r"""
|
41 |
+
Gets paligemma token type ids for computing loss.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
batch_token_type_ids: shape (batch_size, sequence_length)
|
45 |
+
"""
|
46 |
+
batch_token_type_ids = []
|
47 |
+
for imglen, seqlen in zip(imglens, seqlens):
|
48 |
+
image_seqlen = imglen * getattr(processor, "image_seqlen")
|
49 |
+
batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
|
50 |
+
|
51 |
+
return batch_token_type_ids
|
52 |
+
|
53 |
+
|
54 |
+
class BasePlugin:
    r"""
    Default multimodal plugin.

    It only checks that the given modalities are supported and otherwise leaves messages and
    token ids unchanged. Model-specific subclasses override ``process_messages``,
    ``process_token_ids`` and ``get_mm_inputs``.
    """

    def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
        # A ``None`` token marks the modality as unsupported (see ``_validate_input``).
        self.image_token = image_token
        self.video_token = video_token

    def _validate_input(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
    ) -> None:
        r"""
        Validates if this model accepts the input modalities.

        Raises:
            ValueError: if images (videos) are given but this plugin has no image (video) token.
        """
        if len(images) != 0 and self.image_token is None:
            raise ValueError("This model does not support image input.")

        if len(videos) != 0 and self.video_token is None:
            raise ValueError("This model does not support video input.")

    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
        r"""
        Pre-processes a single image.

        Downscales the image, keeping its aspect ratio, so that its longer side does not exceed
        ``kwargs["image_resolution"]``, then converts it to RGB if needed.
        """
        image_resolution: int = kwargs.get("image_resolution")
        if max(image.width, image.height) > image_resolution:
            resize_factor = image_resolution / max(image.width, image.height)
            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
            image = image.resize((width, height), resample=Image.NEAREST)

        if image.mode != "RGB":
            image = image.convert("RGB")

        return image

    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
        r"""
        Computes video sample frames according to fps.

        The result is ``duration * time_base`` (the stream length, presumably in seconds) times
        ``kwargs["video_fps"]``, capped by the stream's frame count and ``kwargs["video_maxlen"]``,
        rounded down.
        """
        video_fps: float = kwargs.get("video_fps")
        video_maxlen: int = kwargs.get("video_maxlen")
        total_frames = video_stream.frames
        sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
        sample_frames = min(total_frames, video_maxlen, sample_frames)
        return math.floor(sample_frames)

    def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
        r"""
        Regularizes images to avoid error. Including reading and pre-processing.

        Each input may be a path, a dict with ``bytes``/``path`` keys, or a PIL image.

        Raises:
            ValueError: if an input cannot be turned into a PIL image.
        """
        results = []
        for image in images:
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, dict):
                if image["bytes"] is not None:
                    image = Image.open(BytesIO(image["bytes"]))
                else:
                    image = Image.open(image["path"])

            if not isinstance(image, ImageObject):
                raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))

            results.append(self._preprocess_image(image, **kwargs))

        return results

    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
        r"""
        Regularizes videos to avoid error. Including reading, resizing and converting.

        Frames are sampled at evenly spaced indices (count from ``_get_video_sample_frames``) and
        pre-processed like images.

        Raises:
            ValueError: if a file has no video stream.
        """
        results = []
        for video in videos:
            # Use the container as a context manager so the file handle is released after
            # decoding (previously the container was never closed).
            with av.open(video, "r") as container:
                video_stream = next((stream for stream in container.streams if stream.type == "video"), None)
                if video_stream is None:
                    raise ValueError("Cannot find a video stream in {}.".format(video))

                total_frames = video_stream.frames
                sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
                # A set gives O(1) membership tests while iterating over every decoded frame.
                sample_indices = set(np.linspace(0, total_frames - 1, sample_frames).astype(np.int32).tolist())
                frames: List["ImageObject"] = []
                container.seek(0)
                for frame_idx, frame in enumerate(container.decode(video_stream)):
                    if frame_idx in sample_indices:
                        frames.append(frame.to_image())

            frames = self._regularize_images(frames, **kwargs)
            results.append(frames)

        return results

    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: "ProcessorMixin",
    ) -> Dict[str, "torch.Tensor"]:
        r"""
        Processes visual inputs.

        Returns: (llava and paligemma)
            pixel_values: tensor with shape (B, C, H, W)

        Returns: (qwen2-vl)
            pixel_values: tensor with shape (num_patches, patch_dim)
            image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height

            It holds num_patches == torch.prod(image_grid_thw)
        """
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
        input_dict = {"images": None}  # default key
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_resolution=getattr(processor, "image_resolution", 512),
            )
            input_dict["images"] = images

        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
                image_resolution=getattr(processor, "video_resolution", 128),
                video_fps=getattr(processor, "video_fps", 1.0),
                video_maxlen=getattr(processor, "video_maxlen", 64),
            )
            input_dict["videos"] = videos

        mm_inputs = {}
        if image_processor != video_processor:
            if input_dict.get("images") is not None:
                mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
            if input_dict.get("videos") is not None:
                mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
        elif input_dict.get("images") is not None or input_dict.get("videos") is not None:  # same processor (qwen2-vl)
            mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))

        return mm_inputs

    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Pre-processes input messages before tokenization for VLMs.

        The base implementation only validates the inputs and returns ``messages`` unchanged.
        """
        self._validate_input(images, videos)
        return messages

    def process_token_ids(
        self,
        input_ids: List[int],
        labels: Optional[List[int]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
    ) -> Tuple[List[int], Optional[List[int]]]:
        r"""
        Pre-processes token ids after tokenization for VLMs.

        The base implementation only validates the inputs and returns the ids unchanged.
        """
        self._validate_input(images, videos)
        return input_ids, labels

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Builds batched multimodal inputs for VLMs.

        The base implementation only validates the inputs and returns an empty dict.
        """
        self._validate_input(images, videos)
        return {}
|
232 |
+
|
233 |
+
|
234 |
+
class LlavaPlugin(BasePlugin):
    r"""
    Plugin for LLaVA-style models: every image placeholder is expanded to a fixed number
    (``processor.image_seqlen``) of image tokens.
    """

    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Expands every image placeholder in a copy of ``messages``.

        Raises:
            ValueError: if the number of images and placeholders differ.
        """
        self._validate_input(images, videos)
        placeholder_count = 0
        seqlen_per_image = getattr(processor, "image_seqlen")
        processed = deepcopy(messages)
        for msg in processed:
            text = msg["content"]
            # Mark placeholders one at a time so each occurrence is counted exactly once.
            while IMAGE_PLACEHOLDER in text:
                text = text.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                placeholder_count += 1

            msg["content"] = text.replace("{{image}}", self.image_token * seqlen_per_image)

        if placeholder_count != len(images):
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        return processed

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Returns the processed visual inputs from ``_get_mm_inputs``.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
|
272 |
+
|
273 |
+
|
274 |
+
class LlavaNextPlugin(BasePlugin):
    r"""
    Plugin for LLaVA-NeXT: each image token is expanded to a per-image number of tokens that
    depends on the original image size (via ``processor._get_number_of_features``).
    """

    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Expands every image token in a copy of ``messages`` according to the matching image's size.

        Raises:
            ValueError: if the number of images and image tokens differ.
        """
        self._validate_input(images, videos)
        num_image_tokens = 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        # Always bind the iterator: previously image tokens without images crashed with an
        # UnboundLocalError, and extra tokens leaked a bare StopIteration.
        image_sizes = iter(mm_inputs.get("image_sizes", []))
        if "pixel_values" in mm_inputs:
            # Processed size, read from ``pixel_values[0][0]`` (presumably the first patch of the first image).
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        for message in messages:
            content = message["content"]
            while self.image_token in content:
                image_size = next(image_sizes, None)
                if image_size is None:  # more image tokens than images
                    raise ValueError(
                        "The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
                    )

                orig_height, orig_width = image_size
                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                if processor.vision_feature_select_strategy == "default":
                    image_seqlen -= 1

                num_image_tokens += 1
                # A temporary marker keeps the inserted tokens from being matched again by the loop.
                content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)

            message["content"] = content.replace("{{image}}", self.image_token)

        if len(images) != num_image_tokens:
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Returns the processed visual inputs from ``_get_mm_inputs``.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
|
321 |
+
|
322 |
+
|
323 |
+
class LlavaNextVideoPlugin(BasePlugin):
    r"""
    Plugin for LLaVA-NeXT-Video: image tokens are expanded per image size, and each video token
    is expanded to a fixed number of tokens computed from the processed video frames.
    """

    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Expands image and video tokens in a copy of ``messages``.

        Raises:
            ValueError: if the number of images (videos) and image (video) tokens differ.
        """
        self._validate_input(images, videos)
        num_image_tokens = 0
        num_video_tokens = 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        if "pixel_values" in mm_inputs:
            image_sizes = iter(mm_inputs["image_sizes"])
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
            for message in messages:
                content = message["content"]

                while self.image_token in content:
                    image_size = next(image_sizes, None)
                    if image_size is None:  # more image tokens than images; avoid leaking StopIteration
                        raise ValueError(
                            "The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
                        )

                    orig_height, orig_width = image_size
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                    if processor.vision_feature_select_strategy == "default":
                        image_seqlen -= 1

                    num_image_tokens += 1
                    content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)

                message["content"] = content.replace("{{image}}", self.image_token)

        if "pixel_values_videos" in mm_inputs:
            pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
            height, width = get_image_size(pixel_values_video[0])
            num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
            video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer

            for message in messages:
                content = message["content"]
                while self.video_token in content:
                    num_video_tokens += 1
                    content = content.replace(self.video_token, "{{video}}", 1)
                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)

        if len(images) != num_image_tokens:
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        if len(videos) != num_video_tokens:
            # Fixed: this message previously named the image placeholder.
            raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Returns the processed visual inputs from ``_get_mm_inputs``.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
|
388 |
+
|
389 |
+
|
390 |
+
class PaliGemmaPlugin(BasePlugin):
    r"""
    Plugin for PaliGemma: image placeholders are stripped from the text, and the image tokens are
    prepended to the token ids instead, with matching ``token_type_ids``.
    """

    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Counts and removes every image placeholder in a copy of ``messages``.

        Raises:
            ValueError: if the number of images and placeholders differ.
        """
        self._validate_input(images, videos)
        placeholder_count = 0
        result = deepcopy(messages)
        for msg in result:
            text = msg["content"]
            while IMAGE_PLACEHOLDER in text:
                text = text.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                placeholder_count += 1

            # The image tokens are inserted by ``process_token_ids`` instead.
            msg["content"] = text.replace("{{image}}", "")

        if placeholder_count != len(images):
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        return result

    @override
    def process_token_ids(
        self,
        input_ids: List[int],
        labels: Optional[List[int]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
    ) -> Tuple[List[int], Optional[List[int]]]:
        r"""
        Prepends ``len(images) * processor.image_seqlen`` image tokens to ``input_ids`` and the same
        number of ``IGNORE_INDEX`` entries to ``labels`` (when given).
        """
        self._validate_input(images, videos)
        prefix_len = len(images) * getattr(processor, "image_seqlen")
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        new_input_ids = [image_token_id] * prefix_len + input_ids
        new_labels = None if labels is None else [IGNORE_INDEX] * prefix_len + labels
        return new_input_ids, new_labels

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Returns the processed visual inputs plus ``token_type_ids`` for the image-token prefix.
        """
        self._validate_input(images, videos)
        visual_inputs = self._get_mm_inputs(images, videos, processor)
        visual_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
        return visual_inputs
|
449 |
+
|
450 |
+
|
451 |
+
class Qwen2vlPlugin(BasePlugin):
    r"""
    Plugin for Qwen2-VL: each image/video placeholder is wrapped in
    ``<|vision_start|>...<|vision_end|>`` and expanded according to its ``*_grid_thw`` entry.
    """

    @override
    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
        r"""
        Applies the base pre-processing, then enlarges sides shorter than 28 px and clamps aspect
        ratios above 200:1 (in either direction) to 180:1.
        """
        image = super()._preprocess_image(image, **kwargs)
        if min(image.width, image.height) < 28:
            image = image.resize((max(image.width, 28), max(image.height, 28)), resample=Image.NEAREST)

        if image.width / image.height > 200:
            image = image.resize((image.height * 180, image.height), resample=Image.NEAREST)

        if image.height / image.width > 200:
            image = image.resize((image.width, image.width * 180), resample=Image.NEAREST)

        return image

    @override
    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
        r"""
        Same as the base implementation, rounded down to an even number of frames.
        """
        frame_count = super()._get_video_sample_frames(video_stream, **kwargs)
        return frame_count - frame_count % 2

    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Expands image and video placeholders in a copy of ``messages``. Each one becomes
        ``grid_thw.prod() // merge_size**2`` tokens between the vision start/end markers.

        Raises:
            ValueError: if the number of images (videos) and placeholders differ.
        """
        self._validate_input(images, videos)
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        merge_length: int = getattr(image_processor, "merge_size") ** 2
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        image_grid_thw = mm_inputs.get("image_grid_thw", [])
        video_grid_thw = mm_inputs.get("video_grid_thw", [])

        image_idx = video_idx = 0
        result = deepcopy(messages)
        for msg in result:
            text = msg["content"]
            while IMAGE_PLACEHOLDER in text:
                if image_idx >= len(image_grid_thw):
                    raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))

                image_tokens = self.image_token * (image_grid_thw[image_idx].prod() // merge_length)
                text = text.replace(IMAGE_PLACEHOLDER, "<|vision_start|>{}<|vision_end|>".format(image_tokens), 1)
                image_idx += 1

            while VIDEO_PLACEHOLDER in text:
                if video_idx >= len(video_grid_thw):
                    raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))

                video_tokens = self.video_token * (video_grid_thw[video_idx].prod() // merge_length)
                text = text.replace(VIDEO_PLACEHOLDER, "<|vision_start|>{}<|vision_end|>".format(video_tokens), 1)
                video_idx += 1

            msg["content"] = text

        if len(images) != image_idx:
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        if len(videos) != video_idx:
            raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))

        return result

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Returns the processed visual inputs from ``_get_mm_inputs``.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
|
542 |
+
|
543 |
+
|
544 |
+
class VideoLlavaPlugin(BasePlugin):
    r"""
    Plugin for Video-LLaVA: every image token and every video token is expanded to a fixed
    number of tokens derived from the processed pixel values and ``processor.patch_size``.
    """

    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Expands image and video tokens in a copy of ``messages``.

        Raises:
            ValueError: if the number of images (videos) and image (video) tokens differ.
        """
        self._validate_input(images, videos)
        image_count, video_count = 0, 0
        result = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        has_images = "pixel_values_images" in mm_inputs
        has_videos = "pixel_values_videos" in mm_inputs
        if has_images or has_videos:
            num_frames = 0
            if has_images:
                height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
                num_frames = 1

            if has_videos:
                # When both modalities are present, the video geometry overrides the image one.
                first_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                height, width = get_image_size(first_video[0])
                num_frames = first_video.shape[0]  # frame dim is always after batch dim

            tokens_per_frame = (height // processor.patch_size) * (width // processor.patch_size) + 1
            video_seqlen = tokens_per_frame * num_frames
            # With the "default" strategy the image length drops one token; the video length
            # above is computed before that adjustment.
            if processor.vision_feature_select_strategy == "default":
                image_seqlen = tokens_per_frame - 1
            else:
                image_seqlen = tokens_per_frame

            for msg in result:
                text = msg["content"]
                while self.image_token in text:
                    text = text.replace(self.image_token, "{{image}}", 1)
                    image_count += 1

                while self.video_token in text:
                    text = text.replace(self.video_token, "{{video}}", 1)
                    video_count += 1

                text = text.replace("{{image}}", self.image_token * image_seqlen)
                msg["content"] = text.replace("{{video}}", self.video_token * video_seqlen)

        if len(images) != image_count:
            raise ValueError("The number of images does not match the number of {} tokens".format(self.image_token))

        if len(videos) != video_count:
            raise ValueError("The number of videos does not match the number of {} tokens".format(self.video_token))

        return result

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Returns the processed visual inputs from ``_get_mm_inputs``.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
|
605 |
+
|
606 |
+
|
607 |
+
# Registry of multimodal plugin classes, keyed by the name accepted by ``get_mm_plugin``.
PLUGINS = dict(
    base=BasePlugin,
    llava=LlavaPlugin,
    llava_next=LlavaNextPlugin,
    llava_next_video=LlavaNextVideoPlugin,
    paligemma=PaliGemmaPlugin,
    qwen2_vl=Qwen2vlPlugin,
    video_llava=VideoLlavaPlugin,
)
|
616 |
+
|
617 |
+
|
618 |
+
def get_mm_plugin(
    name: str,
    image_token: Optional[str] = None,
    video_token: Optional[str] = None,
) -> "BasePlugin":
    r"""
    Instantiates the multimodal plugin registered under ``name`` in ``PLUGINS``.

    Args:
        name: registry key of the plugin.
        image_token: the model's image token, or ``None`` if images are unsupported.
        video_token: the model's video token, or ``None`` if videos are unsupported.

    Raises:
        ValueError: if no plugin is registered under ``name``.
    """
    if name not in PLUGINS:
        raise ValueError("Multimodal plugin `{}` not found.".format(name))

    return PLUGINS[name](image_token, video_token)
|
llamafactory/data/parser.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import os
|
17 |
+
from dataclasses import dataclass
|
18 |
+
from typing import Any, Dict, List, Literal, Optional, Sequence
|
19 |
+
|
20 |
+
from transformers.utils import cached_file
|
21 |
+
|
22 |
+
from ..extras.constants import DATA_CONFIG
|
23 |
+
from ..extras.misc import use_modelscope
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class DatasetAttr:
|
28 |
+
r"""
|
29 |
+
Dataset attributes.
|
30 |
+
"""
|
31 |
+
|
32 |
+
# basic configs
|
33 |
+
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
34 |
+
dataset_name: str
|
35 |
+
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
36 |
+
ranking: bool = False
|
37 |
+
# extra configs
|
38 |
+
subset: Optional[str] = None
|
39 |
+
split: str = "train"
|
40 |
+
folder: Optional[str] = None
|
41 |
+
num_samples: Optional[int] = None
|
42 |
+
# common columns
|
43 |
+
system: Optional[str] = None
|
44 |
+
tools: Optional[str] = None
|
45 |
+
images: Optional[str] = None
|
46 |
+
videos: Optional[str] = None
|
47 |
+
# rlhf columns
|
48 |
+
chosen: Optional[str] = None
|
49 |
+
rejected: Optional[str] = None
|
50 |
+
kto_tag: Optional[str] = None
|
51 |
+
# alpaca columns
|
52 |
+
prompt: Optional[str] = "instruction"
|
53 |
+
query: Optional[str] = "input"
|
54 |
+
response: Optional[str] = "output"
|
55 |
+
history: Optional[str] = None
|
56 |
+
# sharegpt columns
|
57 |
+
messages: Optional[str] = "conversations"
|
58 |
+
# sharegpt tags
|
59 |
+
role_tag: Optional[str] = "from"
|
60 |
+
content_tag: Optional[str] = "value"
|
61 |
+
user_tag: Optional[str] = "human"
|
62 |
+
assistant_tag: Optional[str] = "gpt"
|
63 |
+
observation_tag: Optional[str] = "observation"
|
64 |
+
function_tag: Optional[str] = "function_call"
|
65 |
+
system_tag: Optional[str] = "system"
|
66 |
+
|
67 |
+
def __repr__(self) -> str:
|
68 |
+
return self.dataset_name
|
69 |
+
|
70 |
+
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
|
71 |
+
setattr(self, key, obj.get(key, default))
|
72 |
+
|
73 |
+
|
74 |
+
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
|
75 |
+
r"""
|
76 |
+
Gets the attributes of the datasets.
|
77 |
+
"""
|
78 |
+
if dataset_names is None:
|
79 |
+
dataset_names = []
|
80 |
+
|
81 |
+
if dataset_dir == "ONLINE":
|
82 |
+
dataset_info = None
|
83 |
+
else:
|
84 |
+
if dataset_dir.startswith("REMOTE:"):
|
85 |
+
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
|
86 |
+
else:
|
87 |
+
config_path = os.path.join(dataset_dir, DATA_CONFIG)
|
88 |
+
|
89 |
+
try:
|
90 |
+
with open(config_path, "r") as f:
|
91 |
+
dataset_info = json.load(f)
|
92 |
+
except Exception as err:
|
93 |
+
if len(dataset_names) != 0:
|
94 |
+
raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
|
95 |
+
|
96 |
+
dataset_info = None
|
97 |
+
|
98 |
+
dataset_list: List["DatasetAttr"] = []
|
99 |
+
for name in dataset_names:
|
100 |
+
if dataset_info is None: # dataset_dir is ONLINE
|
101 |
+
load_from = "ms_hub" if use_modelscope() else "hf_hub"
|
102 |
+
dataset_attr = DatasetAttr(load_from, dataset_name=name)
|
103 |
+
dataset_list.append(dataset_attr)
|
104 |
+
continue
|
105 |
+
|
106 |
+
if name not in dataset_info:
|
107 |
+
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
|
108 |
+
|
109 |
+
has_hf_url = "hf_hub_url" in dataset_info[name]
|
110 |
+
has_ms_url = "ms_hub_url" in dataset_info[name]
|
111 |
+
|
112 |
+
if has_hf_url or has_ms_url:
|
113 |
+
if (use_modelscope() and has_ms_url) or (not has_hf_url):
|
114 |
+
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
|
115 |
+
else:
|
116 |
+
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
117 |
+
elif "script_url" in dataset_info[name]:
|
118 |
+
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
119 |
+
else:
|
120 |
+
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
|
121 |
+
|
122 |
+
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
123 |
+
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
124 |
+
dataset_attr.set_attr("subset", dataset_info[name])
|
125 |
+
dataset_attr.set_attr("split", dataset_info[name], default="train")
|
126 |
+
dataset_attr.set_attr("folder", dataset_info[name])
|
127 |
+
dataset_attr.set_attr("num_samples", dataset_info[name])
|
128 |
+
|
129 |
+
if "columns" in dataset_info[name]:
|
130 |
+
column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
|
131 |
+
if dataset_attr.formatting == "alpaca":
|
132 |
+
column_names.extend(["prompt", "query", "response", "history"])
|
133 |
+
else:
|
134 |
+
column_names.extend(["messages"])
|
135 |
+
|
136 |
+
for column_name in column_names:
|
137 |
+
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
|
138 |
+
|
139 |
+
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
|
140 |
+
tag_names = (
|
141 |
+
"role_tag",
|
142 |
+
"content_tag",
|
143 |
+
"user_tag",
|
144 |
+
"assistant_tag",
|
145 |
+
"observation_tag",
|
146 |
+
"function_tag",
|
147 |
+
"system_tag",
|
148 |
+
)
|
149 |
+
for tag in tag_names:
|
150 |
+
dataset_attr.set_attr(tag, dataset_info[name]["tags"])
|
151 |
+
|
152 |
+
dataset_list.append(dataset_attr)
|
153 |
+
|
154 |
+
return dataset_list
|
llamafactory/data/preprocess.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from functools import partial
|
16 |
+
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
|
17 |
+
|
18 |
+
from .processors.feedback import preprocess_feedback_dataset
|
19 |
+
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
|
20 |
+
from .processors.pretrain import preprocess_pretrain_dataset
|
21 |
+
from .processors.supervised import (
|
22 |
+
preprocess_packed_supervised_dataset,
|
23 |
+
preprocess_supervised_dataset,
|
24 |
+
print_supervised_dataset_example,
|
25 |
+
)
|
26 |
+
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
|
27 |
+
|
28 |
+
|
29 |
+
if TYPE_CHECKING:
|
30 |
+
from transformers import PreTrainedTokenizer, ProcessorMixin
|
31 |
+
|
32 |
+
from ..hparams import DataArguments
|
33 |
+
from .template import Template
|
34 |
+
|
35 |
+
|
36 |
+
def get_preprocess_and_print_func(
|
37 |
+
data_args: "DataArguments",
|
38 |
+
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
39 |
+
template: "Template",
|
40 |
+
tokenizer: "PreTrainedTokenizer",
|
41 |
+
processor: Optional["ProcessorMixin"],
|
42 |
+
do_generate: bool = False,
|
43 |
+
) -> Tuple[Callable, Callable]:
|
44 |
+
if stage == "pt":
|
45 |
+
preprocess_func = partial(
|
46 |
+
preprocess_pretrain_dataset,
|
47 |
+
tokenizer=tokenizer,
|
48 |
+
data_args=data_args,
|
49 |
+
)
|
50 |
+
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
51 |
+
elif stage == "sft" and not do_generate:
|
52 |
+
if data_args.packing:
|
53 |
+
if data_args.neat_packing: # hack datasets to have int32 attention mask
|
54 |
+
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
|
55 |
+
|
56 |
+
def __init__(self, data, **kwargs):
|
57 |
+
return TypedSequence.__init__(
|
58 |
+
self,
|
59 |
+
data,
|
60 |
+
type=kwargs.pop("type", None),
|
61 |
+
try_type=kwargs.pop("try_type", None),
|
62 |
+
optimized_int_type=kwargs.pop("optimized_int_type", None),
|
63 |
+
)
|
64 |
+
|
65 |
+
OptimizedTypedSequence.__init__ = __init__
|
66 |
+
preprocess_func = partial(
|
67 |
+
preprocess_packed_supervised_dataset,
|
68 |
+
template=template,
|
69 |
+
tokenizer=tokenizer,
|
70 |
+
processor=processor,
|
71 |
+
data_args=data_args,
|
72 |
+
)
|
73 |
+
else:
|
74 |
+
preprocess_func = partial(
|
75 |
+
preprocess_supervised_dataset,
|
76 |
+
template=template,
|
77 |
+
tokenizer=tokenizer,
|
78 |
+
processor=processor,
|
79 |
+
data_args=data_args,
|
80 |
+
)
|
81 |
+
|
82 |
+
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
83 |
+
elif stage == "rm":
|
84 |
+
preprocess_func = partial(
|
85 |
+
preprocess_pairwise_dataset,
|
86 |
+
template=template,
|
87 |
+
tokenizer=tokenizer,
|
88 |
+
processor=processor,
|
89 |
+
data_args=data_args,
|
90 |
+
)
|
91 |
+
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
92 |
+
elif stage == "kto":
|
93 |
+
preprocess_func = partial(
|
94 |
+
preprocess_feedback_dataset,
|
95 |
+
template=template,
|
96 |
+
tokenizer=tokenizer,
|
97 |
+
processor=processor,
|
98 |
+
data_args=data_args,
|
99 |
+
)
|
100 |
+
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
101 |
+
else:
|
102 |
+
preprocess_func = partial(
|
103 |
+
preprocess_unsupervised_dataset,
|
104 |
+
template=template,
|
105 |
+
tokenizer=tokenizer,
|
106 |
+
processor=processor,
|
107 |
+
data_args=data_args,
|
108 |
+
)
|
109 |
+
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
110 |
+
|
111 |
+
return preprocess_func, print_function
|
llamafactory/data/processors/__init__.py
ADDED
File without changes
|
llamafactory/data/processors/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (173 Bytes). View file
|
|
llamafactory/data/processors/__pycache__/feedback.cpython-311.pyc
ADDED
Binary file (6.84 kB). View file
|
|
llamafactory/data/processors/__pycache__/pairwise.cpython-311.pyc
ADDED
Binary file (7.81 kB). View file
|
|