morizon committed
Commit
7dce37f
1 Parent(s): fb8ac53

Create app.py

Files changed (1)
  1. app.py +202 -0
app.py ADDED
import subprocess
import sys

def install_package():
    # install the LLaVA-JP package (tanuki-moe branch) so that the llavajp imports below resolve
    subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/hibikaze-git/LLaVA-JP@feature/tanuki-moe"])

install_package()

import gradio as gr
import torch
import transformers
from transformers import BitsAndBytesConfig
from llavajp.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llavajp.conversation import conv_templates
from llavajp.model.llava_llama import LlavaLlamaForCausalLM
from llavajp.train.dataset import tokenizer_image_token
import spaces

model_path = "weblab-GENIAC/Tanuki-8B-vision"

# load model
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

# 8-bit (bitsandbytes) quantization; the multimodal projector and vision tower stay unquantized
bnb_model_from_pretrained_args = {}
bnb_model_from_pretrained_args.update(dict(
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_skip_modules=["mm_projector", "vision_tower"],
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    )
))

model = LlavaLlamaForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    **bnb_model_from_pretrained_args
)

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_path,
    model_max_length=8192,
    padding_side="right",
    use_fast=False,
)
model.eval()
conv_mode = "v1"

@spaces.GPU
@torch.inference_mode()
def inference_fn(
    image,
    prompt,
    max_len,
    temperature,
    top_p,
    no_repeat_ngram_size
):
    # prepare inputs
    # image pre-process
    image_size = model.get_model().vision_tower.image_processor.size["height"]
    if model.get_model().vision_tower.scales is not None:
        image_size = model.get_model().vision_tower.image_processor.size[
            "height"
        ] * len(model.get_model().vision_tower.scales)

    if device == "cuda":
        image_tensor = (
            model.get_model()
            .vision_tower.image_processor(
                image,
                return_tensors="pt",
                size={"height": image_size, "width": image_size},
            )["pixel_values"]
            .half()
            .cuda()
            .to(torch_dtype)
        )
    else:
        image_tensor = (
            model.get_model()
            .vision_tower.image_processor(
                image,
                return_tensors="pt",
                size={"height": image_size, "width": image_size},
            )["pixel_values"]
            .to(torch_dtype)
        )

    # create prompt
    inp = DEFAULT_IMAGE_TOKEN + "\n" + prompt
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0)
    if device == "cuda":
        input_ids = input_ids.to(device)

    input_ids = input_ids[:, :-1]  # </sep> ends up at the end of the input, so drop it

    # generate
    output_ids = model.generate(
        inputs=input_ids,
        images=image_tensor,
        do_sample=temperature != 0.0,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_len,
        repetition_penalty=1.0,
        use_cache=False,
        no_repeat_ngram_size=int(no_repeat_ngram_size)  # generate expects an int; the slider/examples may supply a float
    )

    output_ids = [
        token_id for token_id in output_ids.tolist()[0] if token_id != IMAGE_TOKEN_INDEX
    ]

    print(output_ids)
    output = tokenizer.decode(output_ids, skip_special_tokens=True)

    print(output)

    # keep only the text after the "システム: " marker (the model's reply)
    target = "システム: "
    idx = output.find(target)
    output = output[idx + len(target):]

    return output

with gr.Blocks() as demo:
    gr.Markdown("# LLaVA-JP Demo")

    with gr.Row():
        with gr.Column():
            # input_instruction = gr.TextArea(label="instruction", value=DEFAULT_INSTRUCTION)
            input_image = gr.Image(type="pil", label="image")
            prompt = gr.Textbox(label="prompt (optional)", value="")
            with gr.Accordion(label="Configs", open=False):
                max_len = gr.Slider(
                    minimum=10,
                    maximum=256,
                    value=200,
                    step=5,
                    interactive=True,
                    label="Max New Tokens",
                )

                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.1,
                    interactive=True,
                    label="Temperature",
                )

                top_p = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=1.0,
                    step=0.1,
                    interactive=True,
                    label="Top p",
                )

                no_repeat_ngram_size = gr.Slider(
                    minimum=0,
                    maximum=4,
                    value=3.0,
                    step=1,
                    interactive=True,
                    label="No Repeat Ngram Size (setting this to 1 or 2 breaks the output)",
                )
            # button
            input_button = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Textbox(label="Output")

    inputs = [input_image, prompt, max_len, temperature, top_p, no_repeat_ngram_size]
    input_button.click(inference_fn, inputs=inputs, outputs=[output])
    prompt.submit(inference_fn, inputs=inputs, outputs=[output])
    img2txt_examples = gr.Examples(
        examples=[
            [
                "https://raw.githubusercontent.com/hibikaze-git/LLaVA-JP/feature/package/imgs/sample1.jpg",
                "猫の隣には何がありますか?",
                128,
                0.0,
                1.0,
                3.0
            ],
        ],
        inputs=inputs,
    )


if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0")
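
For reference, a minimal sketch of driving inference_fn from a plain Python session instead of the Gradio UI. The sample.jpg path is a hypothetical placeholder; the parameter values mirror the bundled example. Importing app installs the dependency and loads the model but does not launch the demo.

from PIL import Image

import app  # the module added in this commit

img = Image.open("sample.jpg").convert("RGB")  # hypothetical local image path
answer = app.inference_fn(
    img,                          # PIL image
    "猫の隣には何がありますか?",  # same example question bundled with the demo
    128,                          # max_len
    0.0,                          # temperature (greedy decoding)
    1.0,                          # top_p
    3,                            # no_repeat_ngram_size
)
print(answer)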