import gradio as gr
import torch
import transformers
from transformers import BitsAndBytesConfig

from llavajp.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llavajp.conversation import conv_templates
from llavajp.model.llava_llama import LlavaLlamaForCausalLM
from llavajp.train.dataset import tokenizer_image_token

import spaces

# import subprocess
# import sys

# def install_package():
#     subprocess.check_call([sys.executable, "-m", "pip", "install", "https://github.com/hibikaze-git/LLaVA-JP@feature/tanuki-moe"])

model_path = "weblab-GENIAC/Tanuki-8B-vision"

# load model
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

# 8-bit quantization config; the multimodal projector and vision tower are kept unquantized
bnb_model_from_pretrained_args = {}
bnb_model_from_pretrained_args.update(dict(
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_skip_modules=["mm_projector", "vision_tower"],
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    )
))

model = LlavaLlamaForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    **bnb_model_from_pretrained_args
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_path,
    model_max_length=8192,
    padding_side="right",
    use_fast=False,
)
model.eval()

conv_mode = "v1"


@spaces.GPU(duration=120)
@torch.inference_mode()
def inference_fn(
    image,
    prompt,
    max_len,
    temperature,
    top_p,
    no_repeat_ngram_size,
):
    # prepare inputs
    # image pre-processing: enlarge the target size when the vision tower uses multiple scales
    image_size = model.get_model().vision_tower.image_processor.size["height"]
    if model.get_model().vision_tower.scales is not None:
        image_size = model.get_model().vision_tower.image_processor.size[
            "height"
        ] * len(model.get_model().vision_tower.scales)

    if device == "cuda":
        image_tensor = (
            model.get_model()
            .vision_tower.image_processor(
                image,
                return_tensors="pt",
                size={"height": image_size, "width": image_size},
            )["pixel_values"]
            .half()
            .cuda()
            .to(torch_dtype)
        )
    else:
        image_tensor = (
            model.get_model()
            .vision_tower.image_processor(
                image,
                return_tensors="pt",
                size={"height": image_size, "width": image_size},
            )["pixel_values"]
            .to(torch_dtype)
        )

    # create prompt
    inp = DEFAULT_IMAGE_TOKEN + "\n" + prompt
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0)

    if device == "cuda":
        input_ids = input_ids.to(device)

    input_ids = input_ids[:, :-1]  # drop the token appended at the end of the input

    # generate
    output_ids = model.generate(
        inputs=input_ids,
        images=image_tensor,
        do_sample=temperature != 0.0,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_len,
        repetition_penalty=1.0,
        use_cache=False,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )

    # strip image placeholder tokens before decoding
    output_ids = [
        token_id for token_id in output_ids.tolist()[0] if token_id != IMAGE_TOKEN_INDEX
    ]
    print(output_ids)
    output = tokenizer.decode(output_ids, skip_special_tokens=True)
    print(output)

    # keep only the text after the assistant role prefix of the conversation template
    target = "システム: "
    idx = output.find(target)
    output = output[idx + len(target):]

    return output


with gr.Blocks() as demo:
    gr.Markdown("# Tanuki-8B-vision Demo")

    with gr.Row():
        with gr.Column():
            # input_instruction = gr.TextArea(label="instruction", value=DEFAULT_INSTRUCTION)
            input_image = gr.Image(type="pil", label="image")
            prompt = gr.Textbox(label="prompt (optional)", value="")

            with gr.Accordion(label="Configs", open=False):
                max_len = gr.Slider(
                    minimum=10,
                    maximum=256,
                    value=200,
                    step=5,
                    interactive=True,
                    label="Max New Tokens",
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.1,
                    interactive=True,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=1.0,
                    step=0.1,
                    interactive=True,
                    label="Top p",
                )
                no_repeat_ngram_size = gr.Slider(
                    minimum=0,
                    maximum=4,
                    value=3,
                    step=1,
                    interactive=True,
                    label="No Repeat Ngram Size (setting this to 1 or 2 corrupts the output)",
                )

            # button
            input_button = gr.Button(value="Submit")

        with gr.Column():
            output = gr.Textbox(label="Output")

    inputs = [input_image, prompt, max_len, temperature, top_p, no_repeat_ngram_size]
    input_button.click(inference_fn, inputs=inputs, outputs=[output])
    prompt.submit(inference_fn, inputs=inputs, outputs=[output])

    img2txt_examples = gr.Examples(
        examples=[
            [
                "https://raw.githubusercontent.com/hibikaze-git/LLaVA-JP/feature/package/imgs/sample1.jpg",
                "猫の隣には何がありますか?",  # "What is next to the cat?"
                128,
                0.0,
                1.0,
                3,
            ],
        ],
        inputs=inputs,
    )


if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0")