Jambatypus-v0.1

This model is a QLoRA fine-tuned version of ai21labs/Jamba-v0.1 on the chargoddard/Open-Platypus-Chat dataset.

It was trained on 2x A100 80 GB GPUs using my LazyAxolotl - Jamba notebook.

This repo contains both the adapter and the merged model in FP16 precision.

I recommend using the ChatML template with this model.
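As a rough illustration of what the ChatML template looks like, here is a minimal sketch that renders a conversation into a prompt string. It assumes the standard ChatML layout (`<|im_start|>role`, a newline, the content, then `<|im_end|>`); the helper name `to_chatml` and its `add_generation_prompt` argument are hypothetical and not part of this repo.

```python
def to_chatml(messages, add_generation_prompt=True):
    """Render a list of {"role", "content"} dicts as a ChatML prompt string."""
    prompt = ""
    for m in messages:
        # Each turn is wrapped in <|im_start|>role ... <|im_end|>
        prompt += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
    if add_generation_prompt:
        # Leave the assistant turn open so the model writes the reply
        prompt += "<|im_start|>assistant\n"
    return prompt


messages = [
    {"role": "system", "content": "Perform the task to the best of your ability."},
    {"role": "user", "content": "Who was the first person to walk on the Moon?"},
]
print(to_chatml(messages))
```

The resulting string can be tokenized and passed to `generate`; generation should be cut at the first `<|im_end|>` the model emits.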

Built with Axolotl

See axolotl config

axolotl version: 0.4.0


base_model: ai21labs/Jamba-v0.1
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: chargoddard/Open-Platypus-Chat
    type: sharegpt
chat_template: chatml
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false

use_wandb: true
wandb_project: axolotl
wandb_entity:
wandb_watch:
wandb_name: Jambatypus-v0.1
wandb_log_model:

adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

low_cpu_mem_usage: true
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 4
save_total_limit: 2
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
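
To make the adapter settings above more concrete, the sketch below works out what `lora_r: 16` and `lora_alpha: 32` imply, assuming the usual LoRA formulation in which the low-rank update is scaled by `alpha / r` and each adapted linear layer gains `r * (d_in + d_out)` trainable weights. The 4096 x 4096 layer shape is a hypothetical example, not a value read from the Jamba config.

```python
lora_r = 16      # from the config above
lora_alpha = 32  # from the config above


def lora_scaling(alpha, r):
    """Scale factor applied to the low-rank update (standard LoRA: alpha / r)."""
    return alpha / r


def lora_params(d_in, d_out, r):
    """Trainable weights added to one linear layer: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


scaling = lora_scaling(lora_alpha, lora_r)

# Hypothetical layer shape, for illustration only
d_in = d_out = 4096
added = lora_params(d_in, d_out, lora_r)
frozen = d_in * d_out

print(f"scaling factor: {scaling}")
print(f"adapter weights per layer: {added} ({added / frozen:.2%} of the frozen layer)")
```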

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 0.0002
  • train_batch_size: 1
  • eval_batch_size: 1
  • seed: 42
  • distributed_type: multi-GPU
  • num_devices: 2
  • gradient_accumulation_steps: 8
  • total_train_batch_size: 16
  • total_eval_batch_size: 2
  • optimizer: Adam with betas=(0.9,0.95) and epsilon=1e-05
  • lr_scheduler_type: cosine
  • lr_scheduler_warmup_steps: 10
  • num_epochs: 1
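
The total batch sizes listed above follow from the per-device batch size, the gradient accumulation steps, and the number of GPUs. A small sketch of that arithmetic, using the values from the list and `sequence_len: 4096` from the axolotl config (the tokens-per-step figure is an upper bound that assumes every packed sequence is full):

```python
micro_batch_size = 1             # train_batch_size per device
eval_batch_size = 1              # eval_batch_size per device
gradient_accumulation_steps = 8
num_devices = 2
sequence_len = 4096              # from the axolotl config

# Sequences contributing to one optimizer step
total_train_batch_size = micro_batch_size * gradient_accumulation_steps * num_devices

# Evaluation does not accumulate gradients, so only the device count multiplies
total_eval_batch_size = eval_batch_size * num_devices

# Upper bound on tokens per optimizer step when sequences are packed and padded to sequence_len
max_tokens_per_step = total_train_batch_size * sequence_len

print(total_train_batch_size, total_eval_batch_size, max_tokens_per_step)
```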

Training results

Training Loss   Epoch   Step   Validation Loss
0.6274          0.01    1      1.0298
0.44            0.25    42     0.9770
0.4406          0.5     84     0.9653
0.4445          0.75    126    0.9645
0.4609          1.0     168    0.9641
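
For context on the evaluation points in the table, the sketch below estimates the learning rate at each logged step. It assumes the schedule is a linear warmup over 10 steps followed by a half-cosine decay to zero over the remaining steps (the common behaviour of a cosine scheduler with warmup), with 168 total steps taken from the table; the function name `lr_at` is hypothetical and the actual logged values may differ slightly.

```python
import math


def lr_at(step, base_lr=2e-4, warmup_steps=10, total_steps=168):
    """Learning rate at a given optimizer step: linear warmup, then cosine decay."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Evaluation steps taken from the training results table
for step in (1, 42, 84, 126, 168):
    print(f"step {step:>3}: lr = {lr_at(step):.2e}")
```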

Framework versions

  • PEFT 0.10.0
  • Transformers 4.40.0.dev0
  • Pytorch 2.1.2+cu118
  • Datasets 2.18.0
  • Tokenizers 0.15.0

💻 Usage

The following code creates a Gradio chat interface with Jambatypus.

!pip install -qqq -U git+https://github.com/huggingface/transformers
!pip install -qqq mamba-ssm "causal-conv1d>=1.2.0"
!pip install -qqq accelerate bitsandbytes torch datasets peft gradio
!pip install -qqq flash-attn --no-build-isolation


import torch
import gradio as gr
from threading import Thread
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

STOP_TOKEN = "<|im_end|>"

def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    # Format the system prompt, past turns, and the new message with the ChatML template
    instruction = '<|im_start|>system\n' + system_prompt + '<|im_end|>\n'
    for human, assistant in history:
        instruction += '<|im_start|>user\n' + human + '<|im_end|>\n'
        instruction += '<|im_start|>assistant\n' + assistant + '<|im_end|>\n'
    instruction += '<|im_start|>user\n' + message + '<|im_end|>\n<|im_start|>assistant\n'

    # Stream decoded text as it is generated, without echoing the prompt
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=int(max_new_tokens),  # sliders may return floats
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
        top_p=top_p
    )
    # Run generation in a background thread so the streamer can be consumed here
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    generated = ""
    for new_text in streamer:
        generated += new_text
        # The stop marker may be split across chunks, so search the accumulated text
        if STOP_TOKEN in generated:
            yield generated.split(STOP_TOKEN)[0].rstrip()
            break
        yield generated


# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

# 8-bit quantization config (the Mamba modules are left unquantized)
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["mamba"]
)
# Load the quantized base model; the Jambatypus adapter is attached afterwards
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
    attn_implementation="flash_attention_2",
    low_cpu_mem_usage=True,
    quantization_config=quantization_config
)
config = PeftConfig.from_pretrained("mlabonne/Jambatypus-v0.1")
model = PeftModel.from_pretrained(model, "mlabonne/Jambatypus-v0.1")

# Create Gradio interface
gr.ChatInterface(
    predict,
    title="Jambatypus",
    description="Chat with Jambatypus!",
    examples=[
        ["Can you solve the equation 2x + 3 = 11 for x?"],
        ["Write an epic poem about Ancient Rome."],
        ["Who was the first person to walk on the Moon?"],
        ["Use a list comprehension to create a list of squares for numbers from 1 to 10."],
        ["Recommend some popular science fiction books."],
        ["Can you write a short story about a time-traveling detective?"]
    ],
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox("Perform the task to the best of your ability.", label="System prompt"),
        gr.Slider(0, 1, 0.8, label="Temperature"),
        gr.Slider(128, 4096, 1024, label="Max new tokens"),
        gr.Slider(1, 80, 40, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Soft(primary_hue="green"),
).queue().launch(share=True)
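
The stop-marker handling in a streaming loop can be checked without loading the model. The sketch below is one possible approach, not necessarily what the interface above does: it accumulates streamed chunks and cuts the reply at the first `<|im_end|>`, even when the marker arrives split across chunks. The helper `truncate_stream` and the fake chunk list are hypothetical.

```python
STOP_TOKEN = "<|im_end|>"


def truncate_stream(chunks, stop=STOP_TOKEN):
    """Yield the growing reply and stop once the stop marker appears in it."""
    text = ""
    for chunk in chunks:
        text += chunk
        if stop in text:
            # Keep only what came before the marker, then end the stream
            yield text.split(stop)[0]
            return
        yield text


# Simulated streamer output: the marker is split over two chunks
fake_stream = ["Neil ", "Armstrong", "<|im_", "end|>", "ignored"]
replies = list(truncate_stream(fake_stream))
print(replies[-1])
```

Note that a partial marker can appear briefly in an intermediate yield before the full marker is seen; buffering the last few characters would avoid that at the cost of slightly delayed output.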

Model size: 51.6B params (Safetensors)
Tensor type: FP16 · F32