# Jambatypus-v0.1
This model is a QLoRA fine-tune of [ai21labs/Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1) on the [chargoddard/Open-Platypus-Chat](https://huggingface.co/datasets/chargoddard/Open-Platypus-Chat) dataset.

It was trained on 2x A100 80 GB GPUs using my LazyAxolotl - Jamba notebook.

This repo contains both the LoRA adapter and the merged model in FP16 precision.

I recommend using the ChatML chat template with this model.
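For reference, a ChatML conversation is laid out as follows (this mirrors the layout that the Usage code below builds by hand; `{prompt}` is a placeholder, and the system prompt is just an example):

```
<|im_start|>system
Perform the task to the best of your ability.
<|im_end|>
<|im_start|>user
{prompt}
<|im_end|>
<|im_start|>assistant
```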
<details><summary>See axolotl config</summary>

axolotl version: `0.4.0`

```yaml
base_model: ai21labs/Jamba-v0.1
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: chargoddard/Open-Platypus-Chat
    type: sharegpt
chat_template: chatml
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false

use_wandb: true
wandb_project: axolotl
wandb_entity:
wandb_watch:
wandb_name: Jambatypus-v0.1
wandb_log_model:

adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

low_cpu_mem_usage: true
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 4
save_total_limit: 2
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
```

</details>
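If you want to reproduce the run, a config like the one above can be launched with axolotl's CLI (a sketch; the config file name and your `accelerate` setup are assumptions):

```bash
# Tokenize and pack the dataset, then launch QLoRA fine-tuning
python -m axolotl.cli.preprocess config.yaml
accelerate launch -m axolotl.cli.train config.yaml
```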
## Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 0.0002
- train_batch_size: 1
- eval_batch_size: 1
- seed: 42
- distributed_type: multi-GPU
- num_devices: 2
- gradient_accumulation_steps: 8
- total_train_batch_size: 16
- total_eval_batch_size: 2
- optimizer: Adam with betas=(0.9,0.95) and epsilon=1e-05
- lr_scheduler_type: cosine
- lr_scheduler_warmup_steps: 10
- num_epochs: 1
## Training results

| Training Loss | Epoch | Step | Validation Loss |
|:-------------:|:-----:|:----:|:---------------:|
| 0.6274        | 0.01  | 1    | 1.0298          |
| 0.44          | 0.25  | 42   | 0.9770          |
| 0.4406        | 0.5   | 84   | 0.9653          |
| 0.4445        | 0.75  | 126  | 0.9645          |
| 0.4609        | 1.0   | 168  | 0.9641          |
## Framework versions
- PEFT 0.10.0
- Transformers 4.40.0.dev0
- Pytorch 2.1.2+cu118
- Datasets 2.18.0
- Tokenizers 0.15.0
## 💻 Usage
The following code creates a Gradio chat interface with Jambatypus.
```python
!pip install -qqq -U git+https://github.com/huggingface/transformers
!pip install -qqq mamba-ssm "causal-conv1d>=1.2.0"
!pip install -qqq accelerate bitsandbytes torch datasets peft gradio
!pip install -qqq flash-attn --no-build-isolation
```
```python
import torch
import gradio as gr
from threading import Thread
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

STOP_TOKEN = "<|im_end|>"

def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    # Format the conversation history with the ChatML template
    instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
    for human, assistant in history:
        instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant + '\n<|im_end|>\n'
    instruction += '<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'

    # Stream tokens from a background generation thread
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask
    generate_kwargs = dict(
        {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device)},
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for new_token in streamer:
        if STOP_TOKEN in new_token:
            # Strip the stop token (and the trailing newline before it) from the final chunk
            outputs.append(new_token[:new_token.find(STOP_TOKEN)].rstrip("\n"))
            yield "".join(outputs)
            break
        outputs.append(new_token)
        yield "".join(outputs)

# Load the tokenizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

# 8-bit quantization config (the Mamba blocks are kept in higher precision)
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["mamba"]
)

# Load the quantized base model
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
    attn_implementation="flash_attention_2",
    low_cpu_mem_usage=True,
    quantization_config=quantization_config
)

# Attach the Jambatypus LoRA adapter
config = PeftConfig.from_pretrained("mlabonne/Jambatypus-v0.1")
model = PeftModel.from_pretrained(model, "mlabonne/Jambatypus-v0.1")

# Create the Gradio chat interface
gr.ChatInterface(
    predict,
    title="Jambatypus",
    description="Chat with Jambatypus!",
    examples=[
        ["Can you solve the equation 2x + 3 = 11 for x?"],
        ["Write an epic poem about Ancient Rome."],
        ["Who was the first person to walk on the Moon?"],
        ["Use a list comprehension to create a list of squares for numbers from 1 to 10."],
        ["Recommend some popular science fiction books."],
        ["Can you write a short story about a time-traveling detective?"]
    ],
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox("Perform the task to the best of your ability.", label="System prompt"),
        gr.Slider(0, 1, 0.8, label="Temperature"),
        gr.Slider(128, 4096, 1024, label="Max new tokens"),
        gr.Slider(1, 80, 40, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Soft(primary_hue="green"),
).queue().launch(share=True)
```
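Since this repo also ships the merged FP16 weights, you can skip the adapter step above and load them directly instead (a minimal sketch; it assumes you have enough VRAM to hold the merged model in FP16):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the merged FP16 model directly (no PEFT adapter required)
tokenizer = AutoTokenizer.from_pretrained("mlabonne/Jambatypus-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mlabonne/Jambatypus-v0.1",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
```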