Safetensors
File size: 4,601 Bytes
c74db76
 
f42d64a
 
c74db76
 
170c18b
e3fca71
a20b574
e38d61e
b49e481
94a3267
c74db76
a20b574
e38d61e
e3fca71
94a3267
c74db76
3ae6e36
 
94a3267
 
 
c74db76
a20b574
e38d61e
c74db76
 
4aad6e1
c74db76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aad6e1
c74db76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a20b574
e38d61e
94a3267
587b5bc
 
86b41fe
 
87bbcd3
86b41fe
 
7732356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
---
license: cc-by-4.0
datasets:
- cyberagent/chatbot-arena-ja-calm2-7b-chat-experimental
---

Mixtral8X7B Instructの日本語生成を安定させるためのLoraです。  

**目的**

Mixtral-8x7Bは高性能な言語モデルですが、日本語出力に多言語が混入するcode-switchingがよく見られます。
元の性能を維持しながら、日本語生成を安定させる方法として、Loraの効果を検証しました。

**学習データセット**

学習データセットとして、下記のDPOデータセットを使用しています。  
DPO trainingはVRAM消費が多く、今回はchosenのデータを使用したsft学習しています。

Chatbot Arena Conversations JA (calm2) Dataset  
:[cyberagent/chatbot-arena-ja-calm2-7b-chat-experimental](https://huggingface.co/datasets/cyberagent/chatbot-arena-ja-calm2-7b-chat-experimental)  
指示文 : [lmsys/chatbot_arena_conversations](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations)のユーザ入力(CC-BY 4.0)を利用。  
指示文の和訳 : [facebookの翻訳モデル(MIT License)](https://huggingface.co/facebook/wmt21-dense-24-wide-en-x)が使用されています。  
応答文 : calm2-7b-chat(Apache 2.0)の出力です。

**evaluation**

大きな性能低下がないことを確認しました

##Lora

num_fewshot: 2, batch_size: 1
|         Task         |Version|  Metric   | Value |   |Stderr|
|----------------------|------:|-----------|------:|---|-----:|
|jsquad-1.1-0.3        |    1.1|exact_match|72.3323|   |      |
|                      |       |f1         |85.4772|   |      |
|jcommonsenseqa-1.1-0.3|    1.1|acc        | 0.7498|±  |0.0130|
|                      |       |acc_norm   | 0.4138|±  |0.0147|


num_fewshot: 2, batch_size: 1
|      Task       |Version|  Metric   | Value |   |Stderr|
|-----------------|------:|-----------|------:|---|-----:|
|jnli-1.1-0.3     |    1.1|acc        | 0.5912|±  |0.0100|
|                 |       |acc_norm   | 0.4108|±  |0.0100|
|marc_ja-1.1-0.3  |    1.1|acc        | 0.9620|±  |0.0025|
|                 |       |acc_norm   | 0.9620|±  |0.0025|
|jaqket_v2-0.1-0.3|    0.1|exact_match|71.6495|   |      |
|                 |       |f1         |79.4725|   |      |


##Base model

num_fewshot: 3,3, batch_size: 1
|         Task         |Version|  Metric   | Value |   |Stderr|
|----------------------|------:|-----------|------:|---|-----:|
|jsquad-1.1-0.3        |    1.1|exact_match|68.1225|   |      |
|                      |       |f1         |83.5285|   |      |
|jcommonsenseqa-1.1-0.3|    1.1|acc        | 0.7766|±  |0.0125|
|                      |       |acc_norm   | 0.4629|±  |0.0149|


num_fewshot: 2, batch_size: 1
|      Task       |Version|  Metric   | Value |   |Stderr|
|-----------------|------:|-----------|------:|---|-----:|
|jnli-1.1-0.3     |    1.1|acc        | 0.6228|±  |0.0098|
|                 |       |acc_norm   | 0.5288|±  |0.0101|
|marc_ja-1.1-0.3  |    1.1|acc        | 0.9630|±  |0.0025|
|                 |       |acc_norm   | 0.9630|±  |0.0025|
|jaqket_v2-0.1-0.3|    0.1|exact_match|67.9553|   |      |
|                 |       |f1         |78.7550|   |      |


**その他**

Lora学習時のcontext長は4096tokenまでですが、4k token以上の出力も可能です。  

注:bf16での使用を想定しています。
量子化推論する場合は、bf16でモデルを読み込んだ状態でLora適応またはマージ、その後に量子化してください。

**2/8更新**
学習強度が1/3と、2/3のcheck pointも公開しました  
こちらのほうがベースモデルの汎化性能維持できている可能性があります

**learningstrength0.3**  
num_fewshot: 2,2, batch_size: 1
|         Task         |Version|  Metric   | Value |   |Stderr|
|----------------------|------:|-----------|------:|---|-----:|
|jsquad-1.1-0.3        |    1.1|exact_match|72.1747|   |      |
|                      |       |f1         |85.3325|   |      |
|jcommonsenseqa-1.1-0.3|    1.1|acc        | 0.7534|±  |0.0129|
|                      |       |acc_norm   | 0.4111|±  |0.0147|

**learningstrength0.6**  
num_fewshot: 2,2, batch_size: 1
|         Task         |Version|  Metric   | Value |   |Stderr|
|----------------------|------:|-----------|------:|---|-----:|
|jsquad-1.1-0.3        |    1.1|exact_match|72.3548|   |      |
|                      |       |f1         |85.5144|   |      |
|jcommonsenseqa-1.1-0.3|    1.1|acc        | 0.7480|±  |0.0130|
|                      |       |acc_norm   | 0.4111|±  |0.0147|