import subprocess
import sys

def install_package():
    # Install LLaVA-JP (tanuki-moe branch) from GitHub; pip needs the git+ prefix for VCS URLs.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/hibikaze-git/LLaVA-JP@feature/tanuki-moe"])

install_package()
import gradio as gr
import torch
import transformers
from transformers import BitsAndBytesConfig
from llavajp.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llavajp.conversation import conv_templates
from llavajp.model.llava_llama import LlavaLlamaForCausalLM
from llavajp.train.dataset import tokenizer_image_token
import spaces
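
# Hugging Face Hub repo for the Tanuki-8B vision-language model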
model_path = "weblab-GENIAC/Tanuki-8B-vision"
# load model
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
# 8-bit quantization config; keep the multimodal projector and vision tower unquantized.
bnb_model_from_pretrained_args = dict(
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_skip_modules=["mm_projector", "vision_tower"],
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    ),
)
model = LlavaLlamaForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
use_safetensors=True,
**bnb_model_from_pretrained_args
)
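
# Load the companion tokenizer (slow implementation, right padding).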
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_path,
model_max_length=8192,
padding_side="right",
use_fast=False,
)
model.eval()
conv_mode = "v1"  # key into conv_templates; selects the prompt format
@spaces.GPU
@torch.inference_mode()
def inference_fn(
image,
prompt,
max_len,
temperature,
top_p,
no_repeat_ngram_size
):
# prepare inputs
# image pre-process
    # Base input resolution from the image processor; a multi-scale vision
    # tower multiplies the effective input size by the number of scales.
    image_size = model.get_model().vision_tower.image_processor.size["height"]
    if model.get_model().vision_tower.scales is not None:
        image_size *= len(model.get_model().vision_tower.scales)
    # Preprocess the image, then move it to the model's device and dtype in one
    # step (avoids the redundant fp16 round-trip of .half().cuda().to(dtype)).
    image_tensor = (
        model.get_model()
        .vision_tower.image_processor(
            image,
            return_tensors="pt",
            size={"height": image_size, "width": image_size},
        )["pixel_values"]
        .to(device=device, dtype=torch_dtype)
    )
# create prompt
inp = DEFAULT_IMAGE_TOKEN + "\n" + prompt
conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
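    # Tokenize the prompt, mapping the image placeholder to IMAGE_TOKEN_INDEX.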
input_ids = tokenizer_image_token(
prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
).unsqueeze(0)
if device == "cuda":
input_ids = input_ids.to(device)
    input_ids = input_ids[:, :-1]  # drop the trailing </sep> token that gets appended to the input
# generate
output_ids = model.generate(
inputs=input_ids,
images=image_tensor,
        do_sample=(temperature != 0.0),
temperature=temperature,
top_p=top_p,
max_new_tokens=max_len,
repetition_penalty=1.0,
use_cache=False,
no_repeat_ngram_size=no_repeat_ngram_size
)
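    # Remove image placeholder ids (not real vocabulary tokens) before decoding.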
output_ids = [
token_id for token_id in output_ids.tolist()[0] if token_id != IMAGE_TOKEN_INDEX
]
print(output_ids)
output = tokenizer.decode(output_ids, skip_special_tokens=True)
print(output)
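    # Keep only the assistant reply, which follows the "システム: " role marker.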
    target = "システム: "
    idx = output.find(target)
    if idx != -1:
        output = output[idx + len(target):]
return output
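
# Build the Gradio UI.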
with gr.Blocks() as demo:
gr.Markdown("# LLaVA-JP Demo")
with gr.Row():
with gr.Column():
input_image = gr.Image(type="pil", label="image")
prompt = gr.Textbox(label="prompt (optional)", value="")
with gr.Accordion(label="Configs", open=False):
max_len = gr.Slider(
minimum=10,
maximum=256,
value=200,
step=5,
interactive=True,
label="Max New Tokens",
)
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.0,
step=0.1,
interactive=True,
label="Temperature",
)
top_p = gr.Slider(
minimum=0.5,
maximum=1.0,
value=1.0,
step=0.1,
interactive=True,
label="Top p",
)
                no_repeat_ngram_size = gr.Slider(
                    minimum=0,
                    maximum=4,
                    value=3,
                    step=1,
                    interactive=True,
                    label="No Repeat Ngram Size (setting 1 or 2 corrupts the output)",
                )
# button
input_button = gr.Button(value="Submit")
with gr.Column():
output = gr.Textbox(label="Output")
inputs = [input_image, prompt, max_len, temperature, top_p, no_repeat_ngram_size]
input_button.click(inference_fn, inputs=inputs, outputs=[output])
prompt.submit(inference_fn, inputs=inputs, outputs=[output])
img2txt_examples = gr.Examples(
examples=[
[
"https://raw.githubusercontent.com/hibikaze-git/LLaVA-JP/feature/package/imgs/sample1.jpg",
"猫の隣には何がありますか?",
128,
0.0,
1.0,
                3
],
],
inputs=inputs,
)
if __name__ == "__main__":
demo.queue().launch(server_name="0.0.0.0")