license: cc-by-4.0
datasets:
- cyberagent/chatbot-arena-ja-calm2-7b-chat-experimental
Mixtral8X7B Instructの日本語生成を安定させるためのLoraです。
目的
Mixtral-8x7Bは高性能な言語モデルですが、日本語出力に多言語が混入するcode-switchingがよく見られます。 元の性能を維持しながら、日本語生成を安定させる方法として、Loraの効果を検証しました。
学習データセット
学習データセットとして、下記のDPOデータセットを使用しています。
DPO trainingはVRAM消費が多く、今回はchosenのデータを使用したsft学習しています。
Chatbot Arena Conversations JA (calm2) Dataset
:cyberagent/chatbot-arena-ja-calm2-7b-chat-experimental
指示文 : lmsys/chatbot_arena_conversationsのユーザ入力(CC-BY 4.0)を利用。
指示文の和訳 : facebookの翻訳モデル(MIT License)が使用されています。
応答文 : calm2-7b-chat(Apache 2.0)の出力です。
evaluation
大きな性能低下がないことを確認しました
##Lora
num_fewshot: 2, batch_size: 1
Task | Version | Metric | Value | Stderr | |
---|---|---|---|---|---|
jsquad-1.1-0.3 | 1.1 | exact_match | 72.3323 | ||
f1 | 85.4772 | ||||
jcommonsenseqa-1.1-0.3 | 1.1 | acc | 0.7498 | ± | 0.0130 |
acc_norm | 0.4138 | ± | 0.0147 |
num_fewshot: 2, batch_size: 1
Task | Version | Metric | Value | Stderr | |
---|---|---|---|---|---|
jnli-1.1-0.3 | 1.1 | acc | 0.5912 | ± | 0.0100 |
acc_norm | 0.4108 | ± | 0.0100 | ||
marc_ja-1.1-0.3 | 1.1 | acc | 0.9620 | ± | 0.0025 |
acc_norm | 0.9620 | ± | 0.0025 | ||
jaqket_v2-0.1-0.3 | 0.1 | exact_match | 71.6495 | ||
f1 | 79.4725 |
##Base model
num_fewshot: 3,3, batch_size: 1
Task | Version | Metric | Value | Stderr | |
---|---|---|---|---|---|
jsquad-1.1-0.3 | 1.1 | exact_match | 68.1225 | ||
f1 | 83.5285 | ||||
jcommonsenseqa-1.1-0.3 | 1.1 | acc | 0.7766 | ± | 0.0125 |
acc_norm | 0.4629 | ± | 0.0149 |
num_fewshot: 2, batch_size: 1
Task | Version | Metric | Value | Stderr | |
---|---|---|---|---|---|
jnli-1.1-0.3 | 1.1 | acc | 0.6228 | ± | 0.0098 |
acc_norm | 0.5288 | ± | 0.0101 | ||
marc_ja-1.1-0.3 | 1.1 | acc | 0.9630 | ± | 0.0025 |
acc_norm | 0.9630 | ± | 0.0025 | ||
jaqket_v2-0.1-0.3 | 0.1 | exact_match | 67.9553 | ||
f1 | 78.7550 |
その他
Lora学習時のcontext長は4096tokenまでですが、4k token以上の出力も可能です。
注:bf16での使用を想定しています。 量子化推論する場合は、bf16でモデルを読み込んだ状態でLora適応またはマージ、その後に量子化してください。
2/8更新
学習強度が1/3と、2/3のcheck pointも公開しました
こちらのほうがベースモデルの汎化性能維持できている可能性があります
learningstrength0.3
num_fewshot: 2,2, batch_size: 1
Task | Version | Metric | Value | Stderr | |
---|---|---|---|---|---|
jsquad-1.1-0.3 | 1.1 | exact_match | 72.1747 | ||
f1 | 85.3325 | ||||
jcommonsenseqa-1.1-0.3 | 1.1 | acc | 0.7534 | ± | 0.0129 |
acc_norm | 0.4111 | ± | 0.0147 |
learningstrength0.6
num_fewshot: 2,2, batch_size: 1
Task | Version | Metric | Value | Stderr | |
---|---|---|---|---|---|
jsquad-1.1-0.3 | 1.1 | exact_match | 72.3548 | ||
f1 | 85.5144 | ||||
jcommonsenseqa-1.1-0.3 | 1.1 | acc | 0.7480 | ± | 0.0130 |
acc_norm | 0.4111 | ± | 0.0147 |