realaer committed on
Commit
f6f64ac
1 Parent(s): 6fbb369

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. README.md +2 -8
  2. api.py +33 -0
  3. llamafactory.egg-info/PKG-INFO +815 -0
  4. llamafactory.egg-info/SOURCES.txt +123 -0
  5. llamafactory.egg-info/dependency_links.txt +1 -0
  6. llamafactory.egg-info/entry_points.txt +3 -0
  7. llamafactory.egg-info/requires.txt +82 -0
  8. llamafactory.egg-info/top_level.txt +1 -0
  9. llamafactory/__init__.py +46 -0
  10. llamafactory/__pycache__/__init__.cpython-311.pyc +0 -0
  11. llamafactory/api/__init__.py +0 -0
  12. llamafactory/api/app.py +134 -0
  13. llamafactory/api/chat.py +237 -0
  14. llamafactory/api/common.py +34 -0
  15. llamafactory/api/protocol.py +153 -0
  16. llamafactory/chat/__init__.py +19 -0
  17. llamafactory/chat/__pycache__/__init__.cpython-311.pyc +0 -0
  18. llamafactory/chat/__pycache__/base_engine.cpython-311.pyc +0 -0
  19. llamafactory/chat/__pycache__/chat_model.cpython-311.pyc +0 -0
  20. llamafactory/chat/__pycache__/hf_engine.cpython-311.pyc +0 -0
  21. llamafactory/chat/__pycache__/vllm_engine.cpython-311.pyc +0 -0
  22. llamafactory/chat/base_engine.py +102 -0
  23. llamafactory/chat/chat_model.py +187 -0
  24. llamafactory/chat/hf_engine.py +343 -0
  25. llamafactory/chat/vllm_engine.py +230 -0
  26. llamafactory/cli.py +121 -0
  27. llamafactory/data/__init__.py +37 -0
  28. llamafactory/data/__pycache__/__init__.cpython-311.pyc +0 -0
  29. llamafactory/data/__pycache__/aligner.cpython-311.pyc +0 -0
  30. llamafactory/data/__pycache__/collator.cpython-311.pyc +0 -0
  31. llamafactory/data/__pycache__/data_utils.cpython-311.pyc +0 -0
  32. llamafactory/data/__pycache__/formatter.cpython-311.pyc +0 -0
  33. llamafactory/data/__pycache__/loader.cpython-311.pyc +0 -0
  34. llamafactory/data/__pycache__/mm_plugin.cpython-311.pyc +0 -0
  35. llamafactory/data/__pycache__/parser.cpython-311.pyc +0 -0
  36. llamafactory/data/__pycache__/preprocess.cpython-311.pyc +0 -0
  37. llamafactory/data/__pycache__/template.cpython-311.pyc +0 -0
  38. llamafactory/data/__pycache__/tool_utils.cpython-311.pyc +0 -0
  39. llamafactory/data/aligner.py +258 -0
  40. llamafactory/data/collator.py +189 -0
  41. llamafactory/data/data_utils.py +92 -0
  42. llamafactory/data/formatter.py +148 -0
  43. llamafactory/data/loader.py +292 -0
  44. llamafactory/data/mm_plugin.py +627 -0
  45. llamafactory/data/parser.py +154 -0
  46. llamafactory/data/preprocess.py +111 -0
  47. llamafactory/data/processors/__init__.py +0 -0
  48. llamafactory/data/processors/__pycache__/__init__.cpython-311.pyc +0 -0
  49. llamafactory/data/processors/__pycache__/feedback.cpython-311.pyc +0 -0
  50. llamafactory/data/processors/__pycache__/pairwise.cpython-311.pyc +0 -0
README.md CHANGED
@@ -1,12 +1,6 @@
  ---
- title: Src
- emoji: 👀
- colorFrom: yellow
- colorTo: blue
+ title: src
+ app_file: webui.py
  sdk: gradio
  sdk_version: 4.44.1
- app_file: app.py
- pinned: false
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
api.py ADDED
@@ -0,0 +1,33 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+
+ import uvicorn
+
+ from llamafactory.api.app import create_app
+ from llamafactory.chat import ChatModel
+
+
+ def main():
+     chat_model = ChatModel()
+     app = create_app(chat_model)
+     api_host = os.environ.get("API_HOST", "0.0.0.0")
+     api_port = int(os.environ.get("API_PORT", "8000"))
+     print("Visit http://localhost:{}/docs for API document.".format(api_port))
+     uvicorn.run(app, host=api_host, port=api_port)
+
+
+ if __name__ == "__main__":
+     main()
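As a quick usage sketch (not part of this commit): the script reads `API_HOST` and `API_PORT` from the environment before starting uvicorn, so it can be launched on a non-default port as below. Passing an inference yaml (here the one from the README examples) to supply the `ChatModel` arguments is an assumption, since this diff does not show how `ChatModel()` resolves its configuration.

```bash
# Launch the OpenAI-style API server on port 8001 instead of the default 8000.
# The yaml argument is illustrative; adjust it to your own inference config.
API_HOST=0.0.0.0 API_PORT=8001 python api.py examples/inference/llama3_lora_sft.yaml
```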
llamafactory.egg-info/PKG-INFO ADDED
@@ -0,0 +1,815 @@
1
+ Metadata-Version: 2.1
2
+ Name: llamafactory
3
+ Version: 0.9.1.dev0
4
+ Summary: Easy-to-use LLM fine-tuning framework
5
+ Home-page: https://github.com/hiyouga/LLaMA-Factory
6
+ Author: hiyouga
7
+ Author-email: [email protected]
8
+ License: Apache 2.0 License
9
+ Keywords: LLaMA,BLOOM,Falcon,LLM,ChatGPT,transformer,pytorch,deep learning
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: Intended Audience :: Education
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: License :: OSI Approved :: Apache Software License
15
+ Classifier: Operating System :: OS Independent
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.8
18
+ Classifier: Programming Language :: Python :: 3.9
19
+ Classifier: Programming Language :: Python :: 3.10
20
+ Classifier: Programming Language :: Python :: 3.11
21
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
22
+ Requires-Python: >=3.8.0
23
+ Description-Content-Type: text/markdown
24
+ License-File: LICENSE
25
+ Requires-Dist: transformers<=4.45.0,>=4.41.2
26
+ Requires-Dist: datasets<=2.21.0,>=2.16.0
27
+ Requires-Dist: accelerate<=0.34.2,>=0.30.1
28
+ Requires-Dist: peft<=0.12.0,>=0.11.1
29
+ Requires-Dist: trl<=0.9.6,>=0.8.6
30
+ Requires-Dist: gradio>=4.0.0
31
+ Requires-Dist: pandas>=2.0.0
32
+ Requires-Dist: scipy
33
+ Requires-Dist: einops
34
+ Requires-Dist: sentencepiece
35
+ Requires-Dist: tiktoken
36
+ Requires-Dist: protobuf
37
+ Requires-Dist: uvicorn
38
+ Requires-Dist: pydantic
39
+ Requires-Dist: fastapi
40
+ Requires-Dist: sse-starlette
41
+ Requires-Dist: matplotlib>=3.7.0
42
+ Requires-Dist: fire
43
+ Requires-Dist: packaging
44
+ Requires-Dist: pyyaml
45
+ Requires-Dist: numpy<2.0.0
46
+ Requires-Dist: av
47
+ Provides-Extra: torch
48
+ Requires-Dist: torch>=1.13.1; extra == "torch"
49
+ Provides-Extra: torch-npu
50
+ Requires-Dist: torch==2.1.0; extra == "torch-npu"
51
+ Requires-Dist: torch-npu==2.1.0.post3; extra == "torch-npu"
52
+ Requires-Dist: decorator; extra == "torch-npu"
53
+ Provides-Extra: metrics
54
+ Requires-Dist: nltk; extra == "metrics"
55
+ Requires-Dist: jieba; extra == "metrics"
56
+ Requires-Dist: rouge-chinese; extra == "metrics"
57
+ Provides-Extra: deepspeed
58
+ Requires-Dist: deepspeed<=0.14.4,>=0.10.0; extra == "deepspeed"
59
+ Provides-Extra: liger-kernel
60
+ Requires-Dist: liger-kernel; extra == "liger-kernel"
61
+ Provides-Extra: bitsandbytes
62
+ Requires-Dist: bitsandbytes>=0.39.0; extra == "bitsandbytes"
63
+ Provides-Extra: hqq
64
+ Requires-Dist: hqq; extra == "hqq"
65
+ Provides-Extra: eetq
66
+ Requires-Dist: eetq; extra == "eetq"
67
+ Provides-Extra: gptq
68
+ Requires-Dist: optimum>=1.17.0; extra == "gptq"
69
+ Requires-Dist: auto-gptq>=0.5.0; extra == "gptq"
70
+ Provides-Extra: awq
71
+ Requires-Dist: autoawq; extra == "awq"
72
+ Provides-Extra: aqlm
73
+ Requires-Dist: aqlm[gpu]>=1.1.0; extra == "aqlm"
74
+ Provides-Extra: vllm
75
+ Requires-Dist: vllm<=0.6.2,>=0.4.3; extra == "vllm"
76
+ Provides-Extra: galore
77
+ Requires-Dist: galore-torch; extra == "galore"
78
+ Provides-Extra: badam
79
+ Requires-Dist: badam>=1.2.1; extra == "badam"
80
+ Provides-Extra: adam-mini
81
+ Requires-Dist: adam-mini; extra == "adam-mini"
82
+ Provides-Extra: qwen
83
+ Requires-Dist: transformers_stream_generator; extra == "qwen"
84
+ Provides-Extra: modelscope
85
+ Requires-Dist: modelscope; extra == "modelscope"
86
+ Provides-Extra: dev
87
+ Requires-Dist: ruff; extra == "dev"
88
+ Requires-Dist: pytest; extra == "dev"
89
+
90
+ ![# LLaMA Factory](assets/logo.png)
91
+
92
+ [![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
93
+ [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
94
+ [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
95
+ [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
96
+ [![Citation](https://img.shields.io/badge/citation-91-green)](#projects-using-llama-factory)
97
+ [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
98
+ [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
99
+ [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
100
+ [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
101
+ [![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
102
+ [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
103
+ [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
104
+
105
+ [![GitHub Trend](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
106
+
107
+ 👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg).
108
+
109
+ \[ English | [中文](README_zh.md) \]
110
+
111
+ **Fine-tuning a large language model can be as easy as...**
112
+
113
+ https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3
114
+
115
+ Choose your path:
116
+
117
+ - **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
118
+ - **PAI-DSW**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
119
+ - **Local machine**: Please refer to [usage](#getting-started)
120
+ - **Documentation (WIP)**: https://llamafactory.readthedocs.io/zh-cn/latest/
121
+
122
+ > [!NOTE]
123
+ > Except for the links above, all other websites are unauthorized third-party websites. Please use them with caution.
124
+
125
+ ## Table of Contents
126
+
127
+ - [Features](#features)
128
+ - [Benchmark](#benchmark)
129
+ - [Changelog](#changelog)
130
+ - [Supported Models](#supported-models)
131
+ - [Supported Training Approaches](#supported-training-approaches)
132
+ - [Provided Datasets](#provided-datasets)
133
+ - [Requirement](#requirement)
134
+ - [Getting Started](#getting-started)
135
+ - [Projects using LLaMA Factory](#projects-using-llama-factory)
136
+ - [License](#license)
137
+ - [Citation](#citation)
138
+ - [Acknowledgement](#acknowledgement)
139
+
140
+ ## Features
141
+
142
+ - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
143
+ - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
144
+ - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
145
+ - **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
146
+ - **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
147
+ - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
148
+ - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
149
+
150
+ ## Benchmark
151
+
152
+ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging the 4-bit quantization technique, LLaMA Factory's QLoRA further improves efficiency in terms of GPU memory usage.
153
+
154
+ ![benchmark](assets/benchmark.svg)
155
+
156
+ <details><summary>Definitions</summary>
157
+
158
+ - **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024)
159
+ - **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
160
+ - **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
161
+ - We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA Factory's LoRA tuning.
162
+
163
+ </details>
164
+
165
+ ## Changelog
166
+
167
+ [24/09/19] We support fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
168
+
169
+ [24/08/30] We support fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thanks to [@simonJJJ](https://github.com/simonJJJ)'s PR.
170
+
171
+ [24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
172
+
173
+ [24/08/09] We support the **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thanks to [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
174
+
175
+ <details><summary>Full Changelog</summary>
176
+
177
+ [24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thanks to [@chuan298](https://github.com/chuan298)'s PR.
178
+
179
+ [24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
180
+
181
+ [24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
182
+
183
+ [24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
184
+
185
+ [24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models; you need to fine-tune them with the `paligemma` template for chat completion.
186
+
187
+ [24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
188
+
189
+ [24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
190
+
191
+ [24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
192
+
193
+ [24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available on Hugging Face; check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
194
+
195
+ [24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
196
+
197
+ [24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)** optimizer. See [examples](examples/README.md) for usage.
198
+
199
+ [24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
200
+
201
+ [24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See [examples](examples/README.md) for usage.
202
+
203
+ [24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
204
+
205
+ [24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See [examples](examples/README.md) for usage.
206
+
207
+ [24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
208
+
209
+ [24/03/07] We supported **[GaLore](https://arxiv.org/abs/2403.03507)** optimizer. See [examples](examples/README.md) for usage.
210
+
211
+ [24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
212
+
213
+ [24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `use_dora: true` to activate DoRA training.
214
+
215
+ [24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See [examples](examples/README.md) for usage.
216
+
217
+ [24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
218
+
219
+ [24/01/18] We supported **agent tuning** for most models, equipping models with tool-using abilities by fine-tuning with `dataset: glaive_toolcall_en`.
220
+
221
+ [23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try the `use_unsloth: true` argument to activate the unsloth patch. It achieves **170%** speed in our benchmark; check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
222
+
223
+ [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
224
+
225
+ [23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage.
226
+
227
+ [23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
228
+
229
+ [23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `shift_attn: true` argument to enable shift short attention.
230
+
231
+ [23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [examples](examples/README.md) for usage.
232
+
233
+ [23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `flash_attn: fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
234
+
235
+ [23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `rope_scaling: linear` argument in training and `rope_scaling: dynamic` argument at inference to extrapolate the position embeddings.
236
+
237
+ [23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README.md) for usage.
238
+
239
+ [23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode.
240
+
241
+ [23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
242
+
243
+ [23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thanks to [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
244
+
245
+ [23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
246
+
247
+ [23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets, see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details.
248
+
249
+ [23/06/22] We aligned the [demo API](src/api_demo.py) with [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format, so you can plug the fine-tuned model into **arbitrary ChatGPT-based applications**.
250
+
251
+ [23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). See [examples](examples/README.md) for usage.
252
+
253
+ </details>
254
+
255
+ ## Supported Models
256
+
257
+ | Model | Model size | Template |
258
+ | ----------------------------------------------------------------- | -------------------------------- | ---------------- |
259
+ | [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
260
+ | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
261
+ | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
262
+ | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
263
+ | [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
264
+ | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
265
+ | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
266
+ | [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
267
+ | [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
268
+ | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
269
+ | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
270
+ | [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
271
+ | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
272
+ | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
273
+ | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
274
+ | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
275
+ | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
276
+ | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
277
+ | [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
278
+ | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
279
+ | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
280
+ | [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
281
+ | [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
282
+ | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
283
+ | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
284
+ | [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
285
+ | [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
286
+ | [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
287
+
288
+ > [!NOTE]
289
+ > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna`, etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
290
+ >
291
+ > Remember to use the **SAME** template in training and inference.
292
+
293
+ Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of the models we support.
294
+
295
+ You can also add a custom chat template to [template.py](src/llamafactory/data/template.py).
296
+
297
+ ## Supported Training Approaches
298
+
299
+ | Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA |
300
+ | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
301
+ | Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
302
+ | Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
303
+ | Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
304
+ | PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
305
+ | DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
306
+ | KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
307
+ | ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
308
+ | SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
309
+
310
+ > [!TIP]
311
+ > The implementation details of PPO can be found in [this blog](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html).
312
+
313
+ ## Provided Datasets
314
+
315
+ <details><summary>Pre-training datasets</summary>
316
+
317
+ - [Wiki Demo (en)](data/wiki_demo.txt)
318
+ - [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
319
+ - [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
320
+ - [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
321
+ - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
322
+ - [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
323
+ - [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
324
+ - [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
325
+ - [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
326
+ - [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
327
+ - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
328
+
329
+ </details>
330
+
331
+ <details><summary>Supervised fine-tuning datasets</summary>
332
+
333
+ - [Identity (en&zh)](data/identity.json)
334
+ - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
335
+ - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
336
+ - [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
337
+ - [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
338
+ - [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
339
+ - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
340
+ - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
341
+ - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
342
+ - [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
343
+ - [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
344
+ - [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
345
+ - [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
346
+ - [UltraChat (en)](https://github.com/thunlp/UltraChat)
347
+ - [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
348
+ - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
349
+ - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
350
+ - [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
351
+ - [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
352
+ - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
353
+ - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
354
+ - [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
355
+ - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
356
+ - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
357
+ - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
358
+ - [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
359
+ - [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
360
+ - [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
361
+ - [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
362
+ - [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
363
+ - [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
364
+ - [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
365
+ - [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
366
+ - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
367
+ - [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
368
+ - [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
369
+ - [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
370
+ - [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
371
+ - [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
372
+ - [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
373
+ - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
374
+ - [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
375
+ - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
376
+ - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
377
+ - [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
378
+ - [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
379
+ - [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
380
+ - [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
381
+ - [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
382
+ - [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
383
+ - [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
384
+
385
+ </details>
386
+
387
+ <details><summary>Preference datasets</summary>
388
+
389
+ - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
390
+ - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
391
+ - [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
392
+ - [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
393
+ - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
394
+ - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
395
+ - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
396
+ - [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
397
+ - [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
398
+
399
+ </details>
400
+
401
+ Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
402
+
403
+ ```bash
404
+ pip install --upgrade huggingface_hub
405
+ huggingface-cli login
406
+ ```
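For non-interactive environments (e.g., remote servers or CI), a token-based login along these lines should also work; `HF_TOKEN` is a placeholder for your own access token:

```bash
# Log in without the interactive prompt by passing an access token directly.
huggingface-cli login --token "$HF_TOKEN"
```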
407
+
408
+ ## Requirement
409
+
410
+ | Mandatory | Minimum | Recommend |
411
+ | ------------ | ------- | --------- |
412
+ | python | 3.8 | 3.11 |
413
+ | torch | 1.13.1 | 2.4.0 |
414
+ | transformers | 4.41.2 | 4.43.4 |
415
+ | datasets | 2.16.0 | 2.20.0 |
416
+ | accelerate | 0.30.1 | 0.32.0 |
417
+ | peft | 0.11.1 | 0.12.0 |
418
+ | trl | 0.8.6 | 0.9.6 |
419
+
420
+ | Optional | Minimum | Recommend |
421
+ | ------------ | ------- | --------- |
422
+ | CUDA | 11.6 | 12.2 |
423
+ | deepspeed | 0.10.0 | 0.14.0 |
424
+ | bitsandbytes | 0.39.0 | 0.43.1 |
425
+ | vllm | 0.4.3 | 0.5.0 |
426
+ | flash-attn | 2.3.0 | 2.6.3 |
427
+
428
+ ### Hardware Requirement
429
+
430
+ \* *estimated*
431
+
432
+ | Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
433
+ | ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
434
+ | Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
435
+ | Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
436
+ | Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
437
+ | LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
438
+ | QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
439
+ | QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
440
+ | QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
441
+
442
+ ## Getting Started
443
+
444
+ ### Installation
445
+
446
+ > [!IMPORTANT]
447
+ > Installation is mandatory.
448
+
449
+ ```bash
450
+ git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
451
+ cd LLaMA-Factory
452
+ pip install -e ".[torch,metrics]"
453
+ ```
454
+
455
+ Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, quality
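Multiple extras can be combined in a single editable install. For example, a sketch that also pulls in DeepSpeed and bitsandbytes support (adjust the list to your hardware and use case):

```bash
# Combine several optional dependency groups in one editable install.
pip install -e ".[torch,metrics,deepspeed,bitsandbytes]"
```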
456
+
457
+ > [!TIP]
458
+ > Use `pip install --no-deps -e .` to resolve package conflicts.
459
+
460
+ <details><summary>For Windows users</summary>
461
+
462
+ If you want to enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library that supports CUDA 11.1 to 12.2. Please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
463
+
464
+ ```bash
465
+ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
466
+ ```
467
+
468
+ To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
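Once the matching wheel has been downloaded, installing it is a single pip command; the file name below is purely illustrative and should be replaced with the wheel that matches your Python, CUDA and torch versions:

```bash
# Install the downloaded pre-compiled wheel (example file name only).
pip install flash_attn-2.6.3-cp311-cp311-win_amd64.whl
```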
469
+
470
+ </details>
471
+
472
+ <details><summary>For Ascend NPU users</summary>
473
+
474
+ To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
475
+
476
+ ```bash
477
+ # replace the url according to your CANN version and devices
478
+ # install CANN Toolkit
479
+ wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
480
+ bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
481
+
482
+ # install CANN Kernels
483
+ wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
484
+ bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
485
+
486
+ # set env variables
487
+ source /usr/local/Ascend/ascend-toolkit/set_env.sh
488
+ ```
489
+
490
+ | Requirement | Minimum | Recommend |
491
+ | ------------ | ------- | ----------- |
492
+ | CANN | 8.0.RC1 | 8.0.RC1 |
493
+ | torch | 2.1.0 | 2.1.0 |
494
+ | torch-npu | 2.1.0 | 2.1.0.post3 |
495
+ | deepspeed | 0.13.2 | 0.13.2 |
496
+
497
+ Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
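For example, restricting a run to the first NPU could look like the following sketch, reusing the quickstart yaml (adjust the device index and config to your setup):

```bash
# Select NPU 0 only, then launch LoRA fine-tuning as in the quickstart.
ASCEND_RT_VISIBLE_DEVICES=0 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```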
498
+
499
+ If you cannot run inference on NPU devices, try setting `do_sample: false` in the configurations.
500
+
501
+ Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
502
+
503
+ </details>
504
+
505
+ ### Data Preparation
506
+
507
+ Please refer to [data/README.md](data/README.md) for details about the format of the dataset files. You can either use datasets from the Hugging Face / ModelScope hub or load datasets from local disk.
508
+
509
+ > [!NOTE]
510
+ > Please update `data/dataset_info.json` to use your custom dataset.
511
+
512
+ ### Quickstart
513
+
514
+ Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
515
+
516
+ ```bash
517
+ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
518
+ llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
519
+ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
520
+ ```
521
+
522
+ See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
523
+
524
+ > [!TIP]
525
+ > Use `llamafactory-cli help` to show help information.
526
+
527
+ ### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
528
+
529
+ ```bash
530
+ llamafactory-cli webui
531
+ ```
532
+
533
+ ### Build Docker
534
+
535
+ For CUDA users:
536
+
537
+ ```bash
538
+ cd docker/docker-cuda/
539
+ docker compose up -d
540
+ docker compose exec llamafactory bash
541
+ ```
542
+
543
+ For Ascend NPU users:
544
+
545
+ ```bash
546
+ cd docker/docker-npu/
547
+ docker compose up -d
548
+ docker compose exec llamafactory bash
549
+ ```
550
+
551
+ For AMD ROCm users:
552
+
553
+ ```bash
554
+ cd docker/docker-rocm/
555
+ docker compose up -d
556
+ docker compose exec llamafactory bash
557
+ ```
558
+
559
+ <details><summary>Build without Docker Compose</summary>
560
+
561
+ For CUDA users:
562
+
563
+ ```bash
564
+ docker build -f ./docker/docker-cuda/Dockerfile \
565
+ --build-arg INSTALL_BNB=false \
566
+ --build-arg INSTALL_VLLM=false \
567
+ --build-arg INSTALL_DEEPSPEED=false \
568
+ --build-arg INSTALL_FLASHATTN=false \
569
+ --build-arg PIP_INDEX=https://pypi.org/simple \
570
+ -t llamafactory:latest .
571
+
572
+ docker run -dit --gpus=all \
573
+ -v ./hf_cache:/root/.cache/huggingface \
574
+ -v ./ms_cache:/root/.cache/modelscope \
575
+ -v ./data:/app/data \
576
+ -v ./output:/app/output \
577
+ -p 7860:7860 \
578
+ -p 8000:8000 \
579
+ --shm-size 16G \
580
+ --name llamafactory \
581
+ llamafactory:latest
582
+
583
+ docker exec -it llamafactory bash
584
+ ```
585
+
586
+ For Ascend NPU users:
587
+
588
+ ```bash
589
+ # Choose docker image upon your environment
590
+ docker build -f ./docker/docker-npu/Dockerfile \
591
+ --build-arg INSTALL_DEEPSPEED=false \
592
+ --build-arg PIP_INDEX=https://pypi.org/simple \
593
+ -t llamafactory:latest .
594
+
595
+ # Change `device` upon your resources
596
+ docker run -dit \
597
+ -v ./hf_cache:/root/.cache/huggingface \
598
+ -v ./ms_cache:/root/.cache/modelscope \
599
+ -v ./data:/app/data \
600
+ -v ./output:/app/output \
601
+ -v /usr/local/dcmi:/usr/local/dcmi \
602
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
603
+ -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
604
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
605
+ -p 7860:7860 \
606
+ -p 8000:8000 \
607
+ --device /dev/davinci0 \
608
+ --device /dev/davinci_manager \
609
+ --device /dev/devmm_svm \
610
+ --device /dev/hisi_hdc \
611
+ --shm-size 16G \
612
+ --name llamafactory \
613
+ llamafactory:latest
614
+
615
+ docker exec -it llamafactory bash
616
+ ```
617
+
618
+ For AMD ROCm users:
619
+
620
+ ```bash
621
+ docker build -f ./docker/docker-rocm/Dockerfile \
622
+ --build-arg INSTALL_BNB=false \
623
+ --build-arg INSTALL_VLLM=false \
624
+ --build-arg INSTALL_DEEPSPEED=false \
625
+ --build-arg INSTALL_FLASHATTN=false \
626
+ --build-arg PIP_INDEX=https://pypi.org/simple \
627
+ -t llamafactory:latest .
628
+
629
+ docker run -dit \
630
+ -v ./hf_cache:/root/.cache/huggingface \
631
+ -v ./ms_cache:/root/.cache/modelscope \
632
+ -v ./data:/app/data \
633
+ -v ./output:/app/output \
634
+ -v ./saves:/app/saves \
635
+ -p 7860:7860 \
636
+ -p 8000:8000 \
637
+ --device /dev/kfd \
638
+ --device /dev/dri \
639
+ --shm-size 16G \
640
+ --name llamafactory \
641
+ llamafactory:latest
642
+
643
+ docker exec -it llamafactory bash
644
+ ```
645
+
646
+ </details>
647
+
648
+ <details><summary>Details about volume</summary>
649
+
650
+ - `hf_cache`: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
651
+ - `ms_cache`: Similar to Hugging Face cache but for ModelScope users.
652
+ - `data`: Place datasets in this directory on the host machine so that they can be selected in the LLaMA Board GUI.
653
+ - `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.
654
+
655
+ </details>
656
+
657
+ ### Deploy with OpenAI-style API and vLLM
658
+
659
+ ```bash
660
+ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
661
+ ```
662
+
663
+ > [!TIP]
664
+ > Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for the API documentation.
665
+
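Since the server exposes an OpenAI-style API, any OpenAI-compatible client can talk to it. A minimal request sketch is shown below; the `/v1/chat/completions` route and the `model` value are assumptions based on the OpenAI format rather than details stated in this README:

```bash
# Send a chat completion request to the locally deployed model.
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "test", "messages": [{"role": "user", "content": "Hello!"}]}'
```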
666
+ ### Download from ModelScope Hub
667
+
668
+ If you have trouble downloading models and datasets from Hugging Face, you can use ModelScope.
669
+
670
+ ```bash
671
+ export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
672
+ ```
673
+
674
+ Train the model by specifying a model ID from the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
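Putting it together, a training launch that pulls the base model from ModelScope instead of Hugging Face could look like the sketch below, assuming the yaml sets `model_name_or_path: LLM-Research/Meta-Llama-3-8B-Instruct` as described above:

```bash
# Route model and dataset downloads through the ModelScope Hub, then train as usual.
export USE_MODELSCOPE_HUB=1   # use `set USE_MODELSCOPE_HUB=1` on Windows
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```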
675
+
676
+ ### Use W&B Logger
677
+
678
+ To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to the yaml files.
679
+
680
+ ```yaml
681
+ report_to: wandb
682
+ run_name: test_run # optional
683
+ ```
684
+
685
+ Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
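For example (a sketch; the key value is a placeholder), exporting the key before launching training is enough once `report_to: wandb` is set in the yaml:

```bash
# Log this training run to your W&B account.
export WANDB_API_KEY=your_key_here   # placeholder: get yours from https://wandb.ai/authorize
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```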
686
+
687
+ ## Projects using LLaMA Factory
688
+
689
+ If you have a project that should be incorporated, please contact us via email or create a pull request.
690
+
691
+ <details><summary>Click to show</summary>
692
+
693
+ 1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
694
+ 1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
695
+ 1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
696
+ 1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
697
+ 1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
698
+ 1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
699
+ 1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
700
+ 1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
701
+ 1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
702
+ 1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
703
+ 1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
704
+ 1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
705
+ 1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
706
+ 1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
707
+ 1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
708
+ 1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
709
+ 1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
710
+ 1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
711
+ 1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
712
+ 1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
713
+ 1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
714
+ 1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
715
+ 1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
716
+ 1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
717
+ 1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
718
+ 1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
719
+ 1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
720
+ 1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
721
+ 1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
722
+ 1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
723
+ 1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
724
+ 1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
725
+ 1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
726
+ 1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
727
+ 1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
728
+ 1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
729
+ 1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
730
+ 1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
731
+ 1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
732
+ 1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
733
+ 1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
734
+ 1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
735
+ 1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
736
+ 1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
737
+ 1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
738
+ 1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
739
+ 1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
740
+ 1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
741
+ 1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
742
+ 1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
743
+ 1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
744
+ 1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
745
+ 1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
746
+ 1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
747
+ 1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
748
+ 1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
749
+ 1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
750
+ 1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
751
+ 1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
752
+ 1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
753
+ 1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
754
+ 1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
755
+ 1. Li et al. Calibrating LLMs with Preference Optimization on Thought Trees for Generating Rationale in Science Question Scoring. 2024. [[arxiv]](https://arxiv.org/abs/2406.19949)
756
+ 1. Yang et al. Financial Knowledge Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2407.00365)
757
+ 1. Lin et al. DogeRM: Equipping Reward Models with Domain Knowledge through Model Merging. 2024. [[arxiv]](https://arxiv.org/abs/2407.01470)
758
+ 1. Bako et al. Evaluating the Semantic Profiling Abilities of LLMs for Natural Language Utterances in Data Visualization. 2024. [[arxiv]](https://arxiv.org/abs/2407.06129)
759
+ 1. Huang et al. RoLoRA: Fine-tuning Rotated Outlier-free LLMs for Effective Weight-Activation Quantization. 2024. [[arxiv]](https://arxiv.org/abs/2407.08044)
760
+ 1. Jiang et al. LLM-Collaboration on Automatic Science Journalism for the General Audience. 2024. [[arxiv]](https://arxiv.org/abs/2407.09756)
761
+ 1. Inouye et al. Applied Auto-tuning on LoRA Hyperparameters. 2024. [[paper]](https://scholarcommons.scu.edu/cseng_senior/272/)
762
+ 1. Qi et al. Research on Tibetan Tourism Viewpoints information generation system based on LLM. 2024. [[arxiv]](https://arxiv.org/abs/2407.13561)
763
+ 1. Xu et al. Course-Correction: Safety Alignment Using Synthetic Preferences. 2024. [[arxiv]](https://arxiv.org/abs/2407.16637)
764
+ 1. Sun et al. LAMBDA: A Large Model Based Data Agent. 2024. [[arxiv]](https://arxiv.org/abs/2407.17535)
765
+ 1. Zhu et al. CollectiveSFT: Scaling Large Language Models for Chinese Medical Benchmark with Collective Instructions in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2407.19705)
766
+ 1. Yu et al. Correcting Negative Bias in Large Language Models through Negative Attention Score Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2408.00137)
767
+ 1. Xie et al. The Power of Personalized Datasets: Advancing Chinese Composition Writing for Elementary School through Targeted Model Fine-Tuning. IALP 2024. [[paper]](https://www.asianlp.sg/conferences/ialp2024/proceedings/papers/IALP2024_P055.pdf)
768
+ 1. Liu et al. Instruct-Code-Llama: Improving Capabilities of Language Model in Competition Level Code Generation by Online Judge Feedback. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_11)
769
+ 1. Wang et al. Cybernetic Sentinels: Unveiling the Impact of Safety Data Selection on Model Security in Supervised Fine-Tuning. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_23)
770
+ 1. Xia et al. Understanding the Performance and Estimating the Cost of LLM Fine-Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2408.04693)
771
+ 1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
772
+ 1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
773
+ 1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
774
+ 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
775
+ 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in the Chinese legal domain, based on Baichuan-13B, capable of retrieving and reasoning over legal knowledge.
776
+ 1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
777
+ 1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
778
+ 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI personality large language models, capable of endowing any LLM with any of the 16 MBTI personality types through different datasets and training methods.
779
+ 1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generating metadata for Stable Diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
780
+ 1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
781
+ 1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
782
+ 1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
783
+ 1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way to build multi-agent LLM applications; it supports model fine-tuning via LLaMA Factory.
784
+
785
+ </details>
786
+
787
+ ## License
788
+
789
+ This repository is licensed under the [Apache-2.0 License](LICENSE).
790
+
791
+ Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
792
+
793
+ ## Citation
794
+
795
+ If this work is helpful, please cite it as:
796
+
797
+ ```bibtex
798
+ @inproceedings{zheng2024llamafactory,
799
+ title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
800
+ author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
801
+ booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
802
+ address={Bangkok, Thailand},
803
+ publisher={Association for Computational Linguistics},
804
+ year={2024},
805
+ url={http://arxiv.org/abs/2403.13372}
806
+ }
807
+ ```
808
+
809
+ ## Acknowledgement
810
+
811
+ This repo benefits from [PEFT](https://github.com/huggingface/peft), [TRL](https://github.com/huggingface/trl), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful work.
812
+
813
+ ## Star History
814
+
815
+ ![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date)
llamafactory.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,123 @@
1
+ LICENSE
2
+ MANIFEST.in
3
+ README.md
4
+ pyproject.toml
5
+ requirements.txt
6
+ setup.py
7
+ src/llamafactory/__init__.py
8
+ src/llamafactory/cli.py
9
+ src/llamafactory/launcher.py
10
+ src/llamafactory.egg-info/PKG-INFO
11
+ src/llamafactory.egg-info/SOURCES.txt
12
+ src/llamafactory.egg-info/dependency_links.txt
13
+ src/llamafactory.egg-info/entry_points.txt
14
+ src/llamafactory.egg-info/requires.txt
15
+ src/llamafactory.egg-info/top_level.txt
16
+ src/llamafactory/api/__init__.py
17
+ src/llamafactory/api/app.py
18
+ src/llamafactory/api/chat.py
19
+ src/llamafactory/api/common.py
20
+ src/llamafactory/api/protocol.py
21
+ src/llamafactory/chat/__init__.py
22
+ src/llamafactory/chat/base_engine.py
23
+ src/llamafactory/chat/chat_model.py
24
+ src/llamafactory/chat/hf_engine.py
25
+ src/llamafactory/chat/vllm_engine.py
26
+ src/llamafactory/data/__init__.py
27
+ src/llamafactory/data/aligner.py
28
+ src/llamafactory/data/collator.py
29
+ src/llamafactory/data/data_utils.py
30
+ src/llamafactory/data/formatter.py
31
+ src/llamafactory/data/loader.py
32
+ src/llamafactory/data/mm_plugin.py
33
+ src/llamafactory/data/parser.py
34
+ src/llamafactory/data/preprocess.py
35
+ src/llamafactory/data/template.py
36
+ src/llamafactory/data/tool_utils.py
37
+ src/llamafactory/data/processors/__init__.py
38
+ src/llamafactory/data/processors/feedback.py
39
+ src/llamafactory/data/processors/pairwise.py
40
+ src/llamafactory/data/processors/pretrain.py
41
+ src/llamafactory/data/processors/processor_utils.py
42
+ src/llamafactory/data/processors/supervised.py
43
+ src/llamafactory/data/processors/unsupervised.py
44
+ src/llamafactory/eval/__init__.py
45
+ src/llamafactory/eval/evaluator.py
46
+ src/llamafactory/eval/template.py
47
+ src/llamafactory/extras/__init__.py
48
+ src/llamafactory/extras/constants.py
49
+ src/llamafactory/extras/env.py
50
+ src/llamafactory/extras/logging.py
51
+ src/llamafactory/extras/misc.py
52
+ src/llamafactory/extras/packages.py
53
+ src/llamafactory/extras/ploting.py
54
+ src/llamafactory/hparams/__init__.py
55
+ src/llamafactory/hparams/data_args.py
56
+ src/llamafactory/hparams/evaluation_args.py
57
+ src/llamafactory/hparams/finetuning_args.py
58
+ src/llamafactory/hparams/generating_args.py
59
+ src/llamafactory/hparams/model_args.py
60
+ src/llamafactory/hparams/parser.py
61
+ src/llamafactory/model/__init__.py
62
+ src/llamafactory/model/adapter.py
63
+ src/llamafactory/model/loader.py
64
+ src/llamafactory/model/patcher.py
65
+ src/llamafactory/model/model_utils/__init__.py
66
+ src/llamafactory/model/model_utils/attention.py
67
+ src/llamafactory/model/model_utils/checkpointing.py
68
+ src/llamafactory/model/model_utils/embedding.py
69
+ src/llamafactory/model/model_utils/liger_kernel.py
70
+ src/llamafactory/model/model_utils/longlora.py
71
+ src/llamafactory/model/model_utils/misc.py
72
+ src/llamafactory/model/model_utils/mod.py
73
+ src/llamafactory/model/model_utils/moe.py
74
+ src/llamafactory/model/model_utils/packing.py
75
+ src/llamafactory/model/model_utils/quantization.py
76
+ src/llamafactory/model/model_utils/rope.py
77
+ src/llamafactory/model/model_utils/unsloth.py
78
+ src/llamafactory/model/model_utils/valuehead.py
79
+ src/llamafactory/model/model_utils/visual.py
80
+ src/llamafactory/train/__init__.py
81
+ src/llamafactory/train/callbacks.py
82
+ src/llamafactory/train/test_utils.py
83
+ src/llamafactory/train/trainer_utils.py
84
+ src/llamafactory/train/tuner.py
85
+ src/llamafactory/train/dpo/__init__.py
86
+ src/llamafactory/train/dpo/trainer.py
87
+ src/llamafactory/train/dpo/workflow.py
88
+ src/llamafactory/train/kto/__init__.py
89
+ src/llamafactory/train/kto/trainer.py
90
+ src/llamafactory/train/kto/workflow.py
91
+ src/llamafactory/train/ppo/__init__.py
92
+ src/llamafactory/train/ppo/ppo_utils.py
93
+ src/llamafactory/train/ppo/trainer.py
94
+ src/llamafactory/train/ppo/workflow.py
95
+ src/llamafactory/train/pt/__init__.py
96
+ src/llamafactory/train/pt/trainer.py
97
+ src/llamafactory/train/pt/workflow.py
98
+ src/llamafactory/train/rm/__init__.py
99
+ src/llamafactory/train/rm/metric.py
100
+ src/llamafactory/train/rm/trainer.py
101
+ src/llamafactory/train/rm/workflow.py
102
+ src/llamafactory/train/sft/__init__.py
103
+ src/llamafactory/train/sft/metric.py
104
+ src/llamafactory/train/sft/trainer.py
105
+ src/llamafactory/train/sft/workflow.py
106
+ src/llamafactory/webui/__init__.py
107
+ src/llamafactory/webui/chatter.py
108
+ src/llamafactory/webui/common.py
109
+ src/llamafactory/webui/css.py
110
+ src/llamafactory/webui/engine.py
111
+ src/llamafactory/webui/interface.py
112
+ src/llamafactory/webui/locales.py
113
+ src/llamafactory/webui/manager.py
114
+ src/llamafactory/webui/runner.py
115
+ src/llamafactory/webui/utils.py
116
+ src/llamafactory/webui/components/__init__.py
117
+ src/llamafactory/webui/components/chatbot.py
118
+ src/llamafactory/webui/components/data.py
119
+ src/llamafactory/webui/components/eval.py
120
+ src/llamafactory/webui/components/export.py
121
+ src/llamafactory/webui/components/infer.py
122
+ src/llamafactory/webui/components/top.py
123
+ src/llamafactory/webui/components/train.py
llamafactory.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
llamafactory.egg-info/entry_points.txt ADDED
@@ -0,0 +1,3 @@
1
+ [console_scripts]
2
+ llamafactory-cli = llamafactory.cli:main
3
+ lmf = llamafactory.cli:main
llamafactory.egg-info/requires.txt ADDED
@@ -0,0 +1,82 @@
1
+ transformers<=4.45.0,>=4.41.2
2
+ datasets<=2.21.0,>=2.16.0
3
+ accelerate<=0.34.2,>=0.30.1
4
+ peft<=0.12.0,>=0.11.1
5
+ trl<=0.9.6,>=0.8.6
6
+ gradio>=4.0.0
7
+ pandas>=2.0.0
8
+ scipy
9
+ einops
10
+ sentencepiece
11
+ tiktoken
12
+ protobuf
13
+ uvicorn
14
+ pydantic
15
+ fastapi
16
+ sse-starlette
17
+ matplotlib>=3.7.0
18
+ fire
19
+ packaging
20
+ pyyaml
21
+ numpy<2.0.0
22
+ av
23
+
24
+ [adam-mini]
25
+ adam-mini
26
+
27
+ [aqlm]
28
+ aqlm[gpu]>=1.1.0
29
+
30
+ [awq]
31
+ autoawq
32
+
33
+ [badam]
34
+ badam>=1.2.1
35
+
36
+ [bitsandbytes]
37
+ bitsandbytes>=0.39.0
38
+
39
+ [deepspeed]
40
+ deepspeed<=0.14.4,>=0.10.0
41
+
42
+ [dev]
43
+ ruff
44
+ pytest
45
+
46
+ [eetq]
47
+ eetq
48
+
49
+ [galore]
50
+ galore-torch
51
+
52
+ [gptq]
53
+ optimum>=1.17.0
54
+ auto-gptq>=0.5.0
55
+
56
+ [hqq]
57
+ hqq
58
+
59
+ [liger-kernel]
60
+ liger-kernel
61
+
62
+ [metrics]
63
+ nltk
64
+ jieba
65
+ rouge-chinese
66
+
67
+ [modelscope]
68
+ modelscope
69
+
70
+ [qwen]
71
+ transformers_stream_generator
72
+
73
+ [torch]
74
+ torch>=1.13.1
75
+
76
+ [torch-npu]
77
+ torch==2.1.0
78
+ torch-npu==2.1.0.post3
79
+ decorator
80
+
81
+ [vllm]
82
+ vllm<=0.6.2,>=0.4.3
llamafactory.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ llamafactory
llamafactory/__init__.py ADDED
@@ -0,0 +1,46 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ r"""
16
+ Efficient fine-tuning of large language models.
17
+
18
+ Level:
19
+ api, webui > chat, eval, train > data, model > hparams > extras
20
+
21
+ Dependency graph:
22
+ main:
23
+ transformers>=4.41.2,<=4.45.0
24
+ datasets>=2.16.0,<=2.21.0
25
+ accelerate>=0.30.1,<=0.34.2
26
+ peft>=0.11.1,<=0.12.0
27
+ trl>=0.8.6,<=0.9.6
28
+ attention:
29
+ transformers>=4.42.4 (gemma+fa2)
30
+ longlora:
31
+ transformers>=4.41.2,<=4.45.0
32
+ packing:
33
+ transformers>=4.41.2,<=4.45.0
34
+
35
+ Disable version checking: DISABLE_VERSION_CHECK=1
36
+ Enable VRAM recording: RECORD_VRAM=1
37
+ Force check imports: FORCE_CHECK_IMPORTS=1
38
+ Force using torchrun: FORCE_TORCHRUN=1
39
+ Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
40
+ Use modelscope: USE_MODELSCOPE_HUB=1
41
+ """
42
+
43
+ from .extras.env import VERSION
44
+
45
+
46
+ __version__ = VERSION
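
The toggles listed in the module docstring above are plain environment variables read at import/startup time; a minimal usage sketch (the chosen values are illustrative):

```python
# Minimal sketch: set the documented environment toggles before importing llamafactory.
import os

os.environ["LLAMAFACTORY_VERBOSITY"] = "WARN"  # quieter logging (documented above)
os.environ["DISABLE_VERSION_CHECK"] = "1"      # skip dependency version checks (documented above)

import llamafactory

print(llamafactory.__version__)
```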
llamafactory/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (959 Bytes).
 
llamafactory/api/__init__.py ADDED
File without changes
llamafactory/api/app.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import os
17
+ from contextlib import asynccontextmanager
18
+ from functools import partial
19
+ from typing import Optional
20
+
21
+ from typing_extensions import Annotated
22
+
23
+ from ..chat import ChatModel
24
+ from ..extras.misc import torch_gc
25
+ from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
26
+ from .chat import (
27
+ create_chat_completion_response,
28
+ create_score_evaluation_response,
29
+ create_stream_chat_completion_response,
30
+ )
31
+ from .protocol import (
32
+ ChatCompletionRequest,
33
+ ChatCompletionResponse,
34
+ ModelCard,
35
+ ModelList,
36
+ ScoreEvaluationRequest,
37
+ ScoreEvaluationResponse,
38
+ )
39
+
40
+
41
+ if is_fastapi_available():
42
+ from fastapi import Depends, FastAPI, HTTPException, status
43
+ from fastapi.middleware.cors import CORSMiddleware
44
+ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
45
+
46
+
47
+ if is_starlette_available():
48
+ from sse_starlette import EventSourceResponse
49
+
50
+
51
+ if is_uvicorn_available():
52
+ import uvicorn
53
+
54
+
55
+ async def sweeper() -> None:
56
+ while True:
57
+ torch_gc()
58
+ await asyncio.sleep(300)
59
+
60
+
61
+ @asynccontextmanager
62
+ async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
63
+ if chat_model.engine_type == "huggingface":
64
+ asyncio.create_task(sweeper())
65
+
66
+ yield
67
+ torch_gc()
68
+
69
+
70
+ def create_app(chat_model: "ChatModel") -> "FastAPI":
71
+ root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
72
+ app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
73
+ app.add_middleware(
74
+ CORSMiddleware,
75
+ allow_origins=["*"],
76
+ allow_credentials=True,
77
+ allow_methods=["*"],
78
+ allow_headers=["*"],
79
+ )
80
+ api_key = os.environ.get("API_KEY", None)
81
+ security = HTTPBearer(auto_error=False)
82
+
83
+ async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
84
+ if api_key and (auth is None or auth.credentials != api_key):
85
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
86
+
87
+ @app.get(
88
+ "/v1/models",
89
+ response_model=ModelList,
90
+ status_code=status.HTTP_200_OK,
91
+ dependencies=[Depends(verify_api_key)],
92
+ )
93
+ async def list_models():
94
+ model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
95
+ return ModelList(data=[model_card])
96
+
97
+ @app.post(
98
+ "/v1/chat/completions",
99
+ response_model=ChatCompletionResponse,
100
+ status_code=status.HTTP_200_OK,
101
+ dependencies=[Depends(verify_api_key)],
102
+ )
103
+ async def create_chat_completion(request: ChatCompletionRequest):
104
+ if not chat_model.engine.can_generate:
105
+ raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
106
+
107
+ if request.stream:
108
+ generate = create_stream_chat_completion_response(request, chat_model)
109
+ return EventSourceResponse(generate, media_type="text/event-stream")
110
+ else:
111
+ return await create_chat_completion_response(request, chat_model)
112
+
113
+ @app.post(
114
+ "/v1/score/evaluation",
115
+ response_model=ScoreEvaluationResponse,
116
+ status_code=status.HTTP_200_OK,
117
+ dependencies=[Depends(verify_api_key)],
118
+ )
119
+ async def create_score_evaluation(request: ScoreEvaluationRequest):
120
+ if chat_model.engine.can_generate:
121
+ raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
122
+
123
+ return await create_score_evaluation_response(request, chat_model)
124
+
125
+ return app
126
+
127
+
128
+ def run_api() -> None:
129
+ chat_model = ChatModel()
130
+ app = create_app(chat_model)
131
+ api_host = os.environ.get("API_HOST", "0.0.0.0")
132
+ api_port = int(os.environ.get("API_PORT", "8000"))
133
+ print("Visit http://localhost:{}/docs for API document.".format(api_port))
134
+ uvicorn.run(app, host=api_host, port=api_port)
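
A minimal client sketch for the OpenAI-compatible endpoints defined above, assuming the server runs on the default `localhost:8000` and that `API_KEY` was set to `sk-test` (both values are illustrative assumptions):

```python
# Query the /v1/chat/completions route exposed by create_app() above.
import requests

headers = {"Authorization": "Bearer sk-test"}  # only needed if API_KEY is set
payload = {
    "model": "gpt-3.5-turbo",  # matches the default API_MODEL_NAME above
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, headers=headers)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```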
llamafactory/api/chat.py ADDED
@@ -0,0 +1,237 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import base64
16
+ import io
17
+ import json
18
+ import os
19
+ import re
20
+ import uuid
21
+ from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
22
+
23
+ from ..data import Role as DataRole
24
+ from ..extras.logging import get_logger
25
+ from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
26
+ from .common import dictify, jsonify
27
+ from .protocol import (
28
+ ChatCompletionMessage,
29
+ ChatCompletionResponse,
30
+ ChatCompletionResponseChoice,
31
+ ChatCompletionResponseUsage,
32
+ ChatCompletionStreamResponse,
33
+ ChatCompletionStreamResponseChoice,
34
+ Finish,
35
+ Function,
36
+ FunctionCall,
37
+ Role,
38
+ ScoreEvaluationResponse,
39
+ )
40
+
41
+
42
+ if is_fastapi_available():
43
+ from fastapi import HTTPException, status
44
+
45
+
46
+ if is_pillow_available():
47
+ from PIL import Image
48
+
49
+
50
+ if is_requests_available():
51
+ import requests
52
+
53
+
54
+ if TYPE_CHECKING:
55
+ from ..chat import ChatModel
56
+ from ..data.mm_plugin import ImageInput
57
+ from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
58
+
59
+
60
+ logger = get_logger(__name__)
61
+ ROLE_MAPPING = {
62
+ Role.USER: DataRole.USER.value,
63
+ Role.ASSISTANT: DataRole.ASSISTANT.value,
64
+ Role.SYSTEM: DataRole.SYSTEM.value,
65
+ Role.FUNCTION: DataRole.FUNCTION.value,
66
+ Role.TOOL: DataRole.OBSERVATION.value,
67
+ }
68
+
69
+
70
+ def _process_request(
71
+ request: "ChatCompletionRequest",
72
+ ) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
73
+ logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
74
+
75
+ if len(request.messages) == 0:
76
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
77
+
78
+ if request.messages[0].role == Role.SYSTEM:
79
+ system = request.messages.pop(0).content
80
+ else:
81
+ system = None
82
+
83
+ if len(request.messages) % 2 == 0:
84
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
85
+
86
+ input_messages = []
87
+ image = None
88
+ for i, message in enumerate(request.messages):
89
+ if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
90
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
91
+ elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
92
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
93
+
94
+ if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
95
+ tool_calls = [
96
+ {"name": tool_call.function.name, "arguments": tool_call.function.arguments}
97
+ for tool_call in message.tool_calls
98
+ ]
99
+ content = json.dumps(tool_calls, ensure_ascii=False)
100
+ input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
101
+ elif isinstance(message.content, list):
102
+ for input_item in message.content:
103
+ if input_item.type == "text":
104
+ input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
105
+ else:
106
+ image_url = input_item.image_url.url
107
+ if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
108
+ image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
109
+ elif os.path.isfile(image_url): # local file
110
+ image_stream = open(image_url, "rb")
111
+ else: # web uri
112
+ image_stream = requests.get(image_url, stream=True).raw
113
+
114
+ image = Image.open(image_stream).convert("RGB")
115
+ else:
116
+ input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
117
+
118
+ tool_list = request.tools
119
+ if isinstance(tool_list, list) and len(tool_list):
120
+ try:
121
+ tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
122
+ except json.JSONDecodeError:
123
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
124
+ else:
125
+ tools = None
126
+
127
+ return input_messages, system, tools, image
128
+
129
+
130
+ def _create_stream_chat_completion_chunk(
131
+ completion_id: str,
132
+ model: str,
133
+ delta: "ChatCompletionMessage",
134
+ index: Optional[int] = 0,
135
+ finish_reason: Optional["Finish"] = None,
136
+ ) -> str:
137
+ choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
138
+ chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
139
+ return jsonify(chunk)
140
+
141
+
142
+ async def create_chat_completion_response(
143
+ request: "ChatCompletionRequest", chat_model: "ChatModel"
144
+ ) -> "ChatCompletionResponse":
145
+ completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
146
+ input_messages, system, tools, image = _process_request(request)
147
+ responses = await chat_model.achat(
148
+ input_messages,
149
+ system,
150
+ tools,
151
+ image,
152
+ do_sample=request.do_sample,
153
+ temperature=request.temperature,
154
+ top_p=request.top_p,
155
+ max_new_tokens=request.max_tokens,
156
+ num_return_sequences=request.n,
157
+ stop=request.stop,
158
+ )
159
+
160
+ prompt_length, response_length = 0, 0
161
+ choices = []
162
+ for i, response in enumerate(responses):
163
+ if tools:
164
+ result = chat_model.engine.template.extract_tool(response.response_text)
165
+ else:
166
+ result = response.response_text
167
+
168
+ if isinstance(result, list):
169
+ tool_calls = []
170
+ for tool in result:
171
+ function = Function(name=tool[0], arguments=tool[1])
172
+ tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
173
+
174
+ response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
175
+ finish_reason = Finish.TOOL
176
+ else:
177
+ response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
178
+ finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
179
+
180
+ choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
181
+ prompt_length = response.prompt_length
182
+ response_length += response.response_length
183
+
184
+ usage = ChatCompletionResponseUsage(
185
+ prompt_tokens=prompt_length,
186
+ completion_tokens=response_length,
187
+ total_tokens=prompt_length + response_length,
188
+ )
189
+
190
+ return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
191
+
192
+
193
+ async def create_stream_chat_completion_response(
194
+ request: "ChatCompletionRequest", chat_model: "ChatModel"
195
+ ) -> AsyncGenerator[str, None]:
196
+ completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
197
+ input_messages, system, tools, image = _process_request(request)
198
+ if tools:
199
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
200
+
201
+ if request.n > 1:
202
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
203
+
204
+ yield _create_stream_chat_completion_chunk(
205
+ completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
206
+ )
207
+ async for new_token in chat_model.astream_chat(
208
+ input_messages,
209
+ system,
210
+ tools,
211
+ image,
212
+ do_sample=request.do_sample,
213
+ temperature=request.temperature,
214
+ top_p=request.top_p,
215
+ max_new_tokens=request.max_tokens,
216
+ stop=request.stop,
217
+ ):
218
+ if len(new_token) != 0:
219
+ yield _create_stream_chat_completion_chunk(
220
+ completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
221
+ )
222
+
223
+ yield _create_stream_chat_completion_chunk(
224
+ completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
225
+ )
226
+ yield "[DONE]"
227
+
228
+
229
+ async def create_score_evaluation_response(
230
+ request: "ScoreEvaluationRequest", chat_model: "ChatModel"
231
+ ) -> "ScoreEvaluationResponse":
232
+ score_id = "scoreval-{}".format(uuid.uuid4().hex)
233
+ if len(request.messages) == 0:
234
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
235
+
236
+ scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
237
+ return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)
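
Because `_process_request()` above accepts base64 data URIs for the `image_url` content type, a multimodal request body can be sketched as follows (the local file name is an illustrative assumption):

```python
# Build an OpenAI-style multimodal message that _process_request() can parse.
import base64

with open("example.png", "rb") as f:  # hypothetical local image
    data_uri = "data:image/png;base64," + base64.b64encode(f.read()).decode()

payload = {
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": data_uri}},
            ],
        }
    ],
}
```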
llamafactory/api/common.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ from typing import TYPE_CHECKING, Any, Dict
17
+
18
+
19
+ if TYPE_CHECKING:
20
+ from pydantic import BaseModel
21
+
22
+
23
+ def dictify(data: "BaseModel") -> Dict[str, Any]:
24
+ try: # pydantic v2
25
+ return data.model_dump(exclude_unset=True)
26
+ except AttributeError: # pydantic v1
27
+ return data.dict(exclude_unset=True)
28
+
29
+
30
+ def jsonify(data: "BaseModel") -> str:
31
+ try: # pydantic v2
32
+ return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
33
+ except AttributeError: # pydantic v1
34
+ return data.json(exclude_unset=True, ensure_ascii=False)
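
A tiny usage sketch for the two helpers above, using a model from `llamafactory/api/protocol.py` (the id value is illustrative):

```python
from llamafactory.api.common import dictify, jsonify
from llamafactory.api.protocol import ModelCard

card = ModelCard(id="my-model")
print(dictify(card))  # {'id': 'my-model'} -- only explicitly set fields are kept
print(jsonify(card))  # '{"id": "my-model"}'
```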
llamafactory/api/protocol.py ADDED
@@ -0,0 +1,153 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import time
16
+ from enum import Enum, unique
17
+ from typing import Any, Dict, List, Optional, Union
18
+
19
+ from pydantic import BaseModel, Field
20
+ from typing_extensions import Literal
21
+
22
+
23
+ @unique
24
+ class Role(str, Enum):
25
+ USER = "user"
26
+ ASSISTANT = "assistant"
27
+ SYSTEM = "system"
28
+ FUNCTION = "function"
29
+ TOOL = "tool"
30
+
31
+
32
+ @unique
33
+ class Finish(str, Enum):
34
+ STOP = "stop"
35
+ LENGTH = "length"
36
+ TOOL = "tool_calls"
37
+
38
+
39
+ class ModelCard(BaseModel):
40
+ id: str
41
+ object: Literal["model"] = "model"
42
+ created: int = Field(default_factory=lambda: int(time.time()))
43
+ owned_by: Literal["owner"] = "owner"
44
+
45
+
46
+ class ModelList(BaseModel):
47
+ object: Literal["list"] = "list"
48
+ data: List[ModelCard] = []
49
+
50
+
51
+ class Function(BaseModel):
52
+ name: str
53
+ arguments: str
54
+
55
+
56
+ class FunctionDefinition(BaseModel):
57
+ name: str
58
+ description: str
59
+ parameters: Dict[str, Any]
60
+
61
+
62
+ class FunctionAvailable(BaseModel):
63
+ type: Literal["function", "code_interpreter"] = "function"
64
+ function: Optional[FunctionDefinition] = None
65
+
66
+
67
+ class FunctionCall(BaseModel):
68
+ id: str
69
+ type: Literal["function"] = "function"
70
+ function: Function
71
+
72
+
73
+ class ImageURL(BaseModel):
74
+ url: str
75
+
76
+
77
+ class MultimodalInputItem(BaseModel):
78
+ type: Literal["text", "image_url"]
79
+ text: Optional[str] = None
80
+ image_url: Optional[ImageURL] = None
81
+
82
+
83
+ class ChatMessage(BaseModel):
84
+ role: Role
85
+ content: Optional[Union[str, List[MultimodalInputItem]]] = None
86
+ tool_calls: Optional[List[FunctionCall]] = None
87
+
88
+
89
+ class ChatCompletionMessage(BaseModel):
90
+ role: Optional[Role] = None
91
+ content: Optional[str] = None
92
+ tool_calls: Optional[List[FunctionCall]] = None
93
+
94
+
95
+ class ChatCompletionRequest(BaseModel):
96
+ model: str
97
+ messages: List[ChatMessage]
98
+ tools: Optional[List[FunctionAvailable]] = None
99
+ do_sample: Optional[bool] = None
100
+ temperature: Optional[float] = None
101
+ top_p: Optional[float] = None
102
+ n: int = 1
103
+ max_tokens: Optional[int] = None
104
+ stop: Optional[Union[str, List[str]]] = None
105
+ stream: bool = False
106
+
107
+
108
+ class ChatCompletionResponseChoice(BaseModel):
109
+ index: int
110
+ message: ChatCompletionMessage
111
+ finish_reason: Finish
112
+
113
+
114
+ class ChatCompletionStreamResponseChoice(BaseModel):
115
+ index: int
116
+ delta: ChatCompletionMessage
117
+ finish_reason: Optional[Finish] = None
118
+
119
+
120
+ class ChatCompletionResponseUsage(BaseModel):
121
+ prompt_tokens: int
122
+ completion_tokens: int
123
+ total_tokens: int
124
+
125
+
126
+ class ChatCompletionResponse(BaseModel):
127
+ id: str
128
+ object: Literal["chat.completion"] = "chat.completion"
129
+ created: int = Field(default_factory=lambda: int(time.time()))
130
+ model: str
131
+ choices: List[ChatCompletionResponseChoice]
132
+ usage: ChatCompletionResponseUsage
133
+
134
+
135
+ class ChatCompletionStreamResponse(BaseModel):
136
+ id: str
137
+ object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
138
+ created: int = Field(default_factory=lambda: int(time.time()))
139
+ model: str
140
+ choices: List[ChatCompletionStreamResponseChoice]
141
+
142
+
143
+ class ScoreEvaluationRequest(BaseModel):
144
+ model: str
145
+ messages: List[str]
146
+ max_length: Optional[int] = None
147
+
148
+
149
+ class ScoreEvaluationResponse(BaseModel):
150
+ id: str
151
+ object: Literal["score.evaluation"] = "score.evaluation"
152
+ model: str
153
+ scores: List[float]
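
A short sketch of assembling a chat completion request from the models above (the model name and message text are illustrative assumptions):

```python
from llamafactory.api.protocol import ChatCompletionRequest, ChatMessage, Role

request = ChatCompletionRequest(
    model="gpt-3.5-turbo",
    messages=[ChatMessage(role=Role.USER, content="What does LoRA do?")],
    temperature=0.7,
    max_tokens=128,
)
# pydantic v2 serialization; see jsonify() in llamafactory/api/common.py for the v1 fallback
print(request.model_dump_json(exclude_unset=True))
```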
llamafactory/chat/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .base_engine import BaseEngine
16
+ from .chat_model import ChatModel
17
+
18
+
19
+ __all__ = ["BaseEngine", "ChatModel"]
llamafactory/chat/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (321 Bytes).
 
llamafactory/chat/__pycache__/base_engine.cpython-311.pyc ADDED
Binary file (4.03 kB).
 
llamafactory/chat/__pycache__/chat_model.cpython-311.pyc ADDED
Binary file (8.77 kB).
 
llamafactory/chat/__pycache__/hf_engine.cpython-311.pyc ADDED
Binary file (18.8 kB).
 
llamafactory/chat/__pycache__/vllm_engine.cpython-311.pyc ADDED
Binary file (11.1 kB).
 
llamafactory/chat/base_engine.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from dataclasses import dataclass
17
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from transformers import PreTrainedModel, PreTrainedTokenizer
22
+ from vllm import AsyncLLMEngine
23
+
24
+ from ..data import Template
25
+ from ..data.mm_plugin import ImageInput, VideoInput
26
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
27
+
28
+
29
+ @dataclass
30
+ class Response:
31
+ response_text: str
32
+ response_length: int
33
+ prompt_length: int
34
+ finish_reason: Literal["stop", "length"]
35
+
36
+
37
+ class BaseEngine(ABC):
38
+ r"""
39
+ Base class for inference engine of chat models.
40
+
41
+ Must implement the async methods: chat(), stream_chat() and get_scores().
42
+ """
43
+
44
+ model: Union["PreTrainedModel", "AsyncLLMEngine"]
45
+ tokenizer: "PreTrainedTokenizer"
46
+ can_generate: bool
47
+ template: "Template"
48
+ generating_args: Dict[str, Any]
49
+
50
+ @abstractmethod
51
+ def __init__(
52
+ self,
53
+ model_args: "ModelArguments",
54
+ data_args: "DataArguments",
55
+ finetuning_args: "FinetuningArguments",
56
+ generating_args: "GeneratingArguments",
57
+ ) -> None:
58
+ r"""
59
+ Initializes an inference engine.
60
+ """
61
+ ...
62
+
63
+ @abstractmethod
64
+ async def chat(
65
+ self,
66
+ messages: Sequence[Dict[str, str]],
67
+ system: Optional[str] = None,
68
+ tools: Optional[str] = None,
69
+ image: Optional["ImageInput"] = None,
70
+ video: Optional["VideoInput"] = None,
71
+ **input_kwargs,
72
+ ) -> List["Response"]:
73
+ r"""
74
+ Gets a list of responses of the chat model.
75
+ """
76
+ ...
77
+
78
+ @abstractmethod
79
+ async def stream_chat(
80
+ self,
81
+ messages: Sequence[Dict[str, str]],
82
+ system: Optional[str] = None,
83
+ tools: Optional[str] = None,
84
+ image: Optional["ImageInput"] = None,
85
+ video: Optional["VideoInput"] = None,
86
+ **input_kwargs,
87
+ ) -> AsyncGenerator[str, None]:
88
+ r"""
89
+ Gets the response token-by-token of the chat model.
90
+ """
91
+ ...
92
+
93
+ @abstractmethod
94
+ async def get_scores(
95
+ self,
96
+ batch_input: List[str],
97
+ **input_kwargs,
98
+ ) -> List[float]:
99
+ r"""
100
+ Gets a list of scores of the reward model.
101
+ """
102
+ ...
llamafactory/chat/chat_model.py ADDED
@@ -0,0 +1,187 @@
1
+ # Copyright 2024 THUDM and the LlamaFactory team.
2
+ #
3
+ # This code is inspired by the THUDM's ChatGLM implementation.
4
+ # https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import asyncio
19
+ import os
20
+ from threading import Thread
21
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
22
+
23
+ from ..extras.misc import torch_gc
24
+ from ..hparams import get_infer_args
25
+ from .hf_engine import HuggingfaceEngine
26
+ from .vllm_engine import VllmEngine
27
+
28
+
29
+ if TYPE_CHECKING:
30
+ from ..data.mm_plugin import ImageInput, VideoInput
31
+ from .base_engine import BaseEngine, Response
32
+
33
+
34
+ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
35
+ asyncio.set_event_loop(loop)
36
+ loop.run_forever()
37
+
38
+
39
+ class ChatModel:
40
+ r"""
41
+ General class for chat models. Backed by huggingface or vllm engines.
42
+
43
+ Supports both sync and async methods.
44
+ Sync methods: chat(), stream_chat() and get_scores().
45
+ Async methods: achat(), astream_chat() and aget_scores().
46
+ """
47
+
48
+ def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
49
+ model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
50
+ self.engine_type = model_args.infer_backend
51
+ if model_args.infer_backend == "huggingface":
52
+ self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
53
+ elif model_args.infer_backend == "vllm":
54
+ self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
55
+ else:
56
+ raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
57
+
58
+ self._loop = asyncio.new_event_loop()
59
+ self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
60
+ self._thread.start()
61
+
62
+ def chat(
63
+ self,
64
+ messages: Sequence[Dict[str, str]],
65
+ system: Optional[str] = None,
66
+ tools: Optional[str] = None,
67
+ image: Optional["ImageInput"] = None,
68
+ video: Optional["VideoInput"] = None,
69
+ **input_kwargs,
70
+ ) -> List["Response"]:
71
+ r"""
72
+ Gets a list of responses of the chat model.
73
+ """
74
+ task = asyncio.run_coroutine_threadsafe(
75
+ self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
76
+ )
77
+ return task.result()
78
+
79
+ async def achat(
80
+ self,
81
+ messages: Sequence[Dict[str, str]],
82
+ system: Optional[str] = None,
83
+ tools: Optional[str] = None,
84
+ image: Optional["ImageInput"] = None,
85
+ video: Optional["VideoInput"] = None,
86
+ **input_kwargs,
87
+ ) -> List["Response"]:
88
+ r"""
89
+ Asynchronously gets a list of responses of the chat model.
90
+ """
91
+ return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
92
+
93
+ def stream_chat(
94
+ self,
95
+ messages: Sequence[Dict[str, str]],
96
+ system: Optional[str] = None,
97
+ tools: Optional[str] = None,
98
+ image: Optional["ImageInput"] = None,
99
+ video: Optional["VideoInput"] = None,
100
+ **input_kwargs,
101
+ ) -> Generator[str, None, None]:
102
+ r"""
103
+ Gets the response token-by-token of the chat model.
104
+ """
105
+ generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
106
+ while True:
107
+ try:
108
+ task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
109
+ yield task.result()
110
+ except StopAsyncIteration:
111
+ break
112
+
113
+ async def astream_chat(
114
+ self,
115
+ messages: Sequence[Dict[str, str]],
116
+ system: Optional[str] = None,
117
+ tools: Optional[str] = None,
118
+ image: Optional["ImageInput"] = None,
119
+ video: Optional["VideoInput"] = None,
120
+ **input_kwargs,
121
+ ) -> AsyncGenerator[str, None]:
122
+ r"""
123
+ Asynchronously gets the response token-by-token of the chat model.
124
+ """
125
+ async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
126
+ yield new_token
127
+
128
+ def get_scores(
129
+ self,
130
+ batch_input: List[str],
131
+ **input_kwargs,
132
+ ) -> List[float]:
133
+ r"""
134
+ Gets a list of scores of the reward model.
135
+ """
136
+ task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
137
+ return task.result()
138
+
139
+ async def aget_scores(
140
+ self,
141
+ batch_input: List[str],
142
+ **input_kwargs,
143
+ ) -> List[float]:
144
+ r"""
145
+ Asynchronously gets a list of scores of the reward model.
146
+ """
147
+ return await self.engine.get_scores(batch_input, **input_kwargs)
148
+
149
+
150
+ def run_chat() -> None:
151
+ if os.name != "nt":
152
+ try:
153
+ import readline # noqa: F401
154
+ except ImportError:
155
+ print("Install `readline` for a better experience.")
156
+
157
+ chat_model = ChatModel()
158
+ messages = []
159
+ print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
160
+
161
+ while True:
162
+ try:
163
+ query = input("\nUser: ")
164
+ except UnicodeDecodeError:
165
+ print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
166
+ continue
167
+ except Exception:
168
+ raise
169
+
170
+ if query.strip() == "exit":
171
+ break
172
+
173
+ if query.strip() == "clear":
174
+ messages = []
175
+ torch_gc()
176
+ print("History has been removed.")
177
+ continue
178
+
179
+ messages.append({"role": "user", "content": query})
180
+ print("Assistant: ", end="", flush=True)
181
+
182
+ response = ""
183
+ for new_text in chat_model.stream_chat(messages):
184
+ print(new_text, end="", flush=True)
185
+ response += new_text
186
+ print()
187
+ messages.append({"role": "assistant", "content": response})
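
Beyond the CLI loop above, `ChatModel` can also be used programmatically; a minimal sketch in which the checkpoint and template names are illustrative assumptions:

```python
from llamafactory.chat import ChatModel

chat_model = ChatModel(dict(
    model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",  # assumed checkpoint
    template="llama3",                                          # assumed template name
    infer_backend="huggingface",
))
messages = [{"role": "user", "content": "Summarize LoRA in one sentence."}]
print(chat_model.chat(messages, max_new_tokens=64)[0].response_text)

# The synchronous generator wraps astream_chat() on the background event loop:
for token in chat_model.stream_chat(messages):
    print(token, end="", flush=True)
```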
llamafactory/chat/hf_engine.py ADDED
@@ -0,0 +1,343 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import concurrent.futures
17
+ import os
18
+ from threading import Thread
19
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
20
+
21
+ import torch
22
+ from transformers import GenerationConfig, TextIteratorStreamer
23
+ from typing_extensions import override
24
+
25
+ from ..data import get_template_and_fix_tokenizer
26
+ from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
27
+ from ..extras.logging import get_logger
28
+ from ..extras.misc import get_logits_processor
29
+ from ..model import load_model, load_tokenizer
30
+ from .base_engine import BaseEngine, Response
31
+
32
+
33
+ if TYPE_CHECKING:
34
+ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
35
+ from trl import PreTrainedModelWrapper
36
+
37
+ from ..data import Template
38
+ from ..data.mm_plugin import ImageInput, VideoInput
39
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
40
+
41
+
42
+ logger = get_logger(__name__)
43
+
44
+
45
+ class HuggingfaceEngine(BaseEngine):
46
+ def __init__(
47
+ self,
48
+ model_args: "ModelArguments",
49
+ data_args: "DataArguments",
50
+ finetuning_args: "FinetuningArguments",
51
+ generating_args: "GeneratingArguments",
52
+ ) -> None:
53
+ self.can_generate = finetuning_args.stage == "sft"
54
+ tokenizer_module = load_tokenizer(model_args)
55
+ self.tokenizer = tokenizer_module["tokenizer"]
56
+ self.processor = tokenizer_module["processor"]
57
+ self.tokenizer.padding_side = "left" if self.can_generate else "right"
58
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
59
+ self.model = load_model(
60
+ self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
61
+ ) # must after fixing tokenizer to resize vocab
62
+ self.generating_args = generating_args.to_dict()
63
+ try:
64
+ asyncio.get_event_loop()
65
+ except RuntimeError:
66
+ logger.warning("There is no current event loop, creating a new one.")
67
+ loop = asyncio.new_event_loop()
68
+ asyncio.set_event_loop(loop)
69
+
70
+ self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
71
+
72
+ @staticmethod
73
+ def _process_args(
74
+ model: "PreTrainedModel",
75
+ tokenizer: "PreTrainedTokenizer",
76
+ processor: Optional["ProcessorMixin"],
77
+ template: "Template",
78
+ generating_args: Dict[str, Any],
79
+ messages: Sequence[Dict[str, str]],
80
+ system: Optional[str] = None,
81
+ tools: Optional[str] = None,
82
+ image: Optional["ImageInput"] = None,
83
+ video: Optional["VideoInput"] = None,
84
+ input_kwargs: Optional[Dict[str, Any]] = {},
85
+ ) -> Tuple[Dict[str, Any], int]:
86
+ mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
87
+ if image is not None:
88
+ mm_input_dict.update({"images": [image], "imglens": [1]})
89
+ if IMAGE_PLACEHOLDER not in messages[0]["content"]:
90
+ messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
91
+
92
+ if video is not None:
93
+ mm_input_dict.update({"videos": [video], "vidlens": [1]})
94
+ if VIDEO_PLACEHOLDER not in messages[0]["content"]:
95
+ messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
96
+
97
+ messages = template.mm_plugin.process_messages(
98
+ messages, mm_input_dict["images"], mm_input_dict["videos"], processor
99
+ )
100
+ paired_messages = messages + [{"role": "assistant", "content": ""}]
101
+ system = system or generating_args["default_system"]
102
+ prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
103
+ prompt_ids, _ = template.mm_plugin.process_token_ids(
104
+ prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
105
+ )
106
+ prompt_length = len(prompt_ids)
107
+ inputs = torch.tensor([prompt_ids], device=model.device)
108
+ attention_mask = torch.ones_like(inputs, dtype=torch.bool)
109
+
110
+ do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
111
+ temperature: Optional[float] = input_kwargs.pop("temperature", None)
112
+ top_p: Optional[float] = input_kwargs.pop("top_p", None)
113
+ top_k: Optional[float] = input_kwargs.pop("top_k", None)
114
+ num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
115
+ repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
116
+ length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
117
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
118
+ max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
119
+ stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
120
+
121
+ if stop is not None:
122
+ logger.warning("Stop parameter is not supported by the huggingface engine yet.")
123
+
124
+ generating_args = generating_args.copy()
125
+ generating_args.update(
126
+ dict(
127
+ do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
128
+ temperature=temperature if temperature is not None else generating_args["temperature"],
129
+ top_p=top_p if top_p is not None else generating_args["top_p"],
130
+ top_k=top_k if top_k is not None else generating_args["top_k"],
131
+ num_return_sequences=num_return_sequences,
132
+ repetition_penalty=repetition_penalty
133
+ if repetition_penalty is not None
134
+ else generating_args["repetition_penalty"],
135
+ length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
136
+ eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
137
+ pad_token_id=tokenizer.pad_token_id,
138
+ )
139
+ )
140
+
141
+ if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0
142
+ generating_args["do_sample"] = True
143
+ generating_args["temperature"] = generating_args["temperature"] or 1.0
144
+
145
+ if not generating_args["temperature"]:
146
+ generating_args["do_sample"] = False
147
+
148
+ if not generating_args["do_sample"]:
149
+ generating_args.pop("temperature", None)
150
+ generating_args.pop("top_p", None)
151
+
152
+ if max_length:
153
+ generating_args.pop("max_new_tokens", None)
154
+ generating_args["max_length"] = max_length
155
+
156
+ if max_new_tokens:
157
+ generating_args.pop("max_length", None)
158
+ generating_args["max_new_tokens"] = max_new_tokens
159
+
160
+ gen_kwargs = dict(
161
+ inputs=inputs,
162
+ attention_mask=attention_mask,
163
+ generation_config=GenerationConfig(**generating_args),
164
+ logits_processor=get_logits_processor(),
165
+ )
166
+
167
+ mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
168
+ for key, value in mm_inputs.items():
169
+ value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
170
+ gen_kwargs[key] = value.to(model.device)
171
+
172
+ return gen_kwargs, prompt_length
173
+
174
+ @staticmethod
175
+ @torch.inference_mode()
176
+ def _chat(
177
+ model: "PreTrainedModel",
178
+ tokenizer: "PreTrainedTokenizer",
179
+ processor: Optional["ProcessorMixin"],
180
+ template: "Template",
181
+ generating_args: Dict[str, Any],
182
+ messages: Sequence[Dict[str, str]],
183
+ system: Optional[str] = None,
184
+ tools: Optional[str] = None,
185
+ image: Optional["ImageInput"] = None,
186
+ video: Optional["VideoInput"] = None,
187
+ input_kwargs: Optional[Dict[str, Any]] = {},
188
+ ) -> List["Response"]:
189
+ gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
190
+ model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
191
+ )
192
+ generate_output = model.generate(**gen_kwargs)
193
+ response_ids = generate_output[:, prompt_length:]
194
+ response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
195
+ results = []
196
+ for i in range(len(response)):
197
+ eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
198
+ response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
199
+ results.append(
200
+ Response(
201
+ response_text=response[i],
202
+ response_length=response_length,
203
+ prompt_length=prompt_length,
204
+ finish_reason="stop" if len(eos_index) else "length",
205
+ )
206
+ )
207
+
208
+ return results
209
+
210
+ @staticmethod
211
+ @torch.inference_mode()
212
+ def _stream_chat(
213
+ model: "PreTrainedModel",
214
+ tokenizer: "PreTrainedTokenizer",
215
+ processor: Optional["ProcessorMixin"],
216
+ template: "Template",
217
+ generating_args: Dict[str, Any],
218
+ messages: Sequence[Dict[str, str]],
219
+ system: Optional[str] = None,
220
+ tools: Optional[str] = None,
221
+ image: Optional["ImageInput"] = None,
222
+ video: Optional["VideoInput"] = None,
223
+ input_kwargs: Optional[Dict[str, Any]] = {},
224
+ ) -> Callable[[], str]:
225
+ gen_kwargs, _ = HuggingfaceEngine._process_args(
226
+ model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
227
+ )
228
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
229
+ gen_kwargs["streamer"] = streamer
230
+ thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
231
+ thread.start()
232
+
233
+ def stream():
234
+ try:
235
+ return streamer.__next__()
236
+ except StopIteration:
237
+ raise StopAsyncIteration()
238
+
239
+ return stream
240
+
241
+ @staticmethod
242
+ @torch.inference_mode()
243
+ def _get_scores(
244
+ model: "PreTrainedModelWrapper",
245
+ tokenizer: "PreTrainedTokenizer",
246
+ batch_input: List[str],
247
+ input_kwargs: Optional[Dict[str, Any]] = {},
248
+ ) -> List[float]:
249
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
250
+ device = getattr(model.pretrained_model, "device", "cuda")
251
+ inputs: Dict[str, "torch.Tensor"] = tokenizer(
252
+ batch_input,
253
+ padding=True,
254
+ truncation=True,
255
+ max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
256
+ return_tensors="pt",
257
+ add_special_tokens=False,
258
+ ).to(device)
259
+ values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
260
+ scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
261
+ return scores
262
+
263
+ @override
264
+ async def chat(
265
+ self,
266
+ messages: Sequence[Dict[str, str]],
267
+ system: Optional[str] = None,
268
+ tools: Optional[str] = None,
269
+ image: Optional["ImageInput"] = None,
270
+ video: Optional["VideoInput"] = None,
271
+ **input_kwargs,
272
+ ) -> List["Response"]:
273
+ if not self.can_generate:
274
+ raise ValueError("The current model does not support `chat`.")
275
+
276
+ loop = asyncio.get_running_loop()
277
+ input_args = (
278
+ self.model,
279
+ self.tokenizer,
280
+ self.processor,
281
+ self.template,
282
+ self.generating_args,
283
+ messages,
284
+ system,
285
+ tools,
286
+ image,
287
+ video,
288
+ input_kwargs,
289
+ )
290
+ async with self.semaphore:
291
+ with concurrent.futures.ThreadPoolExecutor() as pool:
292
+ return await loop.run_in_executor(pool, self._chat, *input_args)
293
+
294
+ @override
295
+ async def stream_chat(
296
+ self,
297
+ messages: Sequence[Dict[str, str]],
298
+ system: Optional[str] = None,
299
+ tools: Optional[str] = None,
300
+ image: Optional["ImageInput"] = None,
301
+ video: Optional["VideoInput"] = None,
302
+ **input_kwargs,
303
+ ) -> AsyncGenerator[str, None]:
304
+ if not self.can_generate:
305
+ raise ValueError("The current model does not support `stream_chat`.")
306
+
307
+ loop = asyncio.get_running_loop()
308
+ input_args = (
309
+ self.model,
310
+ self.tokenizer,
311
+ self.processor,
312
+ self.template,
313
+ self.generating_args,
314
+ messages,
315
+ system,
316
+ tools,
317
+ image,
318
+ video,
319
+ input_kwargs,
320
+ )
321
+ async with self.semaphore:
322
+ with concurrent.futures.ThreadPoolExecutor() as pool:
323
+ stream = self._stream_chat(*input_args)
324
+ while True:
325
+ try:
326
+ yield await loop.run_in_executor(pool, stream)
327
+ except StopAsyncIteration:
328
+ break
329
+
330
+ @override
331
+ async def get_scores(
332
+ self,
333
+ batch_input: List[str],
334
+ **input_kwargs,
335
+ ) -> List[float]:
336
+ if self.can_generate:
337
+ raise ValueError("Cannot get scores using an auto-regressive model.")
338
+
339
+ loop = asyncio.get_running_loop()
340
+ input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
341
+ async with self.semaphore:
342
+ with concurrent.futures.ThreadPoolExecutor() as pool:
343
+ return await loop.run_in_executor(pool, self._get_scores, *input_args)
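Note: the engine above only exposes async entry points. A minimal driver sketch, assuming a `HuggingfaceEngine` instance named `engine` has already been built from parsed model/data/finetuning/generating arguments (construction is outside this hunk); the prompt text and `max_new_tokens=64` are illustrative values only:

```python
import asyncio

async def demo(engine) -> None:
    # Single-turn request; messages follow the role/content schema used above.
    messages = [{"role": "user", "content": "Hello!"}]

    # One-shot completion: returns a list of Response objects.
    responses = await engine.chat(messages, max_new_tokens=64)
    print(responses[0].response_text)

    # Incremental decoding via the async generator.
    async for delta in engine.stream_chat(messages, max_new_tokens=64):
        print(delta, end="", flush=True)

# asyncio.run(demo(engine))  # `engine` is assumed to exist and is not constructed here
```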
llamafactory/chat/vllm_engine.py ADDED
@@ -0,0 +1,230 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import uuid
16
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
17
+
18
+ from typing_extensions import override
19
+
20
+ from ..data import get_template_and_fix_tokenizer
21
+ from ..extras.constants import IMAGE_PLACEHOLDER
22
+ from ..extras.logging import get_logger
23
+ from ..extras.misc import get_device_count
24
+ from ..extras.packages import is_pillow_available, is_vllm_available
25
+ from ..model import load_config, load_tokenizer
26
+ from ..model.model_utils.quantization import QuantizationMethod
27
+ from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
28
+ from .base_engine import BaseEngine, Response
29
+
30
+
31
+ if is_pillow_available():
32
+ from PIL import Image
33
+ from PIL.Image import Image as ImageObject
34
+
35
+
36
+ if is_vllm_available():
37
+ from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
38
+ from vllm.lora.request import LoRARequest
39
+
40
+
41
+ if TYPE_CHECKING:
42
+ from ..data.mm_plugin import ImageInput, VideoInput
43
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
44
+
45
+
46
+ logger = get_logger(__name__)
47
+
48
+
49
+ class VllmEngine(BaseEngine):
50
+ def __init__(
51
+ self,
52
+ model_args: "ModelArguments",
53
+ data_args: "DataArguments",
54
+ finetuning_args: "FinetuningArguments",
55
+ generating_args: "GeneratingArguments",
56
+ ) -> None:
57
+ config = load_config(model_args) # may download model from ms hub
58
+ if getattr(config, "quantization_config", None): # gptq models should use float16
59
+ quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
60
+ quant_method = quantization_config.get("quant_method", "")
61
+ if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
62
+ model_args.infer_dtype = "float16"
63
+
64
+ self.can_generate = finetuning_args.stage == "sft"
65
+ tokenizer_module = load_tokenizer(model_args)
66
+ self.tokenizer = tokenizer_module["tokenizer"]
67
+ self.processor = tokenizer_module["processor"]
68
+ self.tokenizer.padding_side = "left"
69
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
70
+ self.generating_args = generating_args.to_dict()
71
+
72
+ engine_args = {
73
+ "model": model_args.model_name_or_path,
74
+ "trust_remote_code": True,
75
+ "download_dir": model_args.cache_dir,
76
+ "dtype": model_args.infer_dtype,
77
+ "max_model_len": model_args.vllm_maxlen,
78
+ "tensor_parallel_size": get_device_count() or 1,
79
+ "gpu_memory_utilization": model_args.vllm_gpu_util,
80
+ "disable_log_stats": True,
81
+ "disable_log_requests": True,
82
+ "enforce_eager": model_args.vllm_enforce_eager,
83
+ "enable_lora": model_args.adapter_name_or_path is not None,
84
+ "max_lora_rank": model_args.vllm_max_lora_rank,
85
+ }
86
+
87
+ if getattr(config, "is_yi_vl_derived_model", None):
88
+ import vllm.model_executor.models.llava
89
+
90
+ logger.info("Detected Yi-VL model, applying projector patch.")
91
+ vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
92
+
93
+ self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
94
+ if model_args.adapter_name_or_path is not None:
95
+ self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
96
+ else:
97
+ self.lora_request = None
98
+
99
+ async def _generate(
100
+ self,
101
+ messages: Sequence[Dict[str, str]],
102
+ system: Optional[str] = None,
103
+ tools: Optional[str] = None,
104
+ image: Optional["ImageInput"] = None,
105
+ video: Optional["VideoInput"] = None,
106
+ **input_kwargs,
107
+ ) -> AsyncIterator["RequestOutput"]:
108
+ request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
109
+ if image is not None:
110
+ if IMAGE_PLACEHOLDER not in messages[0]["content"]:
111
+ messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
112
+
113
+ paired_messages = messages + [{"role": "assistant", "content": ""}]
114
+ system = system or self.generating_args["default_system"]
115
+ prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
116
+ prompt_length = len(prompt_ids)
117
+
118
+ use_beam_search: bool = self.generating_args["num_beams"] > 1
119
+ temperature: Optional[float] = input_kwargs.pop("temperature", None)
120
+ top_p: Optional[float] = input_kwargs.pop("top_p", None)
121
+ top_k: Optional[float] = input_kwargs.pop("top_k", None)
122
+ num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
123
+ repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
124
+ length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
125
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
126
+ max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
127
+ stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
128
+
129
+ if "max_new_tokens" in self.generating_args:
130
+ max_tokens = self.generating_args["max_new_tokens"]
131
+ elif "max_length" in self.generating_args:
132
+ if self.generating_args["max_length"] > prompt_length:
133
+ max_tokens = self.generating_args["max_length"] - prompt_length
134
+ else:
135
+ max_tokens = 1
136
+
137
+ if max_length:
138
+ max_tokens = max_length - prompt_length if max_length > prompt_length else 1
139
+
140
+ if max_new_tokens:
141
+ max_tokens = max_new_tokens
142
+
143
+ sampling_params = SamplingParams(
144
+ n=num_return_sequences,
145
+ repetition_penalty=(
146
+ repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
147
+ )
148
+ or 1.0, # repetition_penalty must > 0
149
+ temperature=temperature if temperature is not None else self.generating_args["temperature"],
150
+ top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
151
+ top_k=top_k if top_k is not None else self.generating_args["top_k"],
152
+ use_beam_search=use_beam_search,
153
+ length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
154
+ stop=stop,
155
+ stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
156
+ max_tokens=max_tokens,
157
+ skip_special_tokens=True,
158
+ )
159
+
160
+ if image is not None: # add image features
161
+ if not isinstance(image, (str, ImageObject)):
162
+ raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
163
+
164
+ if isinstance(image, str):
165
+ image = Image.open(image).convert("RGB")
166
+
167
+ multi_modal_data = {"image": image}
168
+ else:
169
+ multi_modal_data = None
170
+
171
+ result_generator = self.model.generate(
172
+ inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
173
+ sampling_params=sampling_params,
174
+ request_id=request_id,
175
+ lora_request=self.lora_request,
176
+ )
177
+ return result_generator
178
+
179
+ @override
180
+ async def chat(
181
+ self,
182
+ messages: Sequence[Dict[str, str]],
183
+ system: Optional[str] = None,
184
+ tools: Optional[str] = None,
185
+ image: Optional["ImageInput"] = None,
186
+ video: Optional["VideoInput"] = None,
187
+ **input_kwargs,
188
+ ) -> List["Response"]:
189
+ final_output = None
190
+ generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
191
+ async for request_output in generator:
192
+ final_output = request_output
193
+
194
+ results = []
195
+ for output in final_output.outputs:
196
+ results.append(
197
+ Response(
198
+ response_text=output.text,
199
+ response_length=len(output.token_ids),
200
+ prompt_length=len(final_output.prompt_token_ids),
201
+ finish_reason=output.finish_reason,
202
+ )
203
+ )
204
+
205
+ return results
206
+
207
+ @override
208
+ async def stream_chat(
209
+ self,
210
+ messages: Sequence[Dict[str, str]],
211
+ system: Optional[str] = None,
212
+ tools: Optional[str] = None,
213
+ image: Optional["ImageInput"] = None,
214
+ video: Optional["VideoInput"] = None,
215
+ **input_kwargs,
216
+ ) -> AsyncGenerator[str, None]:
217
+ generated_text = ""
218
+ generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
219
+ async for result in generator:
220
+ delta_text = result.outputs[0].text[len(generated_text) :]
221
+ generated_text = result.outputs[0].text
222
+ yield delta_text
223
+
224
+ @override
225
+ async def get_scores(
226
+ self,
227
+ batch_input: List[str],
228
+ **input_kwargs,
229
+ ) -> List[float]:
230
+ raise NotImplementedError("vLLM engine does not support get_scores.")
llamafactory/cli.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import random
17
+ import subprocess
18
+ import sys
19
+ from enum import Enum, unique
20
+
21
+ from . import launcher
22
+ from .api.app import run_api
23
+ from .chat.chat_model import run_chat
24
+ from .eval.evaluator import run_eval
25
+ from .extras.env import VERSION, print_env
26
+ from .extras.logging import get_logger
27
+ from .extras.misc import get_device_count
28
+ from .train.tuner import export_model, run_exp
29
+ from .webui.interface import run_web_demo, run_web_ui
30
+
31
+
32
+ USAGE = (
33
+ "-" * 70
34
+ + "\n"
35
+ + "| Usage: |\n"
36
+ + "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
37
+ + "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
38
+ + "| llamafactory-cli eval -h: evaluate models |\n"
39
+ + "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
40
+ + "| llamafactory-cli train -h: train models |\n"
41
+ + "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
42
+ + "| llamafactory-cli webui: launch LlamaBoard |\n"
43
+ + "| llamafactory-cli version: show version info |\n"
44
+ + "-" * 70
45
+ )
46
+
47
+ WELCOME = (
48
+ "-" * 58
49
+ + "\n"
50
+ + "| Welcome to LLaMA Factory, version {}".format(VERSION)
51
+ + " " * (21 - len(VERSION))
52
+ + "|\n|"
53
+ + " " * 56
54
+ + "|\n"
55
+ + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
56
+ + "-" * 58
57
+ )
58
+
59
+ logger = get_logger(__name__)
60
+
61
+
62
+ @unique
63
+ class Command(str, Enum):
64
+ API = "api"
65
+ CHAT = "chat"
66
+ ENV = "env"
67
+ EVAL = "eval"
68
+ EXPORT = "export"
69
+ TRAIN = "train"
70
+ WEBDEMO = "webchat"
71
+ WEBUI = "webui"
72
+ VER = "version"
73
+ HELP = "help"
74
+
75
+
76
+ def main():
77
+ command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
78
+ if command == Command.API:
79
+ run_api()
80
+ elif command == Command.CHAT:
81
+ run_chat()
82
+ elif command == Command.ENV:
83
+ print_env()
84
+ elif command == Command.EVAL:
85
+ run_eval()
86
+ elif command == Command.EXPORT:
87
+ export_model()
88
+ elif command == Command.TRAIN:
89
+ force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
90
+ if force_torchrun or get_device_count() > 1:
91
+ master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
92
+ master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
93
+ logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
94
+ process = subprocess.run(
95
+ (
96
+ "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
97
+ "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
98
+ ).format(
99
+ nnodes=os.environ.get("NNODES", "1"),
100
+ node_rank=os.environ.get("RANK", "0"),
101
+ nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
102
+ master_addr=master_addr,
103
+ master_port=master_port,
104
+ file_name=launcher.__file__,
105
+ args=" ".join(sys.argv[1:]),
106
+ ),
107
+ shell=True,
108
+ )
109
+ sys.exit(process.returncode)
110
+ else:
111
+ run_exp()
112
+ elif command == Command.WEBDEMO:
113
+ run_web_demo()
114
+ elif command == Command.WEBUI:
115
+ run_web_ui()
116
+ elif command == Command.VER:
117
+ print(WELCOME)
118
+ elif command == Command.HELP:
119
+ print(USAGE)
120
+ else:
121
+ raise NotImplementedError("Unknown command: {}.".format(command))
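Note: this dispatcher is presumably installed as the `llamafactory-cli` console script (that is the name used in the USAGE banner above). A hedged sketch of driving it programmatically; the argument list is illustrative and mirrors what the shell entry point would pass:

```python
import sys
from llamafactory.cli import main

# Equivalent to running `llamafactory-cli version` in a shell:
# main() pops "version" from sys.argv and prints the WELCOME banner.
sys.argv = ["llamafactory-cli", "version"]
main()
```

For `train`, the code above re-launches itself through `torchrun` whenever `FORCE_TORCHRUN` is truthy or more than one device is visible, taking `NNODES`, `RANK`, `NPROC_PER_NODE`, `MASTER_ADDR` and `MASTER_PORT` from the environment.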
llamafactory/data/__init__.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .collator import (
16
+ KTODataCollatorWithPadding,
17
+ MultiModalDataCollatorForSeq2Seq,
18
+ PairwiseDataCollatorWithPadding,
19
+ SFTDataCollatorWith4DAttentionMask,
20
+ )
21
+ from .data_utils import Role, split_dataset
22
+ from .loader import get_dataset
23
+ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
24
+
25
+
26
+ __all__ = [
27
+ "KTODataCollatorWithPadding",
28
+ "MultiModalDataCollatorForSeq2Seq",
29
+ "PairwiseDataCollatorWithPadding",
30
+ "SFTDataCollatorWith4DAttentionMask",
31
+ "Role",
32
+ "split_dataset",
33
+ "get_dataset",
34
+ "TEMPLATES",
35
+ "Template",
36
+ "get_template_and_fix_tokenizer",
37
+ ]
llamafactory/data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (766 Bytes).
 
llamafactory/data/__pycache__/aligner.cpython-311.pyc ADDED
Binary file (11.6 kB).
 
llamafactory/data/__pycache__/collator.cpython-311.pyc ADDED
Binary file (9.38 kB).
 
llamafactory/data/__pycache__/data_utils.cpython-311.pyc ADDED
Binary file (4.43 kB).
 
llamafactory/data/__pycache__/formatter.cpython-311.pyc ADDED
Binary file (8.97 kB).
 
llamafactory/data/__pycache__/loader.cpython-311.pyc ADDED
Binary file (13.8 kB).
 
llamafactory/data/__pycache__/mm_plugin.cpython-311.pyc ADDED
Binary file (32.6 kB).
 
llamafactory/data/__pycache__/parser.cpython-311.pyc ADDED
Binary file (7.31 kB).
 
llamafactory/data/__pycache__/preprocess.cpython-311.pyc ADDED
Binary file (3.63 kB).
 
llamafactory/data/__pycache__/template.cpython-311.pyc ADDED
Binary file (41.7 kB).
 
llamafactory/data/__pycache__/tool_utils.cpython-311.pyc ADDED
Binary file (8.97 kB).
 
llamafactory/data/aligner.py ADDED
@@ -0,0 +1,258 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from functools import partial
17
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
18
+
19
+ from ..extras.logging import get_logger
20
+ from .data_utils import Role
21
+
22
+
23
+ if TYPE_CHECKING:
24
+ from datasets import Dataset, IterableDataset
25
+ from transformers import Seq2SeqTrainingArguments
26
+
27
+ from ..hparams import DataArguments
28
+ from .mm_plugin import ImageInput, VideoInput
29
+ from .parser import DatasetAttr
30
+
31
+
32
+ logger = get_logger(__name__)
33
+
34
+
35
+ def _convert_images(
36
+ images: Sequence["ImageInput"],
37
+ dataset_attr: "DatasetAttr",
38
+ data_args: "DataArguments",
39
+ ) -> Optional[List["ImageInput"]]:
40
+ r"""
41
+ Optionally concatenates image path to dataset dir when loading from local disk.
42
+ """
43
+ if len(images) == 0:
44
+ return None
45
+
46
+ images = images[:]
47
+ if dataset_attr.load_from in ["script", "file"]:
48
+ for i in range(len(images)):
49
+ if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
50
+ images[i] = os.path.join(data_args.dataset_dir, images[i])
51
+
52
+ return images
53
+
54
+
55
+ def _convert_videos(
56
+ videos: Sequence["VideoInput"],
57
+ dataset_attr: "DatasetAttr",
58
+ data_args: "DataArguments",
59
+ ) -> Optional[List["VideoInput"]]:
60
+ r"""
61
+ Optionally concatenates video path to dataset dir when loading from local disk.
62
+ """
63
+ if len(videos) == 0:
64
+ return None
65
+
66
+ videos = videos[:]
67
+ if dataset_attr.load_from in ["script", "file"]:
68
+ for i in range(len(videos)):
69
+ if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
70
+ videos[i] = os.path.join(data_args.dataset_dir, videos[i])
71
+
72
+ return videos
73
+
74
+
75
+ def convert_alpaca(
76
+ example: Dict[str, Any],
77
+ dataset_attr: "DatasetAttr",
78
+ data_args: "DataArguments",
79
+ ) -> Dict[str, Any]:
80
+ r"""
81
+ Converts alpaca format dataset to the standard format.
82
+ """
83
+ prompt = []
84
+ if dataset_attr.history and isinstance(example[dataset_attr.history], list):
85
+ for old_prompt, old_response in example[dataset_attr.history]:
86
+ prompt.append({"role": Role.USER.value, "content": old_prompt})
87
+ prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
88
+
89
+ query = []
90
+ if dataset_attr.prompt and example[dataset_attr.prompt]:
91
+ query.append(example[dataset_attr.prompt])
92
+
93
+ if dataset_attr.query and example[dataset_attr.query]:
94
+ query.append(example[dataset_attr.query])
95
+
96
+ prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery"
97
+
98
+ if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
99
+ response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
100
+ if example[dataset_attr.kto_tag]:
101
+ response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
102
+ else:
103
+ response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
104
+ elif (
105
+ dataset_attr.ranking
106
+ and isinstance(example[dataset_attr.chosen], str)
107
+ and isinstance(example[dataset_attr.rejected], str)
108
+ ): # pairwise example
109
+ response = [
110
+ {"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
111
+ {"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
112
+ ]
113
+ elif dataset_attr.response and isinstance(example[dataset_attr.response], str): # normal example
114
+ response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
115
+ else: # unsupervised
116
+ response = []
117
+
118
+ convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
119
+ convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
120
+ output = {
121
+ "_prompt": prompt,
122
+ "_response": response,
123
+ "_system": example[dataset_attr.system] if dataset_attr.system else "",
124
+ "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
125
+ "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
126
+ "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
127
+ }
128
+ return output
129
+
130
+
131
+ def convert_sharegpt(
132
+ example: Dict[str, Any],
133
+ dataset_attr: "DatasetAttr",
134
+ data_args: "DataArguments",
135
+ ) -> Dict[str, Any]:
136
+ r"""
137
+ Converts sharegpt format dataset to the standard format.
138
+ """
139
+ tag_mapping = {
140
+ dataset_attr.user_tag: Role.USER.value,
141
+ dataset_attr.assistant_tag: Role.ASSISTANT.value,
142
+ dataset_attr.observation_tag: Role.OBSERVATION.value,
143
+ dataset_attr.function_tag: Role.FUNCTION.value,
144
+ dataset_attr.system_tag: Role.SYSTEM.value,
145
+ }
146
+ odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
147
+ even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
148
+ accept_tags = (odd_tags, even_tags)
149
+ messages = example[dataset_attr.messages]
150
+ if (
151
+ dataset_attr.system_tag
152
+ and len(messages) != 0
153
+ and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
154
+ ):
155
+ system = messages[0][dataset_attr.content_tag]
156
+ messages = messages[1:]
157
+ else:
158
+ system = example[dataset_attr.system] if dataset_attr.system else ""
159
+
160
+ aligned_messages = []
161
+ broken_data = False
162
+ for turn_idx, message in enumerate(messages):
163
+ if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
164
+ logger.warning("Invalid role tag in {}.".format(messages))
165
+ broken_data = True
166
+
167
+ aligned_messages.append(
168
+ {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
169
+ )
170
+
171
+ if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
172
+ dataset_attr.ranking and len(aligned_messages) % 2 == 0
173
+ ):
174
+ logger.warning("Invalid message count in {}.".format(messages))
175
+ broken_data = True
176
+
177
+ if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
178
+ prompt = aligned_messages[:-1]
179
+ response = aligned_messages[-1:]
180
+ if example[dataset_attr.kto_tag]:
181
+ response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
182
+ else:
183
+ response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
184
+ elif (
185
+ dataset_attr.ranking
186
+ and isinstance(example[dataset_attr.chosen], dict)
187
+ and isinstance(example[dataset_attr.rejected], dict)
188
+ ): # pairwise example
189
+ chosen = example[dataset_attr.chosen]
190
+ rejected = example[dataset_attr.rejected]
191
+ if (
192
+ chosen[dataset_attr.role_tag] not in accept_tags[-1]
193
+ or rejected[dataset_attr.role_tag] not in accept_tags[-1]
194
+ ):
195
+ logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
196
+ broken_data = True
197
+
198
+ prompt = aligned_messages
199
+ response = [
200
+ {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
201
+ {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
202
+ ]
203
+ else: # normal example
204
+ prompt = aligned_messages[:-1]
205
+ response = aligned_messages[-1:]
206
+
207
+ if broken_data:
208
+ logger.warning("Skipping this abnormal example.")
209
+ prompt, response = [], []
210
+
211
+ convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
212
+ convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
213
+ output = {
214
+ "_prompt": prompt,
215
+ "_response": response,
216
+ "_system": system,
217
+ "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
218
+ "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
219
+ "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
220
+ }
221
+ return output
222
+
223
+
224
+ def align_dataset(
225
+ dataset: Union["Dataset", "IterableDataset"],
226
+ dataset_attr: "DatasetAttr",
227
+ data_args: "DataArguments",
228
+ training_args: "Seq2SeqTrainingArguments",
229
+ ) -> Union["Dataset", "IterableDataset"]:
230
+ r"""
231
+ Aligned dataset:
232
+ _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
233
+ _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
234
+ _system: "..."
235
+ _tools: "...",
236
+ _images: [],
237
+ _videos: [],
238
+ """
239
+ if dataset_attr.formatting == "alpaca":
240
+ convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
241
+ else:
242
+ convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
243
+
244
+ column_names = list(next(iter(dataset)).keys())
245
+ kwargs = {}
246
+ if not data_args.streaming:
247
+ kwargs = dict(
248
+ num_proc=data_args.preprocessing_num_workers,
249
+ load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
250
+ desc="Converting format of dataset",
251
+ )
252
+
253
+ return dataset.map(
254
+ convert_func,
255
+ batched=False,
256
+ remove_columns=column_names,
257
+ **kwargs,
258
+ )
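Note: a toy illustration of the alpaca-to-standard conversion implemented above. The input column names (`instruction`/`input`/`output`) are the conventional alpaca fields and assume a matching `DatasetAttr`; the concrete values are invented:

```python
# Illustrative alpaca-format row.
alpaca_example = {
    "instruction": "Translate to French",
    "input": "Good morning",
    "output": "Bonjour",
}

# Shape of the same row after convert_alpaca / align_dataset
# (prompt and query are joined with "\n", as noted in the code above).
aligned_example = {
    "_prompt": [{"role": "user", "content": "Translate to French\nGood morning"}],
    "_response": [{"role": "assistant", "content": "Bonjour"}],
    "_system": "",
    "_tools": "",
    "_images": None,
    "_videos": None,
}
```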
llamafactory/data/collator.py ADDED
@@ -0,0 +1,189 @@
1
+ # Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
2
+ #
3
+ # This code is inspired by the OpenAccess AI Collective's axolotl library.
4
+ # https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from dataclasses import dataclass
19
+ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
20
+
21
+ import torch
22
+ from transformers import DataCollatorForSeq2Seq
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers import ProcessorMixin
27
+
28
+ from .template import Template
29
+
30
+
31
+ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
32
+ r"""
33
+ Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
34
+ while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
35
+
36
+ e.g.
37
+ ```python
38
+ # input
39
+ [[1, 1, 2, 2, 2, 0]]
40
+ # output
41
+ [
42
+ [
43
+ [
44
+ [o, x, x, x, x, x],
45
+ [o, o, x, x, x, x],
46
+ [x, x, o, x, x, x],
47
+ [x, x, o, o, x, x],
48
+ [x, x, o, o, o, x],
49
+ [x, x, x, x, x, x],
50
+ ]
51
+ ]
52
+ ]
53
+ ```
54
+ where `o` equals to `0.0`, `x` equals to `min_dtype`.
55
+ """
56
+ bsz, seq_len = attention_mask_with_indices.size()
57
+ min_dtype = torch.finfo(dtype).min
58
+ expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
59
+ # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
60
+ padding_mask = torch.where(expanded_mask != 0, 1, 0)
61
+ # Create a block-diagonal mask.
62
+ attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
63
+ # Use the lower triangular mask to zero out the upper triangular part
64
+ attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
65
+ # Invert the attention mask.
66
+ attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
67
+ return attention_mask_4d
68
+
69
+
70
+ @dataclass
71
+ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
72
+ r"""
73
+ Data collator that supports VLMs.
74
+
75
+ Features should contain input_ids, attention_mask, labels and images.
76
+ """
77
+
78
+ template: Optional["Template"] = None
79
+ processor: Optional["ProcessorMixin"] = None
80
+
81
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
82
+ batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
83
+ for feature in features:
84
+ images = feature.pop("images", None) or []
85
+ videos = feature.pop("videos", None) or []
86
+ batch_images.extend(images)
87
+ batch_videos.extend(videos)
88
+ batch_imglens.append(len(images))
89
+ batch_vidlens.append(len(videos))
90
+ batch_seqlens.append(len(feature["input_ids"]))
91
+
92
+ mm_inputs = self.template.mm_plugin.get_mm_inputs(
93
+ batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
94
+ )
95
+ if "token_type_ids" in mm_inputs:
96
+ token_type_ids = mm_inputs.pop("token_type_ids")
97
+ for i, feature in enumerate(features):
98
+ feature["token_type_ids"] = token_type_ids[i]
99
+
100
+ features: Dict[str, "torch.Tensor"] = super().__call__(features)
101
+ features.update(mm_inputs)
102
+ return features
103
+
104
+
105
+ @dataclass
106
+ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
107
+ r"""
108
+ Data collator for 4d attention mask.
109
+ """
110
+
111
+ block_diag_attn: bool = False
112
+ attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
113
+ compute_dtype: "torch.dtype" = torch.float32
114
+
115
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
116
+ features = super().__call__(features)
117
+ if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
118
+ features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
119
+
120
+ return features
121
+
122
+
123
+ @dataclass
124
+ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
125
+ r"""
126
+ Data collator for pairwise data.
127
+ """
128
+
129
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
130
+ r"""
131
+ Pads batched data to the longest sequence in the batch.
132
+
133
+ We generate 2 * n examples where the first n examples represent chosen examples and
134
+ the last n examples represent rejected examples.
135
+ """
136
+ concatenated_features = []
137
+ for key in ("chosen", "rejected"):
138
+ for feature in features:
139
+ target_feature = {
140
+ "input_ids": feature["{}_input_ids".format(key)],
141
+ "attention_mask": feature["{}_attention_mask".format(key)],
142
+ "labels": feature["{}_labels".format(key)],
143
+ "images": feature["images"],
144
+ "videos": feature["videos"],
145
+ }
146
+ concatenated_features.append(target_feature)
147
+
148
+ return super().__call__(concatenated_features)
149
+
150
+
151
+ @dataclass
152
+ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
153
+ r"""
154
+ Data collator for KTO data.
155
+ """
156
+
157
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
158
+ target_features = []
159
+ kl_features = []
160
+ kto_tags = []
161
+ for feature in features:
162
+ target_feature = {
163
+ "input_ids": feature["input_ids"],
164
+ "attention_mask": feature["attention_mask"],
165
+ "labels": feature["labels"],
166
+ "images": feature["images"],
167
+ "videos": feature["videos"],
168
+ }
169
+ kl_feature = {
170
+ "input_ids": feature["kl_input_ids"],
171
+ "attention_mask": feature["kl_attention_mask"],
172
+ "labels": feature["kl_labels"],
173
+ "images": feature["images"],
174
+ "videos": feature["videos"],
175
+ }
176
+ target_features.append(target_feature)
177
+ kl_features.append(kl_feature)
178
+ kto_tags.append(feature["kto_tags"])
179
+
180
+ batch = super().__call__(target_features)
181
+ kl_batch = super().__call__(kl_features)
182
+ batch["kl_input_ids"] = kl_batch["input_ids"]
183
+ batch["kl_attention_mask"] = kl_batch["attention_mask"]
184
+ batch["kl_labels"] = kl_batch["labels"]
185
+ if "token_type_ids" in kl_batch:
186
+ batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
187
+
188
+ batch["kto_tags"] = torch.tensor(kto_tags)
189
+ return batch
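Note: `prepare_4d_attention_mask` can be exercised on its own; a minimal sketch reproducing the docstring example above (import path follows this file's location in the package):

```python
import torch

from llamafactory.data.collator import prepare_4d_attention_mask

# Two packed sequences (segment ids 1 and 2) followed by one padding position.
packed_mask = torch.tensor([[1, 1, 2, 2, 2, 0]])
mask_4d = prepare_4d_attention_mask(packed_mask, torch.float32)

print(mask_4d.shape)  # torch.Size([1, 1, 6, 6])
# Entries are 0.0 where attention is allowed and torch.finfo(torch.float32).min elsewhere,
# matching the block-diagonal, lower-triangular pattern shown in the docstring.
```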
llamafactory/data/data_utils.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from enum import Enum, unique
16
+ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
17
+
18
+ from datasets import DatasetDict, concatenate_datasets, interleave_datasets
19
+
20
+ from ..extras.logging import get_logger
21
+
22
+
23
+ if TYPE_CHECKING:
24
+ from datasets import Dataset, IterableDataset
25
+
26
+ from ..hparams import DataArguments
27
+
28
+
29
+ logger = get_logger(__name__)
30
+
31
+
32
+ SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
33
+
34
+
35
+ @unique
36
+ class Role(str, Enum):
37
+ USER = "user"
38
+ ASSISTANT = "assistant"
39
+ SYSTEM = "system"
40
+ FUNCTION = "function"
41
+ OBSERVATION = "observation"
42
+
43
+
44
+ class DatasetModule(TypedDict):
45
+ train_dataset: Optional[Union["Dataset", "IterableDataset"]]
46
+ eval_dataset: Optional[Union["Dataset", "IterableDataset"]]
47
+
48
+
49
+ def merge_dataset(
50
+ all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
51
+ ) -> Union["Dataset", "IterableDataset"]:
52
+ r"""
53
+ Merges multiple datasets to a unified dataset.
54
+ """
55
+ if len(all_datasets) == 1:
56
+ return all_datasets[0]
57
+ elif data_args.mix_strategy == "concat":
58
+ if data_args.streaming:
59
+ logger.warning("The samples between different datasets will not be mixed in streaming mode.")
60
+
61
+ return concatenate_datasets(all_datasets)
62
+ elif data_args.mix_strategy.startswith("interleave"):
63
+ if not data_args.streaming:
64
+ logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
65
+
66
+ return interleave_datasets(
67
+ datasets=all_datasets,
68
+ probabilities=data_args.interleave_probs,
69
+ seed=seed,
70
+ stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
71
+ )
72
+ else:
73
+ raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
74
+
75
+
76
+ def split_dataset(
77
+ dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
78
+ ) -> "DatasetDict":
79
+ r"""
80
+ Splits the dataset and returns a dataset dict containing train set and validation set.
81
+
82
+ Supports both map dataset and iterable dataset.
83
+ """
84
+ if data_args.streaming:
85
+ dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
86
+ val_set = dataset.take(int(data_args.val_size))
87
+ train_set = dataset.skip(int(data_args.val_size))
88
+ return DatasetDict({"train": train_set, "validation": val_set})
89
+ else:
90
+ val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
91
+ dataset = dataset.train_test_split(test_size=val_size, seed=seed)
92
+ return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
llamafactory/data/formatter.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import re
17
+ from abc import ABC, abstractmethod
18
+ from dataclasses import dataclass, field
19
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
20
+
21
+ from typing_extensions import override
22
+
23
+ from .data_utils import SLOTS
24
+ from .tool_utils import get_tool_utils
25
+
26
+
27
+ if TYPE_CHECKING:
28
+ from .tool_utils import FunctionCall
29
+
30
+
31
+ @dataclass
32
+ class Formatter(ABC):
33
+ slots: SLOTS = field(default_factory=list)
34
+ tool_format: Optional[str] = None
35
+
36
+ @abstractmethod
37
+ def apply(self, **kwargs) -> SLOTS:
38
+ r"""
39
+ Forms a list of slots according to the inputs to encode.
40
+ """
41
+ ...
42
+
43
+ def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
44
+ r"""
45
+ Extract a list of tuples from the response message if using tools.
46
+
47
+ Each tuple consists of function name and function arguments.
48
+ """
49
+ raise NotImplementedError
50
+
51
+
52
+ @dataclass
53
+ class EmptyFormatter(Formatter):
54
+ def __post_init__(self):
55
+ has_placeholder = False
56
+ for slot in filter(lambda s: isinstance(s, str), self.slots):
57
+ if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
58
+ has_placeholder = True
59
+
60
+ if has_placeholder:
61
+ raise ValueError("Empty formatter should not contain any placeholder.")
62
+
63
+ @override
64
+ def apply(self, **kwargs) -> SLOTS:
65
+ return self.slots
66
+
67
+
68
+ @dataclass
69
+ class StringFormatter(Formatter):
70
+ def __post_init__(self):
71
+ has_placeholder = False
72
+ for slot in filter(lambda s: isinstance(s, str), self.slots):
73
+ if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
74
+ has_placeholder = True
75
+
76
+ if not has_placeholder:
77
+ raise ValueError("A placeholder is required in the string formatter.")
78
+
79
+ @override
80
+ def apply(self, **kwargs) -> SLOTS:
81
+ elements = []
82
+ for slot in self.slots:
83
+ if isinstance(slot, str):
84
+ for name, value in kwargs.items():
85
+ if not isinstance(value, str):
86
+ raise RuntimeError("Expected a string, got {}".format(value))
87
+
88
+ slot = slot.replace("{{" + name + "}}", value, 1)
89
+ elements.append(slot)
90
+ elif isinstance(slot, (dict, set)):
91
+ elements.append(slot)
92
+ else:
93
+ raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
94
+
95
+ return elements
96
+
97
+
98
+ @dataclass
99
+ class FunctionFormatter(Formatter):
100
+ def __post_init__(self):
101
+ self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
102
+
103
+ @override
104
+ def apply(self, **kwargs) -> SLOTS:
105
+ content = kwargs.pop("content")
106
+ functions: List[Tuple[str, str]] = []
107
+ try:
108
+ tool_calls = json.loads(content)
109
+ if not isinstance(tool_calls, list): # parallel function call
110
+ tool_calls = [tool_calls]
111
+
112
+ for tool_call in tool_calls:
113
+ functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
114
+
115
+ except json.JSONDecodeError:
116
+ raise RuntimeError("Invalid JSON format in function message: {}".format(str([content]))) # flat string
117
+
118
+ elements = []
119
+ for name, arguments in functions:
120
+ for slot in self.slots:
121
+ if isinstance(slot, str):
122
+ slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
123
+ elements.append(slot)
124
+ elif isinstance(slot, (dict, set)):
125
+ elements.append(slot)
126
+ else:
127
+ raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
128
+
129
+ return elements
130
+
131
+
132
+ @dataclass
133
+ class ToolFormatter(Formatter):
134
+ def __post_init__(self):
135
+ self.tool_utils = get_tool_utils(self.tool_format)
136
+
137
+ @override
138
+ def apply(self, **kwargs) -> SLOTS:
139
+ content = kwargs.pop("content")
140
+ try:
141
+ tools = json.loads(content)
142
+ return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
143
+ except json.JSONDecodeError:
144
+ raise RuntimeError("Invalid JSON format in tool description: {}".format(str([content]))) # flat string
145
+
146
+ @override
147
+ def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
148
+ return self.tool_utils.tool_extractor(content)
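Note: a small sketch of how the formatters above behave, using an invented chat-template slot:

```python
from llamafactory.data.formatter import EmptyFormatter, StringFormatter

# StringFormatter substitutes {{placeholders}} with keyword arguments and
# passes dict/set slots (special-token markers) through untouched.
user_formatter = StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}])
print(user_formatter.apply(content="How are you?"))
# -> ['<|user|>\nHow are you?', {'eos_token'}]

# EmptyFormatter returns its slots verbatim and rejects any placeholder at init time.
separator = EmptyFormatter(slots=["\n\n"])
print(separator.apply())  # -> ['\n\n']
```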
llamafactory/data/loader.py ADDED
@@ -0,0 +1,292 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import sys
17
+ from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
18
+
19
+ import numpy as np
20
+ from datasets import DatasetDict, load_dataset, load_from_disk
21
+ from transformers.utils.versions import require_version
22
+
23
+ from ..extras.constants import FILEEXT2TYPE
24
+ from ..extras.logging import get_logger
25
+ from ..extras.misc import has_tokenized_data
26
+ from .aligner import align_dataset
27
+ from .data_utils import merge_dataset, split_dataset
28
+ from .parser import get_dataset_list
29
+ from .preprocess import get_preprocess_and_print_func
30
+
31
+
32
+ if TYPE_CHECKING:
33
+ from datasets import Dataset, IterableDataset
34
+ from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
35
+
36
+ from ..hparams import DataArguments, ModelArguments
37
+ from .data_utils import DatasetModule
38
+ from .parser import DatasetAttr
39
+ from .template import Template
40
+
41
+
42
+ logger = get_logger(__name__)
43
+
44
+
45
+ def _load_single_dataset(
46
+ dataset_attr: "DatasetAttr",
47
+ model_args: "ModelArguments",
48
+ data_args: "DataArguments",
49
+ training_args: "Seq2SeqTrainingArguments",
50
+ ) -> Union["Dataset", "IterableDataset"]:
51
+ r"""
52
+ Loads a single dataset and aligns it to the standard format.
53
+ """
54
+ logger.info("Loading dataset {}...".format(dataset_attr))
55
+ data_path, data_name, data_dir, data_files = None, None, None, None
56
+ if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
57
+ data_path = dataset_attr.dataset_name
58
+ data_name = dataset_attr.subset
59
+ data_dir = dataset_attr.folder
60
+
61
+ elif dataset_attr.load_from == "script":
62
+ data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
63
+ data_name = dataset_attr.subset
64
+ data_dir = dataset_attr.folder
65
+
66
+ elif dataset_attr.load_from == "file":
67
+ data_files = []
68
+ local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
69
+ if os.path.isdir(local_path): # is directory
70
+ for file_name in os.listdir(local_path):
71
+ data_files.append(os.path.join(local_path, file_name))
72
+ if data_path is None:
73
+ data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
74
+ elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
75
+ raise ValueError("File types should be identical.")
76
+ elif os.path.isfile(local_path): # is file
77
+ data_files.append(local_path)
78
+ data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
79
+ else:
80
+ raise ValueError("File {} not found.".format(local_path))
81
+
82
+ if data_path is None:
83
+ raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
84
+ else:
85
+ raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
86
+
87
+ if dataset_attr.load_from == "ms_hub":
88
+ require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
89
+ from modelscope import MsDataset
90
+ from modelscope.utils.config_ds import MS_DATASETS_CACHE
91
+
92
+ cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
93
+ dataset = MsDataset.load(
94
+ dataset_name=data_path,
95
+ subset_name=data_name,
96
+ data_dir=data_dir,
97
+ data_files=data_files,
98
+ split=dataset_attr.split,
99
+ cache_dir=cache_dir,
100
+ token=model_args.ms_hub_token,
101
+ use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
102
+ )
103
+ if isinstance(dataset, MsDataset):
104
+ dataset = dataset.to_hf_dataset()
105
+ else:
106
+ dataset = load_dataset(
107
+ path=data_path,
108
+ name=data_name,
109
+ data_dir=data_dir,
110
+ data_files=data_files,
111
+ split=dataset_attr.split,
112
+ cache_dir=model_args.cache_dir,
113
+ token=model_args.hf_hub_token,
114
+ streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
115
+ trust_remote_code=True,
116
+ )
117
+
118
+ if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
119
+ dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
120
+
121
+ if dataset_attr.num_samples is not None and not data_args.streaming:
122
+ target_num = dataset_attr.num_samples
123
+ indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
124
+ target_num -= len(indexes)
125
+ if target_num > 0:
126
+ expand_indexes = np.random.choice(len(dataset), target_num)
127
+ indexes = np.concatenate((indexes, expand_indexes), axis=0)
128
+
129
+ assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
130
+ dataset = dataset.select(indexes)
131
+ logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
132
+
133
+ if data_args.max_samples is not None: # truncate dataset
134
+ max_samples = min(data_args.max_samples, len(dataset))
135
+ dataset = dataset.select(range(max_samples))
136
+
137
+ return align_dataset(dataset, dataset_attr, data_args, training_args)
138
+
139
+
140
+ def _get_merged_dataset(
141
+ dataset_names: Optional[Sequence[str]],
142
+ model_args: "ModelArguments",
143
+ data_args: "DataArguments",
144
+ training_args: "Seq2SeqTrainingArguments",
145
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
146
+ ) -> Optional[Union["Dataset", "IterableDataset"]]:
147
+ r"""
148
+ Gets the merged datasets in the standard format.
149
+ """
150
+ if dataset_names is None:
151
+ return None
152
+
153
+ datasets = []
154
+ for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
155
+ if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
156
+ raise ValueError("The dataset is not applicable in the current training stage.")
157
+
158
+ datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
159
+
160
+ return merge_dataset(datasets, data_args, seed=training_args.seed)
161
+
162
+
163
+ def _get_preprocessed_dataset(
164
+ dataset: Optional[Union["Dataset", "IterableDataset"]],
165
+ data_args: "DataArguments",
166
+ training_args: "Seq2SeqTrainingArguments",
167
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
168
+ template: "Template",
169
+ tokenizer: "PreTrainedTokenizer",
170
+ processor: Optional["ProcessorMixin"] = None,
171
+ is_eval: bool = False,
172
+ ) -> Optional[Union["Dataset", "IterableDataset"]]:
173
+ r"""
174
+ Preprocesses the dataset, including format checking and tokenization.
175
+ """
176
+ if dataset is None:
177
+ return None
178
+
179
+ preprocess_func, print_function = get_preprocess_and_print_func(
180
+ data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
181
+ )
182
+ column_names = list(next(iter(dataset)).keys())
183
+ kwargs = {}
184
+ if not data_args.streaming:
185
+ kwargs = dict(
186
+ num_proc=data_args.preprocessing_num_workers,
187
+ load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
188
+ desc="Running tokenizer on dataset",
189
+ )
190
+
191
+ dataset = dataset.map(
192
+ preprocess_func,
193
+ batched=True,
194
+ batch_size=data_args.preprocessing_batch_size,
195
+ remove_columns=column_names,
196
+ **kwargs,
197
+ )
198
+
199
+ if training_args.should_log:
200
+ try:
201
+ print("eval example:" if is_eval else "training example:")
202
+ print_function(next(iter(dataset)))
203
+ except StopIteration:
204
+ if stage == "pt":
205
+ raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
206
+ else:
207
+ raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
208
+
209
+ return dataset
210
+
211
+
212
+ def get_dataset(
213
+ template: "Template",
214
+ model_args: "ModelArguments",
215
+ data_args: "DataArguments",
216
+ training_args: "Seq2SeqTrainingArguments",
217
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
218
+ tokenizer: "PreTrainedTokenizer",
219
+ processor: Optional["ProcessorMixin"] = None,
220
+ ) -> "DatasetModule":
221
+ r"""
222
+ Gets the train dataset and optionally gets the evaluation dataset.
223
+ """
224
+ # Load tokenized dataset
225
+ if data_args.tokenized_path is not None:
226
+ if has_tokenized_data(data_args.tokenized_path):
227
+ logger.warning("Loading dataset from disk will ignore other data arguments.")
228
+ dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
229
+ logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
230
+
231
+ dataset_module: Dict[str, "Dataset"] = {}
232
+ if "train" in dataset_dict:
233
+ dataset_module["train_dataset"] = dataset_dict["train"]
234
+
235
+ if "validation" in dataset_dict:
236
+ dataset_module["eval_dataset"] = dataset_dict["validation"]
237
+
238
+ if data_args.streaming:
239
+ dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
240
+
241
+ return dataset_module
242
+
243
+ if data_args.streaming:
244
+ raise ValueError("Turn off `streaming` when saving dataset to disk.")
245
+
246
+ # Load and preprocess dataset
247
+ with training_args.main_process_first(desc="load dataset"):
248
+ dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
249
+ eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
250
+
251
+ with training_args.main_process_first(desc="pre-process dataset"):
252
+ dataset = _get_preprocessed_dataset(
253
+ dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
254
+ )
255
+ eval_dataset = _get_preprocessed_dataset(
256
+ eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
257
+ )
258
+
259
+ if data_args.val_size > 1e-6:
260
+ dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
261
+ else:
262
+ dataset_dict = {}
263
+ if dataset is not None:
264
+ if data_args.streaming:
265
+ dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
266
+
267
+ dataset_dict["train"] = dataset
268
+
269
+ if eval_dataset is not None:
270
+ if data_args.streaming:
271
+ eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
272
+
273
+ dataset_dict["validation"] = eval_dataset
274
+
275
+ dataset_dict = DatasetDict(dataset_dict)
276
+
277
+ if data_args.tokenized_path is not None:
278
+ if training_args.should_save:
279
+ dataset_dict.save_to_disk(data_args.tokenized_path)
280
+ logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
281
+ logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
282
+
283
+ sys.exit(0)
284
+
285
+ dataset_module = {}
286
+ if "train" in dataset_dict:
287
+ dataset_module["train_dataset"] = dataset_dict["train"]
288
+
289
+ if "validation" in dataset_dict:
290
+ dataset_module["eval_dataset"] = dataset_dict["validation"]
291
+
292
+ return dataset_module
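A quick aside on the `num_samples` branch in `_load_single_dataset` above: it first draws a random permutation of the dataset, then, if more samples were requested than exist, tops up by sampling with replacement. Below is a minimal standalone sketch of that indexing logic; the `sample_indices` helper and the seeded generator are illustrative only and are not part of llamafactory.

```python
import numpy as np

def sample_indices(dataset_len: int, num_samples: int, seed: int = 42) -> np.ndarray:
    """Mirror of the num_samples logic: unique indices first, then random repeats."""
    rng = np.random.default_rng(seed)
    indexes = rng.permutation(dataset_len)[:num_samples]  # all real samples come first
    remaining = num_samples - len(indexes)
    if remaining > 0:  # requested more than the dataset holds: repeat some samples
        extra = rng.choice(dataset_len, remaining)
        indexes = np.concatenate((indexes, extra), axis=0)

    assert len(indexes) == num_samples, "Sample num mismatched."
    return indexes

print(sample_indices(5, 8))  # five unique indices followed by three repeated ones
```

The diff's version uses the legacy global `np.random` state rather than a `Generator`; either approach yields the same shape of result.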
llamafactory/data/mm_plugin.py ADDED
@@ -0,0 +1,627 @@
1
+ import math
2
+ from copy import deepcopy
3
+ from io import BytesIO
4
+ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
5
+
6
+ import numpy as np
7
+ from transformers.image_utils import get_image_size, to_numpy_array
8
+ from typing_extensions import override
9
+
10
+ from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
11
+ from ..extras.packages import is_pillow_available, is_pyav_available
12
+
13
+
14
+ if is_pillow_available():
15
+ from PIL import Image
16
+ from PIL.Image import Image as ImageObject
17
+
18
+
19
+ if is_pyav_available():
20
+ import av
21
+
22
+
23
+ if TYPE_CHECKING:
24
+ import torch
25
+ from av.stream import Stream
26
+ from transformers import PreTrainedTokenizer, ProcessorMixin
27
+ from transformers.image_processing_utils import BaseImageProcessor
28
+
29
+ class EncodedImage(TypedDict):
30
+ path: Optional[str]
31
+ bytes: Optional[bytes]
32
+
33
+ ImageInput = Union[str, EncodedImage, ImageObject]
34
+ VideoInput = str
35
+
36
+
37
+ def _get_paligemma_token_type_ids(
38
+ imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
39
+ ) -> List[List[int]]:
40
+ r"""
41
+ Gets paligemma token type ids for computing loss.
42
+
43
+ Returns:
44
+ batch_token_type_ids: shape (batch_size, sequence_length)
45
+ """
46
+ batch_token_type_ids = []
47
+ for imglen, seqlen in zip(imglens, seqlens):
48
+ image_seqlen = imglen * getattr(processor, "image_seqlen")
49
+ batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
50
+
51
+ return batch_token_type_ids
52
+
53
+
54
+ class BasePlugin:
55
+ def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
56
+ self.image_token = image_token
57
+ self.video_token = video_token
58
+
59
+ def _validate_input(
60
+ self,
61
+ images: Sequence["ImageInput"],
62
+ videos: Sequence["VideoInput"],
63
+ ) -> None:
64
+ r"""
65
+ Validates if this model accepts the input modalities.
66
+ """
67
+ if len(images) != 0 and self.image_token is None:
68
+ raise ValueError("This model does not support image input.")
69
+
70
+ if len(videos) != 0 and self.video_token is None:
71
+ raise ValueError("This model does not support video input.")
72
+
73
+ def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
74
+ r"""
75
+ Pre-processes a single image.
76
+ """
77
+ image_resolution: int = kwargs.get("image_resolution")
78
+ if max(image.width, image.height) > image_resolution:
79
+ resize_factor = image_resolution / max(image.width, image.height)
80
+ width, height = int(image.width * resize_factor), int(image.height * resize_factor)
81
+ image = image.resize((width, height), resample=Image.NEAREST)
82
+
83
+ if image.mode != "RGB":
84
+ image = image.convert("RGB")
85
+
86
+ return image
87
+
88
+ def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
89
+ r"""
90
+ Computes video sample frames according to fps.
91
+ """
92
+ video_fps: float = kwargs.get("video_fps")
93
+ video_maxlen: int = kwargs.get("video_maxlen")
94
+ total_frames = video_stream.frames
95
+ sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
96
+ sample_frames = min(total_frames, video_maxlen, sample_frames)
97
+ return math.floor(sample_frames)
98
+
99
+ def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
100
+ r"""
101
+ Regularizes images to avoid errors, including reading and pre-processing.
102
+ """
103
+ results = []
104
+ for image in images:
105
+ if isinstance(image, str):
106
+ image = Image.open(image)
107
+ elif isinstance(image, dict):
108
+ if image["bytes"] is not None:
109
+ image = Image.open(BytesIO(image["bytes"]))
110
+ else:
111
+ image = Image.open(image["path"])
112
+
113
+ if not isinstance(image, ImageObject):
114
+ raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
115
+
116
+ results.append(self._preprocess_image(image, **kwargs))
117
+
118
+ return results
119
+
120
+ def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
121
+ r"""
122
+ Regularizes videos to avoid errors, including reading, resizing and converting.
123
+ """
124
+ results = []
125
+ for video in videos:
126
+ container = av.open(video, "r")
127
+ video_stream = next(stream for stream in container.streams if stream.type == "video")
128
+ total_frames = video_stream.frames
129
+ sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
130
+ sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
131
+ frames: List["ImageObject"] = []
132
+ container.seek(0)
133
+ for frame_idx, frame in enumerate(container.decode(video_stream)):
134
+ if frame_idx in sample_indices:
135
+ frames.append(frame.to_image())
136
+
137
+ frames = self._regularize_images(frames, **kwargs)
138
+ results.append(frames)
139
+
140
+ return results
141
+
142
+ def _get_mm_inputs(
143
+ self,
144
+ images: Sequence["ImageInput"],
145
+ videos: Sequence["VideoInput"],
146
+ processor: "ProcessorMixin",
147
+ ) -> Dict[str, "torch.Tensor"]:
148
+ r"""
149
+ Processes visual inputs.
150
+
151
+ Returns: (llava and paligemma)
152
+ pixel_values: tensor with shape (B, C, H, W)
153
+
154
+ Returns: (qwen2-vl)
155
+ pixel_values: tensor with shape (num_patches, patch_dim)
156
+ image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, height, width
157
+
158
+ It holds num_patches == torch.prod(image_grid_thw)
159
+ """
160
+ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
161
+ video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
162
+ input_dict = {"images": None} # default key
163
+ if len(images) != 0:
164
+ images = self._regularize_images(
165
+ images,
166
+ image_resolution=getattr(processor, "image_resolution", 512),
167
+ )
168
+ input_dict["images"] = images
169
+
170
+ if len(videos) != 0:
171
+ videos = self._regularize_videos(
172
+ videos,
173
+ image_resolution=getattr(processor, "video_resolution", 128),
174
+ video_fps=getattr(processor, "video_fps", 1.0),
175
+ video_maxlen=getattr(processor, "video_maxlen", 64),
176
+ )
177
+ input_dict["videos"] = videos
178
+
179
+ mm_inputs = {}
180
+ if image_processor != video_processor:
181
+ if input_dict.get("images") is not None:
182
+ mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
183
+ if input_dict.get("videos") is not None:
184
+ mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
185
+ elif input_dict.get("images") is not None or input_dict.get("videos") is not None: # same processor (qwen2-vl)
186
+ mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
187
+
188
+ return mm_inputs
189
+
190
+ def process_messages(
191
+ self,
192
+ messages: Sequence[Dict[str, str]],
193
+ images: Sequence["ImageInput"],
194
+ videos: Sequence["VideoInput"],
195
+ processor: Optional["ProcessorMixin"],
196
+ ) -> List[Dict[str, str]]:
197
+ r"""
198
+ Pre-processes input messages before tokenization for VLMs.
199
+ """
200
+ self._validate_input(images, videos)
201
+ return messages
202
+
203
+ def process_token_ids(
204
+ self,
205
+ input_ids: List[int],
206
+ labels: Optional[List[int]],
207
+ images: Sequence["ImageInput"],
208
+ videos: Sequence["VideoInput"],
209
+ tokenizer: "PreTrainedTokenizer",
210
+ processor: Optional["ProcessorMixin"],
211
+ ) -> Tuple[List[int], Optional[List[int]]]:
212
+ r"""
213
+ Pre-processes token ids after tokenization for VLMs.
214
+ """
215
+ self._validate_input(images, videos)
216
+ return input_ids, labels
217
+
218
+ def get_mm_inputs(
219
+ self,
220
+ images: Sequence["ImageInput"],
221
+ videos: Sequence["VideoInput"],
222
+ imglens: Sequence[int],
223
+ vidlens: Sequence[int],
224
+ seqlens: Sequence[int],
225
+ processor: Optional["ProcessorMixin"],
226
+ ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
227
+ r"""
228
+ Builds batched multimodal inputs for VLMs.
229
+ """
230
+ self._validate_input(images, videos)
231
+ return {}
232
+
233
+
234
+ class LlavaPlugin(BasePlugin):
235
+ @override
236
+ def process_messages(
237
+ self,
238
+ messages: Sequence[Dict[str, str]],
239
+ images: Sequence["ImageInput"],
240
+ videos: Sequence["VideoInput"],
241
+ processor: Optional["ProcessorMixin"],
242
+ ) -> List[Dict[str, str]]:
243
+ self._validate_input(images, videos)
244
+ num_image_tokens = 0
245
+ image_seqlen = getattr(processor, "image_seqlen")
246
+ messages = deepcopy(messages)
247
+ for message in messages:
248
+ content = message["content"]
249
+ while IMAGE_PLACEHOLDER in content:
250
+ num_image_tokens += 1
251
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
252
+
253
+ message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
254
+
255
+ if len(images) != num_image_tokens:
256
+ raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
257
+
258
+ return messages
259
+
260
+ @override
261
+ def get_mm_inputs(
262
+ self,
263
+ images: Sequence["ImageInput"],
264
+ videos: Sequence["VideoInput"],
265
+ imglens: Sequence[int],
266
+ vidlens: Sequence[int],
267
+ seqlens: Sequence[int],
268
+ processor: Optional["ProcessorMixin"],
269
+ ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
270
+ self._validate_input(images, videos)
271
+ return self._get_mm_inputs(images, videos, processor)
272
+
273
+
274
+ class LlavaNextPlugin(BasePlugin):
275
+ @override
276
+ def process_messages(
277
+ self,
278
+ messages: Sequence[Dict[str, str]],
279
+ images: Sequence["ImageInput"],
280
+ videos: Sequence["VideoInput"],
281
+ processor: Optional["ProcessorMixin"],
282
+ ) -> List[Dict[str, str]]:
283
+ self._validate_input(images, videos)
284
+ num_image_tokens = 0
285
+ messages = deepcopy(messages)
286
+ mm_inputs = self._get_mm_inputs(images, videos, processor)
287
+ if "image_sizes" in mm_inputs:
288
+ image_sizes = iter(mm_inputs["image_sizes"])
289
+ if "pixel_values" in mm_inputs:
290
+ height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
291
+ for message in messages:
292
+ content = message["content"]
293
+ while self.image_token in content:
294
+ image_size = next(image_sizes)
295
+ orig_height, orig_width = image_size
296
+ image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
297
+ if processor.vision_feature_select_strategy == "default":
298
+ image_seqlen -= 1
299
+ num_image_tokens += 1
300
+ content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
301
+
302
+ message["content"] = content.replace("{{image}}", self.image_token)
303
+
304
+ if len(images) != num_image_tokens:
305
+ raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
306
+ return messages
307
+
308
+ @override
309
+ def get_mm_inputs(
310
+ self,
311
+ images: Sequence["ImageInput"],
312
+ videos: Sequence["VideoInput"],
313
+ imglens: Sequence[int],
314
+ vidlens: Sequence[int],
315
+ seqlens: Sequence[int],
316
+ processor: Optional["ProcessorMixin"],
317
+ ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
318
+ self._validate_input(images, videos)
319
+ res = self._get_mm_inputs(images, videos, processor)
320
+ return res
321
+
322
+
323
+ class LlavaNextVideoPlugin(BasePlugin):
324
+ @override
325
+ def process_messages(
326
+ self,
327
+ messages: Sequence[Dict[str, str]],
328
+ images: Sequence["ImageInput"],
329
+ videos: Sequence["VideoInput"],
330
+ processor: Optional["ProcessorMixin"],
331
+ ) -> List[Dict[str, str]]:
332
+ self._validate_input(images, videos)
333
+ num_image_tokens = 0
334
+ num_video_tokens = 0
335
+ messages = deepcopy(messages)
336
+ mm_inputs = self._get_mm_inputs(images, videos, processor)
337
+ if "pixel_values" in mm_inputs:
338
+ image_sizes = iter(mm_inputs["image_sizes"])
339
+ height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
340
+ for message in messages:
341
+ content = message["content"]
342
+
343
+ while self.image_token in content:
344
+ image_size = next(image_sizes)
345
+ orig_height, orig_width = image_size
346
+ image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
347
+ if processor.vision_feature_select_strategy == "default":
348
+ image_seqlen -= 1
349
+ num_image_tokens += 1
350
+ content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
351
+
352
+ message["content"] = content.replace("{{image}}", self.image_token)
353
+
354
+ if "pixel_values_videos" in mm_inputs:
355
+ pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
356
+ height, width = get_image_size(pixel_values_video[0])
357
+ num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
358
+ image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
359
+ video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
360
+
361
+ for message in messages:
362
+ content = message["content"]
363
+ while self.video_token in content:
364
+ num_video_tokens += 1
365
+ content = content.replace(self.video_token, "{{video}}", 1)
366
+ message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
367
+
368
+ if len(images) != num_image_tokens:
369
+ raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
370
+
371
+ if len(videos) != num_video_tokens:
372
+ raise ValueError("The number of videos does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
373
+
374
+ return messages
375
+
376
+ @override
377
+ def get_mm_inputs(
378
+ self,
379
+ images: Sequence["ImageInput"],
380
+ videos: Sequence["VideoInput"],
381
+ imglens: Sequence[int],
382
+ vidlens: Sequence[int],
383
+ seqlens: Sequence[int],
384
+ processor: Optional["ProcessorMixin"],
385
+ ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
386
+ self._validate_input(images, videos)
387
+ return self._get_mm_inputs(images, videos, processor)
388
+
389
+
390
+ class PaliGemmaPlugin(BasePlugin):
391
+ @override
392
+ def process_messages(
393
+ self,
394
+ messages: Sequence[Dict[str, str]],
395
+ images: Sequence["ImageInput"],
396
+ videos: Sequence["VideoInput"],
397
+ processor: Optional["ProcessorMixin"],
398
+ ) -> List[Dict[str, str]]:
399
+ self._validate_input(images, videos)
400
+ num_image_tokens = 0
401
+ messages = deepcopy(messages)
402
+ for message in messages:
403
+ content = message["content"]
404
+ while IMAGE_PLACEHOLDER in content:
405
+ num_image_tokens += 1
406
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
407
+
408
+ message["content"] = content.replace("{{image}}", "")
409
+
410
+ if len(images) != num_image_tokens:
411
+ raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
412
+
413
+ return messages
414
+
415
+ @override
416
+ def process_token_ids(
417
+ self,
418
+ input_ids: List[int],
419
+ labels: Optional[List[int]],
420
+ images: Sequence["ImageInput"],
421
+ videos: Sequence["VideoInput"],
422
+ tokenizer: "PreTrainedTokenizer",
423
+ processor: Optional["ProcessorMixin"],
424
+ ) -> Tuple[List[int], Optional[List[int]]]:
425
+ self._validate_input(images, videos)
426
+ num_images = len(images)
427
+ image_seqlen = num_images * getattr(processor, "image_seqlen")
428
+ image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
429
+ input_ids = [image_token_id] * image_seqlen + input_ids
430
+ if labels is not None:
431
+ labels = [IGNORE_INDEX] * image_seqlen + labels
432
+
433
+ return input_ids, labels
434
+
435
+ @override
436
+ def get_mm_inputs(
437
+ self,
438
+ images: Sequence["ImageInput"],
439
+ videos: Sequence["VideoInput"],
440
+ imglens: Sequence[int],
441
+ vidlens: Sequence[int],
442
+ seqlens: Sequence[int],
443
+ processor: Optional["ProcessorMixin"],
444
+ ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
445
+ self._validate_input(images, videos)
446
+ mm_inputs = self._get_mm_inputs(images, videos, processor)
447
+ mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
448
+ return mm_inputs
449
+
450
+
451
+ class Qwen2vlPlugin(BasePlugin):
452
+ @override
453
+ def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
454
+ image = super()._preprocess_image(image, **kwargs)
455
+ if min(image.width, image.height) < 28:
456
+ width, height = max(image.width, 28), max(image.height, 28)
457
+ image = image.resize((width, height), resample=Image.NEAREST)
458
+
459
+ if image.width / image.height > 200:
460
+ width, height = image.height * 180, image.height
461
+ image = image.resize((width, height), resample=Image.NEAREST)
462
+
463
+ if image.height / image.width > 200:
464
+ width, height = image.width, image.width * 180
465
+ image = image.resize((width, height), resample=Image.NEAREST)
466
+
467
+ return image
468
+
469
+ @override
470
+ def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
471
+ sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
472
+ sample_frames = sample_frames // 2 * 2
473
+ return sample_frames
474
+
475
+ @override
476
+ def process_messages(
477
+ self,
478
+ messages: Sequence[Dict[str, str]],
479
+ images: Sequence["ImageInput"],
480
+ videos: Sequence["VideoInput"],
481
+ processor: Optional["ProcessorMixin"],
482
+ ) -> List[Dict[str, str]]:
483
+ self._validate_input(images, videos)
484
+ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
485
+ merge_length: int = getattr(image_processor, "merge_size") ** 2
486
+ mm_inputs = self._get_mm_inputs(images, videos, processor)
487
+ image_grid_thw = mm_inputs.get("image_grid_thw", [])
488
+ video_grid_thw = mm_inputs.get("video_grid_thw", [])
489
+
490
+ num_image_tokens, num_video_tokens = 0, 0
491
+ messages = deepcopy(messages)
492
+ for message in messages:
493
+ content = message["content"]
494
+ while IMAGE_PLACEHOLDER in content:
495
+ if num_image_tokens >= len(image_grid_thw):
496
+ raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))
497
+
498
+ content = content.replace(
499
+ IMAGE_PLACEHOLDER,
500
+ "<|vision_start|>{}<|vision_end|>".format(
501
+ self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
502
+ ),
503
+ 1,
504
+ )
505
+ num_image_tokens += 1
506
+
507
+ while VIDEO_PLACEHOLDER in content:
508
+ if num_video_tokens >= len(video_grid_thw):
509
+ raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
510
+
511
+ content = content.replace(
512
+ VIDEO_PLACEHOLDER,
513
+ "<|vision_start|>{}<|vision_end|>".format(
514
+ self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
515
+ ),
516
+ 1,
517
+ )
518
+ num_video_tokens += 1
519
+
520
+ message["content"] = content
521
+
522
+ if len(images) != num_image_tokens:
523
+ raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
524
+
525
+ if len(videos) != num_video_tokens:
526
+ raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
527
+
528
+ return messages
529
+
530
+ @override
531
+ def get_mm_inputs(
532
+ self,
533
+ images: Sequence["ImageInput"],
534
+ videos: Sequence["VideoInput"],
535
+ imglens: Sequence[int],
536
+ vidlens: Sequence[int],
537
+ seqlens: Sequence[int],
538
+ processor: Optional["ProcessorMixin"],
539
+ ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
540
+ self._validate_input(images, videos)
541
+ return self._get_mm_inputs(images, videos, processor)
542
+
543
+
544
+ class VideoLlavaPlugin(BasePlugin):
545
+ @override
546
+ def process_messages(
547
+ self,
548
+ messages: Sequence[Dict[str, str]],
549
+ images: Sequence["ImageInput"],
550
+ videos: Sequence["VideoInput"],
551
+ processor: Optional["ProcessorMixin"],
552
+ ) -> List[Dict[str, str]]:
553
+ self._validate_input(images, videos)
554
+ num_image_tokens = 0
555
+ num_video_tokens = 0
556
+ messages = deepcopy(messages)
557
+ mm_inputs = self._get_mm_inputs(images, videos, processor)
558
+ num_frames = 0
559
+ exist_images = "pixel_values_images" in mm_inputs
560
+ exist_videos = "pixel_values_videos" in mm_inputs
561
+ if exist_videos or exist_images:
562
+ if exist_images:
563
+ height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
564
+ num_frames = 1
565
+ if exist_videos:
566
+ pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
567
+ height, width = get_image_size(pixel_values_video[0])
568
+ num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
569
+ image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
570
+ video_seqlen = image_seqlen * num_frames
571
+ if processor.vision_feature_select_strategy == "default":
572
+ image_seqlen -= 1
573
+ for message in messages:
574
+ content = message["content"]
575
+ while self.image_token in content:
576
+ num_image_tokens += 1
577
+ content = content.replace(self.image_token, "{{image}}", 1)
578
+ while self.video_token in content:
579
+ num_video_tokens += 1
580
+ content = content.replace(self.video_token, "{{video}}", 1)
581
+
582
+ content = content.replace("{{image}}", self.image_token * image_seqlen)
583
+ message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
584
+
585
+ if len(images) != num_image_tokens:
586
+ raise ValueError("The number of images does not match the number of {} tokens".format(self.image_token))
587
+
588
+ if len(videos) != num_video_tokens:
589
+ raise ValueError("The number of videos does not match the number of {} tokens".format(self.video_token))
590
+
591
+ return messages
592
+
593
+ @override
594
+ def get_mm_inputs(
595
+ self,
596
+ images: Sequence["ImageInput"],
597
+ videos: Sequence["VideoInput"],
598
+ imglens: Sequence[int],
599
+ vidlens: Sequence[int],
600
+ seqlens: Sequence[int],
601
+ processor: Optional["ProcessorMixin"],
602
+ ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
603
+ self._validate_input(images, videos)
604
+ return self._get_mm_inputs(images, videos, processor)
605
+
606
+
607
+ PLUGINS = {
608
+ "base": BasePlugin,
609
+ "llava": LlavaPlugin,
610
+ "llava_next": LlavaNextPlugin,
611
+ "llava_next_video": LlavaNextVideoPlugin,
612
+ "paligemma": PaliGemmaPlugin,
613
+ "qwen2_vl": Qwen2vlPlugin,
614
+ "video_llava": VideoLlavaPlugin,
615
+ }
616
+
617
+
618
+ def get_mm_plugin(
619
+ name: str,
620
+ image_token: Optional[str] = None,
621
+ video_token: Optional[str] = None,
622
+ ) -> "BasePlugin":
623
+ plugin_class = PLUGINS.get(name, None)
624
+ if plugin_class is None:
625
+ raise ValueError("Multimodal plugin `{}` not found.".format(name))
626
+
627
+ return plugin_class(image_token, video_token)
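To make the placeholder bookkeeping in these plugins concrete, here is a standalone sketch of the two-pass replacement that `LlavaPlugin.process_messages` performs. The token strings and `image_seqlen` value are invented for illustration, and `IMAGE_PLACEHOLDER` is assumed to be the `"<image>"` constant from `extras.constants`.

```python
IMAGE_PLACEHOLDER = "<image>"                  # assumed value of the shared constant
image_token, image_seqlen = "<image_pad>", 3   # illustrative processor settings

messages = [{"role": "user", "content": f"Describe this picture: {IMAGE_PLACEHOLDER}"}]
images = ["cat.png"]                           # one image, so one placeholder is expected

num_image_tokens = 0
for message in messages:
    content = message["content"]
    while IMAGE_PLACEHOLDER in content:        # first pass: count and mark placeholders
        num_image_tokens += 1
        content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)

    # second pass: expand each marker into image_seqlen copies of the image token
    message["content"] = content.replace("{{image}}", image_token * image_seqlen)

assert len(images) == num_image_tokens
print(messages[0]["content"])
# Describe this picture: <image_pad><image_pad><image_pad>
```

The intermediate `{{image}}` marker prevents re-matching an already expanded token on later loop iterations, which is why the real plugins use the same two-step replacement.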
llamafactory/data/parser.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ from dataclasses import dataclass
18
+ from typing import Any, Dict, List, Literal, Optional, Sequence
19
+
20
+ from transformers.utils import cached_file
21
+
22
+ from ..extras.constants import DATA_CONFIG
23
+ from ..extras.misc import use_modelscope
24
+
25
+
26
+ @dataclass
27
+ class DatasetAttr:
28
+ r"""
29
+ Dataset attributes.
30
+ """
31
+
32
+ # basic configs
33
+ load_from: Literal["hf_hub", "ms_hub", "script", "file"]
34
+ dataset_name: str
35
+ formatting: Literal["alpaca", "sharegpt"] = "alpaca"
36
+ ranking: bool = False
37
+ # extra configs
38
+ subset: Optional[str] = None
39
+ split: str = "train"
40
+ folder: Optional[str] = None
41
+ num_samples: Optional[int] = None
42
+ # common columns
43
+ system: Optional[str] = None
44
+ tools: Optional[str] = None
45
+ images: Optional[str] = None
46
+ videos: Optional[str] = None
47
+ # rlhf columns
48
+ chosen: Optional[str] = None
49
+ rejected: Optional[str] = None
50
+ kto_tag: Optional[str] = None
51
+ # alpaca columns
52
+ prompt: Optional[str] = "instruction"
53
+ query: Optional[str] = "input"
54
+ response: Optional[str] = "output"
55
+ history: Optional[str] = None
56
+ # sharegpt columns
57
+ messages: Optional[str] = "conversations"
58
+ # sharegpt tags
59
+ role_tag: Optional[str] = "from"
60
+ content_tag: Optional[str] = "value"
61
+ user_tag: Optional[str] = "human"
62
+ assistant_tag: Optional[str] = "gpt"
63
+ observation_tag: Optional[str] = "observation"
64
+ function_tag: Optional[str] = "function_call"
65
+ system_tag: Optional[str] = "system"
66
+
67
+ def __repr__(self) -> str:
68
+ return self.dataset_name
69
+
70
+ def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
71
+ setattr(self, key, obj.get(key, default))
72
+
73
+
74
+ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
75
+ r"""
76
+ Gets the attributes of the datasets.
77
+ """
78
+ if dataset_names is None:
79
+ dataset_names = []
80
+
81
+ if dataset_dir == "ONLINE":
82
+ dataset_info = None
83
+ else:
84
+ if dataset_dir.startswith("REMOTE:"):
85
+ config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
86
+ else:
87
+ config_path = os.path.join(dataset_dir, DATA_CONFIG)
88
+
89
+ try:
90
+ with open(config_path, "r") as f:
91
+ dataset_info = json.load(f)
92
+ except Exception as err:
93
+ if len(dataset_names) != 0:
94
+ raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
95
+
96
+ dataset_info = None
97
+
98
+ dataset_list: List["DatasetAttr"] = []
99
+ for name in dataset_names:
100
+ if dataset_info is None: # dataset_dir is ONLINE
101
+ load_from = "ms_hub" if use_modelscope() else "hf_hub"
102
+ dataset_attr = DatasetAttr(load_from, dataset_name=name)
103
+ dataset_list.append(dataset_attr)
104
+ continue
105
+
106
+ if name not in dataset_info:
107
+ raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
108
+
109
+ has_hf_url = "hf_hub_url" in dataset_info[name]
110
+ has_ms_url = "ms_hub_url" in dataset_info[name]
111
+
112
+ if has_hf_url or has_ms_url:
113
+ if (use_modelscope() and has_ms_url) or (not has_hf_url):
114
+ dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
115
+ else:
116
+ dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
117
+ elif "script_url" in dataset_info[name]:
118
+ dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
119
+ else:
120
+ dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
121
+
122
+ dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
123
+ dataset_attr.set_attr("ranking", dataset_info[name], default=False)
124
+ dataset_attr.set_attr("subset", dataset_info[name])
125
+ dataset_attr.set_attr("split", dataset_info[name], default="train")
126
+ dataset_attr.set_attr("folder", dataset_info[name])
127
+ dataset_attr.set_attr("num_samples", dataset_info[name])
128
+
129
+ if "columns" in dataset_info[name]:
130
+ column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
131
+ if dataset_attr.formatting == "alpaca":
132
+ column_names.extend(["prompt", "query", "response", "history"])
133
+ else:
134
+ column_names.extend(["messages"])
135
+
136
+ for column_name in column_names:
137
+ dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
138
+
139
+ if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
140
+ tag_names = (
141
+ "role_tag",
142
+ "content_tag",
143
+ "user_tag",
144
+ "assistant_tag",
145
+ "observation_tag",
146
+ "function_tag",
147
+ "system_tag",
148
+ )
149
+ for tag in tag_names:
150
+ dataset_attr.set_attr(tag, dataset_info[name]["tags"])
151
+
152
+ dataset_list.append(dataset_attr)
153
+
154
+ return dataset_list
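For orientation, `get_dataset_list` expects `dataset_info.json` entries shaped roughly like the one below. The dataset name and file are invented for illustration, but every key corresponds to a `DatasetAttr` field populated via `set_attr`.

```python
import json

# Hypothetical dataset_info.json entry (not shipped with the repo) showing the
# fields get_dataset_list() reads: source, formatting, column and tag remapping.
dataset_info = json.loads("""
{
  "my_sharegpt_data": {
    "file_name": "my_sharegpt_data.json",
    "formatting": "sharegpt",
    "num_samples": 1000,
    "columns": {"messages": "conversations", "system": "system"},
    "tags": {"role_tag": "from", "content_tag": "value",
             "user_tag": "human", "assistant_tag": "gpt"}
  }
}
""")
entry = dataset_info["my_sharegpt_data"]
print(entry["formatting"], entry["columns"]["messages"])  # sharegpt conversations
```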
llamafactory/data/preprocess.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from functools import partial
16
+ from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
17
+
18
+ from .processors.feedback import preprocess_feedback_dataset
19
+ from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
20
+ from .processors.pretrain import preprocess_pretrain_dataset
21
+ from .processors.supervised import (
22
+ preprocess_packed_supervised_dataset,
23
+ preprocess_supervised_dataset,
24
+ print_supervised_dataset_example,
25
+ )
26
+ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
27
+
28
+
29
+ if TYPE_CHECKING:
30
+ from transformers import PreTrainedTokenizer, ProcessorMixin
31
+
32
+ from ..hparams import DataArguments
33
+ from .template import Template
34
+
35
+
36
+ def get_preprocess_and_print_func(
37
+ data_args: "DataArguments",
38
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
39
+ template: "Template",
40
+ tokenizer: "PreTrainedTokenizer",
41
+ processor: Optional["ProcessorMixin"],
42
+ do_generate: bool = False,
43
+ ) -> Tuple[Callable, Callable]:
44
+ if stage == "pt":
45
+ preprocess_func = partial(
46
+ preprocess_pretrain_dataset,
47
+ tokenizer=tokenizer,
48
+ data_args=data_args,
49
+ )
50
+ print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
51
+ elif stage == "sft" and not do_generate:
52
+ if data_args.packing:
53
+ if data_args.neat_packing: # hack datasets to have int32 attention mask
54
+ from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
55
+
56
+ def __init__(self, data, **kwargs):
57
+ return TypedSequence.__init__(
58
+ self,
59
+ data,
60
+ type=kwargs.pop("type", None),
61
+ try_type=kwargs.pop("try_type", None),
62
+ optimized_int_type=kwargs.pop("optimized_int_type", None),
63
+ )
64
+
65
+ OptimizedTypedSequence.__init__ = __init__
66
+ preprocess_func = partial(
67
+ preprocess_packed_supervised_dataset,
68
+ template=template,
69
+ tokenizer=tokenizer,
70
+ processor=processor,
71
+ data_args=data_args,
72
+ )
73
+ else:
74
+ preprocess_func = partial(
75
+ preprocess_supervised_dataset,
76
+ template=template,
77
+ tokenizer=tokenizer,
78
+ processor=processor,
79
+ data_args=data_args,
80
+ )
81
+
82
+ print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
83
+ elif stage == "rm":
84
+ preprocess_func = partial(
85
+ preprocess_pairwise_dataset,
86
+ template=template,
87
+ tokenizer=tokenizer,
88
+ processor=processor,
89
+ data_args=data_args,
90
+ )
91
+ print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
92
+ elif stage == "kto":
93
+ preprocess_func = partial(
94
+ preprocess_feedback_dataset,
95
+ template=template,
96
+ tokenizer=tokenizer,
97
+ processor=processor,
98
+ data_args=data_args,
99
+ )
100
+ print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
101
+ else:
102
+ preprocess_func = partial(
103
+ preprocess_unsupervised_dataset,
104
+ template=template,
105
+ tokenizer=tokenizer,
106
+ processor=processor,
107
+ data_args=data_args,
108
+ )
109
+ print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
110
+
111
+ return preprocess_func, print_function
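The function above only selects and binds callables; the actual work happens later when `_get_preprocessed_dataset` in loader.py passes `preprocess_func` to `dataset.map` and calls `print_function` on the first example. Below is a minimal sketch of that contract with stand-in functions; nothing here is real library code.

```python
from functools import partial

def preprocess_toy(examples, tokenizer, data_args):
    # batched mapping: lists in, lists out, like the real processors
    return {"input_ids": [tokenizer(text) for text in examples["text"]]}

def print_toy_example(example, tokenizer):
    print("input_ids:", example["input_ids"])

toy_tokenizer = str.split  # stand-in for a PreTrainedTokenizer
preprocess_func = partial(preprocess_toy, tokenizer=toy_tokenizer, data_args=None)
print_function = partial(print_toy_example, tokenizer=toy_tokenizer)

batch = {"text": ["hello world", "llama factory"]}
processed = preprocess_func(batch)                        # what dataset.map() would call
print_function({"input_ids": processed["input_ids"][0]})  # logged for the first sample
```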
llamafactory/data/processors/__init__.py ADDED
File without changes
llamafactory/data/processors/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (173 Bytes).
llamafactory/data/processors/__pycache__/feedback.cpython-311.pyc ADDED
Binary file (6.84 kB).
llamafactory/data/processors/__pycache__/pairwise.cpython-311.pyc ADDED
Binary file (7.81 kB).