[**中文**](./README_ZH.md) | [**English**](./README.md)
<p align="center" width="100%">
<a href="https://github.com/daiyizheng/TCMChat" target="_blank"><img src="./logo.png" alt="TCMChat" style="width: 25%; min-width: 300px; display: block; margin: auto;"></a>
</p>
# TCMChat: A Generative Large Language Model for Traditional Chinese Medicine
[![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/SCIR-HI/Huatuo-Llama-Med-Chinese/blob/main/LICENSE) [![Python 3.10.12](https://img.shields.io/badge/python-3.10.12-blue.svg)](https://www.python.org/downloads/release/python-390/)
## News
[2024-05-17] Released the open-source model weights on HuggingFace.
## Application
### Install
```shell
git clone https://github.com/daiyizheng/TCMChat
cd TCMChat
```
Install the dependencies. A Python 3.10+ environment is recommended.
```shell
pip install -r requirements.txt
```
### Weights download
- [TCMChat](https://huggingface.co/daiyizheng/TCMChat): TCM knowledge question answering and recommendation, based on Baichuan2-7B-Chat. A programmatic download sketch is shown below.
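If you prefer to fetch the weights programmatically, a minimal sketch using `huggingface_hub` is shown below (the local target directory is an assumed example):

```python
# Sketch: download the TCMChat weights from the HuggingFace Hub.
# Requires `pip install huggingface_hub`; the local directory below is only an example.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="daiyizheng/TCMChat",
    local_dir="./models/TCMChat",  # assumed download location
)
print(f"Weights downloaded to: {local_dir}")
```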
### Inference
#### Command line
```shell
python cli_infer.py \
--model_name_or_path /your/model/path \
--model_type chat
```
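The model can also be queried directly from Python with `transformers`. A minimal sketch, assuming the Baichuan2-style `model.chat` interface exposed via `trust_remote_code` (the model path and question are illustrative):

```python
# Sketch: load TCMChat (built on Baichuan2-7B-Chat) and ask a single question.
# Assumes the Baichuan2-style `model.chat` API available through trust_remote_code.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

model_path = "/your/model/path"  # e.g. the downloaded daiyizheng/TCMChat weights

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
model.generation_config = GenerationConfig.from_pretrained(model_path)

# Example TCM question (illustrative).
messages = [{"role": "user", "content": "What are the main effects of Danggui (Angelica sinensis)?"}]
response = model.chat(tokenizer, messages)
print(response)
```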
#### Web demo
```shell
python gradio_demo.py
```
We also provide an online demo: [https://xomics.com.cn/tcmchat](https://xomics.com.cn/tcmchat)
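If you want to adapt the web demo, a minimal Gradio sketch is shown below; the actual interface in `gradio_demo.py` may differ, and `model`/`tokenizer` are assumed to be loaded as in the inference sketch above:

```python
# Sketch: a minimal Gradio chat UI around the loaded model.
# `model` and `tokenizer` are assumed to be loaded as in the inference sketch above.
import gradio as gr

def respond(message, history):
    # Rebuild the conversation in the Baichuan2-style message format (assumption).
    messages = []
    for user_turn, bot_turn in history:
        messages.append({"role": "user", "content": user_turn})
        messages.append({"role": "assistant", "content": bot_turn})
    messages.append({"role": "user", "content": message})
    return model.chat(tokenizer, messages)

gr.ChatInterface(respond, title="TCMChat").launch(server_name="0.0.0.0", server_port=7860)
```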
### Retrain
#### Dataset Download
- [Pretrain dataset](https://github.com/ZJUFanLab/TCMChat/tree/master/data/pretrain)
- [SFT dataset](https://github.com/ZJUFanLab/TCMChat/tree/master/data/sft)
- [Benchmark dataset](https://github.com/ZJUFanLab/TCMChat/tree/master/data/evaluate)
> Note: only sample data is currently provided; the full original data will be open-sourced in the near future.
#### Pre-training
```shell
train_type="pretrain"
train_file="data/pretrain/train"
validation_file="data/pretrain/test"
block_size="1024"
deepspeed_dir="data/resources/deepspeed_zero_stage2_config.yml"
num_train_epochs="2"
export WANDB_PROJECT="TCM-${train_type}"
date_time=$(date +"%Y%m%d%H%M%S")
run_name="${date_time}_${block_size}"
model_name_or_path="your/path/Baichuan2-7B-Chat"
output_dir="output/${train_type}/${date_time}_${block_size}"
accelerate launch --config_file ${deepspeed_dir} src/pretraining.py \
--model_name_or_path ${model_name_or_path} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--preprocessing_num_workers 20 \
--cache_dir ./cache \
--block_size ${block_size} \
--seed 42 \
--do_train \
--do_eval \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--num_train_epochs ${num_train_epochs} \
--low_cpu_mem_usage True \
--torch_dtype bfloat16 \
--bf16 \
--ddp_find_unused_parameters False \
--gradient_checkpointing True \
--learning_rate 2e-4 \
--warmup_ratio 0.05 \
--weight_decay 0.01 \
--report_to wandb \
--run_name ${run_name} \
--logging_dir logs \
--logging_strategy steps \
--logging_steps 10 \
--eval_steps 50 \
--evaluation_strategy steps \
--save_steps 100 \
--save_strategy steps \
--save_total_limit 13 \
--output_dir ${output_dir} \
--overwrite_output_dir
```
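The `--block_size 1024` option follows the standard causal-LM pre-training recipe: tokenized documents are concatenated and cut into fixed-length chunks before training. A minimal sketch of that grouping step is shown below (the actual logic lives in `src/pretraining.py`; this is only an illustration):

```python
# Sketch: standard block-wise grouping used for causal-LM pre-training.
# Tokenized texts are concatenated and split into fixed chunks of `block_size` tokens.
def group_texts(examples, block_size=1024):
    # Concatenate every field (e.g. input_ids, attention_mask) into one long list.
    concatenated = {k: sum(examples[k], []) for k in examples}
    # Drop the tail so the total length is a multiple of block_size.
    total_length = (len(concatenated["input_ids"]) // block_size) * block_size
    result = {
        k: [v[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, v in concatenated.items()
    }
    # For causal language modeling the labels are a copy of the inputs.
    result["labels"] = result["input_ids"].copy()
    return result
```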
#### Fine-tuning
```shell
train_type="SFT"
model_max_length="1024"
date_time=$(date +"%Y%m%d%H%M%S")
data_path="data/sft/sample_train_baichuan_data.json"
model_name_or_path="your/path/pretrain"
deepspeed_dir="data/resources/deepspeed_zero_stage2_confi_baichuan2.json"
export WANDB_PROJECT="TCM-${train_type}"
run_name="${train_type}_${date_time}"
output_dir="output/${train_type}/${date_time}_${model_max_length}"
deepspeed --hostfile="" src/fine-tune.py \
--report_to "wandb" \
--run_name ${run_name} \
--data_path ${data_path} \
--model_name_or_path ${model_name_or_path} \
--output_dir ${output_dir} \
--model_max_length ${model_max_length} \
--num_train_epochs 4 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 1 \
--save_strategy epoch \
--learning_rate 2e-5 \
--lr_scheduler_type constant \
--adam_beta1 0.9 \
--adam_beta2 0.98 \
--adam_epsilon 1e-8 \
--max_grad_norm 1.0 \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--logging_steps 1 \
--gradient_checkpointing True \
--deepspeed ${deepspeed_dir} \
--bf16 True \
--tf32 True
```
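The SFT data is supplied as a single JSON file (`data/sft/sample_train_baichuan_data.json` above). As an illustration only, a record in the common Baichuan2-style conversation format might be written as below; the field names are an assumption, so check the sample file in the repository for the authoritative schema:

```python
# Sketch: write one SFT record in a Baichuan2-style conversation format.
# The schema ("conversations", "from", "value") is an assumption; consult the
# sample file shipped with the repository for the authoritative format.
import json

record = {
    "id": "sample_0001",
    "conversations": [
        {"from": "human", "value": "What are the main effects of Huangqi (Astragalus)?"},
        {"from": "gpt", "value": "Huangqi is traditionally used to tonify qi, consolidate the exterior, and promote diuresis."},
    ],
}

with open("data/sft/my_sft_data.json", "w", encoding="utf-8") as f:
    json.dump([record], f, ensure_ascii=False, indent=2)
```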
### Training details
Please refer to the experimental section of the paper for details.