[**中文**](./README_ZH.md) | [**English**](./README.md)
<p align="center" width="100%">
<a href="https://github.com/daiyizheng/TCMChat" target="_blank"><img src="./logo.png" alt="TCMChat" style="width: 25%; min-width: 300px; display: block; margin: auto;"></a>
</p>
# TCMChat: Traditional Chinese Medicine Recommendation System based on Large Language Model
[![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/SCIR-HI/Huatuo-Llama-Med-Chinese/blob/main/LICENSE) [![Python 3.10.12](https://img.shields.io/badge/python-3.10.12-blue.svg)](https://www.python.org/downloads/release/python-390/)
## News
[2024-05-17] Model weights released on Hugging Face.
## Usage
### Installation
```
git clone https://github.com/daiyizheng/TCMChat
cd TCMChat
```
Then install the dependencies. Python 3.10+ is recommended:
```
pip install -r requirements.txt
```
### Weight Download
- [TCMChat](https://huggingface.co/daiyizheng/TCMChat): Chinese herbal medicine and formula knowledge Q&A and recommendation, based on Baichuan2-7B-Chat.
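A minimal sketch for fetching the weights with `huggingface_hub` (the `local_dir` below is illustrative; any local path works):
```python
from huggingface_hub import snapshot_download

# Download the TCMChat weights from the Hugging Face Hub.
# local_dir is an assumption -- point it wherever you keep checkpoints.
snapshot_download(repo_id="daiyizheng/TCMChat", local_dir="models/TCMChat")
```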
### Inference
#### Command-line Test
```
python cli_infer.py \
--model_name_or_path /your/model/path \
--model_type chat
```
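To call the model programmatically rather than through `cli_infer.py`, a minimal sketch assuming the standard Baichuan2-style `chat` interface (requires `trust_remote_code=True`; the model path and prompt are placeholders):
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

model_path = "/your/model/path"  # same checkpoint passed to cli_infer.py

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
model.generation_config = GenerationConfig.from_pretrained(model_path)

# Baichuan2-style chat takes a list of role/content messages.
messages = [{"role": "user", "content": "当归有哪些功效?"}]
print(model.chat(tokenizer, messages))
```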
#### Web Demo
```
python gradio_demo.py
```
We also provide an online demo: [https://xomics.com.cn/tcmchat](https://xomics.com.cn/tcmchat)
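`gradio_demo.py` is the script that actually ships with the repo; purely as an illustration of the idea, here is a minimal Gradio chat wrapper around the same Baichuan2-style `chat` call (all names and the history format are assumptions, not the repo's code):
```python
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "/your/model/path"  # placeholder, as in cli_infer.py
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)

def respond(message, history):
    # Rebuild the Baichuan2-style message list from the (user, bot) history pairs.
    messages = []
    for user_turn, bot_turn in history:
        messages.append({"role": "user", "content": user_turn})
        messages.append({"role": "assistant", "content": bot_turn})
    messages.append({"role": "user", "content": message})
    return model.chat(tokenizer, messages)

gr.ChatInterface(respond).launch()
```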
### Retraining
#### Dataset Download
- [Pre-training data](https://github.com/ZJUFanLab/TCMChat/tree/master/data/pretrain)
- [Fine-tuning data](https://github.com/ZJUFanLab/TCMChat/tree/master/data/sft)
- [Benchmark evaluation data](https://github.com/ZJUFanLab/TCMChat/tree/master/data/evaluate)
> Note: only sample data is provided for now; we will fully open-source the raw data in the near future.
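To take a quick look at the sample pre-training corpus, a sketch assuming plain-text files under `data/pretrain/train` (the exact layout is an assumption; adjust to the actual files):
```python
from datasets import load_dataset

# The generic "text" builder reads every text file in the directory.
ds = load_dataset("text", data_dir="data/pretrain/train")
print(ds["train"][0])
```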
#### Pre-training
```shell
train_type="pretrain"
train_file="data/pretrain/train"
validation_file="data/pretrain/test"
block_size="1024"
deepspeed_dir="data/resources/deepspeed_zero_stage2_config.yml"
num_train_epochs="2"
export WANDB_PROJECT="TCM-${train_type}"
date_time=$(date +"%Y%m%d%H%M%S")
run_name="${date_time}_${block_size}"
model_name_or_path="your/path/Baichuan2-7B-Chat"
output_dir="output/${train_type}/${date_time}_${block_size}"
accelerate launch --config_file ${deepspeed_dir} src/pretraining.py \
--model_name_or_path ${model_name_or_path} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--preprocessing_num_workers 20 \
--cache_dir ./cache \
--block_size ${block_size} \
--seed 42 \
--do_train \
--do_eval \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--num_train_epochs ${num_train_epochs} \
--low_cpu_mem_usage True \
--torch_dtype bfloat16 \
--bf16 \
--ddp_find_unused_parameters False \
--gradient_checkpointing True \
--learning_rate 2e-4 \
--warmup_ratio 0.05 \
--weight_decay 0.01 \
--report_to wandb \
--run_name ${run_name} \
--logging_dir logs \
--logging_strategy steps \
--logging_steps 10 \
--eval_steps 50 \
--evaluation_strategy steps \
--save_steps 100 \
--save_strategy steps \
--save_total_limit 13 \
--output_dir ${output_dir} \
--overwrite_output_dir
```
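The script follows the standard Hugging Face causal-LM recipe: raw text is tokenized, concatenated, and sliced into fixed `block_size` chunks so no padding is wasted. A sketch of that grouping step (conventional `run_clm`-style logic, not necessarily the exact code in `src/pretraining.py`):
```python
from itertools import chain

block_size = 1024  # matches --block_size above

def group_texts(examples):
    # Concatenate every tokenized column, then slice into fixed-size blocks.
    concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the ragged tail so every block is exactly block_size tokens.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # For causal LM pretraining, the labels are the inputs themselves.
    result["labels"] = result["input_ids"].copy()
    return result
```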
#### Fine-tuning
```shell
train_type="SFT"
model_max_length="1024"
date_time=$(date +"%Y%m%d%H%M%S")
data_path="data/sft/sample_train_baichuan_data.json"
model_name_or_path="your/path/pretrain"
deepspeed_dir="data/resources/deepspeed_zero_stage2_confi_baichuan2.json"
export WANDB_PROJECT="TCM-${train_type}"
run_name="${train_type}_${date_time}"
output_dir="output/${train_type}/${date_time}_${model_max_length}"
deepspeed --hostfile="" src/fine-tune.py \
--report_to "wandb" \
--run_name ${run_name} \
--data_path ${data_path} \
--model_name_or_path ${model_name_or_path} \
--output_dir ${output_dir} \
--model_max_length ${model_max_length} \
--num_train_epochs 4 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 1 \
--save_strategy epoch \
--learning_rate 2e-5 \
--lr_scheduler_type constant \
--adam_beta1 0.9 \
--adam_beta2 0.98 \
--adam_epsilon 1e-8 \
--max_grad_norm 1.0 \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--logging_steps 1 \
--gradient_checkpointing True \
--deepspeed ${deepspeed_dir} \
--bf16 True \
--tf32 True
```
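The JSON at `data/sft/sample_train_baichuan_data.json` presumably follows the Baichuan2-style `conversations` layout; a sketch of writing one record in that format (the field names are assumptions; check the sample file for the authoritative schema):
```python
import json

# One hypothetical training record in the Baichuan2-style "conversations" layout;
# verify the field names against sample_train_baichuan_data.json.
record = {
    "id": "sample_0",
    "conversations": [
        {"from": "human", "value": "当归有哪些功效?"},
        {"from": "gpt", "value": "当归具有补血活血、调经止痛、润肠通便等功效。"},
    ],
}

# Write a one-record training file in the same shape as the sample data.
with open("data/sft/my_train_data.json", "w", encoding="utf-8") as f:
    json.dump([record], f, ensure_ascii=False, indent=2)
```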
### Training Details
Please refer to the experimental section of our paper.