[**中文**](./README_ZH.md) | [**English**](./README.md)

<p align="center" width="100%">
<a href="https://github.com/daiyizheng/TCMChat" target="_blank"><img src="./logo.png" alt="TCMChat" style="width: 25%; min-width: 300px; display: block; margin: auto;"></a>
</p>

# TCMChat: Traditional Chinese Medicine Recommendation System based on Large Language Model

[![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/daiyizheng/TCMChat/blob/main/LICENSE) [![Python 3.10.12](https://img.shields.io/badge/python-3.10.12-blue.svg)](https://www.python.org/downloads/release/python-31012/)

## News

[2024-05-17] Model weights released on Hugging Face.

## Usage

### Installation

```shell
git clone https://github.com/daiyizheng/TCMChat
cd TCMChat
```

Then install the dependencies; Python 3.10+ is recommended:

```shell
pip install -r requirements.txt
```

### Model weights

- [TCMChat](https://huggingface.co/daiyizheng/TCMChat): knowledge Q&A and recommendation for traditional Chinese medicines and formulas, based on Baichuan2-7B-Chat.
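
After downloading, you can sanity-check the weights directly from Python. The snippet below is a minimal sketch; it assumes the checkpoint keeps the standard Baichuan2-7B-Chat interface (custom modeling code loaded via `trust_remote_code=True` plus a `chat()` helper), so verify against `cli_infer.py` in this repository for the exact usage.

```python
# Minimal sketch: load the TCMChat weights and ask one question.
# Assumes the checkpoint keeps the Baichuan2-7B-Chat interface
# (trust_remote_code modeling files and a chat() helper); see
# cli_infer.py in this repository for the exact usage.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

model_path = "daiyizheng/TCMChat"  # or a local download directory

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
model.generation_config = GenerationConfig.from_pretrained(model_path)

messages = [{"role": "user", "content": "当归有哪些功效?"}]
print(model.chat(tokenizer, messages))
```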

### Inference

#### Command-line demo

```shell
python cli_infer.py \
--model_name_or_path /your/model/path \
--model_type chat
```

#### Web demo

```shell
python gradio_demo.py
```

We also provide an online demo: [https://xomics.com.cn/tcmchat](https://xomics.com.cn/tcmchat)


### Retraining
#### Dataset download

- [Pre-training data](https://github.com/ZJUFanLab/TCMChat/tree/master/data/pretrain)
- [Fine-tuning data](https://github.com/ZJUFanLab/TCMChat/tree/master/data/sft)
- [Benchmark evaluation data](https://github.com/ZJUFanLab/TCMChat/tree/master/data/evaluate)

> Note: only sample data is provided for now; we will fully open-source the raw data in the near future.
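
For reference, the fine-tuning script below consumes a single JSON file (`--data_path`). The record layout sketched here is an assumption modeled on the conversation schema of the upstream Baichuan2 `fine-tune.py` script, not a confirmed TCMChat spec; check the provided sample file for the authoritative format.

```python
# Hypothetical SFT record layout, assumed to follow the upstream
# Baichuan2 fine-tune.py conversation schema; consult the provided
# sample file (data/sft/sample_train_baichuan_data.json) for the
# authoritative format.
import json

sample = [
    {
        "id": "sample_0",
        "conversations": [
            {"from": "human", "value": "请推荐一个治疗风寒感冒的经典方剂。"},
            {"from": "gpt", "value": "可以考虑麻黄汤。(示例回答)"},
        ],
    }
]

with open("my_sft_data.json", "w", encoding="utf-8") as f:
    json.dump(sample, f, ensure_ascii=False, indent=2)
```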


#### Pre-training

```shell
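# Continued pre-training of Baichuan2-7B-Chat with accelerate + DeepSpeed (ZeRO stage 2 config)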
train_type="pretrain"
train_file="data/pretrain/train"
validation_file="data/pretrain/test"
block_size="1024"
deepspeed_dir="data/resources/deepspeed_zero_stage2_config.yml"
num_train_epochs="2"
export WANDB_PROJECT="TCM-${train_type}"
date_time=$(date +"%Y%m%d%H%M%S")
run_name="${date_time}_${block_size}"
model_name_or_path="your/path/Baichuan2-7B-Chat"
output_dir="output/${train_type}/${date_time}_${block_size}"


accelerate launch --config_file ${deepspeed_dir} src/pretraining.py \
--model_name_or_path ${model_name_or_path}  \
--train_file  ${train_file}  \
--validation_file ${validation_file}  \
--preprocessing_num_workers 20  \
--cache_dir ./cache \
--block_size  ${block_size}  \
--seed 42  \
--do_train  \
--do_eval  \
--per_device_train_batch_size  32  \
--per_device_eval_batch_size  32  \
--num_train_epochs ${num_train_epochs}  \
--low_cpu_mem_usage  True \
--torch_dtype bfloat16  \
--bf16  \
--ddp_find_unused_parameters False  \
--gradient_checkpointing True  \
--learning_rate 2e-4 \
--warmup_ratio 0.05 \
--weight_decay 0.01 \
--report_to wandb  \
--run_name ${run_name}  \
--logging_dir  logs \
--logging_strategy steps \
--logging_steps 10 \
--eval_steps 50 \
--evaluation_strategy steps \
--save_steps 100 \
--save_strategy steps \
--save_total_limit 13 \
--output_dir  ${output_dir}  \
--overwrite_output_dir
```

#### Fine-tuning
```shell
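# Supervised fine-tuning (SFT) of the pre-trained checkpoint with DeepSpeed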
train_type="SFT"
model_max_length="1024"
date_time=$(date +"%Y%m%d%H%M%S")
data_path="data/sft/sample_train_baichuan_data.json"
model_name_or_path="your/path/pretrain"
deepspeed_dir="data/resources/deepspeed_zero_stage2_confi_baichuan2.json"
export WANDB_PROJECT="TCM-${train_type}"
run_name="${train_type}_${date_time}"
output_dir="output/${train_type}/${date_time}_${model_max_length}"


deepspeed --hostfile="" src/fine-tune.py  \
    --report_to "wandb" \
    --run_name ${run_name}  \
    --data_path ${data_path} \
    --model_name_or_path ${model_name_or_path} \
    --output_dir ${output_dir} \
    --model_max_length ${model_max_length} \
    --num_train_epochs 4 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 1 \
    --save_strategy epoch \
    --learning_rate 2e-5 \
    --lr_scheduler_type constant \
    --adam_beta1 0.9 \
    --adam_beta2 0.98 \
    --adam_epsilon 1e-8 \
    --max_grad_norm 1.0 \
    --weight_decay 1e-4 \
    --warmup_ratio 0.0 \
    --logging_steps 1 \
    --gradient_checkpointing True \
    --deepspeed ${deepspeed_dir} \
    --bf16 True \
    --tf32 True
```

### Training details

Please refer to the experiments section of the paper.