Safetensors
aixsatoshi's picture
Update README.md
87bbcd3 verified
---
license: cc-by-4.0
datasets:
- cyberagent/chatbot-arena-ja-calm2-7b-chat-experimental
---
Mixtral8X7B Instructの日本語生成を安定させるためのLoraです。
**目的**
Mixtral-8x7Bは高性能な言語モデルですが、日本語出力に多言語が混入するcode-switchingがよく見られます。
元の性能を維持しながら、日本語生成を安定させる方法として、Loraの効果を検証しました。
**学習データセット**
学習データセットとして、下記のDPOデータセットを使用しています。
DPO trainingはVRAM消費が多く、今回はchosenのデータを使用したsft学習しています。
Chatbot Arena Conversations JA (calm2) Dataset
:[cyberagent/chatbot-arena-ja-calm2-7b-chat-experimental](https://huggingface.co/datasets/cyberagent/chatbot-arena-ja-calm2-7b-chat-experimental)
指示文 : [lmsys/chatbot_arena_conversations](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations)のユーザ入力(CC-BY 4.0)を利用。
指示文の和訳 : [facebookの翻訳モデル(MIT License)](https://huggingface.co/facebook/wmt21-dense-24-wide-en-x)が使用されています。
応答文 : calm2-7b-chat(Apache 2.0)の出力です。
**evaluation**
大きな性能低下がないことを確認しました
##Lora
num_fewshot: 2, batch_size: 1
| Task |Version| Metric | Value | |Stderr|
|----------------------|------:|-----------|------:|---|-----:|
|jsquad-1.1-0.3 | 1.1|exact_match|72.3323| | |
| | |f1 |85.4772| | |
|jcommonsenseqa-1.1-0.3| 1.1|acc | 0.7498|± |0.0130|
| | |acc_norm | 0.4138|± |0.0147|
num_fewshot: 2, batch_size: 1
| Task |Version| Metric | Value | |Stderr|
|-----------------|------:|-----------|------:|---|-----:|
|jnli-1.1-0.3 | 1.1|acc | 0.5912|± |0.0100|
| | |acc_norm | 0.4108|± |0.0100|
|marc_ja-1.1-0.3 | 1.1|acc | 0.9620|± |0.0025|
| | |acc_norm | 0.9620|± |0.0025|
|jaqket_v2-0.1-0.3| 0.1|exact_match|71.6495| | |
| | |f1 |79.4725| | |
##Base model
num_fewshot: 3,3, batch_size: 1
| Task |Version| Metric | Value | |Stderr|
|----------------------|------:|-----------|------:|---|-----:|
|jsquad-1.1-0.3 | 1.1|exact_match|68.1225| | |
| | |f1 |83.5285| | |
|jcommonsenseqa-1.1-0.3| 1.1|acc | 0.7766|± |0.0125|
| | |acc_norm | 0.4629|± |0.0149|
num_fewshot: 2, batch_size: 1
| Task |Version| Metric | Value | |Stderr|
|-----------------|------:|-----------|------:|---|-----:|
|jnli-1.1-0.3 | 1.1|acc | 0.6228|± |0.0098|
| | |acc_norm | 0.5288|± |0.0101|
|marc_ja-1.1-0.3 | 1.1|acc | 0.9630|± |0.0025|
| | |acc_norm | 0.9630|± |0.0025|
|jaqket_v2-0.1-0.3| 0.1|exact_match|67.9553| | |
| | |f1 |78.7550| | |
**その他**
Lora学習時のcontext長は4096tokenまでですが、4k token以上の出力も可能です。
注:bf16での使用を想定しています。
量子化推論する場合は、bf16でモデルを読み込んだ状態でLora適応またはマージ、その後に量子化してください。
**2/8更新**
学習強度が1/3と、2/3のcheck pointも公開しました
こちらのほうがベースモデルの汎化性能維持できている可能性があります
**learningstrength0.3**
num_fewshot: 2,2, batch_size: 1
| Task |Version| Metric | Value | |Stderr|
|----------------------|------:|-----------|------:|---|-----:|
|jsquad-1.1-0.3 | 1.1|exact_match|72.1747| | |
| | |f1 |85.3325| | |
|jcommonsenseqa-1.1-0.3| 1.1|acc | 0.7534|± |0.0129|
| | |acc_norm | 0.4111|± |0.0147|
**learningstrength0.6**
num_fewshot: 2,2, batch_size: 1
| Task |Version| Metric | Value | |Stderr|
|----------------------|------:|-----------|------:|---|-----:|
|jsquad-1.1-0.3 | 1.1|exact_match|72.3548| | |
| | |f1 |85.5144| | |
|jcommonsenseqa-1.1-0.3| 1.1|acc | 0.7480|± |0.0130|
| | |acc_norm | 0.4111|± |0.0147|