Add an example for 4-bit inference?

#11
by ct-2 - opened

There seems to be an explanation of how to fine-tune the model in 4-bit; would it be possible to provide more information on 4-bit inference? Thanks!

Could you be a bit more specific about what you're looking for beyond the 4-bit example already provided on the model card?

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
# 4-bit bitsandbytes quantization settings.
# - llm_int8_skip_modules=["mamba"] keeps modules whose names match "mamba"
#   out of quantization (NOTE(review): transformers applies this list to
#   4-bit loading as well despite the option's name — confirm for the
#   installed version).
# - bnb_4bit_compute_dtype defaults to float32; setting it to bfloat16 matches
#   the torch_dtype the model is loaded with, so the dequantized matmuls are
#   not silently upcast to float32 (slower, and bitsandbytes warns about it).
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_skip_modules=["mamba"],
)

# Manual placement across two GPUs, keyed by module name:
#   GPU 0 -> embedding + decoder layers 0-15
#   GPU 1 -> decoder layers 16-31 + final layer norm + LM head
device_map = {
    'model.embed_tokens': 0,
    **{f'model.layers.{idx}': (0 if idx < 16 else 1) for idx in range(32)},
    'model.final_layernorm': 1,
    'lm_head': 1,
}

# Local checkout of the weights; shared by the model and the tokenizer.
model_path = "./ai21labs_AI21-Jamba-1.5-Mini"

# Load in bfloat16 with the 4-bit quantization config and the manual
# two-GPU placement defined above.
# NOTE(review): attn_implementation="flash_attention_2" needs the flash-attn
# package; use_mamba_kernels=False selects the non-kernel Mamba code path
# (presumably to avoid requiring the optimized Mamba CUDA kernels — confirm).
load_kwargs = dict(
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    quantization_config=quantization_config,
    device_map=device_map,
    use_mamba_kernels=False,
)
model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)

tokenizer = AutoTokenizer.from_pretrained(model_path)

messages = [
    {"role": "system", "content": "You are an ancient oracle who speaks in cryptic but wise phrases, always hinting at deeper meanings."},
    {"role": "user", "content": "Hello!"},
]

# Render the chat template (with the assistant turn opened) and move the
# prompt to the model's device (the device of its first parameter).
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt').to(model.device)
# One unpadded prompt: every position is a real token, so an all-ones mask is
# exact. Passing it explicitly means generate() does not have to infer it.
attention_mask = torch.ones_like(input_ids)

outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=216)

# generate() returns prompt + continuation, so keep only the tokens after the
# prompt. This replaces splitting the full decoded text on the user's message,
# which raises IndexError if that text is not reproduced verbatim and
# truncates the reply if the assistant happens to repeat it.
prompt_length = input_ids.shape[-1]
assistant_response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
print(assistant_response)
# Output: Seek and you shall find. The path is winding, but the journey is enlightening. What wisdom do you seek from the ancient echoes?

Hi there, thank you for your reply! Does this example look correct? I have fit Jamba-Mini onto two 24 GB GPUs. However, I don't know how to write and run an example for vLLM with bitsandbytes (4-bit) quantization.

Sign up or log in to comment