# SteveJobs-Chat / app.py
import torch
import gradio as gr
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
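# Quantize the base model to 8-bit; modules kept on CPU fall back to fp32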
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)
model_name = "lmsys/vicuna-7b-v1.5"

# Load the base model; 8-bit loading is driven by quantization_config
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    # torch_dtype=torch.float16,
    device_map="cpu",
    quantization_config=quantization_config,
)
# Attach the fine-tuned Steve Jobs PEFT adapter to the quantized base model
new_model = "emya/vicuna-7b-v1.5-steve-jobs-8bit-v1"
model = PeftModel.from_pretrained(base_model, new_model)

# Vicuna is a causal language model, so use the text-generation pipeline
tokenizer = AutoTokenizer.from_pretrained(model_name)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
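# Generate a short Steve Jobs-style reply for a single user message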
def predict(text):
    prompt = f"{text} (Answer in a few sentences)"
    # return_full_text=False strips the prompt so only the model's answer comes back
    output = pipe(prompt, max_new_tokens=128, return_full_text=False)
    return output[0]["generated_text"]
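# Plain text-in / text-out Gradio interface wired to predict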
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
)

demo.launch()