3v324v23 committed
Commit 438f2e2
1 Parent(s): ccb7d61

add training scripts

Files changed (1):
  train.sh +137 -0
train.sh ADDED
@@ -0,0 +1,137 @@
+ TRAIN_FILE=/home/jhju/datasets/qrecc/qrecc_train.json
+ EVAL_FILE=/home/jhju/datasets/qrecc/qrecc_test.json
+ TEST_FILE=dataset/2023_test_topics.json
+ BASE=google/flan-t5-base
+
+ preprocess:
+ 	# convert the NAACL baseline to a TREC run file
+ 	python3 utils/convert_scai_baseline_to_run.py \
+ 		--scai-baseline-json dataset/scai-qrecc21-naacl-baseline.json
+ 	# convert qrels to TREC format
+ 	python3 utils/convert_scai_qrels_to_trec.py \
+ 		--scai-qrels-json dataset/scai_qrecc_test_qrel.json
+
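The first preprocessing script presumably rewrites the SCAI-QReCC baseline JSON into the standard six-column TREC run format. A minimal sketch of that kind of conversion, assuming the baseline JSON maps each query id to a ranked list of passage ids (the schema, scoring, and function name here are illustrative, not the actual utils/convert_scai_baseline_to_run.py):

import json

# Hypothetical converter; the real input schema of
# scai-qrecc21-naacl-baseline.json may differ.
def to_trec_run(baseline_json, out_path, tag="naacl-baseline"):
    with open(baseline_json) as f:
        rankings = json.load(f)  # assumed shape: {qid: [docid, ...]}
    with open(out_path, "w") as out:
        for qid, docids in rankings.items():
            for rank, docid in enumerate(docids, start=1):
                # standard TREC run line: qid Q0 docid rank score tag
                out.write(f"{qid} Q0 {docid} {rank} {1.0 / rank:.4f} {tag}\n")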
+ train_flatten:
+ 	python3 train_flatten.py \
+ 		--model_name_or_path ${BASE} \
+ 		--tokenizer_name ${BASE} \
+ 		--config_name ${BASE} \
+ 		--train_file ${TRAIN_FILE} \
+ 		--eval_file ${EVAL_FILE} \
+ 		--output_dir models/ckpt/function-base-flatten \
+ 		--per_device_train_batch_size 8 \
+ 		--max_src_length 256 \
+ 		--max_tgt_length 64 \
+ 		--learning_rate 1e-4 \
+ 		--evaluation_strategy steps \
+ 		--max_steps 20000 \
+ 		--save_steps 5000 \
+ 		--eval_steps 500 \
+ 		--do_train \
+ 		--do_eval \
+ 		--optim adafactor \
+ 		--n_conversations 6 \
+ 		--warmup_steps 1000 \
+ 		--lr_scheduler_type linear \
+ 		--instruction_prefix 'Based on previous conversations, rewrite the user utterance: {} into a standalone query.' \
+ 		--conversation_prefix 'user: {0} system: {1}' \
+ 		--report_to wandb
+
+
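The two prefix flags above are str.format templates. A minimal sketch of how the flattened source text is presumably assembled from them; the actual assembly in train_flatten.py may order or join the pieces differently:

instruction_prefix = ('Based on previous conversations, rewrite the user '
                      'utterance: {} into a standalone query.')
conversation_prefix = 'user: {0} system: {1}'

def build_source(utterance, history, n_conversations=6):
    # keep only the most recent n_conversations (question, response) pairs
    turns = [conversation_prefix.format(q, r)
             for q, r in history[-n_conversations:]]
    return ' '.join([instruction_prefix.format(utterance)] + turns)

print(build_source('what about its side effects?',
                   [('what is ibuprofen?', 'Ibuprofen is an NSAID ...')]))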
+ train:
+ 	python3 train.py \
+ 		--model_name_or_path ${BASE} \
+ 		--tokenizer_name ${BASE} \
+ 		--config_name ${BASE} \
+ 		--train_file ${TRAIN_FILE} \
+ 		--eval_file ${EVAL_FILE} \
+ 		--output_dir models/ckpt/function-base \
+ 		--per_device_train_batch_size 8 \
+ 		--max_src_length 256 \
+ 		--max_tgt_length 64 \
+ 		--evaluation_strategy steps \
+ 		--max_steps 20000 \
+ 		--save_steps 5000 \
+ 		--eval_steps 500 \
+ 		--do_train \
+ 		--do_eval \
+ 		--optim adafactor \
+ 		--n_conversations 6 \
+ 		--warmup_steps 1000 \
+ 		--learning_rate 1e-3 \
+ 		--lr_scheduler_type linear \
+ 		--instruction_prefix 'Rewrite the user query: {0} based on the context: turn number: {1} question: {2} response: {3}' \
+ 		--report_to wandb
+
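Here the instruction template carries four positional slots. A hedged guess at how train.py fills them, reading the slots as (current query, turn number, last question, last response); the real mapping may differ:

template = ('Rewrite the user query: {0} based on the context: '
            'turn number: {1} question: {2} response: {3}')
history = [('what is ibuprofen?', 'Ibuprofen is an NSAID ...')]
last_q, last_r = history[-1]
print(template.format('what about its side effects?',
                      len(history), last_q, last_r))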
+ train_ntr:
+ 	python3 train_ntr.py \
+ 		--model_name_or_path ${BASE} \
+ 		--tokenizer_name ${BASE} \
+ 		--config_name ${BASE} \
+ 		--train_file ${TRAIN_FILE} \
+ 		--eval_file ${EVAL_FILE} \
+ 		--output_dir models/ckpt/ntr-base \
+ 		--per_device_train_batch_size 8 \
+ 		--per_device_eval_batch_size 8 \
+ 		--max_src_length 512 \
+ 		--max_tgt_length 64 \
+ 		--evaluation_strategy steps \
+ 		--max_steps 20000 \
+ 		--save_steps 5000 \
+ 		--eval_steps 500 \
+ 		--do_train \
+ 		--do_eval \
+ 		--optim adafactor \
+ 		--n_conversations 6 \
+ 		--learning_rate 1e-3 \
+ 		--lr_scheduler_type linear \
+ 		--report_to wandb
+
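train_ntr takes no prefix flags, which suggests the classic NTR-style input: history turns and the current question concatenated with a separator. A sketch assuming the ' ||| ' convention used by castorini/t5-base-canard; whether train_ntr.py uses the same separator is an assumption:

def ntr_source(question, history, n_conversations=6):
    # flatten the last n_conversations (question, response) pairs in order
    flat = [text for pair in history[-n_conversations:] for text in pair]
    return ' ||| '.join(flat + [question])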
+ train_compressed:
+ 	python3 train_compressed.py \
+ 		--model_name_or_path ${BASE} \
+ 		--tokenizer_name ${BASE} \
+ 		--config_name ${BASE} \
+ 		--train_file ${TRAIN_FILE} \
+ 		--eval_file ${EVAL_FILE} \
+ 		--output_dir models/ckpt/function-base-compressed \
+ 		--per_device_train_batch_size 8 \
+ 		--max_src_length 64 \
+ 		--max_tgt_length 64 \
+ 		--max_src_conv_length 256 \
+ 		--learning_rate 1e-4 \
+ 		--evaluation_strategy steps \
+ 		--max_steps 20000 \
+ 		--save_steps 5000 \
+ 		--eval_steps 500 \
+ 		--do_train \
+ 		--do_eval \
+ 		--optim adafactor \
+ 		--n_conversations 10 \
+ 		--warmup_steps 1000 \
+ 		--lr_scheduler_type linear \
+ 		--instruction_prefix 'Rewrite the user utterance: {}, based on previous conversations. conversation: ' \
+ 		--conversation_prefix 'user: {0} system: {1}' \
+ 		--report_to wandb
+
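The split between --max_src_length 64 and --max_src_conv_length 256 suggests the instruction and the conversation history are tokenized separately under different caps. A speculative sketch of that two-part encoding with the flan-t5 tokenizer; nothing here is taken from train_compressed.py itself:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('google/flan-t5-base')
# short cap for the instruction plus the current utterance ...
inst = tok('Rewrite the user utterance: what about its side effects?, '
           'based on previous conversations. conversation: ',
           truncation=True, max_length=64, return_tensors='pt')
# ... and a longer cap for the flattened conversation history
conv = tok('user: what is ibuprofen? system: Ibuprofen is an NSAID ...',
           truncation=True, max_length=256, return_tensors='pt')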
+ rewrite_by_t5ntr:
+ 	python3 generate_ikat.py \
+ 		--model_name castorini/t5-base-canard \
+ 		--model_path castorini/t5-base-canard \
+ 		--input_file ${TEST_FILE} \
+ 		--output_jsonl results/ikat_test/t5ntr_history_3-3.jsonl \
+ 		--device cuda:0 \
+ 		--batch_size 4 \
+ 		--n_conversations 3 \
+ 		--n_responses 3 \
+ 		--num_beams 5 \
+ 		--max_src_length 512 \
+ 		--max_tgt_length 256
+
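A stand-alone rewrite with castorini/t5-base-canard that mirrors the generation flags above (beam search, 512/256 length caps). The ' ||| '-separated input is that model's documented convention; the example dialogue is invented:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tok = AutoTokenizer.from_pretrained('castorini/t5-base-canard')
model = AutoModelForSeq2SeqLM.from_pretrained('castorini/t5-base-canard')
# history turns and the current question, joined by the model's separator
src = ('what is ibuprofen? ||| Ibuprofen is an NSAID ... ||| '
       'what about its side effects?')
ids = tok(src, return_tensors='pt', truncation=True, max_length=512).input_ids
out = model.generate(ids, num_beams=5, max_length=256)
print(tok.decode(out[0], skip_special_tokens=True))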
+ index_bm25:
+ 	python3 -m pyserini.index.lucene \
+ 		--collection JsonCollection \
+ 		--input /home/jhju/datasets/qrecc/collection-paragraph/ \
+ 		--index /home/jhju/indexes/qrecc-commoncrawl-lucene/ \
+ 		--generator DefaultLuceneDocumentGenerator \
+ 		--threads 8
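Once the index is built, a quick BM25 sanity check with pyserini's LuceneSearcher (the query string is illustrative):

from pyserini.search.lucene import LuceneSearcher

searcher = LuceneSearcher('/home/jhju/indexes/qrecc-commoncrawl-lucene/')
hits = searcher.search('what are the side effects of ibuprofen?', k=10)
for rank, hit in enumerate(hits, start=1):
    print(f'{rank:2d} {hit.docid} {hit.score:.3f}')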