Spaces:
Runtime error
Runtime error
zhanghaoji
commited on
Commit
•
eb0678a
1
Parent(s):
378cd97
init
Browse files- app.py +172 -62
- app_old.py +63 -0
- flash_vstream/__init__.py +1 -0
- flash_vstream/constants.py +15 -0
- flash_vstream/conversation.py +337 -0
- flash_vstream/eval_video/eval_activitynet_qa.py +296 -0
- flash_vstream/eval_video/eval_any_dataset_features.py +340 -0
- flash_vstream/eval_video/model_msvd_qa.py +157 -0
- flash_vstream/eval_video/model_msvd_qa_featuresloader.py +179 -0
- flash_vstream/mm_utils.py +106 -0
- flash_vstream/model/__init__.py +1 -0
- flash_vstream/model/builder.py +139 -0
- flash_vstream/model/compress_functions.py +277 -0
- flash_vstream/model/language_model/vstream_llama.py +129 -0
- flash_vstream/model/multimodal_encoder/builder.py +13 -0
- flash_vstream/model/multimodal_encoder/clip_encoder.py +80 -0
- flash_vstream/model/multimodal_projector/builder.py +51 -0
- flash_vstream/model/vstream_arch.py +742 -0
- flash_vstream/serve/cli_video_stream.py +351 -0
- flash_vstream/serve/demo.py +144 -0
- flash_vstream/train/llama_flash_attn_monkey_patch.py +117 -0
- flash_vstream/train/llama_xformers_attn_monkey_patch.py +131 -0
- flash_vstream/train/train.py +1069 -0
- flash_vstream/train/train_mem.py +14 -0
- flash_vstream/train/train_xformers.py +15 -0
- flash_vstream/train/vstream_trainer.py +248 -0
- flash_vstream/utils.py +128 -0
- requirements.txt +1 -1
app.py
CHANGED
@@ -1,63 +1,173 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
-
from
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
)
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
)
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
import gradio as gr
|
3 |
+
from flash_vstream.serve.demo import Chat, title_markdown, block_css
|
4 |
+
from flash_vstream.constants import *
|
5 |
+
from flash_vstream.conversation import conv_templates, Conversation
|
6 |
+
import os
|
7 |
+
from PIL import Image
|
8 |
+
import tempfile
|
9 |
+
import imageio
|
10 |
+
import shutil
|
11 |
+
|
12 |
+
|
13 |
+
model_path = "IVGSZ/Flash-VStream-7b"
|
14 |
+
load_8bit = False
|
15 |
+
load_4bit = False
|
16 |
+
|
17 |
+
def save_image_to_local(image):
|
18 |
+
filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
|
19 |
+
image = Image.open(image)
|
20 |
+
image.save(filename)
|
21 |
+
return filename
|
22 |
+
|
23 |
+
|
24 |
+
def save_video_to_local(video_path):
|
25 |
+
filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
|
26 |
+
shutil.copyfile(video_path, filename)
|
27 |
+
return filename
|
28 |
+
|
29 |
+
|
30 |
+
def generate(video, textbox_in, first_run, state, state_, images_tensor):
|
31 |
+
|
32 |
+
flag = 1
|
33 |
+
if not textbox_in:
|
34 |
+
if len(state_.messages) > 0:
|
35 |
+
textbox_in = state_.messages[-1][1]
|
36 |
+
state_.messages.pop(-1)
|
37 |
+
flag = 0
|
38 |
+
else:
|
39 |
+
return "Please enter instruction"
|
40 |
+
|
41 |
+
video = video if video else "none"
|
42 |
+
|
43 |
+
if type(state) is not Conversation:
|
44 |
+
state = conv_templates[conv_mode].copy()
|
45 |
+
state_ = conv_templates[conv_mode].copy()
|
46 |
+
images_tensor = []
|
47 |
+
|
48 |
+
first_run = False if len(state.messages) > 0 else True
|
49 |
+
|
50 |
+
text_en_in = textbox_in.replace("picture", "image")
|
51 |
+
|
52 |
+
image_processor = handler.image_processor
|
53 |
+
|
54 |
+
if os.path.exists(video):
|
55 |
+
video_tensor = handler._get_rawvideo_dec(video, image_processor, max_frames=MAX_IMAGE_LENGTH)
|
56 |
+
for img in video_tensor:
|
57 |
+
images_tensor.append(image_processor(img, return_tensors='pt')['pixel_values'][0].to(handler.model.device, dtype=torch.float16))
|
58 |
+
|
59 |
+
if os.path.exists(video):
|
60 |
+
text_en_in = DEFAULT_IMAGE_TOKEN * len(video_tensor) + '\n' + text_en_in
|
61 |
+
|
62 |
+
text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
|
63 |
+
state_.messages[-1] = (state_.roles[1], text_en_out)
|
64 |
+
|
65 |
+
text_en_out = text_en_out.split('#')[0]
|
66 |
+
textbox_out = text_en_out
|
67 |
+
|
68 |
+
show_images = ""
|
69 |
+
if os.path.exists(video):
|
70 |
+
filename = save_video_to_local(video)
|
71 |
+
show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'
|
72 |
+
|
73 |
+
if flag:
|
74 |
+
state.append_message(state.roles[0], textbox_in + "\n" + show_images)
|
75 |
+
state.append_message(state.roles[1], textbox_out)
|
76 |
+
|
77 |
+
return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=None, interactive=True))
|
78 |
+
|
79 |
+
|
80 |
+
def regenerate(state, state_):
|
81 |
+
state.messages.pop(-1)
|
82 |
+
state_.messages.pop(-1)
|
83 |
+
if len(state.messages) > 0:
|
84 |
+
return state, state_, state.to_gradio_chatbot(), False
|
85 |
+
return (state, state_, state.to_gradio_chatbot(), True)
|
86 |
+
|
87 |
+
|
88 |
+
def clear_history(state, state_):
|
89 |
+
state = conv_templates[conv_mode].copy()
|
90 |
+
state_ = conv_templates[conv_mode].copy()
|
91 |
+
return (gr.update(value=None, interactive=True), \
|
92 |
+
gr.update(value=None, interactive=True),\
|
93 |
+
True, state, state_, state.to_gradio_chatbot(), [])
|
94 |
+
|
95 |
+
|
96 |
+
conv_mode = "simple"
|
97 |
+
handler = Chat(model_path, conv_mode=conv_mode, load_4bit=load_4bit, load_8bit=load_8bit)
|
98 |
+
if not os.path.exists("temp"):
|
99 |
+
os.makedirs("temp")
|
100 |
+
|
101 |
+
print(torch.cuda.memory_allocated())
|
102 |
+
print(torch.cuda.max_memory_allocated())
|
103 |
+
|
104 |
+
with gr.Blocks(title='Flash-VStream', theme=gr.themes.Soft(), css=block_css) as demo:
|
105 |
+
gr.Markdown(title_markdown)
|
106 |
+
state = gr.State()
|
107 |
+
state_ = gr.State()
|
108 |
+
first_run = gr.State()
|
109 |
+
images_tensor = gr.State()
|
110 |
+
|
111 |
+
with gr.Row():
|
112 |
+
with gr.Column(scale=3):
|
113 |
+
video = gr.Video(label="Input Video")
|
114 |
+
|
115 |
+
with gr.Column(scale=7):
|
116 |
+
chatbot = gr.Chatbot(label="Flash-VStream", bubble_full_width=True).style(height=700)
|
117 |
+
with gr.Row():
|
118 |
+
with gr.Column(scale=8):
|
119 |
+
textbox = gr.Textbox(show_label=False,
|
120 |
+
placeholder="Enter text and press Send",
|
121 |
+
container=False)
|
122 |
+
with gr.Column(scale=2, min_width=50):
|
123 |
+
submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
|
124 |
+
|
125 |
+
with gr.Row(visible=True) as button_row:
|
126 |
+
flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
|
127 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
|
128 |
+
clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
|
129 |
+
|
130 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
131 |
+
|
132 |
+
with gr.Row():
|
133 |
+
gr.Examples(
|
134 |
+
examples=[
|
135 |
+
[
|
136 |
+
f"{cur_dir}/examples/video2.mp4",
|
137 |
+
"Describe the video briefly.",
|
138 |
+
]
|
139 |
+
],
|
140 |
+
inputs=[video, textbox],
|
141 |
+
)
|
142 |
+
|
143 |
+
gr.Examples(
|
144 |
+
examples=[
|
145 |
+
[
|
146 |
+
f"{cur_dir}/examples/video4.mp4",
|
147 |
+
"What is the boy doing?",
|
148 |
+
]
|
149 |
+
],
|
150 |
+
inputs=[video, textbox],
|
151 |
+
)
|
152 |
+
|
153 |
+
gr.Examples(
|
154 |
+
examples=[
|
155 |
+
[
|
156 |
+
f"{cur_dir}/examples/video5.mp4",
|
157 |
+
"Why is this video funny?",
|
158 |
+
]
|
159 |
+
],
|
160 |
+
inputs=[video, textbox],
|
161 |
+
)
|
162 |
+
|
163 |
+
submit_btn.click(generate, [video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, video])
|
164 |
+
|
165 |
+
regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
|
166 |
+
generate, [video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, video])
|
167 |
+
|
168 |
+
clear_btn.click(clear_history, [state, state_],
|
169 |
+
[video, textbox, first_run, state, state_, chatbot, images_tensor])
|
170 |
+
|
171 |
+
|
172 |
+
# app = gr.mount_gradio_app(app, demo, path="/")
|
173 |
+
demo.launch()
|
app_old.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from huggingface_hub import InferenceClient
|
3 |
+
|
4 |
+
"""
|
5 |
+
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
|
6 |
+
"""
|
7 |
+
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
8 |
+
|
9 |
+
|
10 |
+
def respond(
|
11 |
+
message,
|
12 |
+
history: list[tuple[str, str]],
|
13 |
+
system_message,
|
14 |
+
max_tokens,
|
15 |
+
temperature,
|
16 |
+
top_p,
|
17 |
+
):
|
18 |
+
messages = [{"role": "system", "content": system_message}]
|
19 |
+
|
20 |
+
for val in history:
|
21 |
+
if val[0]:
|
22 |
+
messages.append({"role": "user", "content": val[0]})
|
23 |
+
if val[1]:
|
24 |
+
messages.append({"role": "assistant", "content": val[1]})
|
25 |
+
|
26 |
+
messages.append({"role": "user", "content": message})
|
27 |
+
|
28 |
+
response = ""
|
29 |
+
|
30 |
+
for message in client.chat_completion(
|
31 |
+
messages,
|
32 |
+
max_tokens=max_tokens,
|
33 |
+
stream=True,
|
34 |
+
temperature=temperature,
|
35 |
+
top_p=top_p,
|
36 |
+
):
|
37 |
+
token = message.choices[0].delta.content
|
38 |
+
|
39 |
+
response += token
|
40 |
+
yield response
|
41 |
+
|
42 |
+
"""
|
43 |
+
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
44 |
+
"""
|
45 |
+
demo = gr.ChatInterface(
|
46 |
+
respond,
|
47 |
+
additional_inputs=[
|
48 |
+
gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
|
49 |
+
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
50 |
+
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
51 |
+
gr.Slider(
|
52 |
+
minimum=0.1,
|
53 |
+
maximum=1.0,
|
54 |
+
value=0.95,
|
55 |
+
step=0.05,
|
56 |
+
label="Top-p (nucleus sampling)",
|
57 |
+
),
|
58 |
+
],
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == "__main__":
|
63 |
+
demo.launch()
|
flash_vstream/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from flash_vstream.model import VStreamLlamaForCausalLM
|
flash_vstream/constants.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on https://github.com/haotian-liu/LLaVA.
|
2 |
+
|
3 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
4 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
5 |
+
|
6 |
+
LOGDIR = "."
|
7 |
+
|
8 |
+
# Model Constants
|
9 |
+
IGNORE_INDEX = -100
|
10 |
+
IMAGE_TOKEN_INDEX = -200
|
11 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
12 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
13 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
14 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
15 |
+
IMAGE_PLACEHOLDER = "<image-placeholder>"
|
flash_vstream/conversation.py
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on https://github.com/haotian-liu/LLaVA.
|
2 |
+
|
3 |
+
import dataclasses
|
4 |
+
from enum import auto, Enum
|
5 |
+
from typing import List, Tuple
|
6 |
+
|
7 |
+
|
8 |
+
class SeparatorStyle(Enum):
|
9 |
+
"""Different separator style."""
|
10 |
+
SINGLE = auto()
|
11 |
+
TWO = auto()
|
12 |
+
MPT = auto()
|
13 |
+
PLAIN = auto()
|
14 |
+
LLAMA_2 = auto()
|
15 |
+
|
16 |
+
|
17 |
+
@dataclasses.dataclass
|
18 |
+
class Conversation:
|
19 |
+
"""A class that keeps all conversation history."""
|
20 |
+
system: str
|
21 |
+
roles: List[str]
|
22 |
+
messages: List[List[str]]
|
23 |
+
offset: int
|
24 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
25 |
+
sep: str = "###"
|
26 |
+
sep2: str = None
|
27 |
+
version: str = "Unknown"
|
28 |
+
|
29 |
+
skip_next: bool = False
|
30 |
+
|
31 |
+
def get_prompt(self):
|
32 |
+
messages = self.messages
|
33 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
34 |
+
messages = self.messages.copy()
|
35 |
+
init_role, init_msg = messages[0].copy()
|
36 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
37 |
+
if 'mmtag' in self.version:
|
38 |
+
messages[0] = (init_role, init_msg)
|
39 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
40 |
+
messages.insert(1, (self.roles[1], "Received."))
|
41 |
+
else:
|
42 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
43 |
+
|
44 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
45 |
+
ret = self.system + self.sep
|
46 |
+
for role, message in messages:
|
47 |
+
if message:
|
48 |
+
if type(message) is tuple:
|
49 |
+
message, _, _ = message
|
50 |
+
ret += role + ": " + message + self.sep
|
51 |
+
else:
|
52 |
+
ret += role + ":"
|
53 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
54 |
+
seps = [self.sep, self.sep2]
|
55 |
+
ret = self.system + seps[0]
|
56 |
+
for i, (role, message) in enumerate(messages):
|
57 |
+
if message:
|
58 |
+
if type(message) is tuple:
|
59 |
+
message, _, _ = message
|
60 |
+
ret += role + ": " + message + seps[i % 2]
|
61 |
+
else:
|
62 |
+
ret += role + ":"
|
63 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
64 |
+
ret = self.system + self.sep
|
65 |
+
for role, message in messages:
|
66 |
+
if message:
|
67 |
+
if type(message) is tuple:
|
68 |
+
message, _, _ = message
|
69 |
+
ret += role + message + self.sep
|
70 |
+
else:
|
71 |
+
ret += role
|
72 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
73 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
|
74 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
75 |
+
ret = ""
|
76 |
+
|
77 |
+
for i, (role, message) in enumerate(messages):
|
78 |
+
if i == 0:
|
79 |
+
assert message, "first message should not be none"
|
80 |
+
assert role == self.roles[0], "first message should come from user"
|
81 |
+
if message:
|
82 |
+
if type(message) is tuple:
|
83 |
+
message, _, _ = message
|
84 |
+
if i == 0: message = wrap_sys(self.system) + message
|
85 |
+
if i % 2 == 0:
|
86 |
+
message = wrap_inst(message)
|
87 |
+
ret += self.sep + message
|
88 |
+
else:
|
89 |
+
ret += " " + message + " " + self.sep2
|
90 |
+
else:
|
91 |
+
ret += ""
|
92 |
+
ret = ret.lstrip(self.sep)
|
93 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
94 |
+
seps = [self.sep, self.sep2]
|
95 |
+
ret = self.system
|
96 |
+
for i, (role, message) in enumerate(messages):
|
97 |
+
if message:
|
98 |
+
if type(message) is tuple:
|
99 |
+
message, _, _ = message
|
100 |
+
ret += message + seps[i % 2]
|
101 |
+
else:
|
102 |
+
ret += ""
|
103 |
+
else:
|
104 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
105 |
+
|
106 |
+
return ret
|
107 |
+
|
108 |
+
def append_message(self, role, message):
|
109 |
+
self.messages.append([role, message])
|
110 |
+
|
111 |
+
def get_images(self, return_pil=False):
|
112 |
+
images = []
|
113 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
114 |
+
if i % 2 == 0:
|
115 |
+
if type(msg) is tuple:
|
116 |
+
import base64
|
117 |
+
from io import BytesIO
|
118 |
+
from PIL import Image
|
119 |
+
msg, image, image_process_mode = msg
|
120 |
+
if image_process_mode == "Pad":
|
121 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
122 |
+
width, height = pil_img.size
|
123 |
+
if width == height:
|
124 |
+
return pil_img
|
125 |
+
elif width > height:
|
126 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
127 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
128 |
+
return result
|
129 |
+
else:
|
130 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
131 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
132 |
+
return result
|
133 |
+
image = expand2square(image)
|
134 |
+
elif image_process_mode in ["Default", "Crop"]:
|
135 |
+
pass
|
136 |
+
elif image_process_mode == "Resize":
|
137 |
+
image = image.resize((336, 336))
|
138 |
+
else:
|
139 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
140 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
141 |
+
aspect_ratio = max_hw / min_hw
|
142 |
+
max_len, min_len = 800, 400
|
143 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
144 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
145 |
+
W, H = image.size
|
146 |
+
if longest_edge != max(image.size):
|
147 |
+
if H > W:
|
148 |
+
H, W = longest_edge, shortest_edge
|
149 |
+
else:
|
150 |
+
H, W = shortest_edge, longest_edge
|
151 |
+
image = image.resize((W, H))
|
152 |
+
if return_pil:
|
153 |
+
images.append(image)
|
154 |
+
else:
|
155 |
+
buffered = BytesIO()
|
156 |
+
image.save(buffered, format="PNG")
|
157 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
158 |
+
images.append(img_b64_str)
|
159 |
+
return images
|
160 |
+
|
161 |
+
def to_gradio_chatbot(self):
|
162 |
+
ret = []
|
163 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
164 |
+
if i % 2 == 0:
|
165 |
+
if type(msg) is tuple:
|
166 |
+
import base64
|
167 |
+
from io import BytesIO
|
168 |
+
msg, image, image_process_mode = msg
|
169 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
170 |
+
aspect_ratio = max_hw / min_hw
|
171 |
+
max_len, min_len = 800, 400
|
172 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
173 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
174 |
+
W, H = image.size
|
175 |
+
if H > W:
|
176 |
+
H, W = longest_edge, shortest_edge
|
177 |
+
else:
|
178 |
+
H, W = shortest_edge, longest_edge
|
179 |
+
image = image.resize((W, H))
|
180 |
+
buffered = BytesIO()
|
181 |
+
image.save(buffered, format="JPEG")
|
182 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
183 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
184 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
185 |
+
ret.append([msg, None])
|
186 |
+
else:
|
187 |
+
ret.append([msg, None])
|
188 |
+
else:
|
189 |
+
ret[-1][-1] = msg
|
190 |
+
return ret
|
191 |
+
|
192 |
+
def copy(self):
|
193 |
+
return Conversation(
|
194 |
+
system=self.system,
|
195 |
+
roles=self.roles,
|
196 |
+
messages=[[x, y] for x, y in self.messages],
|
197 |
+
offset=self.offset,
|
198 |
+
sep_style=self.sep_style,
|
199 |
+
sep=self.sep,
|
200 |
+
sep2=self.sep2,
|
201 |
+
version=self.version)
|
202 |
+
|
203 |
+
def dict(self):
|
204 |
+
if len(self.get_images()) > 0:
|
205 |
+
return {
|
206 |
+
"system": self.system,
|
207 |
+
"roles": self.roles,
|
208 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
209 |
+
"offset": self.offset,
|
210 |
+
"sep": self.sep,
|
211 |
+
"sep2": self.sep2,
|
212 |
+
}
|
213 |
+
return {
|
214 |
+
"system": self.system,
|
215 |
+
"roles": self.roles,
|
216 |
+
"messages": self.messages,
|
217 |
+
"offset": self.offset,
|
218 |
+
"sep": self.sep,
|
219 |
+
"sep2": self.sep2,
|
220 |
+
}
|
221 |
+
|
222 |
+
|
223 |
+
conv_vicuna_v0 = Conversation(
|
224 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
225 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
226 |
+
roles=("Human", "Assistant"),
|
227 |
+
messages=(
|
228 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
229 |
+
("Assistant",
|
230 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
231 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
232 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
233 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
234 |
+
"renewable and non-renewable energy sources:\n"
|
235 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
236 |
+
"energy sources are finite and will eventually run out.\n"
|
237 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
238 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
239 |
+
"and other negative effects.\n"
|
240 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
241 |
+
"have lower operational costs than non-renewable sources.\n"
|
242 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
243 |
+
"locations than non-renewable sources.\n"
|
244 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
245 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
246 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
247 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
248 |
+
),
|
249 |
+
offset=2,
|
250 |
+
sep_style=SeparatorStyle.SINGLE,
|
251 |
+
sep="###",
|
252 |
+
)
|
253 |
+
|
254 |
+
conv_vicuna_v1 = Conversation(
|
255 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
256 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
257 |
+
roles=("USER", "ASSISTANT"),
|
258 |
+
version="v1",
|
259 |
+
messages=(),
|
260 |
+
offset=0,
|
261 |
+
sep_style=SeparatorStyle.TWO,
|
262 |
+
sep=" ",
|
263 |
+
sep2="</s>",
|
264 |
+
)
|
265 |
+
|
266 |
+
conv_vicuna_v1_mcq = Conversation(
|
267 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
268 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions. "
|
269 |
+
"The assistant should give the number of correct answer.",
|
270 |
+
roles=("USER", "ASSISTANT"),
|
271 |
+
version="v1",
|
272 |
+
messages=(),
|
273 |
+
offset=0,
|
274 |
+
sep_style=SeparatorStyle.TWO,
|
275 |
+
sep=" ",
|
276 |
+
sep2="</s>",
|
277 |
+
)
|
278 |
+
|
279 |
+
conv_tiny = Conversation(
|
280 |
+
system="""<|system|>
|
281 |
+
A conversation between a user and an AI assistant. The assistant gives short and honest answers.""",
|
282 |
+
roles=("<|user|>\n", "<|assistant|>\n"),
|
283 |
+
version="mpt",
|
284 |
+
messages=(),
|
285 |
+
offset=0,
|
286 |
+
sep_style=SeparatorStyle.MPT,
|
287 |
+
sep="</s>",
|
288 |
+
)
|
289 |
+
|
290 |
+
conv_llama_2 = Conversation(
|
291 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
292 |
+
|
293 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
294 |
+
roles=("USER", "ASSISTANT"),
|
295 |
+
version="llama_v2",
|
296 |
+
messages=(),
|
297 |
+
offset=0,
|
298 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
299 |
+
sep="<s>",
|
300 |
+
sep2="</s>",
|
301 |
+
)
|
302 |
+
|
303 |
+
conv_mpt = Conversation(
|
304 |
+
system="""<|im_start|>system
|
305 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
306 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
307 |
+
version="mpt",
|
308 |
+
messages=(),
|
309 |
+
offset=0,
|
310 |
+
sep_style=SeparatorStyle.MPT,
|
311 |
+
sep="<|im_end|>",
|
312 |
+
)
|
313 |
+
|
314 |
+
conv_plain = Conversation(
|
315 |
+
system="",
|
316 |
+
roles=("", ""),
|
317 |
+
messages=(
|
318 |
+
),
|
319 |
+
offset=0,
|
320 |
+
sep_style=SeparatorStyle.PLAIN,
|
321 |
+
sep="\n",
|
322 |
+
)
|
323 |
+
|
324 |
+
|
325 |
+
default_conversation = conv_vicuna_v1
|
326 |
+
conv_templates = {
|
327 |
+
"default": conv_vicuna_v0,
|
328 |
+
"v0": conv_vicuna_v0,
|
329 |
+
"v1": conv_vicuna_v1,
|
330 |
+
"vicuna_v1": conv_vicuna_v1,
|
331 |
+
"llama_2": conv_llama_2,
|
332 |
+
"plain": conv_plain,
|
333 |
+
}
|
334 |
+
|
335 |
+
|
336 |
+
if __name__ == "__main__":
|
337 |
+
print(default_conversation.get_prompt())
|
flash_vstream/eval_video/eval_activitynet_qa.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on https://github.com/haotian-liu/LLaVA.
|
2 |
+
|
3 |
+
import os
|
4 |
+
import ast
|
5 |
+
import json
|
6 |
+
import openai
|
7 |
+
import argparse
|
8 |
+
from tqdm import tqdm
|
9 |
+
from time import sleep
|
10 |
+
from collections import defaultdict
|
11 |
+
from multiprocessing.pool import Pool
|
12 |
+
|
13 |
+
def parse_args():
|
14 |
+
parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
|
15 |
+
parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
|
16 |
+
parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
|
17 |
+
parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
|
18 |
+
parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
|
19 |
+
parser.add_argument("--num_chunks", default=1, type=int, help="Result splits")
|
20 |
+
parser.add_argument("--api_key", required=True, type=str, help="OpenAI API key")
|
21 |
+
parser.add_argument("--api_type", default=None, type=str, help="OpenAI API type")
|
22 |
+
parser.add_argument("--api_version", default=None, type=str, help="OpenAI API version")
|
23 |
+
parser.add_argument("--api_base", default=None, type=str, help="OpenAI API base")
|
24 |
+
args = parser.parse_args()
|
25 |
+
return args
|
26 |
+
|
27 |
+
|
28 |
+
def annotate(prediction_set, caption_files, output_dir):
|
29 |
+
"""
|
30 |
+
Evaluates question and answer pairs using GPT-3
|
31 |
+
Returns a score for correctness.
|
32 |
+
"""
|
33 |
+
for file in tqdm(caption_files):
|
34 |
+
key = file[:-5] # Strip file extension
|
35 |
+
qa_set = prediction_set[key]
|
36 |
+
question = qa_set['q']
|
37 |
+
answer = qa_set['a']
|
38 |
+
pred = qa_set['pred']
|
39 |
+
try:
|
40 |
+
# Compute the correctness score
|
41 |
+
completion = openai.ChatCompletion.create(
|
42 |
+
model="gpt-3.5-turbo",
|
43 |
+
messages=[
|
44 |
+
{
|
45 |
+
"role": "system",
|
46 |
+
"content":
|
47 |
+
"You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
|
48 |
+
"Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
|
49 |
+
"------"
|
50 |
+
"##INSTRUCTIONS: "
|
51 |
+
"- Focus on the meaningful match between the predicted answer and the correct answer.\n"
|
52 |
+
"- Consider synonyms or paraphrases as valid matches.\n"
|
53 |
+
"- Evaluate the correctness of the prediction compared to the answer."
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"role": "user",
|
57 |
+
"content":
|
58 |
+
"Please evaluate the following video-based question-answer pair:\n\n"
|
59 |
+
f"Question: {question}\n"
|
60 |
+
f"Correct Answer: {answer}\n"
|
61 |
+
f"Predicted Answer: {pred}\n\n"
|
62 |
+
"Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
|
63 |
+
"Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."
|
64 |
+
"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
|
65 |
+
"For example, your response should look like this: {'pred': 'yes', 'score': 4.8}."
|
66 |
+
}
|
67 |
+
],
|
68 |
+
temperature=0.002
|
69 |
+
)
|
70 |
+
# Convert response to a Python dictionary.
|
71 |
+
response_message = completion["choices"][0]["message"]["content"]
|
72 |
+
response_dict = ast.literal_eval(response_message)
|
73 |
+
result_qa_pair = [response_dict, qa_set]
|
74 |
+
|
75 |
+
# Save the question-answer pairs to a json file.
|
76 |
+
with open(f"{output_dir}/{key}.json", "w") as f:
|
77 |
+
json.dump(result_qa_pair, f)
|
78 |
+
sleep(0.5)
|
79 |
+
|
80 |
+
except Exception as e:
|
81 |
+
print(f"Error processing file '{key}': {e}")
|
82 |
+
sleep(1)
|
83 |
+
|
84 |
+
|
85 |
+
def main():
|
86 |
+
"""
|
87 |
+
Main function to control the flow of the program.
|
88 |
+
"""
|
89 |
+
# Parse arguments.
|
90 |
+
args = parse_args()
|
91 |
+
|
92 |
+
if args.num_chunks > 1:
|
93 |
+
pred_contents = []
|
94 |
+
for _idx in range(args.num_chunks):
|
95 |
+
file = os.path.join(args.pred_path, f"{args.num_chunks}_{_idx}.json")
|
96 |
+
pred_contents += [json.loads(line) for line in open(file)]
|
97 |
+
|
98 |
+
else:
|
99 |
+
file = os.path.join(args.pred_path, f"pred.json")
|
100 |
+
pred_contents = [json.loads(line) for line in open(file)]
|
101 |
+
|
102 |
+
# Dictionary to store the count of occurrences for each video_id
|
103 |
+
video_id_counts = {}
|
104 |
+
new_pred_contents = []
|
105 |
+
|
106 |
+
# Iterate through each sample in pred_contents
|
107 |
+
for sample in pred_contents:
|
108 |
+
video_id = sample['id']
|
109 |
+
if video_id in video_id_counts:
|
110 |
+
video_id_counts[video_id] += 1
|
111 |
+
else:
|
112 |
+
video_id_counts[video_id] = 0
|
113 |
+
|
114 |
+
# Create a new sample with the modified key
|
115 |
+
new_sample = sample
|
116 |
+
new_sample['id'] = f"{video_id}_{video_id_counts[video_id]}"
|
117 |
+
new_pred_contents.append(new_sample)
|
118 |
+
|
119 |
+
# Generating list of id's and corresponding files
|
120 |
+
id_list = [x['id'] for x in new_pred_contents]
|
121 |
+
caption_files = [f"{id}.json" for id in id_list]
|
122 |
+
|
123 |
+
output_dir = args.output_dir
|
124 |
+
# Generate output directory if not exists.
|
125 |
+
if not os.path.exists(output_dir):
|
126 |
+
os.makedirs(output_dir)
|
127 |
+
|
128 |
+
# Preparing dictionary of question-answer sets
|
129 |
+
prediction_set = {}
|
130 |
+
for sample in new_pred_contents:
|
131 |
+
id = sample['id']
|
132 |
+
question = sample['question']
|
133 |
+
answer = sample['answer']
|
134 |
+
pred = sample['pred']
|
135 |
+
qa_set = {"q": question, "a": answer, "pred": pred, "a_type": sample['answer_type'] if 'answer_type' in sample else None}
|
136 |
+
prediction_set[id] = qa_set
|
137 |
+
|
138 |
+
# Set the OpenAI API key.
|
139 |
+
openai.api_key = args.api_key # Your API key here
|
140 |
+
if args.api_type:
|
141 |
+
openai.api_type = args.api_type
|
142 |
+
if args.api_version:
|
143 |
+
openai.api_version = args.api_version
|
144 |
+
if args.api_base:
|
145 |
+
openai.api_base = args.api_base # Your API base here
|
146 |
+
num_tasks = args.num_tasks
|
147 |
+
|
148 |
+
# While loop to ensure that all captions are processed.
|
149 |
+
incomplete_lengths = []
|
150 |
+
for _ in range(100):
|
151 |
+
try:
|
152 |
+
# Files that have not been processed yet.
|
153 |
+
completed_files = os.listdir(output_dir)
|
154 |
+
print(f"completed_files: {len(completed_files)}")
|
155 |
+
|
156 |
+
# Files that have not been processed yet.
|
157 |
+
incomplete_files = [f for f in caption_files if f not in completed_files]
|
158 |
+
print(f"incomplete_files: {len(incomplete_files)}")
|
159 |
+
incomplete_lengths.append(len(incomplete_files))
|
160 |
+
if len(incomplete_lengths) > 5 and len(set(incomplete_lengths[-5:])) <= 1:
|
161 |
+
print(f"incomplete_lengths: {incomplete_lengths}")
|
162 |
+
print(f"incomplete_files: {incomplete_files}")
|
163 |
+
print(f"completed_files: {completed_files}")
|
164 |
+
print(f"failed for 5 times, break")
|
165 |
+
break
|
166 |
+
|
167 |
+
# Break the loop when there are no incomplete files
|
168 |
+
if len(incomplete_files) == 0:
|
169 |
+
break
|
170 |
+
if len(incomplete_files) <= num_tasks:
|
171 |
+
num_tasks = 1
|
172 |
+
|
173 |
+
# Split tasks into parts.
|
174 |
+
part_len = len(incomplete_files) // num_tasks
|
175 |
+
all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
|
176 |
+
task_args = [(prediction_set, part, args.output_dir) for part in all_parts]
|
177 |
+
|
178 |
+
# Use a pool of workers to process the files in parallel.
|
179 |
+
with Pool() as pool:
|
180 |
+
pool.starmap(annotate, task_args)
|
181 |
+
|
182 |
+
except Exception as e:
|
183 |
+
print(f"Error: {e}")
|
184 |
+
|
185 |
+
# Combine all the processed files into one
|
186 |
+
combined_contents = {}
|
187 |
+
json_path = args.output_json
|
188 |
+
|
189 |
+
# Iterate through json files
|
190 |
+
for file_name in os.listdir(output_dir):
|
191 |
+
if file_name.endswith(".json"):
|
192 |
+
file_path = os.path.join(output_dir, file_name)
|
193 |
+
with open(file_path, "r") as json_file:
|
194 |
+
content = json.load(json_file)
|
195 |
+
assert 'pred' in content[0], f"Error: {file_name} don't has key=pred"
|
196 |
+
assert 'score' in content[0], f"Error: {file_name} don't has key=score"
|
197 |
+
combined_contents[file_name[:-5]] = content
|
198 |
+
|
199 |
+
# Write combined content to a json file
|
200 |
+
with open(json_path, "w") as json_file:
|
201 |
+
json.dump(combined_contents, json_file)
|
202 |
+
print("All evaluation completed!")
|
203 |
+
|
204 |
+
class ScoreMeter:
|
205 |
+
def __init__(self):
|
206 |
+
self.score_sum = 0
|
207 |
+
self.count = 0
|
208 |
+
self.yes_count = 0
|
209 |
+
self.no_count = 0
|
210 |
+
self.score_dict = {'yes': defaultdict(int), 'no': defaultdict(int)}
|
211 |
+
|
212 |
+
def add_score(self, score, pred):
|
213 |
+
self.score_sum += score
|
214 |
+
self.count += 1
|
215 |
+
pred_lower = pred.lower()
|
216 |
+
if 'yes' in pred_lower:
|
217 |
+
self.yes_count += 1
|
218 |
+
self.score_dict['yes'][score] += 1
|
219 |
+
elif 'no' in pred_lower:
|
220 |
+
self.no_count += 1
|
221 |
+
self.score_dict['no'][score] += 1
|
222 |
+
|
223 |
+
def get_average_score(self):
|
224 |
+
res = (self.score_sum / self.count) if self.count else 0
|
225 |
+
return f"{res:.6f}"
|
226 |
+
|
227 |
+
def get_accuracy(self, response_type):
|
228 |
+
if response_type == 'yes':
|
229 |
+
res = (self.yes_count / self.count) if self.count else 0
|
230 |
+
elif response_type == 'no':
|
231 |
+
res = (self.no_count / self.count) if self.count else 0
|
232 |
+
else:
|
233 |
+
res = 0
|
234 |
+
return f"{res:.6f}"
|
235 |
+
|
236 |
+
meter_dic = {'total': ScoreMeter()}
|
237 |
+
for key, result in combined_contents.items():
|
238 |
+
# Computing score
|
239 |
+
score_match = result[0]['score']
|
240 |
+
score = int(score_match)
|
241 |
+
pred = result[0]['pred']
|
242 |
+
|
243 |
+
meter_dic["total"].add_score(score, pred)
|
244 |
+
if 'a_type' in result[1] and result[1]['a_type'] is not None:
|
245 |
+
typ = str(result[1]['a_type'])
|
246 |
+
if typ not in meter_dic:
|
247 |
+
meter_dic[typ] = ScoreMeter()
|
248 |
+
meter_dic[typ].add_score(score, pred)
|
249 |
+
|
250 |
+
if 'next' in args.output_dir:
|
251 |
+
typ = typ[0]
|
252 |
+
if typ not in meter_dic:
|
253 |
+
meter_dic[typ] = ScoreMeter()
|
254 |
+
meter_dic[typ].add_score(score, pred)
|
255 |
+
|
256 |
+
csv_dic = {'acc': meter_dic["total"].get_accuracy('yes'), 'score': meter_dic["total"].get_average_score()}
|
257 |
+
|
258 |
+
output = ""
|
259 |
+
output += "Yes count: " + str(meter_dic["total"].yes_count) + "\n"
|
260 |
+
output += "No count: " + str(meter_dic["total"].no_count) + "\n"
|
261 |
+
output += "Accuracy: " + str(meter_dic["total"].get_accuracy('yes')) + "\n"
|
262 |
+
output += "Average score: " + str(meter_dic["total"].get_average_score()) + "\n"
|
263 |
+
output += "\n"
|
264 |
+
output += "Total Score Yes/No distribution:\n"
|
265 |
+
for key, value in meter_dic["total"].score_dict.items():
|
266 |
+
output += f"{key}:\n"
|
267 |
+
for k in range(0, 6):
|
268 |
+
v = value[k]
|
269 |
+
output += f"{k}: {v}\n"
|
270 |
+
output += "\n"
|
271 |
+
output += "Answer Type Score distribution:\n"
|
272 |
+
output += 'Type, Accuracy, Avg_score\n'
|
273 |
+
key_list = sorted([k for k in meter_dic.keys()])
|
274 |
+
for key in key_list:
|
275 |
+
output += f"{key}, {meter_dic[key].get_accuracy('yes')}, {meter_dic[key].get_average_score()}\n"
|
276 |
+
csv_dic[key] = meter_dic[key].get_accuracy('yes')
|
277 |
+
|
278 |
+
output += "\n"
|
279 |
+
for k in csv_dic.keys():
|
280 |
+
output += f"{k}, "
|
281 |
+
output = output.rstrip(', ') # Remove the trailing comma and space
|
282 |
+
output += "\n"
|
283 |
+
|
284 |
+
for k in csv_dic.keys():
|
285 |
+
output += str(csv_dic[k]) + ", "
|
286 |
+
output = output.rstrip(', ') # Remove the trailing comma and space
|
287 |
+
output += "\n"
|
288 |
+
|
289 |
+
print(output)
|
290 |
+
args.output_csv = args.output_json.replace(".json", ".csv")
|
291 |
+
with open(args.output_csv, 'w') as f:
|
292 |
+
f.write(output)
|
293 |
+
|
294 |
+
if __name__ == "__main__":
|
295 |
+
main()
|
296 |
+
|
flash_vstream/eval_video/eval_any_dataset_features.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Flash-VStream Authors
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import argparse
|
17 |
+
import subprocess
|
18 |
+
import multiprocessing
|
19 |
+
|
20 |
+
def exec(cmd, sub=False, device=None):
|
21 |
+
print(f'exec: {cmd}')
|
22 |
+
if not sub:
|
23 |
+
if isinstance(cmd, list):
|
24 |
+
cmd = ' '.join(cmd)
|
25 |
+
os.system(cmd)
|
26 |
+
else:
|
27 |
+
my_env = os.environ.copy()
|
28 |
+
my_env["CUDA_VISIBLE_DEVICES"] = device
|
29 |
+
subprocess.run(cmd, env=my_env)
|
30 |
+
|
31 |
+
# multi gpu, feature
|
32 |
+
def eval_msvd(args):
|
33 |
+
model_path = args.model_path
|
34 |
+
num_chunks = args.num_chunks
|
35 |
+
if not args.only_eval:
|
36 |
+
processes = []
|
37 |
+
for idx in range(0, num_chunks):
|
38 |
+
cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
|
39 |
+
"--model-path", model_path,
|
40 |
+
"--video_dir", "./data/eval_video/MSVD-QA/video_features",
|
41 |
+
"--gt_file", "./data/eval_video/MSVD-QA/test_qa.json",
|
42 |
+
"--output_dir", os.path.join(model_path, "evaluation", "msvd"),
|
43 |
+
"--output_name", "pred",
|
44 |
+
"--num-chunks", str(num_chunks),
|
45 |
+
"--chunk-idx", str(idx),
|
46 |
+
"--conv-mode", "vicuna_v1"]
|
47 |
+
p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
|
48 |
+
processes.append(p)
|
49 |
+
p.start() # 启动子进程
|
50 |
+
for p in processes:
|
51 |
+
p.join()
|
52 |
+
cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
|
53 |
+
"--pred_path", os.path.join(model_path, "evaluation", "msvd"),
|
54 |
+
"--output_dir", os.path.join(model_path, "evaluation", "msvd", "results"),
|
55 |
+
"--output_json", os.path.join(model_path, "evaluation", "msvd", "results.json"),
|
56 |
+
"--num_chunks", str(num_chunks),
|
57 |
+
"--num_tasks", "16",
|
58 |
+
"--api_key", args.api_key,
|
59 |
+
"--api_base", args.api_base,
|
60 |
+
"--api_type", args.api_type,
|
61 |
+
"--api_version", args.api_version,
|
62 |
+
]
|
63 |
+
exec(cmd)
|
64 |
+
|
65 |
+
# multi gpu, feature
|
66 |
+
def eval_msrvtt(args):
|
67 |
+
model_path = args.model_path
|
68 |
+
num_chunks = args.num_chunks
|
69 |
+
if not args.only_eval:
|
70 |
+
processes = []
|
71 |
+
for idx in range(0, num_chunks):
|
72 |
+
cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
|
73 |
+
"--model-path", model_path,
|
74 |
+
"--video_dir", "./data/eval_video/MSRVTT-QA/video_features",
|
75 |
+
"--gt_file", "./data/eval_video/MSRVTT-QA/test_qa.json",
|
76 |
+
"--output_dir", os.path.join(model_path, "evaluation", "msrvtt"),
|
77 |
+
"--output_name", "pred",
|
78 |
+
"--num-chunks", str(num_chunks),
|
79 |
+
"--chunk-idx", str(idx),
|
80 |
+
"--conv-mode", "vicuna_v1"]
|
81 |
+
p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
|
82 |
+
processes.append(p)
|
83 |
+
p.start() # 启动子进程
|
84 |
+
for p in processes:
|
85 |
+
p.join()
|
86 |
+
cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
|
87 |
+
"--pred_path", os.path.join(model_path, "evaluation", "msrvtt"),
|
88 |
+
"--output_dir", os.path.join(model_path, "evaluation", "msrvtt", "results"),
|
89 |
+
"--output_json", os.path.join(model_path, "evaluation", "msrvtt", "results.json"),
|
90 |
+
"--num_chunks", str(num_chunks),
|
91 |
+
"--num_tasks", "16",
|
92 |
+
"--api_key", args.api_key,
|
93 |
+
"--api_base", args.api_base,
|
94 |
+
"--api_type", args.api_type,
|
95 |
+
"--api_version", args.api_version,
|
96 |
+
]
|
97 |
+
exec(cmd)
|
98 |
+
|
99 |
+
# multi gpu, feature
|
100 |
+
def eval_actnet(args):
|
101 |
+
model_path = args.model_path
|
102 |
+
num_chunks = args.num_chunks
|
103 |
+
if not args.only_eval:
|
104 |
+
processes = []
|
105 |
+
for idx in range(0, num_chunks):
|
106 |
+
cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
|
107 |
+
"--model-path", model_path,
|
108 |
+
"--video_dir", "./data/eval_video/ActivityNet-QA/video_features",
|
109 |
+
"--gt_file", "./data/eval_video/ActivityNet-QA/test_qa.json",
|
110 |
+
"--output_dir", os.path.join(model_path, "evaluation", "actnet"),
|
111 |
+
"--output_name", "pred",
|
112 |
+
"--num-chunks", str(num_chunks),
|
113 |
+
"--chunk-idx", str(idx),
|
114 |
+
"--conv-mode", "vicuna_v1",
|
115 |
+
]
|
116 |
+
|
117 |
+
p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
|
118 |
+
processes.append(p)
|
119 |
+
p.start() # 启动子进程
|
120 |
+
for p in processes:
|
121 |
+
p.join()
|
122 |
+
cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
|
123 |
+
"--pred_path", os.path.join(model_path, "evaluation", "actnet"),
|
124 |
+
"--output_dir", os.path.join(model_path, "evaluation", "actnet", "results"),
|
125 |
+
"--output_json", os.path.join(model_path, "evaluation", "actnet", "results.json"),
|
126 |
+
"--num_chunks", str(num_chunks),
|
127 |
+
"--num_tasks", "16",
|
128 |
+
"--api_key", args.api_key,
|
129 |
+
"--api_base", args.api_base,
|
130 |
+
"--api_type", args.api_type,
|
131 |
+
"--api_version", args.api_version,
|
132 |
+
]
|
133 |
+
exec(cmd)
|
134 |
+
|
135 |
+
# multi gpu, feature
|
136 |
+
def eval_nextoe(args): # follow msvd format, OE follow actnet
|
137 |
+
model_path = args.model_path
|
138 |
+
num_chunks = args.num_chunks
|
139 |
+
if not args.only_eval:
|
140 |
+
processes = []
|
141 |
+
for idx in range(0, num_chunks):
|
142 |
+
cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
|
143 |
+
"--model-path", model_path,
|
144 |
+
"--video_dir", "./data/eval_video/nextoe/video_features",
|
145 |
+
"--gt_file", "./data/eval_video/nextoe/test_qa.json",
|
146 |
+
"--output_dir", os.path.join(model_path, "evaluation", "nextoe"),
|
147 |
+
"--output_name", "pred",
|
148 |
+
"--num-chunks", str(num_chunks),
|
149 |
+
"--chunk-idx", str(idx),
|
150 |
+
"--conv-mode", "vicuna_v1",
|
151 |
+
]
|
152 |
+
|
153 |
+
p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
|
154 |
+
processes.append(p)
|
155 |
+
p.start() # 启动子进程
|
156 |
+
for p in processes:
|
157 |
+
p.join()
|
158 |
+
cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
|
159 |
+
"--pred_path", os.path.join(model_path, "evaluation", "nextoe"),
|
160 |
+
"--output_dir", os.path.join(model_path, "evaluation", "nextoe", "results"),
|
161 |
+
"--output_json", os.path.join(model_path, "evaluation", "nextoe", "results.json"),
|
162 |
+
"--num_chunks", str(num_chunks),
|
163 |
+
"--num_tasks", "16",
|
164 |
+
"--api_key", args.api_key,
|
165 |
+
"--api_base", args.api_base,
|
166 |
+
"--api_type", args.api_type,
|
167 |
+
"--api_version", args.api_version,
|
168 |
+
]
|
169 |
+
exec(cmd)
|
170 |
+
|
171 |
+
# multi gpu, feature
|
172 |
+
def eval_vsmovienet(args): # follow msvd format
|
173 |
+
model_path = args.model_path
|
174 |
+
num_chunks = args.num_chunks
|
175 |
+
if not args.only_eval:
|
176 |
+
processes = []
|
177 |
+
for idx in range(0, num_chunks):
|
178 |
+
cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
|
179 |
+
"--model-path", model_path,
|
180 |
+
"--video_dir", "./data/eval_video/vstream/movienet_video_features",
|
181 |
+
"--gt_file", "./data/eval_video/vstream/test_qa_movienet.json",
|
182 |
+
"--output_dir", os.path.join(model_path, "evaluation", "vsmovienet"),
|
183 |
+
"--output_name", "pred",
|
184 |
+
"--num-chunks", str(num_chunks),
|
185 |
+
"--chunk-idx", str(idx),
|
186 |
+
"--conv-mode", "vicuna_v1",
|
187 |
+
]
|
188 |
+
|
189 |
+
p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
|
190 |
+
processes.append(p)
|
191 |
+
p.start() # 启动子进程
|
192 |
+
for p in processes:
|
193 |
+
p.join()
|
194 |
+
cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
|
195 |
+
"--pred_path", os.path.join(model_path, "evaluation", "vsmovienet"),
|
196 |
+
"--output_dir", os.path.join(model_path, "evaluation", "vsmovienet", "results"),
|
197 |
+
"--output_json", os.path.join(model_path, "evaluation", "vsmovienet", "results.json"),
|
198 |
+
"--num_chunks", str(num_chunks),
|
199 |
+
"--num_tasks", "16",
|
200 |
+
"--api_key", args.api_key,
|
201 |
+
"--api_base", args.api_base,
|
202 |
+
"--api_type", args.api_type,
|
203 |
+
"--api_version", args.api_version,
|
204 |
+
]
|
205 |
+
exec(cmd)
|
206 |
+
|
207 |
+
# multi gpu, feature
|
208 |
+
def eval_vsego4d(args): # follow msvd format
|
209 |
+
model_path = args.model_path
|
210 |
+
num_chunks = args.num_chunks
|
211 |
+
if not args.only_eval:
|
212 |
+
processes = []
|
213 |
+
for idx in range(0, num_chunks):
|
214 |
+
cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
|
215 |
+
"--model-path", model_path,
|
216 |
+
"--video_dir", "./data/eval_video/vstream/ego4d_video_features",
|
217 |
+
"--gt_file", "./data/eval_video/vstream/test_qa_ego4d.json",
|
218 |
+
"--output_dir", os.path.join(model_path, "evaluation", "vsego4d"),
|
219 |
+
"--output_name", "pred",
|
220 |
+
"--num-chunks", str(num_chunks),
|
221 |
+
"--chunk-idx", str(idx),
|
222 |
+
"--conv-mode", "vicuna_v1",
|
223 |
+
]
|
224 |
+
|
225 |
+
p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
|
226 |
+
processes.append(p)
|
227 |
+
p.start() # 启动子进程
|
228 |
+
for p in processes:
|
229 |
+
p.join()
|
230 |
+
cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
|
231 |
+
"--pred_path", os.path.join(model_path, "evaluation", "vsego4d"),
|
232 |
+
"--output_dir", os.path.join(model_path, "evaluation", "vsego4d", "results"),
|
233 |
+
"--output_json", os.path.join(model_path, "evaluation", "vsego4d", "results.json"),
|
234 |
+
"--num_chunks", str(num_chunks),
|
235 |
+
"--num_tasks", "16",
|
236 |
+
"--api_key", args.api_key,
|
237 |
+
"--api_base", args.api_base,
|
238 |
+
"--api_type", args.api_type,
|
239 |
+
"--api_version", args.api_version,
|
240 |
+
]
|
241 |
+
exec(cmd)
|
242 |
+
|
243 |
+
# multi gpu, feature
|
244 |
+
def eval_realtime_vsmovienet(args): # follow msvd format
|
245 |
+
model_path = args.model_path
|
246 |
+
num_chunks = args.num_chunks
|
247 |
+
if not args.only_eval:
|
248 |
+
processes = []
|
249 |
+
for idx in range(0, num_chunks):
|
250 |
+
cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
|
251 |
+
"--model-path", model_path,
|
252 |
+
"--video_dir", "./data/eval_video/vstream-realtime/movienet_video_features",
|
253 |
+
"--gt_file", "./data/eval_video/vstream-realtime/test_qa_movienet.json",
|
254 |
+
"--output_dir", os.path.join(model_path, "evaluation", "realtime_vsmovienet"),
|
255 |
+
"--output_name", "pred",
|
256 |
+
"--num-chunks", str(num_chunks),
|
257 |
+
"--chunk-idx", str(idx),
|
258 |
+
"--conv-mode", "vicuna_v1",
|
259 |
+
]
|
260 |
+
|
261 |
+
p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
|
262 |
+
processes.append(p)
|
263 |
+
p.start() # 启动子进程
|
264 |
+
for p in processes:
|
265 |
+
p.join()
|
266 |
+
cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
|
267 |
+
"--pred_path", os.path.join(model_path, "evaluation", "realtime_vsmovienet"),
|
268 |
+
"--output_dir", os.path.join(model_path, "evaluation", "realtime_vsmovienet", "results"),
|
269 |
+
"--output_json", os.path.join(model_path, "evaluation", "realtime_vsmovienet", "results.json"),
|
270 |
+
"--num_chunks", str(num_chunks),
|
271 |
+
"--num_tasks", "16",
|
272 |
+
"--api_key", args.api_key,
|
273 |
+
"--api_base", args.api_base,
|
274 |
+
"--api_type", args.api_type,
|
275 |
+
"--api_version", args.api_version,
|
276 |
+
]
|
277 |
+
exec(cmd)
|
278 |
+
|
279 |
+
# multi gpu, feature
|
280 |
+
def eval_realtime_vsego4d(args): # follow msvd format
|
281 |
+
model_path = args.model_path
|
282 |
+
num_chunks = args.num_chunks
|
283 |
+
if not args.only_eval:
|
284 |
+
processes = []
|
285 |
+
for idx in range(0, num_chunks):
|
286 |
+
cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
|
287 |
+
"--model-path", model_path,
|
288 |
+
"--video_dir", "./data/eval_video/vstream-realtime/ego4d_video_features",
|
289 |
+
"--gt_file", "./data/eval_video/vstream-realtime/test_qa_ego4d.json",
|
290 |
+
"--output_dir", os.path.join(model_path, "evaluation", "realtime_vsego4d"),
|
291 |
+
"--output_name", "pred",
|
292 |
+
"--num-chunks", str(num_chunks),
|
293 |
+
"--chunk-idx", str(idx),
|
294 |
+
"--conv-mode", "vicuna_v1",
|
295 |
+
]
|
296 |
+
|
297 |
+
p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
|
298 |
+
processes.append(p)
|
299 |
+
p.start() # 启动子进程
|
300 |
+
for p in processes:
|
301 |
+
p.join()
|
302 |
+
cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
|
303 |
+
"--pred_path", os.path.join(model_path, "evaluation", "realtime_vsego4d"),
|
304 |
+
"--output_dir", os.path.join(model_path, "evaluation", "realtime_vsego4d", "results"),
|
305 |
+
"--output_json", os.path.join(model_path, "evaluation", "realtime_vsego4d", "results.json"),
|
306 |
+
"--num_chunks", str(num_chunks),
|
307 |
+
"--num_tasks", "16",
|
308 |
+
"--api_key", args.api_key,
|
309 |
+
"--api_base", args.api_base,
|
310 |
+
"--api_type", args.api_type,
|
311 |
+
"--api_version", args.api_version,
|
312 |
+
]
|
313 |
+
exec(cmd)
|
314 |
+
|
315 |
+
|
316 |
+
if __name__ == "__main__":
|
317 |
+
parser = argparse.ArgumentParser()
|
318 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
319 |
+
parser.add_argument("--dataset", type=str, default=None)
|
320 |
+
parser.add_argument("--api_key", type=str, default=None)
|
321 |
+
parser.add_argument("--api_base", type=str, default=None)
|
322 |
+
parser.add_argument("--api_type", type=str, default=None)
|
323 |
+
parser.add_argument("--api_version", type=str, default=None)
|
324 |
+
parser.add_argument("--num_chunks", type=int, default=1)
|
325 |
+
parser.add_argument("--only_eval", action="store_true")
|
326 |
+
parser.add_argument("--vizlen", type=int, default=0)
|
327 |
+
parser.add_argument("--use_speech", action="store_true", default=False)
|
328 |
+
args = parser.parse_args()
|
329 |
+
func_dic = {'msvd': eval_msvd,
|
330 |
+
'msrvtt': eval_msrvtt,
|
331 |
+
'actnet': eval_actnet,
|
332 |
+
'nextoe': eval_nextoe,
|
333 |
+
'vsmovienet': eval_vsmovienet,
|
334 |
+
'vsego4d': eval_vsego4d,
|
335 |
+
'realtime_vsmovienet': eval_realtime_vsmovienet,
|
336 |
+
'realtime_vsego4d': eval_realtime_vsego4d,
|
337 |
+
}
|
338 |
+
if args.dataset in func_dic:
|
339 |
+
print(f'Execute {args.dataset} evaluation')
|
340 |
+
func_dic[args.dataset](args)
|
flash_vstream/eval_video/model_msvd_qa.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on https://github.com/haotian-liu/LLaVA.
|
2 |
+
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
import argparse
|
8 |
+
from tqdm import tqdm
|
9 |
+
from decord import VideoReader, cpu
|
10 |
+
|
11 |
+
from llama_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
12 |
+
from llama_vstream.conversation import conv_templates, SeparatorStyle
|
13 |
+
from llama_vstream.model.builder import load_pretrained_model
|
14 |
+
from llama_vstream.utils import disable_torch_init
|
15 |
+
from llama_vstream.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
16 |
+
|
17 |
+
|
18 |
+
def split_list(lst, n):
|
19 |
+
"""Split a list into n (roughly) equal-sized chunks"""
|
20 |
+
chunk_size = math.ceil(len(lst) / n) # integer division
|
21 |
+
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
|
22 |
+
|
23 |
+
|
24 |
+
def get_chunk(lst, n, k):
|
25 |
+
chunks = split_list(lst, n)
|
26 |
+
return chunks[k]
|
27 |
+
|
28 |
+
|
29 |
+
def parse_args():
|
30 |
+
"""
|
31 |
+
Parse command-line arguments.
|
32 |
+
"""
|
33 |
+
parser = argparse.ArgumentParser()
|
34 |
+
|
35 |
+
# Define the command-line arguments
|
36 |
+
parser.add_argument('--video_dir', help='Directory containing video files.', required=True)
|
37 |
+
parser.add_argument('--gt_file', help='Path to the ground truth file containing question.', required=True)
|
38 |
+
parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)
|
39 |
+
parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)
|
40 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
41 |
+
parser.add_argument("--model-base", type=str, default=None)
|
42 |
+
parser.add_argument("--conv-mode", type=str, default=None)
|
43 |
+
parser.add_argument("--num-chunks", type=int, default=1)
|
44 |
+
parser.add_argument("--chunk-idx", type=int, default=0)
|
45 |
+
parser.add_argument("--model-max-length", type=int, default=None)
|
46 |
+
|
47 |
+
return parser.parse_args()
|
48 |
+
|
49 |
+
|
50 |
+
def load_video(video_path):
|
51 |
+
vr = VideoReader(video_path, ctx=cpu(0))
|
52 |
+
total_frame_num = len(vr)
|
53 |
+
fps = round(vr.get_avg_fps())
|
54 |
+
frame_idx = [i for i in range(0, len(vr), fps)]
|
55 |
+
spare_frames = vr.get_batch(frame_idx).asnumpy()
|
56 |
+
return spare_frames
|
57 |
+
|
58 |
+
|
59 |
+
def run_inference(args):
|
60 |
+
"""
|
61 |
+
Run inference on ActivityNet QA DataSet using the Video-ChatGPT model.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
args: Command-line arguments.
|
65 |
+
"""
|
66 |
+
# Initialize the model
|
67 |
+
model_name = get_model_name_from_path(args.model_path)
|
68 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.model_max_length)
|
69 |
+
|
70 |
+
# Load both ground truth file containing questions and answers
|
71 |
+
with open(args.gt_file) as file:
|
72 |
+
gt_questions = json.load(file)
|
73 |
+
gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx)
|
74 |
+
|
75 |
+
# Create the output directory if it doesn't exist
|
76 |
+
if not os.path.exists(args.output_dir):
|
77 |
+
try:
|
78 |
+
os.makedirs(args.output_dir)
|
79 |
+
except Exception as e:
|
80 |
+
print(f'mkdir Except: {e}')
|
81 |
+
|
82 |
+
video_formats = ['.mp4', '.avi', '.mov', '.mkv']
|
83 |
+
if args.num_chunks > 1:
|
84 |
+
output_name = f"{args.num_chunks}_{args.chunk_idx}"
|
85 |
+
else:
|
86 |
+
output_name = args.output_name
|
87 |
+
answers_file = os.path.join(args.output_dir, f"{output_name}.json")
|
88 |
+
ans_file = open(answers_file, "w")
|
89 |
+
|
90 |
+
for sample in tqdm(gt_questions, desc=f"cuda:{args.chunk_idx} "):
|
91 |
+
video_name = sample['video_id']
|
92 |
+
question = sample['question']
|
93 |
+
id = sample['id']
|
94 |
+
answer = sample['answer']
|
95 |
+
|
96 |
+
sample_set = {'id': id, 'question': question, 'answer': answer}
|
97 |
+
|
98 |
+
# Load the video file
|
99 |
+
for fmt in video_formats: # Added this line
|
100 |
+
temp_path = os.path.join(args.video_dir, f"{video_name}{fmt}")
|
101 |
+
if os.path.exists(temp_path):
|
102 |
+
video_path = temp_path
|
103 |
+
break
|
104 |
+
|
105 |
+
# Check if the video exists
|
106 |
+
if os.path.exists(video_path):
|
107 |
+
video = load_video(video_path)
|
108 |
+
video = image_processor.preprocess(video, return_tensors='pt')['pixel_values'].half().cuda()
|
109 |
+
video = [video]
|
110 |
+
|
111 |
+
qs = question
|
112 |
+
if model.config.mm_use_im_start_end:
|
113 |
+
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
|
114 |
+
else:
|
115 |
+
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
116 |
+
|
117 |
+
conv = conv_templates[args.conv_mode].copy()
|
118 |
+
conv.append_message(conv.roles[0], qs)
|
119 |
+
conv.append_message(conv.roles[1], None)
|
120 |
+
prompt = conv.get_prompt()
|
121 |
+
|
122 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
|
123 |
+
|
124 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
125 |
+
keywords = [stop_str]
|
126 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
127 |
+
|
128 |
+
with torch.inference_mode():
|
129 |
+
output_ids = model.generate(
|
130 |
+
input_ids,
|
131 |
+
images=video,
|
132 |
+
do_sample=True,
|
133 |
+
temperature=0.002,
|
134 |
+
max_new_tokens=1024,
|
135 |
+
use_cache=True,
|
136 |
+
stopping_criteria=[stopping_criteria])
|
137 |
+
|
138 |
+
input_token_len = input_ids.shape[1]
|
139 |
+
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
140 |
+
if n_diff_input_output > 0:
|
141 |
+
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
142 |
+
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
143 |
+
outputs = outputs.strip()
|
144 |
+
if outputs.endswith(stop_str):
|
145 |
+
outputs = outputs[:-len(stop_str)]
|
146 |
+
outputs = outputs.strip()
|
147 |
+
|
148 |
+
sample_set['pred'] = outputs
|
149 |
+
ans_file.write(json.dumps(sample_set) + "\n")
|
150 |
+
ans_file.flush()
|
151 |
+
|
152 |
+
ans_file.close()
|
153 |
+
|
154 |
+
|
155 |
+
if __name__ == "__main__":
|
156 |
+
args = parse_args()
|
157 |
+
run_inference(args)
|
flash_vstream/eval_video/model_msvd_qa_featuresloader.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file may have been modified by Flash-VStream Authors (Flash-VStream Modifications”). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
|
2 |
+
# Based on https://github.com/haotian-liu/LLaVA.
|
3 |
+
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import random
|
9 |
+
import argparse
|
10 |
+
from tqdm import tqdm
|
11 |
+
from torch.utils.data import Dataset, DataLoader
|
12 |
+
from safetensors.torch import load_file
|
13 |
+
|
14 |
+
from llama_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
15 |
+
from llama_vstream.conversation import conv_templates, SeparatorStyle
|
16 |
+
from llama_vstream.model.builder import load_pretrained_model
|
17 |
+
from llama_vstream.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
18 |
+
|
19 |
+
|
20 |
+
def split_list(lst, n):
|
21 |
+
"""Split a list into n (roughly) equal-sized chunks"""
|
22 |
+
chunk_size = math.ceil(len(lst) / n) # integer division
|
23 |
+
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
|
24 |
+
|
25 |
+
|
26 |
+
def get_chunk(lst, n, k):
|
27 |
+
chunks = split_list(lst, n)
|
28 |
+
return chunks[k]
|
29 |
+
|
30 |
+
|
31 |
+
def parse_args():
|
32 |
+
"""
|
33 |
+
Parse command-line arguments.
|
34 |
+
"""
|
35 |
+
parser = argparse.ArgumentParser()
|
36 |
+
|
37 |
+
# Define the command-line arguments
|
38 |
+
parser.add_argument('--video_dir', help='Directory containing video files.', required=True)
|
39 |
+
parser.add_argument('--gt_file', help='Path to the ground truth file containing question.', required=True)
|
40 |
+
parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)
|
41 |
+
parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)
|
42 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
43 |
+
parser.add_argument("--model-base", type=str, default=None)
|
44 |
+
parser.add_argument("--conv-mode", type=str, default=None)
|
45 |
+
parser.add_argument("--num-chunks", type=int, default=1)
|
46 |
+
parser.add_argument("--chunk-idx", type=int, default=0)
|
47 |
+
parser.add_argument("--model-max-length", type=int, default=None)
|
48 |
+
return parser.parse_args()
|
49 |
+
|
50 |
+
|
51 |
+
class CustomDataset(Dataset):
|
52 |
+
def __init__(self, questions, video_dir, tokenizer, image_processor, model_config):
|
53 |
+
self.questions = questions
|
54 |
+
self.video_dir = video_dir
|
55 |
+
self.tokenizer = tokenizer
|
56 |
+
self.image_processor = image_processor
|
57 |
+
self.model_config = model_config
|
58 |
+
|
59 |
+
def __getitem__(self, index):
|
60 |
+
sample = self.questions[index]
|
61 |
+
video_name = sample['video_id']
|
62 |
+
try:
|
63 |
+
video_path = os.path.join(self.video_dir, video_name + '.safetensors')
|
64 |
+
video_tensor = load_file(video_path)['feature']
|
65 |
+
except Exception as e:
|
66 |
+
print(f'Dataset Exception: {e}, randomly choose one.')
|
67 |
+
idx = random.randint(0, len(self.questions) - 1)
|
68 |
+
return self.__getitem__(idx)
|
69 |
+
qs = sample['question']
|
70 |
+
if self.model_config.mm_use_im_start_end:
|
71 |
+
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
|
72 |
+
else:
|
73 |
+
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
74 |
+
conv = conv_templates[args.conv_mode].copy()
|
75 |
+
if 'system' in sample:
|
76 |
+
conv.system = conv.system + ' ' + sample['system']
|
77 |
+
conv.append_message(conv.roles[0], qs)
|
78 |
+
conv.append_message(conv.roles[1], None)
|
79 |
+
prompt = conv.get_prompt()
|
80 |
+
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
|
81 |
+
return input_ids, video_tensor
|
82 |
+
|
83 |
+
def __len__(self):
|
84 |
+
return len(self.questions)
|
85 |
+
|
86 |
+
|
87 |
+
def create_data_loader(questions, video_dir, tokenizer, image_processor, model_config, batch_size=1, num_workers=2):
|
88 |
+
assert batch_size == 1, "batch_size must be 1"
|
89 |
+
dataset = CustomDataset(questions, video_dir, tokenizer, image_processor, model_config)
|
90 |
+
data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
|
91 |
+
return data_loader
|
92 |
+
|
93 |
+
|
94 |
+
def run_inference(args):
|
95 |
+
"""
|
96 |
+
Run inference on ActivityNet QA DataSet using the Video-ChatGPT model.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
args: Command-line arguments.
|
100 |
+
"""
|
101 |
+
# Initialize the model
|
102 |
+
model_name = get_model_name_from_path(args.model_path)
|
103 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.model_max_length)
|
104 |
+
|
105 |
+
# Load both ground truth file containing questions and answers
|
106 |
+
with open(args.gt_file) as file:
|
107 |
+
gt_questions = json.load(file)
|
108 |
+
gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx)
|
109 |
+
|
110 |
+
# Create the output directory if it doesn't exist
|
111 |
+
if not os.path.exists(args.output_dir):
|
112 |
+
try:
|
113 |
+
os.makedirs(args.output_dir)
|
114 |
+
except Exception as e:
|
115 |
+
print(f'mkdir Except: {e}')
|
116 |
+
|
117 |
+
video_formats = ['.mp4', '.avi', '.mov', '.mkv']
|
118 |
+
if args.num_chunks > 1:
|
119 |
+
output_name = f"{args.num_chunks}_{args.chunk_idx}"
|
120 |
+
else:
|
121 |
+
output_name = args.output_name
|
122 |
+
answers_file = os.path.join(args.output_dir, f"{output_name}.json")
|
123 |
+
# resume from old exp
|
124 |
+
exist_id_set = set()
|
125 |
+
if os.path.exists(answers_file):
|
126 |
+
with open(answers_file) as f:
|
127 |
+
exist_pred_contents = [json.loads(line) for line in f]
|
128 |
+
exist_id_set = set([x['id'] for x in exist_pred_contents])
|
129 |
+
|
130 |
+
new_gt_questions = []
|
131 |
+
for sample in tqdm(gt_questions):
|
132 |
+
if not sample['id'] in exist_id_set:
|
133 |
+
new_gt_questions.append(sample)
|
134 |
+
gt_questions = new_gt_questions
|
135 |
+
|
136 |
+
data_loader = create_data_loader(gt_questions, args.video_dir, tokenizer, image_processor, model.config)
|
137 |
+
|
138 |
+
conv = conv_templates[args.conv_mode].copy()
|
139 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
140 |
+
keywords = [stop_str]
|
141 |
+
|
142 |
+
with open(answers_file, "a") as ans_file:
|
143 |
+
for data, sample in tqdm(zip(data_loader, gt_questions), desc=f"cuda:{args.chunk_idx} ", total=len(gt_questions)):
|
144 |
+
input_ids, video_tensors = data
|
145 |
+
input_ids = input_ids.to(device='cuda', non_blocking=True)
|
146 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
147 |
+
with torch.inference_mode():
|
148 |
+
output_ids = model.generate(
|
149 |
+
input_ids,
|
150 |
+
features=video_tensors.to(dtype=torch.float16, device='cuda', non_blocking=True),
|
151 |
+
do_sample=True,
|
152 |
+
temperature=0.002,
|
153 |
+
max_new_tokens=1024,
|
154 |
+
use_cache=True,
|
155 |
+
stopping_criteria=[stopping_criteria],
|
156 |
+
)
|
157 |
+
input_token_len = input_ids.shape[1]
|
158 |
+
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
159 |
+
if n_diff_input_output > 0:
|
160 |
+
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
161 |
+
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
162 |
+
outputs = outputs.strip()
|
163 |
+
if outputs.endswith(stop_str):
|
164 |
+
outputs = outputs[:-len(stop_str)]
|
165 |
+
outputs = outputs.strip()
|
166 |
+
sample_set = {
|
167 |
+
'id': sample['id'],
|
168 |
+
'question': sample['question'],
|
169 |
+
'answer': sample['answer'],
|
170 |
+
'answer_type': sample['answer_type'] if 'answer_type' in sample else None,
|
171 |
+
'pred': outputs
|
172 |
+
}
|
173 |
+
ans_file.write(json.dumps(sample_set) + "\n")
|
174 |
+
ans_file.flush()
|
175 |
+
|
176 |
+
|
177 |
+
if __name__ == "__main__":
|
178 |
+
args = parse_args()
|
179 |
+
run_inference(args)
|
flash_vstream/mm_utils.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on https://github.com/haotian-liu/LLaVA.
|
2 |
+
|
3 |
+
from PIL import Image
|
4 |
+
from io import BytesIO
|
5 |
+
import base64
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from transformers import StoppingCriteria
|
9 |
+
from flash_vstream.constants import IMAGE_TOKEN_INDEX
|
10 |
+
|
11 |
+
|
12 |
+
def load_image_from_base64(image):
|
13 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
14 |
+
|
15 |
+
|
16 |
+
def expand2square(pil_img, background_color):
|
17 |
+
width, height = pil_img.size
|
18 |
+
if width == height:
|
19 |
+
return pil_img
|
20 |
+
elif width > height:
|
21 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
22 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
23 |
+
return result
|
24 |
+
else:
|
25 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
26 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
27 |
+
return result
|
28 |
+
|
29 |
+
|
30 |
+
def process_images(images, image_processor, model_cfg):
|
31 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
32 |
+
new_images = []
|
33 |
+
if image_aspect_ratio == 'pad':
|
34 |
+
for image in images:
|
35 |
+
image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
|
36 |
+
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
37 |
+
new_images.append(image)
|
38 |
+
else:
|
39 |
+
return image_processor(images, return_tensors='pt')['pixel_values']
|
40 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
41 |
+
new_images = torch.stack(new_images, dim=0)
|
42 |
+
return new_images
|
43 |
+
|
44 |
+
|
45 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
46 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
47 |
+
|
48 |
+
def insert_separator(X, sep):
|
49 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
50 |
+
|
51 |
+
input_ids = []
|
52 |
+
offset = 0
|
53 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
54 |
+
offset = 1
|
55 |
+
input_ids.append(prompt_chunks[0][0])
|
56 |
+
|
57 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
58 |
+
input_ids.extend(x[offset:])
|
59 |
+
|
60 |
+
if return_tensors is not None:
|
61 |
+
if return_tensors == 'pt':
|
62 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
63 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
64 |
+
return input_ids
|
65 |
+
|
66 |
+
|
67 |
+
def get_model_name_from_path(model_path):
|
68 |
+
model_path = model_path.strip("/")
|
69 |
+
model_paths = model_path.split("/")
|
70 |
+
if model_paths[-1].startswith('checkpoint-'):
|
71 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
72 |
+
else:
|
73 |
+
return model_paths[-1]
|
74 |
+
|
75 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
76 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
77 |
+
self.keywords = keywords
|
78 |
+
self.keyword_ids = []
|
79 |
+
self.max_keyword_len = 0
|
80 |
+
for keyword in keywords:
|
81 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
82 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
83 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
84 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
85 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
86 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
87 |
+
self.tokenizer = tokenizer
|
88 |
+
self.start_len = input_ids.shape[1]
|
89 |
+
|
90 |
+
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
91 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
92 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
93 |
+
for keyword_id in self.keyword_ids:
|
94 |
+
if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
|
95 |
+
return True
|
96 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
97 |
+
for keyword in self.keywords:
|
98 |
+
if keyword in outputs:
|
99 |
+
return True
|
100 |
+
return False
|
101 |
+
|
102 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
103 |
+
outputs = []
|
104 |
+
for i in range(output_ids.shape[0]):
|
105 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
106 |
+
return all(outputs)
|
flash_vstream/model/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .language_model.vstream_llama import VStreamLlamaForCausalLM, VStreamConfig
|
flash_vstream/model/builder.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file may have been modified by Flash-VStream Authors (Flash-VStream Modifications”). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
|
2 |
+
# ------------------------------------------------------------------------
|
3 |
+
# Based on https://github.com/haotian-liu/LLaVA. Below is the original copyright:
|
4 |
+
# Copyright 2023 Haotian Liu
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
|
19 |
+
import os
|
20 |
+
import warnings
|
21 |
+
import shutil
|
22 |
+
|
23 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
24 |
+
import torch
|
25 |
+
from flash_vstream.model import VStreamLlamaForCausalLM, VStreamConfig
|
26 |
+
from flash_vstream.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
27 |
+
|
28 |
+
|
29 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", **kwargs):
|
30 |
+
kwargs = {"device_map": device_map, **kwargs}
|
31 |
+
|
32 |
+
if device != "cuda":
|
33 |
+
kwargs['device_map'] = {"": device}
|
34 |
+
|
35 |
+
if load_8bit:
|
36 |
+
kwargs['load_in_8bit'] = True
|
37 |
+
elif load_4bit:
|
38 |
+
kwargs['load_in_4bit'] = True
|
39 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
40 |
+
load_in_4bit=True,
|
41 |
+
bnb_4bit_compute_dtype=torch.float16,
|
42 |
+
bnb_4bit_use_double_quant=True,
|
43 |
+
bnb_4bit_quant_type='nf4'
|
44 |
+
)
|
45 |
+
else:
|
46 |
+
kwargs['torch_dtype'] = torch.float16
|
47 |
+
|
48 |
+
if 'vstream' in model_name.lower():
|
49 |
+
# Load LLaMA-VStream model
|
50 |
+
if 'lora' in model_name.lower() and model_base is None:
|
51 |
+
warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
|
52 |
+
if 'lora' in model_name.lower() and model_base is not None:
|
53 |
+
lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
54 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
55 |
+
print('(LoRA) Loading LLaMA-VStream from base model...')
|
56 |
+
model = VStreamLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
|
57 |
+
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
|
58 |
+
if model.lm_head.weight.shape[0] != token_num:
|
59 |
+
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
60 |
+
model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
61 |
+
|
62 |
+
print('(LoRA) Loading additional LLaMA-VStream weights...')
|
63 |
+
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
64 |
+
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
65 |
+
else:
|
66 |
+
# this is probably from HF Hub
|
67 |
+
from huggingface_hub import hf_hub_download
|
68 |
+
def load_from_hf(repo_id, filename, subfolder=None):
|
69 |
+
cache_file = hf_hub_download(
|
70 |
+
repo_id=repo_id,
|
71 |
+
filename=filename,
|
72 |
+
subfolder=subfolder)
|
73 |
+
return torch.load(cache_file, map_location='cpu')
|
74 |
+
non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
|
75 |
+
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
|
76 |
+
if any(k.startswith('model.model.') for k in non_lora_trainables):
|
77 |
+
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
|
78 |
+
model.load_state_dict(non_lora_trainables, strict=False)
|
79 |
+
|
80 |
+
from peft import PeftModel
|
81 |
+
print('Loading LoRA weights...')
|
82 |
+
model = PeftModel.from_pretrained(model, model_path)
|
83 |
+
print('Merging LoRA weights...')
|
84 |
+
model = model.merge_and_unload()
|
85 |
+
print('Model is loaded...')
|
86 |
+
elif model_base is not None:
|
87 |
+
# this may be mm projector only
|
88 |
+
print('Loading LLaMA-VStream from base model...')
|
89 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
90 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
91 |
+
model = VStreamLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
92 |
+
|
93 |
+
mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
|
94 |
+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
95 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
96 |
+
else:
|
97 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
98 |
+
model = VStreamLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
99 |
+
else:
|
100 |
+
# Load language model
|
101 |
+
if model_base is not None:
|
102 |
+
# PEFT model
|
103 |
+
from peft import PeftModel
|
104 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
105 |
+
model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
|
106 |
+
print(f"Loading LoRA weights from {model_path}")
|
107 |
+
model = PeftModel.from_pretrained(model, model_path)
|
108 |
+
print(f"Merging weights")
|
109 |
+
model = model.merge_and_unload()
|
110 |
+
print('Convert to FP16...')
|
111 |
+
model.to(torch.float16)
|
112 |
+
else:
|
113 |
+
use_fast = False
|
114 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
115 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
116 |
+
|
117 |
+
image_processor = None
|
118 |
+
|
119 |
+
if 'vstream' in model_name.lower():
|
120 |
+
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
121 |
+
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
122 |
+
if mm_use_im_patch_token:
|
123 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
124 |
+
if mm_use_im_start_end:
|
125 |
+
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
126 |
+
model.resize_token_embeddings(len(tokenizer))
|
127 |
+
|
128 |
+
vision_tower = model.get_vision_tower()
|
129 |
+
if not vision_tower.is_loaded:
|
130 |
+
vision_tower.load_model()
|
131 |
+
vision_tower.to(device=device, dtype=torch.float16)
|
132 |
+
image_processor = vision_tower.image_processor
|
133 |
+
|
134 |
+
if hasattr(model.config, "max_sequence_length"):
|
135 |
+
context_len = model.config.max_sequence_length
|
136 |
+
else:
|
137 |
+
context_len = 2048
|
138 |
+
|
139 |
+
return tokenizer, model, image_processor, context_len
|
flash_vstream/model/compress_functions.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Flash-VStream Authors
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import random
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
|
20 |
+
def drop_feature(img_feature, video_max_frames, img_similarity=None):
|
21 |
+
T, P, D = img_feature.shape
|
22 |
+
indices = [[i] for i in range(T)]
|
23 |
+
T0 = video_max_frames
|
24 |
+
if T <= T0:
|
25 |
+
return img_feature, img_similarity, [indices]
|
26 |
+
cur_feature = img_feature[:T0] # [T0, P, D]
|
27 |
+
if img_similarity is not None:
|
28 |
+
cur_sim = img_similarity[:T0 - 1]
|
29 |
+
else:
|
30 |
+
cur_sim = F.cosine_similarity(cur_feature[:-1].view(T0 - 1, P * D), cur_feature[1:].view(T0 - 1, P * D)) # [T0 - 1]
|
31 |
+
cur_indices = indices[:T0]
|
32 |
+
step_indices = [cur_indices]
|
33 |
+
for i in range(T0, T):
|
34 |
+
new_feature = img_feature[i]
|
35 |
+
new_sim = F.cosine_similarity(cur_feature[-1].view(-1), new_feature.view(-1), dim=0)
|
36 |
+
all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
|
37 |
+
all_indices = cur_indices + [[i]]
|
38 |
+
all_sim = torch.cat([cur_sim, new_sim.unsqueeze(0)], dim=0)
|
39 |
+
idx = torch.argmax(all_sim)
|
40 |
+
if random.randint(0, 1) > 0:
|
41 |
+
idx = idx + 1
|
42 |
+
cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
|
43 |
+
if idx + 1 == T0 + 1:
|
44 |
+
cur_sim = all_sim[:T0 - 1]
|
45 |
+
cur_indices = all_indices[:-1]
|
46 |
+
elif idx == 0:
|
47 |
+
cur_sim = all_sim[1:]
|
48 |
+
cur_indices = all_indices[1:]
|
49 |
+
else:
|
50 |
+
cur_sim = torch.cat([all_sim[:idx], all_sim[idx + 1:]])
|
51 |
+
cur_sim[idx - 1] = F.cosine_similarity(all_feature[idx - 1].view(-1), all_feature[idx + 1].view(-1), dim=0)
|
52 |
+
cur_indices = all_indices[:idx] + all_indices[idx + 1:]
|
53 |
+
step_indices.append(cur_indices)
|
54 |
+
# print(f'Note: perform drop feature {img_feature.shape} to {cur_feature.shape}')
|
55 |
+
return cur_feature, cur_sim, step_indices
|
56 |
+
|
57 |
+
|
58 |
+
def merge_feature(img_feature, video_max_frames, img_similarity=None):
|
59 |
+
T, P, D = img_feature.shape
|
60 |
+
indices = [[i] for i in range(T)]
|
61 |
+
T0 = video_max_frames
|
62 |
+
if T <= T0:
|
63 |
+
return img_feature, img_similarity, [indices]
|
64 |
+
cur_feature = img_feature[:T0] # [T0, P, D]
|
65 |
+
cur_indices = indices[:T0]
|
66 |
+
step_indices = [cur_indices]
|
67 |
+
if img_similarity is not None:
|
68 |
+
cur_sim = img_similarity[:T0 - 1]
|
69 |
+
else:
|
70 |
+
cur_sim = F.cosine_similarity(cur_feature[:-1].view(T0 - 1, P * D), cur_feature[1:].view(T0 - 1, P * D)) # [T0 - 1]
|
71 |
+
for i in range(T0, T):
|
72 |
+
new_feature = img_feature[i]
|
73 |
+
new_sim = F.cosine_similarity(cur_feature[-1].view(-1), new_feature.view(-1), dim=0)
|
74 |
+
all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
|
75 |
+
all_sim = torch.cat([cur_sim, new_sim.unsqueeze(0)], dim=0)
|
76 |
+
all_indices = cur_indices + [[i]]
|
77 |
+
idx = torch.argmax(all_sim)
|
78 |
+
all_feature[idx + 1] = (all_feature[idx] + all_feature[idx + 1]) / 2.0
|
79 |
+
all_indices[idx + 1] = all_indices[idx] + all_indices[idx + 1]
|
80 |
+
cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
|
81 |
+
cur_sim = torch.cat([all_sim[:idx], all_sim[idx + 1:]])
|
82 |
+
cur_indices = all_indices[:idx] + all_indices[idx + 1:]
|
83 |
+
if idx > 0:
|
84 |
+
cur_sim[idx - 1] = F.cosine_similarity(all_feature[idx - 1].view(-1), all_feature[idx + 1].view(-1), dim=0)
|
85 |
+
if idx + 1 < T0:
|
86 |
+
cur_sim[idx] = F.cosine_similarity(all_feature[idx + 1].view(-1), all_feature[idx + 2].view(-1), dim=0)
|
87 |
+
step_indices.append(cur_indices)
|
88 |
+
# print(f'Note: perform merge feature {img_feature.shape} to {cur_feature.shape}')
|
89 |
+
return cur_feature, cur_sim, step_indices
|
90 |
+
|
91 |
+
|
92 |
+
def kmeans_feature(img_feature, video_max_frames, img_similarity=None):
|
93 |
+
def kmeans_torch(X, num_clusters, distance='euclidean', tol=1e-4, max_iter=10):
|
94 |
+
indices = torch.randperm(X.size(0))[:num_clusters]
|
95 |
+
centroids = X[indices]
|
96 |
+
for i in range(max_iter):
|
97 |
+
if distance == 'euclidean':
|
98 |
+
dists = torch.cdist(X, centroids, p=2)
|
99 |
+
else:
|
100 |
+
raise NotImplementedError("Only Euclidean distance is supported yet")
|
101 |
+
labels = torch.argmin(dists, dim=1)
|
102 |
+
new_centroids = []
|
103 |
+
for j in range(num_clusters):
|
104 |
+
cluster_points = X[labels == j]
|
105 |
+
if len(cluster_points) > 0:
|
106 |
+
new_centroid = cluster_points.mean(0)
|
107 |
+
else: # fix nan centroids
|
108 |
+
new_centroid = X[random.randint(0, X.size(0) - 1)]
|
109 |
+
new_centroids.append(new_centroid)
|
110 |
+
new_centroids = torch.stack(new_centroids)
|
111 |
+
diff = torch.norm(centroids - new_centroids, dim=1).sum()
|
112 |
+
if diff < tol:
|
113 |
+
break
|
114 |
+
centroids = new_centroids
|
115 |
+
return centroids, labels, i
|
116 |
+
T, P, D = img_feature.shape
|
117 |
+
T0 = video_max_frames
|
118 |
+
if T <= T0:
|
119 |
+
return img_feature, img_similarity, [[[i] for i in range(T)]]
|
120 |
+
X = img_feature.view(T, -1) # [T, P, D]
|
121 |
+
centroids, labels, exit_step = kmeans_torch(X, T0)
|
122 |
+
reduced_feature = centroids.view(T0, P, D)
|
123 |
+
# print(f'Note: perform kmeans feature {img_feature.shape} to {reduced_feature.shape}, exit at step={exit_step}') # actually, K=T0
|
124 |
+
step_indices = [[] for _ in range(T0)]
|
125 |
+
for i in range(T0):
|
126 |
+
step_indices[i] = [j for j in range(T) if labels[j] == i]
|
127 |
+
return reduced_feature, img_similarity, [step_indices]
|
128 |
+
|
129 |
+
|
130 |
+
def weighted_kmeans_feature(img_feature, video_max_frames, weights=None):
|
131 |
+
if weights is None:
|
132 |
+
weights = torch.ones(img_feature.size(0), dtype=img_feature.dtype, device=img_feature.device)
|
133 |
+
def weighted_kmeans_torch(X, num_clusters, weights=None, distance='euclidean', tol=1e-4, max_iter=10):
|
134 |
+
indices = torch.randperm(X.size(0), device=X.device)[:num_clusters]
|
135 |
+
centroids = X[indices]
|
136 |
+
for i in range(max_iter):
|
137 |
+
if distance == 'euclidean':
|
138 |
+
dists = ((X.unsqueeze(1) - centroids.unsqueeze(0)) ** 2).sum(dim=2).sqrt()
|
139 |
+
else:
|
140 |
+
raise NotImplementedError("Only Euclidean distance is supported yet")
|
141 |
+
labels = torch.argmin(dists, dim=1)
|
142 |
+
weighted_sum = torch.zeros_like(centroids)
|
143 |
+
weights_sum = torch.zeros(num_clusters, dtype=X.dtype, device=X.device)
|
144 |
+
for j in range(num_clusters):
|
145 |
+
cluster_mask = labels == j
|
146 |
+
weighted_sum[j] = torch.sum(weights[cluster_mask, None] * X[cluster_mask], dim=0)
|
147 |
+
weights_sum[j] = torch.sum(weights[cluster_mask])
|
148 |
+
mask = weights_sum > 0
|
149 |
+
new_centroids = torch.zeros_like(weighted_sum)
|
150 |
+
new_centroids[mask] = weighted_sum[mask] / weights_sum[mask, None]
|
151 |
+
if mask.sum() < num_clusters: # fix nan centroids
|
152 |
+
new_centroids[~mask] = torch.stack([X[random.randint(0, X.size(0) - 1)] for _ in range(num_clusters - mask.sum())])
|
153 |
+
diff = torch.norm(centroids - new_centroids, dim=1).sum()
|
154 |
+
if diff < tol:
|
155 |
+
break
|
156 |
+
centroids = new_centroids
|
157 |
+
return centroids, labels, weights_sum, i
|
158 |
+
T, P, D = img_feature.shape
|
159 |
+
T0 = video_max_frames
|
160 |
+
if T <= T0:
|
161 |
+
return img_feature, weights, [[[i] for i in range(T)]]
|
162 |
+
X = img_feature.view(T, -1) # [T, P, D]
|
163 |
+
centroids, labels, weights, exit_step = weighted_kmeans_torch(X, T0, weights)
|
164 |
+
reduced_feature = centroids.view(T0, P, D)
|
165 |
+
# print(f'Note: perform weighted kmeans feature {img_feature.shape} to {reduced_feature.shape}, exit at step={exit_step}') # actually, K=T0
|
166 |
+
step_indices = [[] for _ in range(T0)]
|
167 |
+
for i in range(T0):
|
168 |
+
step_indices[i] = [j for j in range(T) if labels[j] == i]
|
169 |
+
return reduced_feature, weights, [step_indices]
|
170 |
+
|
171 |
+
|
172 |
+
def k_drop_feature(img_feature, video_max_frames, img_similarity=None):
|
173 |
+
T, P, D = img_feature.shape
|
174 |
+
indices = [[i] for i in range(T)]
|
175 |
+
T0 = video_max_frames
|
176 |
+
if T <= T0:
|
177 |
+
return img_feature, img_similarity, [indices]
|
178 |
+
cur_feature = img_feature[:T0] # [T0, P, D]
|
179 |
+
normed_cur_features = F.normalize(cur_feature.view(T0, P * D), p=2, dim=1)
|
180 |
+
cur_sim = torch.mm(normed_cur_features, normed_cur_features.T) # [T0, T0]
|
181 |
+
cur_sim.fill_diagonal_(-100.0)
|
182 |
+
cur_indices = indices[:T0]
|
183 |
+
step_indices = [cur_indices]
|
184 |
+
for i in range(T0, T):
|
185 |
+
# get new feature
|
186 |
+
new_feature = img_feature[i]
|
187 |
+
normed_new_feature = F.normalize(new_feature.view(1, P * D), p=2, dim=1)
|
188 |
+
new_sim = torch.mm(normed_cur_features, normed_new_feature.T) # [T0, 1]
|
189 |
+
all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
|
190 |
+
normed_all_features = torch.cat([normed_cur_features, normed_new_feature], dim=0)
|
191 |
+
all_indices = cur_indices + [[i]]
|
192 |
+
# get new similarity
|
193 |
+
all_sim_1 = torch.cat([cur_sim, new_sim], dim=1) # [T0, T0 + 1]
|
194 |
+
all_sim = torch.cat([all_sim_1, torch.ones_like(all_sim_1[-1:]) * -100.0], dim=0) # [T0 + 1, T0 + 1]
|
195 |
+
all_sim[-1, :-1] = new_sim.T
|
196 |
+
# choose compression position
|
197 |
+
idx = torch.argmax(all_sim)
|
198 |
+
left, right = idx // (T0 + 1), idx % (T0 + 1)
|
199 |
+
if random.randint(0, 1) > 0:
|
200 |
+
idx = left
|
201 |
+
else:
|
202 |
+
idx = right
|
203 |
+
assert all_sim[left, right] == torch.max(all_sim)
|
204 |
+
# get compressed feature and similarity
|
205 |
+
cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
|
206 |
+
normed_cur_features = torch.cat([normed_all_features[:idx], normed_all_features[idx + 1:]])
|
207 |
+
cur_indices = all_indices[:idx] + all_indices[idx + 1:]
|
208 |
+
cur_sim_1 = torch.cat([all_sim[:idx], all_sim[idx + 1:]], dim=0) # [T0, T0 + 1]
|
209 |
+
cur_sim = torch.cat([cur_sim_1[:, :idx], cur_sim_1[:, idx + 1:]], dim=1) # [T0, T0]
|
210 |
+
step_indices.append(cur_indices)
|
211 |
+
# print(f'Note: perform k-drop feature {img_feature.shape} to {cur_feature.shape}')
|
212 |
+
return cur_feature, None, step_indices
|
213 |
+
|
214 |
+
|
215 |
+
def k_merge_feature(img_feature, video_max_frames, img_similarity=None):
|
216 |
+
T, P, D = img_feature.shape
|
217 |
+
indices = [[i] for i in range(T)]
|
218 |
+
T0 = video_max_frames
|
219 |
+
if T <= T0:
|
220 |
+
return img_feature, img_similarity, [indices]
|
221 |
+
cur_feature = img_feature[:T0] # [T0, P, D]
|
222 |
+
normed_cur_features = F.normalize(cur_feature.view(T0, P * D), p=2, dim=1)
|
223 |
+
cur_sim = torch.mm(normed_cur_features, normed_cur_features.T) # [T0, T0]
|
224 |
+
cur_sim.fill_diagonal_(-100.0)
|
225 |
+
cur_indices = indices[:T0]
|
226 |
+
step_indices = [cur_indices]
|
227 |
+
for i in range(T0, T):
|
228 |
+
# get new feature
|
229 |
+
new_feature = img_feature[i]
|
230 |
+
normed_new_feature = F.normalize(new_feature.view(1, P * D), p=2, dim=1)
|
231 |
+
new_sim = torch.mm(normed_cur_features, normed_new_feature.T) # [T0, 1]
|
232 |
+
all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
|
233 |
+
normed_all_features = torch.cat([normed_cur_features, normed_new_feature], dim=0)
|
234 |
+
all_indices = cur_indices + [[i]]
|
235 |
+
# get new similarity
|
236 |
+
all_sim_1 = torch.cat([cur_sim, new_sim], dim=1) # [T0, T0 + 1]
|
237 |
+
all_sim = torch.cat([all_sim_1, torch.ones_like(all_sim_1[-1:]) * -100.0], dim=0) # [T0 + 1, T0 + 1]
|
238 |
+
all_sim[-1, :-1] = new_sim.T
|
239 |
+
# choose compression position
|
240 |
+
idx = torch.argmax(all_sim)
|
241 |
+
left, right = idx // (T0 + 1), idx % (T0 + 1)
|
242 |
+
assert all_sim[left, right] == torch.max(all_sim)
|
243 |
+
# update feature
|
244 |
+
all_feature[right] = (all_feature[left] + all_feature[right]) / 2.0
|
245 |
+
normed_all_features[right] = F.normalize(all_feature[right].view(1, P * D), p=2, dim=1)
|
246 |
+
all_indices[right] = all_indices[left] + all_indices[right]
|
247 |
+
# update similarity
|
248 |
+
new_sim = torch.mm(normed_all_features, normed_all_features[right:right+1].T) # [T0 + 1, 1]
|
249 |
+
all_sim[right, :] = new_sim.T
|
250 |
+
all_sim[:, right:right+1] = new_sim
|
251 |
+
all_sim[right, right] = -100.0
|
252 |
+
# get compressed feature and similarity
|
253 |
+
cur_feature = torch.cat([all_feature[:left], all_feature[left + 1:]])
|
254 |
+
normed_cur_features = torch.cat([normed_all_features[:left], normed_all_features[left + 1:]])
|
255 |
+
cur_indices = all_indices[:left] + all_indices[left + 1:]
|
256 |
+
cur_sim_1 = torch.cat([all_sim[:left], all_sim[left + 1:]], dim=0) # [T0, T0 + 1]
|
257 |
+
cur_sim = torch.cat([cur_sim_1[:, :left], cur_sim_1[:, left + 1:]], dim=1) # [T0, T0]
|
258 |
+
step_indices.append(cur_indices)
|
259 |
+
# print(f'Note: perform k-merge feature {img_feature.shape} to {cur_feature.shape}')
|
260 |
+
return cur_feature, cur_sim, step_indices
|
261 |
+
|
262 |
+
|
263 |
+
def attention_feature(img_feature, video_max_frames, attention_fn=None, update_ratio=0.2):
|
264 |
+
T, P, D = img_feature.shape
|
265 |
+
T0 = video_max_frames
|
266 |
+
if T <= T0:
|
267 |
+
return img_feature, None
|
268 |
+
cur_feature = img_feature[:T0] # [T0, P, D]
|
269 |
+
turing_memory = cur_feature.reshape(T0*P, D) # [T0*P, D]
|
270 |
+
for i in range(T0, T, T0):
|
271 |
+
j = min(i + T0, T)
|
272 |
+
new_feature = img_feature[i:j] # [P, D]
|
273 |
+
new_feature = new_feature.reshape(-1, D) # [n*P, D]
|
274 |
+
turing_memory = attention_fn(turing_memory, new_feature, update_ratio=update_ratio) # [T0*P, n*P]
|
275 |
+
cur_feature = turing_memory.reshape(T0, P, D)
|
276 |
+
# print(f'Note: perform {attention_fn.__name__} feature {img_feature.shape} to {cur_feature.shape}')
|
277 |
+
return cur_feature, None
|
flash_vstream/model/language_model/vstream_llama.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file may have been modified by Flash-VStream Authors (Flash-VStream Modifications”). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
|
2 |
+
# ------------------------------------------------------------------------
|
3 |
+
# Based on https://github.com/haotian-liu/LLaVA. Below is the original copyright:
|
4 |
+
# Copyright 2023 Haotian Liu
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from typing import List, Optional, Tuple, Union
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM, \
|
23 |
+
LlamaConfig, LlamaModel, LlamaForCausalLM
|
24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
25 |
+
from flash_vstream.model.vstream_arch import VStreamMetaModel, VStreamMetaForCausalLM
|
26 |
+
|
27 |
+
|
28 |
+
class VStreamConfig(LlamaConfig):
|
29 |
+
model_type = "vstream"
|
30 |
+
|
31 |
+
|
32 |
+
class VStreamLlamaModel(VStreamMetaModel, LlamaModel):
|
33 |
+
config_class = VStreamConfig
|
34 |
+
|
35 |
+
def __init__(self, config: LlamaConfig):
|
36 |
+
super(VStreamLlamaModel, self).__init__(config)
|
37 |
+
|
38 |
+
|
39 |
+
class VStreamLlamaForCausalLM(VStreamMetaForCausalLM, LlamaForCausalLM):
|
40 |
+
config_class = VStreamConfig
|
41 |
+
|
42 |
+
def __init__(self, config):
|
43 |
+
super(VStreamLlamaForCausalLM, self).__init__(config)
|
44 |
+
self.model = VStreamLlamaModel(config)
|
45 |
+
self.pretraining_tp = config.pretraining_tp
|
46 |
+
self.vocab_size = config.vocab_size
|
47 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
48 |
+
|
49 |
+
# Initialize weights and apply final processing
|
50 |
+
self.post_init()
|
51 |
+
|
52 |
+
def get_model(self):
|
53 |
+
return self.model
|
54 |
+
|
55 |
+
def forward(
|
56 |
+
self,
|
57 |
+
input_ids: torch.LongTensor = None,
|
58 |
+
attention_mask: Optional[torch.Tensor] = None,
|
59 |
+
position_ids: Optional[torch.LongTensor] = None,
|
60 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
61 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
62 |
+
labels: Optional[torch.LongTensor] = None,
|
63 |
+
use_cache: Optional[bool] = True,
|
64 |
+
output_attentions: Optional[bool] = None,
|
65 |
+
output_hidden_states: Optional[bool] = None,
|
66 |
+
images: Optional[torch.FloatTensor] = None,
|
67 |
+
features: Optional[torch.FloatTensor] = None,
|
68 |
+
return_dict: Optional[bool] = None,
|
69 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
70 |
+
if inputs_embeds is None:
|
71 |
+
if self.use_video_streaming_mode:
|
72 |
+
(
|
73 |
+
input_ids,
|
74 |
+
position_ids,
|
75 |
+
attention_mask,
|
76 |
+
past_key_values,
|
77 |
+
inputs_embeds,
|
78 |
+
labels
|
79 |
+
) = self.prepare_inputs_labels_for_multimodal_streaming(
|
80 |
+
input_ids,
|
81 |
+
position_ids,
|
82 |
+
attention_mask,
|
83 |
+
past_key_values,
|
84 |
+
labels,
|
85 |
+
)
|
86 |
+
else:
|
87 |
+
(
|
88 |
+
input_ids,
|
89 |
+
position_ids,
|
90 |
+
attention_mask,
|
91 |
+
past_key_values,
|
92 |
+
inputs_embeds,
|
93 |
+
labels
|
94 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
95 |
+
input_ids,
|
96 |
+
position_ids,
|
97 |
+
attention_mask,
|
98 |
+
past_key_values,
|
99 |
+
labels,
|
100 |
+
images,
|
101 |
+
features,
|
102 |
+
)
|
103 |
+
return super().forward(
|
104 |
+
input_ids=input_ids,
|
105 |
+
attention_mask=attention_mask,
|
106 |
+
position_ids=position_ids,
|
107 |
+
past_key_values=past_key_values,
|
108 |
+
inputs_embeds=inputs_embeds,
|
109 |
+
labels=labels,
|
110 |
+
use_cache=use_cache,
|
111 |
+
output_attentions=output_attentions,
|
112 |
+
output_hidden_states=output_hidden_states,
|
113 |
+
return_dict=return_dict
|
114 |
+
)
|
115 |
+
|
116 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
117 |
+
images = kwargs.pop("images", None)
|
118 |
+
features = kwargs.pop("features", None)
|
119 |
+
_inputs = super().prepare_inputs_for_generation(
|
120 |
+
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
121 |
+
)
|
122 |
+
if images is not None:
|
123 |
+
_inputs['images'] = images
|
124 |
+
if features is not None:
|
125 |
+
_inputs['features'] = features
|
126 |
+
return _inputs
|
127 |
+
|
128 |
+
AutoConfig.register("vstream", VStreamConfig)
|
129 |
+
AutoModelForCausalLM.register(VStreamConfig, VStreamLlamaForCausalLM)
|
flash_vstream/model/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on https://github.com/haotian-liu/LLaVA.
|
2 |
+
|
3 |
+
import os
|
4 |
+
from .clip_encoder import CLIPVisionTower
|
5 |
+
|
6 |
+
|
7 |
+
def build_vision_tower(vision_tower_cfg, **kwargs):
|
8 |
+
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
|
9 |
+
is_absolute_path_exists = os.path.exists(vision_tower)
|
10 |
+
if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
|
11 |
+
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
12 |
+
|
13 |
+
raise ValueError(f'Unknown vision tower: {vision_tower}')
|
flash_vstream/model/multimodal_encoder/clip_encoder.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on https://github.com/haotian-liu/LLaVA.
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
7 |
+
|
8 |
+
|
9 |
+
class CLIPVisionTower(nn.Module):
|
10 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
11 |
+
super().__init__()
|
12 |
+
|
13 |
+
self.is_loaded = False
|
14 |
+
|
15 |
+
self.vision_tower_name = vision_tower
|
16 |
+
self.select_layer = args.mm_vision_select_layer
|
17 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
18 |
+
|
19 |
+
if not delay_load:
|
20 |
+
self.load_model()
|
21 |
+
else:
|
22 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
|
23 |
+
|
24 |
+
def load_model(self):
|
25 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
26 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
|
27 |
+
self.vision_tower.requires_grad_(False)
|
28 |
+
|
29 |
+
self.is_loaded = True
|
30 |
+
|
31 |
+
def feature_select(self, image_forward_outs):
|
32 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
33 |
+
if self.select_feature == 'patch':
|
34 |
+
image_features = image_features[:, 1:]
|
35 |
+
elif self.select_feature == 'cls_patch':
|
36 |
+
image_features = image_features
|
37 |
+
else:
|
38 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
39 |
+
return image_features
|
40 |
+
|
41 |
+
@torch.no_grad()
|
42 |
+
def forward(self, images):
|
43 |
+
if type(images) is list:
|
44 |
+
image_features = []
|
45 |
+
for image in images:
|
46 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
47 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
48 |
+
image_features.append(image_feature)
|
49 |
+
else:
|
50 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
51 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
52 |
+
|
53 |
+
return image_features
|
54 |
+
|
55 |
+
@property
|
56 |
+
def dummy_feature(self):
|
57 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
58 |
+
|
59 |
+
@property
|
60 |
+
def dtype(self):
|
61 |
+
return self.vision_tower.dtype
|
62 |
+
|
63 |
+
@property
|
64 |
+
def device(self):
|
65 |
+
return self.vision_tower.device
|
66 |
+
|
67 |
+
@property
|
68 |
+
def config(self):
|
69 |
+
if self.is_loaded:
|
70 |
+
return self.vision_tower.config
|
71 |
+
else:
|
72 |
+
return self.cfg_only
|
73 |
+
|
74 |
+
@property
|
75 |
+
def hidden_size(self):
|
76 |
+
return self.config.hidden_size
|
77 |
+
|
78 |
+
@property
|
79 |
+
def num_patches(self):
|
80 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
flash_vstream/model/multimodal_projector/builder.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on https://github.com/haotian-liu/LLaVA.
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import re
|
6 |
+
|
7 |
+
|
8 |
+
class IdentityMap(nn.Module):
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
def forward(self, x, *args, **kwargs):
|
13 |
+
return x
|
14 |
+
|
15 |
+
@property
|
16 |
+
def config(self):
|
17 |
+
return {"mm_projector_type": 'identity'}
|
18 |
+
|
19 |
+
|
20 |
+
class SimpleResBlock(nn.Module):
|
21 |
+
def __init__(self, channels):
|
22 |
+
super().__init__()
|
23 |
+
self.pre_norm = nn.LayerNorm(channels)
|
24 |
+
|
25 |
+
self.proj = nn.Sequential(
|
26 |
+
nn.Linear(channels, channels),
|
27 |
+
nn.GELU(),
|
28 |
+
nn.Linear(channels, channels)
|
29 |
+
)
|
30 |
+
def forward(self, x):
|
31 |
+
x = self.pre_norm(x)
|
32 |
+
return x + self.proj(x)
|
33 |
+
|
34 |
+
|
35 |
+
def build_vision_projector(config, input_dim, delay_load=False, **kwargs):
|
36 |
+
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
37 |
+
|
38 |
+
if projector_type == 'linear':
|
39 |
+
return nn.Linear(input_dim, config.hidden_size)
|
40 |
+
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
41 |
+
if mlp_gelu_match:
|
42 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
43 |
+
modules = [nn.Linear(input_dim, config.hidden_size)]
|
44 |
+
for _ in range(1, mlp_depth):
|
45 |
+
modules.append(nn.GELU())
|
46 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
47 |
+
return nn.Sequential(*modules)
|
48 |
+
if projector_type == 'identity':
|
49 |
+
return IdentityMap()
|
50 |
+
|
51 |
+
raise ValueError(f'Unknown projector type: {projector_type}')
|
flash_vstream/model/vstream_arch.py
ADDED
@@ -0,0 +1,742 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file may have been modified by Flash-VStream Authors (Flash-VStream Modifications”). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
|
2 |
+
# ------------------------------------------------------------------------
|
3 |
+
# Based on https://github.com/haotian-liu/LLaVA. Below is the original copyright:
|
4 |
+
# Copyright 2023 Haotian Liu
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import time
|
19 |
+
import math
|
20 |
+
import logging
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from torch.multiprocessing import Lock, Manager
|
25 |
+
|
26 |
+
from abc import ABC, abstractmethod
|
27 |
+
from flash_vstream.model.multimodal_encoder.builder import build_vision_tower
|
28 |
+
from flash_vstream.model.multimodal_projector.builder import build_vision_projector
|
29 |
+
from flash_vstream.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
30 |
+
|
31 |
+
from flash_vstream.model.compress_functions import drop_feature, merge_feature, kmeans_feature, weighted_kmeans_feature, k_drop_feature, k_merge_feature, attention_feature
|
32 |
+
|
33 |
+
|
34 |
+
class NeuralTuringMachine(nn.Module):
|
35 |
+
def __init__(self, input_dim=1024, output_dim=1024, attention_dropout=0.1):
|
36 |
+
super(NeuralTuringMachine, self).__init__()
|
37 |
+
self.input_dim = input_dim
|
38 |
+
self.output_dim = output_dim
|
39 |
+
self.q_proj = nn.Linear(input_dim, output_dim)
|
40 |
+
self.k_proj = nn.Linear(input_dim, output_dim)
|
41 |
+
self.v_proj = nn.Linear(input_dim, output_dim)
|
42 |
+
self.dropout = nn.Dropout(attention_dropout)
|
43 |
+
self.out_proj = nn.Linear(output_dim, input_dim)
|
44 |
+
self.out_dropout = nn.Dropout(attention_dropout)
|
45 |
+
self.out_ln = nn.LayerNorm(input_dim, eps=1e-12)
|
46 |
+
|
47 |
+
def get_weight(self, x, y):
|
48 |
+
query = self.q_proj(x)
|
49 |
+
key = self.k_proj(y)
|
50 |
+
scores = torch.matmul(query, key.transpose(0, 1)) / math.sqrt(self.output_dim)
|
51 |
+
weight = F.softmax(scores, dim=-1)
|
52 |
+
return weight
|
53 |
+
|
54 |
+
def forward(self, x, y):
|
55 |
+
query = self.q_proj(x)
|
56 |
+
key = self.k_proj(y)
|
57 |
+
scores = torch.matmul(query, key.transpose(0, 1)) / math.sqrt(self.output_dim)
|
58 |
+
weight = F.softmax(scores, dim=-1)
|
59 |
+
attn = self.dropout(weight)
|
60 |
+
value = self.v_proj(y)
|
61 |
+
output = torch.matmul(attn, value)
|
62 |
+
output = self.out_proj(output)
|
63 |
+
output = self.out_dropout(output)
|
64 |
+
output = self.out_ln(output.unsqueeze(0)).squeeze(0)
|
65 |
+
return output
|
66 |
+
|
67 |
+
|
68 |
+
class VStreamMetaModel:
|
69 |
+
|
70 |
+
def __init__(self, config):
|
71 |
+
super(VStreamMetaModel, self).__init__(config)
|
72 |
+
|
73 |
+
self.mm_input_dim = config.mm_hidden_size
|
74 |
+
if getattr(config, 'mm_use_4_vision_tokens', False):
|
75 |
+
self.mm_input_dim = self.mm_input_dim * 4
|
76 |
+
|
77 |
+
if hasattr(config, "mm_vision_tower"):
|
78 |
+
self.vision_tower = build_vision_tower(config, delay_load=True)
|
79 |
+
self.mm_projector = build_vision_projector(config, self.mm_input_dim)
|
80 |
+
|
81 |
+
compress_Turing_hidden_dim = getattr(self.config, "compress_Turing_hidden_dim", 32)
|
82 |
+
self.attention_model = NeuralTuringMachine(self.mm_input_dim, compress_Turing_hidden_dim)
|
83 |
+
|
84 |
+
def get_vision_tower(self):
|
85 |
+
vision_tower = getattr(self, 'vision_tower', None)
|
86 |
+
if type(vision_tower) is list:
|
87 |
+
vision_tower = vision_tower[0]
|
88 |
+
return vision_tower
|
89 |
+
|
90 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
91 |
+
vision_tower = model_args.vision_tower
|
92 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
93 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
94 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
95 |
+
|
96 |
+
self.config.mm_vision_tower = vision_tower
|
97 |
+
|
98 |
+
if self.get_vision_tower() is None:
|
99 |
+
vision_tower = build_vision_tower(model_args)
|
100 |
+
|
101 |
+
if fsdp is not None and len(fsdp) > 0:
|
102 |
+
self.vision_tower = [vision_tower]
|
103 |
+
else:
|
104 |
+
self.vision_tower = vision_tower
|
105 |
+
else:
|
106 |
+
if fsdp is not None and len(fsdp) > 0:
|
107 |
+
vision_tower = self.vision_tower[0]
|
108 |
+
else:
|
109 |
+
vision_tower = self.vision_tower
|
110 |
+
vision_tower.load_model()
|
111 |
+
|
112 |
+
self.config.use_mm_proj = True
|
113 |
+
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
|
114 |
+
self.config.mm_hidden_size = vision_tower.hidden_size
|
115 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
116 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
117 |
+
|
118 |
+
self.config.compress_type = getattr(model_args, "compress_type", None)
|
119 |
+
self.config.compress_size = getattr(model_args, "compress_size", 1)
|
120 |
+
self.config.compress_long_memory_size = getattr(model_args, "compress_long_memory_size", 1)
|
121 |
+
self.config.compress_Turing_memory_size = getattr(model_args, "compress_Turing_memory_size", 1)
|
122 |
+
self.config.compress_Turing_update_ratio = getattr(model_args, "compress_Turing_update_ratio", 0.2)
|
123 |
+
self.config.video_max_frames = getattr(model_args, "video_max_frames", 50)
|
124 |
+
self.config.video_long_memory_length = getattr(model_args, "video_long_memory_length", 10)
|
125 |
+
self.config.video_Turing_memory_length = getattr(model_args, "video_Turing_memory_length", 10)
|
126 |
+
self.config.video_short_memory_length = getattr(model_args, "video_short_memory_length", 10)
|
127 |
+
self.config.video_current_memory_length = getattr(model_args, "video_current_memory_length", 1)
|
128 |
+
self.config.video_sample_type = getattr(model_args, "video_sample_type", "center")
|
129 |
+
|
130 |
+
if getattr(self, 'mm_projector', None) is None:
|
131 |
+
self.mm_projector = build_vision_projector(self.config)
|
132 |
+
else:
|
133 |
+
# In case it is frozen by LoRA
|
134 |
+
for p in self.mm_projector.parameters():
|
135 |
+
p.requires_grad = True
|
136 |
+
|
137 |
+
if pretrain_mm_mlp_adapter is not None:
|
138 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
139 |
+
def get_w(weights, keyword):
|
140 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
141 |
+
self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
|
142 |
+
|
143 |
+
class VStreamMetaForCausalLM(ABC):
|
144 |
+
|
145 |
+
def __init__(self, config):
|
146 |
+
super(VStreamMetaForCausalLM, self).__init__(config)
|
147 |
+
# support video streaming mode
|
148 |
+
self.use_video_streaming_mode = False
|
149 |
+
self.video_embedding_memory = None # set to torch.multiprocessing.Manager.list() when launching
|
150 |
+
self.video_embedding_mem_lock = Lock()
|
151 |
+
|
152 |
+
@abstractmethod
|
153 |
+
def get_model(self):
|
154 |
+
pass
|
155 |
+
|
156 |
+
def get_vision_tower(self):
|
157 |
+
return self.get_model().get_vision_tower()
|
158 |
+
|
159 |
+
def encode_images(self, images):
|
160 |
+
image_features = self.get_model().get_vision_tower()(images)
|
161 |
+
return image_features
|
162 |
+
|
163 |
+
def reshape_2x2_image_features(self, image_features):
|
164 |
+
B, P, D = image_features.shape
|
165 |
+
patch_size = round(math.sqrt(P))
|
166 |
+
assert patch_size % 2 == 0, "Patch size must be divisible by 2."
|
167 |
+
image_features = image_features.reshape(B, patch_size, patch_size, D)
|
168 |
+
image_features_2x2 = image_features.reshape(B, patch_size // 2, 2, patch_size // 2, 2, D)
|
169 |
+
image_features_2x2 = image_features_2x2.permute(0, 1, 3, 2, 4, 5)
|
170 |
+
image_features_2x2 = image_features_2x2.reshape(B, patch_size // 2, patch_size // 2, 4 * D) # concat 2x2 neighbor patches
|
171 |
+
image_features = image_features_2x2.reshape(B, (patch_size // 2) ** 2, 4 * D)
|
172 |
+
return image_features
|
173 |
+
|
174 |
+
def attention(self, turing_memory, new_feature, update_ratio=0.2):
|
175 |
+
T1, D1 = turing_memory.shape
|
176 |
+
T2, D2 = new_feature.shape
|
177 |
+
assert D1 == D2, f"dimmension not match, {D1} != {D2}"
|
178 |
+
model = self.get_model().attention_model
|
179 |
+
weight = model.get_weight(turing_memory, new_feature)
|
180 |
+
weight = weight * update_ratio # [T1, T2]
|
181 |
+
decay = weight.sum(dim=1, keepdim=True) # [T0*P, 1], 表示当前NTM memory和新来的feat的相似度
|
182 |
+
turing_memory = turing_memory * (1 - decay) + torch.mm(weight, new_feature)
|
183 |
+
return turing_memory
|
184 |
+
|
185 |
+
def attention2(self, turing_memory, new_feature, update_ratio=0.2): # deprecated
|
186 |
+
T1, D1 = turing_memory.shape
|
187 |
+
T2, D2 = new_feature.shape
|
188 |
+
assert D1 == D2, f"dimmension not match, {D1} != {D2}"
|
189 |
+
model = self.get_model().attention_model
|
190 |
+
turing_memory = model.forward(turing_memory, new_feature)
|
191 |
+
return turing_memory
|
192 |
+
|
193 |
+
def compress_spatial_features(self, image_features, compress_size=1):
|
194 |
+
compress_type = getattr(self.config, "compress_type", None)
|
195 |
+
patch_size = round(math.sqrt(image_features.shape[1]))
|
196 |
+
assert patch_size * patch_size == image_features.shape[1], f"For ViT feature map, {patch_size}*{patch_size}={patch_size**2} != {image_features.shape[1]}"
|
197 |
+
if patch_size == compress_size:
|
198 |
+
return image_features
|
199 |
+
elif compress_type is not None:
|
200 |
+
if 'mean' in self.config.compress_type:
|
201 |
+
# TODO: currently use 1 token per frame (or image), direct poolt
|
202 |
+
if compress_size == 1:
|
203 |
+
image_features = image_features.mean(dim=1, keepdim=True)
|
204 |
+
else:
|
205 |
+
image_features = image_features.view(-1, patch_size, patch_size, image_features.shape[-1])
|
206 |
+
image_features = image_features.permute(0, 3, 1, 2) # [B*T, D, P, P]
|
207 |
+
pooled_features = F.avg_pool2d(image_features, (patch_size // compress_size, patch_size // compress_size))
|
208 |
+
pooled_features = pooled_features.permute(0, 2, 3, 1) # [B*T, P, P, D]
|
209 |
+
image_features = pooled_features.view(-1, compress_size * compress_size, pooled_features.shape[-1])
|
210 |
+
else:
|
211 |
+
raise NotImplementedError(f"`compress_type` {self.config.compress_type} is not supported yet.")
|
212 |
+
return image_features
|
213 |
+
|
214 |
+
def compress_temporal_features(self, image_features):
|
215 |
+
video_long_memory_length = getattr(self.config, "video_long_memory_length", 10)
|
216 |
+
video_Turing_memory_length = getattr(self.config, "video_Turing_memory_length", 10)
|
217 |
+
video_short_memory_length = getattr(self.config, "video_short_memory_length", 10) # not used
|
218 |
+
video_current_memory_length = getattr(self.config, "video_current_memory_length", 1)
|
219 |
+
compress_long_memory_size = getattr(self.config, "compress_long_memory_size", 1)
|
220 |
+
compress_Turing_memory_size = getattr(self.config, "compress_Turing_memory_size", 1)
|
221 |
+
compress_Turing_update_ratio = getattr(self.config, "compress_Turing_update_ratio", 0.2)
|
222 |
+
compress_fn_dic = {
|
223 |
+
'drop': drop_feature,
|
224 |
+
'merge': merge_feature,
|
225 |
+
'kmeans': kmeans_feature,
|
226 |
+
'weighted_kmeans': weighted_kmeans_feature,
|
227 |
+
'kdrop': k_drop_feature,
|
228 |
+
'kmerge': k_merge_feature,
|
229 |
+
'attention': attention_feature,
|
230 |
+
}
|
231 |
+
compress_type = self.config.video_sample_type
|
232 |
+
if compress_type in compress_fn_dic:
|
233 |
+
compress_fn = compress_fn_dic[compress_type]
|
234 |
+
else:
|
235 |
+
raise NotImplementedError(f'max_length = {self.config.video_max_frames},'
|
236 |
+
f'while video_sample_type = {compress_type} is not supported yet.')
|
237 |
+
new_image_features = []
|
238 |
+
step_indices = []
|
239 |
+
step_features = []
|
240 |
+
for img_feature in image_features: # [T, P*P, D]
|
241 |
+
cur_start = min(video_current_memory_length, img_feature.shape[0])
|
242 |
+
### Calc Spatial Memory
|
243 |
+
if cur_start == 0:
|
244 |
+
cur_memory = img_feature[:0]
|
245 |
+
long_memory = img_feature
|
246 |
+
Turing_memory = img_feature
|
247 |
+
else:
|
248 |
+
cur_memory = img_feature[-cur_start:] # [C, P*P, D]
|
249 |
+
long_memory = img_feature[:-cur_start] # [L, P*P, D]
|
250 |
+
Turing_memory = img_feature[:-cur_start] # [L, P*P, D]
|
251 |
+
if compress_long_memory_size * compress_long_memory_size != long_memory.shape[1]:
|
252 |
+
long_memory = self.compress_spatial_features(long_memory, compress_long_memory_size) # [L, P'*P', D]
|
253 |
+
if compress_Turing_memory_size * compress_Turing_memory_size != Turing_memory.shape[1]:
|
254 |
+
Turing_memory = self.compress_spatial_features(Turing_memory, compress_Turing_memory_size) # [L, P'*P', D]
|
255 |
+
### Calc Temporal Memory
|
256 |
+
if video_long_memory_length == 0 or long_memory.shape[0] == 0:
|
257 |
+
long_memory_compreesed = long_memory[:0]
|
258 |
+
else:
|
259 |
+
long_memory_compreesed, weight, step_long_indices = compress_fn(long_memory, video_long_memory_length) # [L_long, P'*P', D], [L_long]
|
260 |
+
### Calc Retrieved Memory
|
261 |
+
sorted_indices = torch.argsort(weight, descending=True) # [L_long]
|
262 |
+
key_centroids = long_memory[sorted_indices] # [L_long, P'*P', D]
|
263 |
+
key_length = 3
|
264 |
+
if key_centroids.shape[0] > key_length:
|
265 |
+
key_centroids = key_centroids[:key_length]
|
266 |
+
dists = ((long_memory.unsqueeze(1) - key_centroids.unsqueeze(0)) ** 2).sum(dim=3).sum(dim=2).sqrt() # [L_long, k_L]
|
267 |
+
min_indices = torch.argmin(dists, dim=0) # [k_L]
|
268 |
+
key_memory = img_feature[min_indices]
|
269 |
+
cur_memory = torch.cat([key_memory, cur_memory], dim=0)
|
270 |
+
### Calc Abstract Memory
|
271 |
+
if video_Turing_memory_length == 0 or Turing_memory.shape[0] == 0:
|
272 |
+
Turing_memory_compreesed = Turing_memory[:0]
|
273 |
+
else:
|
274 |
+
Turing_memory_compreesed, _ = attention_feature(Turing_memory, video_Turing_memory_length, self.attention, update_ratio=compress_Turing_update_ratio)
|
275 |
+
memory_feature = torch.cat([Turing_memory_compreesed.flatten(0, 1), long_memory_compreesed.flatten(0, 1), cur_memory.flatten(0, 1)], dim=0)
|
276 |
+
new_image_features.append(memory_feature)
|
277 |
+
return new_image_features
|
278 |
+
|
279 |
+
def cat_proj(self, all_features): # concatenate features and project them together
|
280 |
+
feature_split_size = [x.shape[0] for x in all_features]
|
281 |
+
feature_embed = torch.cat(all_features, dim=0)
|
282 |
+
feature_proj = self.get_model().mm_projector(feature_embed)
|
283 |
+
feature_proj = torch.split(feature_proj, feature_split_size, dim=0)
|
284 |
+
return feature_proj
|
285 |
+
|
286 |
+
def prepare_inputs_labels_for_multimodal(
|
287 |
+
self,
|
288 |
+
input_ids,
|
289 |
+
position_ids,
|
290 |
+
attention_mask,
|
291 |
+
past_key_values,
|
292 |
+
labels,
|
293 |
+
images,
|
294 |
+
features
|
295 |
+
):
|
296 |
+
vision_tower = self.get_vision_tower()
|
297 |
+
if vision_tower is None or (images is None and features is None) or input_ids.shape[1] == 1:
|
298 |
+
if past_key_values is not None and vision_tower is not None and ((images is not None) or (features is not None)) and input_ids.shape[1] == 1:
|
299 |
+
target_shape = past_key_values[-1][-1].shape[-2] + 1
|
300 |
+
if target_shape - attention_mask.shape[1] >= 0:
|
301 |
+
attention_mask = torch.cat((attention_mask, torch.ones(
|
302 |
+
(attention_mask.shape[0], target_shape - attention_mask.shape[1]),
|
303 |
+
dtype=attention_mask.dtype,
|
304 |
+
device=attention_mask.device
|
305 |
+
)), dim=1)
|
306 |
+
elif target_shape - attention_mask.shape[1] < 0:
|
307 |
+
attention_mask = attention_mask[:, :target_shape]
|
308 |
+
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
309 |
+
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
310 |
+
|
311 |
+
if (features is not None) or (type(images) is list) or (images.ndim == 5):
|
312 |
+
compress_size = getattr(self.config, "compress_size", 1)
|
313 |
+
if images is not None:
|
314 |
+
images = [image if len(image.shape) == 4 else image.unsqueeze(0) for image in images] # [B, T, C, H, W]
|
315 |
+
concat_images = torch.cat([image for image in images], dim=0) # [B*T, C, H, W]
|
316 |
+
image_features = self.encode_images(concat_images) # [B*T, P, D]
|
317 |
+
if getattr(self.config, 'mm_use_4_vision_tokens', False):
|
318 |
+
image_features = self.reshape_2x2_image_features(image_features) # [B*T, P/4, 4*D]
|
319 |
+
image_features = self.compress_spatial_features(image_features, compress_size) # [B*T, P', D]
|
320 |
+
split_sizes = [image.shape[0] for image in images]
|
321 |
+
image_features = torch.split(image_features, split_sizes, dim=0) # [B, T, P, D]
|
322 |
+
else:
|
323 |
+
image_features = [feat if len(feat.shape) == 3 else feat.unsqueeze(0) for feat in features]
|
324 |
+
origin_img_features = image_features
|
325 |
+
if getattr(self.config, 'mm_use_4_vision_tokens', False):
|
326 |
+
image_features = [self.reshape_2x2_image_features(img_feature) for img_feature in image_features] # [B*T, P/4, 4*D]
|
327 |
+
image_features = [self.compress_spatial_features(image_feature, compress_size) for image_feature in image_features] # [B*T, P', D]
|
328 |
+
# perform memory consolidation
|
329 |
+
image_features = self.compress_temporal_features(image_features) # [B, TP, D]
|
330 |
+
image_features = [x.to(self.device) for x in image_features] # [B, TP, D]
|
331 |
+
image_features = self.cat_proj(image_features)
|
332 |
+
else:
|
333 |
+
image_features = self.encode_images(images).to(self.device) # [B, 576, 2048]
|
334 |
+
if getattr(self.config, 'mm_use_4_vision_tokens', False):
|
335 |
+
image_features = self.reshape_2x2_image_features(image_features) # [B*T, P/4, 4*D]
|
336 |
+
image_features = self.get_model().mm_projector(image_features)
|
337 |
+
|
338 |
+
# TODO: image start / end is not implemented here to support pretraining.
|
339 |
+
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
|
340 |
+
raise NotImplementedError
|
341 |
+
|
342 |
+
_labels = labels
|
343 |
+
_position_ids = position_ids
|
344 |
+
_attention_mask = attention_mask
|
345 |
+
if attention_mask is None:
|
346 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
347 |
+
else:
|
348 |
+
attention_mask = attention_mask.bool()
|
349 |
+
if position_ids is None:
|
350 |
+
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
351 |
+
if labels is None:
|
352 |
+
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
353 |
+
|
354 |
+
# remove the padding using attention_mask -- TODO: double check
|
355 |
+
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
|
356 |
+
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
357 |
+
new_input_embeds = []
|
358 |
+
new_labels = []
|
359 |
+
cur_image_idx = 0
|
360 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
361 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
362 |
+
if num_images == 0:
|
363 |
+
cur_image_features = image_features[cur_image_idx]
|
364 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
365 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
|
366 |
+
new_input_embeds.append(cur_input_embeds)
|
367 |
+
new_labels.append(labels[batch_idx])
|
368 |
+
cur_image_idx += 1
|
369 |
+
continue
|
370 |
+
|
371 |
+
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] # only input first image_token
|
372 |
+
cur_input_ids_noim = []
|
373 |
+
cur_labels = labels[batch_idx]
|
374 |
+
cur_labels_noim = []
|
375 |
+
for i in range(len(image_token_indices) - 1):
|
376 |
+
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
|
377 |
+
cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
|
378 |
+
split_sizes = [x.shape[0] for x in cur_labels_noim]
|
379 |
+
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
|
380 |
+
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
|
381 |
+
cur_new_input_embeds = []
|
382 |
+
cur_new_labels = []
|
383 |
+
|
384 |
+
for i in range(num_images + 1):
|
385 |
+
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
386 |
+
cur_new_labels.append(cur_labels_noim[i])
|
387 |
+
if i < num_images:
|
388 |
+
cur_image_features = image_features[cur_image_idx]
|
389 |
+
cur_image_idx += 1
|
390 |
+
cur_new_input_embeds.append(cur_image_features)
|
391 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
|
392 |
+
|
393 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
394 |
+
cur_new_labels = torch.cat(cur_new_labels)
|
395 |
+
|
396 |
+
new_input_embeds.append(cur_new_input_embeds)
|
397 |
+
new_labels.append(cur_new_labels)
|
398 |
+
assert cur_image_idx == batch_idx + 1
|
399 |
+
|
400 |
+
# Truncate sequences to max length as image embeddings can make the sequence longer
|
401 |
+
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
|
402 |
+
if tokenizer_model_max_length is not None:
|
403 |
+
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
|
404 |
+
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
|
405 |
+
|
406 |
+
# Combine them
|
407 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
408 |
+
batch_size = len(new_input_embeds)
|
409 |
+
|
410 |
+
new_input_embeds_padded = []
|
411 |
+
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
|
412 |
+
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
|
413 |
+
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
|
414 |
+
|
415 |
+
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
|
416 |
+
cur_len = cur_new_embed.shape[0]
|
417 |
+
if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
|
418 |
+
new_input_embeds_padded.append(torch.cat((
|
419 |
+
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
|
420 |
+
cur_new_embed
|
421 |
+
), dim=0))
|
422 |
+
if cur_len > 0:
|
423 |
+
new_labels_padded[i, -cur_len:] = cur_new_labels
|
424 |
+
attention_mask[i, -cur_len:] = True
|
425 |
+
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
426 |
+
else:
|
427 |
+
new_input_embeds_padded.append(torch.cat((
|
428 |
+
cur_new_embed,
|
429 |
+
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
|
430 |
+
), dim=0))
|
431 |
+
if cur_len > 0:
|
432 |
+
new_labels_padded[i, :cur_len] = cur_new_labels
|
433 |
+
attention_mask[i, :cur_len] = True
|
434 |
+
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
435 |
+
|
436 |
+
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
437 |
+
|
438 |
+
if _labels is None:
|
439 |
+
new_labels = None
|
440 |
+
else:
|
441 |
+
new_labels = new_labels_padded
|
442 |
+
|
443 |
+
if _attention_mask is None:
|
444 |
+
attention_mask = None
|
445 |
+
else:
|
446 |
+
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
447 |
+
|
448 |
+
if _position_ids is None:
|
449 |
+
position_ids = None
|
450 |
+
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
|
451 |
+
|
452 |
+
def prepare_inputs_labels_for_multimodal_streaming( # Asynchronous encoding with a SemLock, only for videos, batch_size=1
|
453 |
+
self,
|
454 |
+
input_ids,
|
455 |
+
position_ids,
|
456 |
+
attention_mask,
|
457 |
+
past_key_values,
|
458 |
+
labels
|
459 |
+
):
|
460 |
+
assert self.use_video_streaming_mode
|
461 |
+
logger = logging.getLogger(__name__)
|
462 |
+
vision_tower = self.get_vision_tower()
|
463 |
+
if vision_tower is None or input_ids.shape[1] == 1:
|
464 |
+
if past_key_values is not None and vision_tower is not None and input_ids.shape[1] == 1:
|
465 |
+
target_shape = past_key_values[-1][-1].shape[-2] + 1
|
466 |
+
if target_shape - attention_mask.shape[1] >= 0:
|
467 |
+
attention_mask = torch.cat((attention_mask, torch.ones(
|
468 |
+
(attention_mask.shape[0], target_shape - attention_mask.shape[1]),
|
469 |
+
dtype=attention_mask.dtype,
|
470 |
+
device=attention_mask.device
|
471 |
+
)), dim=1)
|
472 |
+
elif target_shape - attention_mask.shape[1] < 0:
|
473 |
+
attention_mask = attention_mask[:, :target_shape]
|
474 |
+
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
475 |
+
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
476 |
+
# Have some tries to avoid deadlock
|
477 |
+
attempt_times = 0
|
478 |
+
while attempt_times < 300:
|
479 |
+
try:
|
480 |
+
with self.video_embedding_mem_lock:
|
481 |
+
cur_memory, long_memory_compreesed, Turing_memory_compreesed, _ = self.video_embedding_memory
|
482 |
+
logger.info(f'Read cur_memory={cur_memory.shape} {cur_memory.dtype}, long_memory_compreesed={long_memory_compreesed.shape} {long_memory_compreesed.dtype}, Turing_memory_compreesed={Turing_memory_compreesed.shape} {Turing_memory_compreesed.dtype}')
|
483 |
+
image_feature = torch.cat([Turing_memory_compreesed.flatten(0, 1), long_memory_compreesed.flatten(0, 1), cur_memory.flatten(0, 1)], dim=0)
|
484 |
+
image_features = [image_feature.to(self.device)]
|
485 |
+
break
|
486 |
+
|
487 |
+
except Exception as e:
|
488 |
+
logger.error(f'Attempt:{attempt_times} Failed to get video features, Error: {e}')
|
489 |
+
image_features = []
|
490 |
+
time.sleep(0.1)
|
491 |
+
attempt_times += 1
|
492 |
+
|
493 |
+
image_features = [x.to(self.device) for x in image_features] # [B, TP, D]
|
494 |
+
image_features = self.cat_proj(image_features)
|
495 |
+
|
496 |
+
# TODO: image start / end is not implemented here to support pretraining.
|
497 |
+
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
|
498 |
+
raise NotImplementedError
|
499 |
+
|
500 |
+
_labels = labels
|
501 |
+
_position_ids = position_ids
|
502 |
+
_attention_mask = attention_mask
|
503 |
+
if attention_mask is None:
|
504 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
505 |
+
else:
|
506 |
+
attention_mask = attention_mask.bool()
|
507 |
+
if position_ids is None:
|
508 |
+
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
509 |
+
if labels is None:
|
510 |
+
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
511 |
+
|
512 |
+
# remove the padding using attention_mask -- TODO: double check
|
513 |
+
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
|
514 |
+
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
515 |
+
|
516 |
+
new_input_embeds = []
|
517 |
+
new_labels = []
|
518 |
+
cur_image_idx = 0
|
519 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
520 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
521 |
+
if num_images == 0:
|
522 |
+
cur_image_features = image_features[cur_image_idx]
|
523 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
524 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
|
525 |
+
new_input_embeds.append(cur_input_embeds)
|
526 |
+
new_labels.append(labels[batch_idx])
|
527 |
+
cur_image_idx += 1
|
528 |
+
continue
|
529 |
+
|
530 |
+
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] # only input first image_token
|
531 |
+
cur_input_ids_noim = []
|
532 |
+
cur_labels = labels[batch_idx]
|
533 |
+
cur_labels_noim = []
|
534 |
+
for i in range(len(image_token_indices) - 1):
|
535 |
+
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
|
536 |
+
cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
|
537 |
+
split_sizes = [x.shape[0] for x in cur_labels_noim]
|
538 |
+
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
|
539 |
+
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
|
540 |
+
cur_new_input_embeds = []
|
541 |
+
cur_new_labels = []
|
542 |
+
|
543 |
+
for i in range(num_images + 1):
|
544 |
+
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
545 |
+
cur_new_labels.append(cur_labels_noim[i])
|
546 |
+
if i < num_images:
|
547 |
+
cur_image_features = image_features[cur_image_idx]
|
548 |
+
cur_image_idx += 1
|
549 |
+
cur_new_input_embeds.append(cur_image_features)
|
550 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
|
551 |
+
|
552 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
553 |
+
cur_new_labels = torch.cat(cur_new_labels)
|
554 |
+
|
555 |
+
new_input_embeds.append(cur_new_input_embeds)
|
556 |
+
new_labels.append(cur_new_labels)
|
557 |
+
assert cur_image_idx == batch_idx + 1
|
558 |
+
|
559 |
+
# Truncate sequences to max length as image embeddings can make the sequence longer
|
560 |
+
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
|
561 |
+
if tokenizer_model_max_length is not None:
|
562 |
+
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
|
563 |
+
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
|
564 |
+
|
565 |
+
# Combine them
|
566 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
567 |
+
batch_size = len(new_input_embeds)
|
568 |
+
|
569 |
+
new_input_embeds_padded = []
|
570 |
+
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
|
571 |
+
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
|
572 |
+
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
|
573 |
+
|
574 |
+
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
|
575 |
+
cur_len = cur_new_embed.shape[0]
|
576 |
+
if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
|
577 |
+
new_input_embeds_padded.append(torch.cat((
|
578 |
+
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
|
579 |
+
cur_new_embed
|
580 |
+
), dim=0))
|
581 |
+
if cur_len > 0:
|
582 |
+
new_labels_padded[i, -cur_len:] = cur_new_labels
|
583 |
+
attention_mask[i, -cur_len:] = True
|
584 |
+
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
585 |
+
else:
|
586 |
+
new_input_embeds_padded.append(torch.cat((
|
587 |
+
cur_new_embed,
|
588 |
+
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
|
589 |
+
), dim=0))
|
590 |
+
if cur_len > 0:
|
591 |
+
new_labels_padded[i, :cur_len] = cur_new_labels
|
592 |
+
attention_mask[i, :cur_len] = True
|
593 |
+
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
594 |
+
|
595 |
+
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
596 |
+
|
597 |
+
if _labels is None:
|
598 |
+
new_labels = None
|
599 |
+
else:
|
600 |
+
new_labels = new_labels_padded
|
601 |
+
|
602 |
+
if _attention_mask is None:
|
603 |
+
attention_mask = None
|
604 |
+
else:
|
605 |
+
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
606 |
+
|
607 |
+
if _position_ids is None:
|
608 |
+
position_ids = None
|
609 |
+
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
|
610 |
+
|
611 |
+
def embed_video_streaming( # Asynchronous encoding with a SemLock, only for videos, batch_size=1
|
612 |
+
self,
|
613 |
+
images
|
614 |
+
):
|
615 |
+
assert self.use_video_streaming_mode
|
616 |
+
logger = logging.getLogger(__name__)
|
617 |
+
|
618 |
+
compress_size = getattr(self.config, "compress_size", 1)
|
619 |
+
video_long_memory_length = getattr(self.config, "video_long_memory_length", 10)
|
620 |
+
video_Turing_memory_length = getattr(self.config, "video_Turing_memory_length", 10)
|
621 |
+
video_short_memory_length = getattr(self.config, "video_short_memory_length", 10) # not used
|
622 |
+
video_current_memory_length = getattr(self.config, "video_current_memory_length", 1)
|
623 |
+
compress_long_memory_size = getattr(self.config, "compress_long_memory_size", 1)
|
624 |
+
compress_Turing_memory_size = getattr(self.config, "compress_Turing_memory_size", 1)
|
625 |
+
compress_Turing_update_ratio = getattr(self.config, "compress_Turing_update_ratio", 0.2)
|
626 |
+
compress_fn_dic = {
|
627 |
+
'drop': drop_feature,
|
628 |
+
'merge': merge_feature,
|
629 |
+
'kmeans': kmeans_feature,
|
630 |
+
'weighted_kmeans': weighted_kmeans_feature,
|
631 |
+
'kdrop': k_drop_feature,
|
632 |
+
'kmerge': k_merge_feature,
|
633 |
+
'uni_kmerge': k_merge_feature,
|
634 |
+
'both_kmerge': k_merge_feature,
|
635 |
+
'split_kmerge': k_merge_feature,
|
636 |
+
'attention': attention_feature,
|
637 |
+
}
|
638 |
+
|
639 |
+
if type(images) is list or images.ndim == 5:
|
640 |
+
assert len(images) == 1
|
641 |
+
images = [image if len(image.shape) == 4 else image.unsqueeze(0) for image in images] # [B, T, C, H, W]
|
642 |
+
concat_images = torch.cat([image for image in images], dim=0) # [B*T, C, H, W]
|
643 |
+
image_features = self.encode_images(concat_images) # [B*T, P, D]
|
644 |
+
image_features = self.compress_spatial_features(image_features, compress_size) # [B*T, P', D]
|
645 |
+
split_sizes = [image.shape[0] for image in images]
|
646 |
+
image_features = torch.split(image_features, split_sizes, dim=0) # [B, T, P, D]
|
647 |
+
else:
|
648 |
+
raise NotImplementedError('Should input video frames, not a single image')
|
649 |
+
image_feature = image_features[0].detach().to(torch.float16).to(self.device) # [T, P, D]
|
650 |
+
img_feature_buffer = image_feature.cpu()
|
651 |
+
|
652 |
+
cur_start = min(video_current_memory_length, image_feature.shape[0])
|
653 |
+
if cur_start == 0:
|
654 |
+
cur_memory = image_feature[:0]
|
655 |
+
else:
|
656 |
+
cur_memory = image_feature[-cur_start:] # [L_c, P*P, D]
|
657 |
+
long_memory = image_feature
|
658 |
+
Turing_memory = image_feature
|
659 |
+
if compress_long_memory_size * compress_long_memory_size != long_memory.shape[1]:
|
660 |
+
long_memory = self.compress_spatial_features(long_memory, compress_long_memory_size) # [L_l, P'*P', D]
|
661 |
+
if compress_Turing_memory_size * compress_Turing_memory_size != Turing_memory.shape[1]:
|
662 |
+
Turing_memory = self.compress_spatial_features(Turing_memory, compress_Turing_memory_size) # [L_t, P'*P', D]
|
663 |
+
compress_type = self.config.video_sample_type
|
664 |
+
if compress_type in compress_fn_dic:
|
665 |
+
compress_fn = compress_fn_dic[compress_type]
|
666 |
+
else:
|
667 |
+
raise NotImplementedError(f'max_length = {self.config.video_max_frames},'
|
668 |
+
f'while video_sample_type = {compress_type} is not supported yet.')
|
669 |
+
long_memory_compreesed = long_memory
|
670 |
+
Turing_memory_compreesed = Turing_memory
|
671 |
+
# Read old memory from shared memory, do not need an I/O lock
|
672 |
+
if self.video_embedding_memory is not None and len(self.video_embedding_memory) > 0:
|
673 |
+
old_cur_memory, old_long_memory_compreesed, old_Turing_memory_compreesed, old_img_feature_buffer = self.video_embedding_memory
|
674 |
+
old_long_memory_compreesed = old_long_memory_compreesed.to(self.device)
|
675 |
+
old_Turing_memory_compreesed = old_Turing_memory_compreesed.to(self.device)
|
676 |
+
img_feature_buffer = torch.cat([old_img_feature_buffer, image_feature.cpu()], dim=0)
|
677 |
+
assert isinstance(old_long_memory_compreesed, torch.Tensor) and old_long_memory_compreesed.shape[1:] == long_memory_compreesed.shape[1:]
|
678 |
+
long_memory = torch.cat((old_long_memory_compreesed, long_memory_compreesed), dim=0)
|
679 |
+
long_memory_compreesed, weight, step_long_indices = compress_fn(long_memory, video_long_memory_length)
|
680 |
+
# Retrive key frames
|
681 |
+
sorted_indices = torch.argsort(weight, descending=True) # [L_long]
|
682 |
+
key_centroids = long_memory[sorted_indices] # [L_long, P'*P', D]
|
683 |
+
key_length = 3
|
684 |
+
if key_centroids.shape[0] > key_length:
|
685 |
+
key_centroids = key_centroids[:key_length]
|
686 |
+
dists = ((long_memory.unsqueeze(1) - key_centroids.unsqueeze(0)) ** 2).sum(dim=3).sum(dim=2).sqrt() # [L_long, k_L]
|
687 |
+
min_indices = torch.argmin(dists, dim=0) # [k_L]
|
688 |
+
key_memory = img_feature_buffer[min_indices.cpu()].to(self.device)
|
689 |
+
cur_memory = torch.cat([key_memory, cur_memory], dim=0)
|
690 |
+
Turing_memory = torch.cat((old_Turing_memory_compreesed, Turing_memory_compreesed), dim=0)
|
691 |
+
Turing_memory_compreesed, _ = attention_feature(Turing_memory, video_Turing_memory_length, self.attention, update_ratio=compress_Turing_update_ratio)
|
692 |
+
# Write to shared memory, need an I/O lock
|
693 |
+
with self.video_embedding_mem_lock:
|
694 |
+
self.video_embedding_memory[:] = [cur_memory.cpu(), long_memory_compreesed.cpu(), Turing_memory_compreesed.cpu(), img_feature_buffer] # Only change content
|
695 |
+
logger.info(f'Write cur_memory={cur_memory.shape} {cur_memory.dtype}, long_memory_compreesed={long_memory_compreesed.shape} {long_memory_compreesed.dtype}, Turing_memory_compreesed={Turing_memory_compreesed.shape} {Turing_memory_compreesed.dtype}')
|
696 |
+
|
697 |
+
return []
|
698 |
+
|
699 |
+
|
700 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
701 |
+
if model_args.mm_use_im_patch_token:
|
702 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
703 |
+
self.resize_token_embeddings(len(tokenizer))
|
704 |
+
|
705 |
+
if model_args.mm_use_im_start_end:
|
706 |
+
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
707 |
+
self.resize_token_embeddings(len(tokenizer))
|
708 |
+
|
709 |
+
if num_new_tokens > 0:
|
710 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
711 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
712 |
+
|
713 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
714 |
+
dim=0, keepdim=True)
|
715 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
716 |
+
dim=0, keepdim=True)
|
717 |
+
|
718 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
719 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
720 |
+
|
721 |
+
if model_args.tune_mm_mlp_adapter:
|
722 |
+
for p in self.get_input_embeddings().parameters():
|
723 |
+
p.requires_grad = True
|
724 |
+
for p in self.get_output_embeddings().parameters():
|
725 |
+
p.requires_grad = False
|
726 |
+
|
727 |
+
if model_args.pretrain_mm_mlp_adapter:
|
728 |
+
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
|
729 |
+
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
|
730 |
+
assert num_new_tokens == 2
|
731 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
732 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
|
733 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
734 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
735 |
+
else:
|
736 |
+
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
|
737 |
+
elif model_args.mm_use_im_patch_token:
|
738 |
+
if model_args.tune_mm_mlp_adapter:
|
739 |
+
for p in self.get_input_embeddings().parameters():
|
740 |
+
p.requires_grad = False
|
741 |
+
for p in self.get_output_embeddings().parameters():
|
742 |
+
p.requires_grad = False
|
flash_vstream/serve/cli_video_stream.py
ADDED
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file may have been modified by Flash-VStream Authors (Flash-VStream Modifications”). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
|
2 |
+
# Based on https://github.com/haotian-liu/LLaVA.
|
3 |
+
"""
|
4 |
+
This file demonstrates an implementation of a multiprocess Real-time Long Video Understanding System. With a multiprocess logging module.
|
5 |
+
main process: CLI server I/O, LLM inference
|
6 |
+
process-1: logger listener
|
7 |
+
process-2: frame generator,
|
8 |
+
process-3: frame memory manager
|
9 |
+
Author: Haoji Zhang, Haotian Liu
|
10 |
+
(This code is based on https://github.com/haotian-liu/LLaVA)
|
11 |
+
"""
|
12 |
+
import argparse
|
13 |
+
import requests
|
14 |
+
import logging
|
15 |
+
import torch
|
16 |
+
import numpy as np
|
17 |
+
import time
|
18 |
+
import os
|
19 |
+
|
20 |
+
from flash_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
21 |
+
from flash_vstream.conversation import conv_templates, SeparatorStyle
|
22 |
+
from flash_vstream.model.builder import load_pretrained_model
|
23 |
+
from flash_vstream.utils import disable_torch_init
|
24 |
+
from flash_vstream.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
25 |
+
|
26 |
+
from torch.multiprocessing import Process, Queue, Manager
|
27 |
+
from transformers import TextStreamer
|
28 |
+
from decord import VideoReader
|
29 |
+
from datetime import datetime
|
30 |
+
from PIL import Image
|
31 |
+
from io import BytesIO
|
32 |
+
|
33 |
+
class _Metric:
|
34 |
+
def __init__(self):
|
35 |
+
self._latest_value = None
|
36 |
+
self._sum = 0.0
|
37 |
+
self._max = 0.0
|
38 |
+
self._count = 0
|
39 |
+
|
40 |
+
@property
|
41 |
+
def val(self):
|
42 |
+
return self._latest_value
|
43 |
+
|
44 |
+
@property
|
45 |
+
def max(self):
|
46 |
+
return self._max
|
47 |
+
|
48 |
+
@property
|
49 |
+
def avg(self):
|
50 |
+
if self._count == 0:
|
51 |
+
return float('nan')
|
52 |
+
return self._sum / self._count
|
53 |
+
|
54 |
+
def add(self, value):
|
55 |
+
self._latest_value = value
|
56 |
+
self._sum += value
|
57 |
+
self._count += 1
|
58 |
+
if value > self._max:
|
59 |
+
self._max = value
|
60 |
+
|
61 |
+
def __str__(self):
|
62 |
+
latest_formatted = f"{self.val:.6f}" if self.val is not None else "None"
|
63 |
+
average_formatted = f"{self.avg:.6f}"
|
64 |
+
max_formatted = f"{self.max:.6f}"
|
65 |
+
return f"{latest_formatted} ({average_formatted}, {max_formatted})"
|
66 |
+
|
67 |
+
|
68 |
+
class MetricMeter:
|
69 |
+
def __init__(self):
|
70 |
+
self._metrics = {}
|
71 |
+
|
72 |
+
def add(self, key, value):
|
73 |
+
if key not in self._metrics:
|
74 |
+
self._metrics[key] = _Metric()
|
75 |
+
self._metrics[key].add(value)
|
76 |
+
|
77 |
+
def val(self, key):
|
78 |
+
metric = self._metrics.get(key)
|
79 |
+
if metric is None or metric.val is None:
|
80 |
+
raise ValueError(f"No values have been added for key '{key}'.")
|
81 |
+
return metric.val
|
82 |
+
|
83 |
+
def avg(self, key):
|
84 |
+
metric = self._metrics.get(key)
|
85 |
+
if metric is None:
|
86 |
+
raise ValueError(f"No values have been added for key '{key}'.")
|
87 |
+
return metric.avg
|
88 |
+
|
89 |
+
def max(self, key):
|
90 |
+
metric = self._metrics.get(key)
|
91 |
+
if metric is None:
|
92 |
+
raise ValueError(f"No values have been added for key '{key}'.")
|
93 |
+
return metric.max
|
94 |
+
|
95 |
+
def __getitem__(self, key):
|
96 |
+
metric = self._metrics.get(key)
|
97 |
+
if metric is None:
|
98 |
+
raise KeyError(f"The key '{key}' does not exist.")
|
99 |
+
return str(metric)
|
100 |
+
|
101 |
+
def load_image(image_file):
|
102 |
+
if image_file.startswith('http://') or image_file.startswith('https://'):
|
103 |
+
response = requests.get(image_file)
|
104 |
+
image = Image.open(BytesIO(response.content)).convert('RGB')
|
105 |
+
else:
|
106 |
+
image = Image.open(image_file).convert('RGB')
|
107 |
+
return image
|
108 |
+
|
109 |
+
def listener(queue, filename):
|
110 |
+
############## Start sub process-1: Listener #############
|
111 |
+
import sys, traceback
|
112 |
+
root = logging.getLogger()
|
113 |
+
root.setLevel(logging.DEBUG)
|
114 |
+
# h = logging.StreamHandler(sys.stdout)
|
115 |
+
h = logging.FileHandler(filename)
|
116 |
+
f = logging.Formatter('%(asctime)s %(processName)-10s %(name)s %(levelname)-8s %(message)s')
|
117 |
+
h.setFormatter(f)
|
118 |
+
root.addHandler(h)
|
119 |
+
while True:
|
120 |
+
try:
|
121 |
+
record = queue.get()
|
122 |
+
if record is None: # None is a signal to finish
|
123 |
+
break
|
124 |
+
logger = logging.getLogger(record.name)
|
125 |
+
logger.handle(record) # No level or filter logic applied - just do it!
|
126 |
+
except Exception:
|
127 |
+
import sys, traceback
|
128 |
+
print('Whoops! Problem:', file=sys.stderr)
|
129 |
+
traceback.print_exc(file=sys.stderr)
|
130 |
+
|
131 |
+
def worker_configurer(queue):
|
132 |
+
h = logging.handlers.QueueHandler(queue) # Just the one handler needed
|
133 |
+
root = logging.getLogger()
|
134 |
+
root.addHandler(h)
|
135 |
+
root.setLevel(logging.DEBUG)
|
136 |
+
|
137 |
+
def video_stream_similator(video_file, frame_queue, log_queue, video_fps=1.0, play_speed=1.0):
|
138 |
+
############## Start sub process-2: Simulator #############
|
139 |
+
worker_configurer(log_queue)
|
140 |
+
logger = logging.getLogger(__name__)
|
141 |
+
logger.setLevel(logging.DEBUG)
|
142 |
+
|
143 |
+
vr = VideoReader(video_file)
|
144 |
+
sample_fps = round(vr.get_avg_fps() / video_fps)
|
145 |
+
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
146 |
+
video = vr.get_batch(frame_idx).asnumpy()
|
147 |
+
video = np.repeat(video, 6, axis=0)
|
148 |
+
length = video.shape[0]
|
149 |
+
sleep_time = 1 / video_fps / play_speed
|
150 |
+
time_meter = MetricMeter()
|
151 |
+
logger.info(f'Simulator Process: start, length = {length}')
|
152 |
+
try:
|
153 |
+
for start in range(0, length):
|
154 |
+
start_time = time.perf_counter()
|
155 |
+
end = min(start + 1, length)
|
156 |
+
video_clip = video[start:end]
|
157 |
+
frame_queue.put(video_clip)
|
158 |
+
if start > 0:
|
159 |
+
time_meter.add('real_sleep', start_time - last_start)
|
160 |
+
logger.info(f'Simulator: write {end - start} frames,\t{start} to {end},\treal_sleep={time_meter["real_sleep"]}')
|
161 |
+
if end < length:
|
162 |
+
time.sleep(sleep_time)
|
163 |
+
last_start = start_time
|
164 |
+
frame_queue.put(None)
|
165 |
+
except Exception as e:
|
166 |
+
print(f'Simulator Exception: {e}')
|
167 |
+
time.sleep(0.1)
|
168 |
+
logger.info(f'Simulator Process: end')
|
169 |
+
|
170 |
+
def frame_memory_manager(model, image_processor, frame_queue, log_queue):
|
171 |
+
############## Start sub process-3: Memory Manager #############
|
172 |
+
worker_configurer(log_queue)
|
173 |
+
logger = logging.getLogger(__name__)
|
174 |
+
logger.setLevel(logging.DEBUG)
|
175 |
+
|
176 |
+
time_meter = MetricMeter()
|
177 |
+
logger.info(f'MemManager Process: start')
|
178 |
+
frame_cnt = 0
|
179 |
+
while True:
|
180 |
+
try:
|
181 |
+
video_clip = frame_queue.get()
|
182 |
+
start_time = time.perf_counter()
|
183 |
+
if video_clip is None:
|
184 |
+
logger.info(f'MemManager: Ooops, get None')
|
185 |
+
break
|
186 |
+
logger.info(f'MemManager: get {video_clip.shape[0]} frames from queue')
|
187 |
+
image = image_processor.preprocess(video_clip, return_tensors='pt')['pixel_values']
|
188 |
+
image = image.unsqueeze(0)
|
189 |
+
image_tensor = image.to(model.device, dtype=torch.float16)
|
190 |
+
# time_2 = time.perf_counter()
|
191 |
+
logger.info(f'MemManager: Start embedding')
|
192 |
+
with torch.inference_mode():
|
193 |
+
model.embed_video_streaming(image_tensor)
|
194 |
+
logger.info(f'MemManager: End embedding')
|
195 |
+
end_time = time.perf_counter()
|
196 |
+
if frame_cnt > 0:
|
197 |
+
time_meter.add('memory_latency', end_time - start_time)
|
198 |
+
logger.info(f'MemManager: embedded {video_clip.shape[0]} frames,\tidx={frame_cnt},\tmemory_latency={time_meter["memory_latency"]}')
|
199 |
+
else:
|
200 |
+
logger.info(f'MemManager: embedded {video_clip.shape[0]} frames,\tidx={frame_cnt},\tmemory_latency={end_time - start_time:.6f}, not logged')
|
201 |
+
frame_cnt += video_clip.shape[0]
|
202 |
+
except Exception as e:
|
203 |
+
print(f'MemManager Exception: {e}')
|
204 |
+
time.sleep(0.1)
|
205 |
+
logger.info(f'MemManager Process: end')
|
206 |
+
|
207 |
+
def main(args):
|
208 |
+
# torch.multiprocessing.log_to_stderr(logging.DEBUG)
|
209 |
+
torch.multiprocessing.set_start_method('spawn', force=True)
|
210 |
+
disable_torch_init()
|
211 |
+
|
212 |
+
log_queue = Queue()
|
213 |
+
frame_queue = Queue(maxsize=10)
|
214 |
+
processes = []
|
215 |
+
|
216 |
+
############## Start listener process #############
|
217 |
+
p1 = Process(target=listener, args=(log_queue, args.log_file))
|
218 |
+
processes.append(p1)
|
219 |
+
p1.start()
|
220 |
+
|
221 |
+
############## Start main process #############
|
222 |
+
worker_configurer(log_queue)
|
223 |
+
logger = logging.getLogger(__name__)
|
224 |
+
|
225 |
+
model_name = get_model_name_from_path(args.model_path)
|
226 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
|
227 |
+
|
228 |
+
logger.info(f'Using conv_mode={args.conv_mode}')
|
229 |
+
|
230 |
+
conv = conv_templates[args.conv_mode].copy()
|
231 |
+
if "mpt" in model_name.lower():
|
232 |
+
roles = ('user', 'assistant')
|
233 |
+
else:
|
234 |
+
roles = conv.roles
|
235 |
+
|
236 |
+
with Manager() as manager:
|
237 |
+
image_tensor = None
|
238 |
+
model.use_video_streaming_mode = True
|
239 |
+
model.video_embedding_memory = manager.list()
|
240 |
+
if args.video_max_frames is not None:
|
241 |
+
model.config.video_max_frames = args.video_max_frames
|
242 |
+
logger.info(f'Important: set model.config.video_max_frames = {model.config.video_max_frames}')
|
243 |
+
|
244 |
+
logger.info(f'Important: set video_fps = {args.video_fps}')
|
245 |
+
logger.info(f'Important: set play_speed = {args.play_speed}')
|
246 |
+
|
247 |
+
############## Start simulator process #############
|
248 |
+
p2 = Process(target=video_stream_similator,
|
249 |
+
args=(args.video_file, frame_queue, log_queue, args.video_fps, args.play_speed))
|
250 |
+
processes.append(p2)
|
251 |
+
p2.start()
|
252 |
+
|
253 |
+
############## Start memory manager process #############
|
254 |
+
p3 = Process(target=frame_memory_manager,
|
255 |
+
args=(model, image_processor, frame_queue, log_queue))
|
256 |
+
processes.append(p3)
|
257 |
+
p3.start()
|
258 |
+
|
259 |
+
# start QA server
|
260 |
+
start_time = datetime.now()
|
261 |
+
time_meter = MetricMeter()
|
262 |
+
conv_cnt = 0
|
263 |
+
while True:
|
264 |
+
time.sleep(5)
|
265 |
+
try:
|
266 |
+
# inp = input(f"{roles[0]}: ")
|
267 |
+
inp = "what is in the video?"
|
268 |
+
except EOFError:
|
269 |
+
inp = ""
|
270 |
+
if not inp:
|
271 |
+
print("exit...")
|
272 |
+
break
|
273 |
+
|
274 |
+
# 获取当前时间
|
275 |
+
now = datetime.now()
|
276 |
+
conv_start_time = time.perf_counter()
|
277 |
+
# 将当前时间格式化为字符串
|
278 |
+
current_time = now.strftime("%H:%M:%S")
|
279 |
+
duration = now.timestamp() - start_time.timestamp()
|
280 |
+
|
281 |
+
# 打印当前时间
|
282 |
+
print("\nCurrent Time:", current_time, "Run for:", duration)
|
283 |
+
print(f"{roles[0]}: {inp}", end="\n")
|
284 |
+
print(f"{roles[1]}: ", end="")
|
285 |
+
# every conversation is a new conversation
|
286 |
+
conv = conv_templates[args.conv_mode].copy()
|
287 |
+
inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
|
288 |
+
conv.append_message(conv.roles[0], inp)
|
289 |
+
|
290 |
+
conv.append_message(conv.roles[1], None)
|
291 |
+
prompt = conv.get_prompt()
|
292 |
+
|
293 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
294 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
295 |
+
keywords = [stop_str]
|
296 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
297 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
298 |
+
|
299 |
+
llm_start_time = time.perf_counter()
|
300 |
+
with torch.inference_mode():
|
301 |
+
output_ids = model.generate(
|
302 |
+
input_ids,
|
303 |
+
images=image_tensor,
|
304 |
+
do_sample=True if args.temperature > 0 else False,
|
305 |
+
temperature=args.temperature,
|
306 |
+
max_new_tokens=args.max_new_tokens,
|
307 |
+
streamer=streamer,
|
308 |
+
use_cache=True,
|
309 |
+
stopping_criteria=[stopping_criteria]
|
310 |
+
)
|
311 |
+
llm_end_time = time.perf_counter()
|
312 |
+
|
313 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
314 |
+
conv.messages[-1][-1] = outputs
|
315 |
+
conv_end_time = time.perf_counter()
|
316 |
+
if conv_cnt > 0:
|
317 |
+
time_meter.add('conv_latency', conv_end_time - conv_start_time)
|
318 |
+
time_meter.add('llm_latency', llm_end_time - llm_start_time)
|
319 |
+
time_meter.add('real_sleep', conv_start_time - last_conv_start_time)
|
320 |
+
logger.info(f'CliServer: idx={conv_cnt},\treal_sleep={time_meter["real_sleep"]},\tconv_latency={time_meter["conv_latency"]},\tllm_latency={time_meter["llm_latency"]}')
|
321 |
+
else:
|
322 |
+
logger.info(f'CliServer: idx={conv_cnt},\tconv_latency={conv_end_time - conv_start_time},\tllm_latency={llm_end_time - llm_start_time}')
|
323 |
+
conv_cnt += 1
|
324 |
+
last_conv_start_time = conv_start_time
|
325 |
+
|
326 |
+
for p in processes:
|
327 |
+
p.terminate()
|
328 |
+
print("All processes finished.")
|
329 |
+
|
330 |
+
|
331 |
+
if __name__ == "__main__":
|
332 |
+
parser = argparse.ArgumentParser()
|
333 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
334 |
+
parser.add_argument("--model-base", type=str, default=None)
|
335 |
+
parser.add_argument("--image-file", type=str, default=None)
|
336 |
+
parser.add_argument("--video-file", type=str, default=None)
|
337 |
+
parser.add_argument("--device", type=str, default="cuda")
|
338 |
+
parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
|
339 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
340 |
+
parser.add_argument("--max-new-tokens", type=int, default=512)
|
341 |
+
parser.add_argument("--load-8bit", action="store_true")
|
342 |
+
parser.add_argument("--load-4bit", action="store_true")
|
343 |
+
parser.add_argument("--debug", action="store_true")
|
344 |
+
|
345 |
+
parser.add_argument("--log-file", type=str, default="tmp_cli.log")
|
346 |
+
parser.add_argument("--use_1process", action="store_true")
|
347 |
+
parser.add_argument("--video_max_frames", type=int, default=None)
|
348 |
+
parser.add_argument("--video_fps", type=float, default=1.0)
|
349 |
+
parser.add_argument("--play_speed", type=float, default=1.0)
|
350 |
+
args = parser.parse_args()
|
351 |
+
main(args)
|
flash_vstream/serve/demo.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from ..constants import *
|
3 |
+
from ..conversation import conv_templates, SeparatorStyle
|
4 |
+
from ..model.builder import load_pretrained_model
|
5 |
+
from ..utils import disable_torch_init
|
6 |
+
from ..mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
|
7 |
+
from PIL import Image
|
8 |
+
import os
|
9 |
+
from decord import VideoReader, cpu
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
|
13 |
+
class Chat:
|
14 |
+
def __init__(self, model_path, conv_mode="simple", load_8bit=False, load_4bit=False):
|
15 |
+
disable_torch_init()
|
16 |
+
self.tokenizer, self.model, self.image_processor, context_len = load_pretrained_model(model_path, None, model_name="ChatUniVi", load_8bit=load_8bit, load_4bit=load_4bit)
|
17 |
+
|
18 |
+
mm_use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
|
19 |
+
mm_use_im_patch_token = getattr(self.model.config, "mm_use_im_patch_token", True)
|
20 |
+
if mm_use_im_patch_token:
|
21 |
+
self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
22 |
+
if mm_use_im_start_end:
|
23 |
+
self.tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
24 |
+
self.model.resize_token_embeddings(len(self.tokenizer))
|
25 |
+
|
26 |
+
vision_tower = self.model.get_vision_tower()
|
27 |
+
if not vision_tower.is_loaded:
|
28 |
+
vision_tower.load_model()
|
29 |
+
|
30 |
+
self.image_processor = vision_tower.image_processor
|
31 |
+
self.conv_mode = conv_mode
|
32 |
+
print(self.model)
|
33 |
+
|
34 |
+
def get_prompt(self, qs, state):
|
35 |
+
state.append_message(state.roles[0], qs)
|
36 |
+
state.append_message(state.roles[1], None)
|
37 |
+
return state
|
38 |
+
|
39 |
+
def _get_rawvideo_dec(self, video_path, image_processor, max_frames=MAX_IMAGE_LENGTH, image_resolution=224,
|
40 |
+
video_framerate=1, s=None, e=None):
|
41 |
+
if s is None:
|
42 |
+
start_time, end_time = None, None
|
43 |
+
else:
|
44 |
+
start_time = int(s)
|
45 |
+
end_time = int(e)
|
46 |
+
start_time = start_time if start_time >= 0. else 0.
|
47 |
+
end_time = end_time if end_time >= 0. else 0.
|
48 |
+
if start_time > end_time:
|
49 |
+
start_time, end_time = end_time, start_time
|
50 |
+
elif start_time == end_time:
|
51 |
+
end_time = start_time + 1
|
52 |
+
|
53 |
+
if os.path.exists(video_path):
|
54 |
+
vreader = VideoReader(video_path, ctx=cpu(0))
|
55 |
+
else:
|
56 |
+
print(video_path)
|
57 |
+
raise FileNotFoundError
|
58 |
+
|
59 |
+
fps = vreader.get_avg_fps()
|
60 |
+
f_start = 0 if start_time is None else int(start_time * fps)
|
61 |
+
f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
|
62 |
+
num_frames = f_end - f_start + 1
|
63 |
+
if num_frames > 0:
|
64 |
+
sample_fps = int(video_framerate)
|
65 |
+
t_stride = int(round(float(fps) / sample_fps))
|
66 |
+
|
67 |
+
all_pos = list(range(f_start, f_end + 1, t_stride))
|
68 |
+
if len(all_pos) > max_frames:
|
69 |
+
sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
|
70 |
+
else:
|
71 |
+
sample_pos = all_pos
|
72 |
+
|
73 |
+
patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
|
74 |
+
return patch_images
|
75 |
+
|
76 |
+
@torch.inference_mode()
|
77 |
+
def generate(self, images_tensor: list, prompt: str, first_run: bool, state):
|
78 |
+
tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
|
79 |
+
|
80 |
+
state = self.get_prompt(prompt, state)
|
81 |
+
prompt = state.get_prompt()
|
82 |
+
print(prompt)
|
83 |
+
|
84 |
+
images_tensor = torch.stack(images_tensor, dim=0)
|
85 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
|
86 |
+
|
87 |
+
temperature = 0.2
|
88 |
+
max_new_tokens = 1024
|
89 |
+
|
90 |
+
stop_str = conv_templates[self.conv_mode].copy().sep if conv_templates[self.conv_mode].copy().sep_style != SeparatorStyle.TWO else \
|
91 |
+
conv_templates[self.conv_mode].copy().sep2
|
92 |
+
keywords = [stop_str]
|
93 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
94 |
+
|
95 |
+
with torch.inference_mode():
|
96 |
+
output_ids = model.generate(
|
97 |
+
input_ids,
|
98 |
+
images=images_tensor,
|
99 |
+
do_sample=True,
|
100 |
+
temperature=temperature,
|
101 |
+
num_beams=1,
|
102 |
+
max_new_tokens=max_new_tokens,
|
103 |
+
use_cache=True,
|
104 |
+
stopping_criteria=[stopping_criteria])
|
105 |
+
|
106 |
+
input_token_len = input_ids.shape[1]
|
107 |
+
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
108 |
+
if n_diff_input_output > 0:
|
109 |
+
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
110 |
+
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
111 |
+
outputs = outputs.strip()
|
112 |
+
if outputs.endswith(stop_str):
|
113 |
+
outputs = outputs[:-len(stop_str)]
|
114 |
+
outputs = outputs.strip()
|
115 |
+
|
116 |
+
print('response', outputs)
|
117 |
+
return outputs, state
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
title_markdown = ("""
|
122 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
123 |
+
<a href="https://github.com/PKU-YuanGroup/Chat-UniVi" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
|
124 |
+
<img src="https://z1.ax1x.com/2023/11/22/pidlXh4.jpg" alt="Chat-UniVi🚀" style="max-width: 120px; height: auto;">
|
125 |
+
</a>
|
126 |
+
<div>
|
127 |
+
<h1 >Chat-UniVi: Unified Visual Representation Empowers Large Language Models with Image and Video Understanding</h1>
|
128 |
+
<h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5>
|
129 |
+
</div>
|
130 |
+
</div>
|
131 |
+
<div align="center">
|
132 |
+
<div style="display:flex; gap: 0.25rem;" align="center">
|
133 |
+
<a href='https://github.com/PKU-YuanGroup/Chat-UniVi'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
|
134 |
+
<a href="https://arxiv.org/pdf/2311.08046.pdf"><img src="https://img.shields.io/badge/Arxiv-2311.08046-red"></a>
|
135 |
+
<a href='https://github.com/PKU-YuanGroup/Chat-UniVi/stargazers'><img src='https://img.shields.io/github/stars/PKU-YuanGroup/Chat-UniVi.svg?style=social'></a>
|
136 |
+
</div>
|
137 |
+
</div>
|
138 |
+
""")
|
139 |
+
|
140 |
+
block_css = """
|
141 |
+
#buttons button {
|
142 |
+
min-width: min(120px,100%);
|
143 |
+
}
|
144 |
+
"""
|
flash_vstream/train/llama_flash_attn_monkey_patch.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on https://github.com/haotian-liu/LLaVA.
|
2 |
+
|
3 |
+
from typing import Optional, Tuple
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
import transformers
|
9 |
+
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
10 |
+
|
11 |
+
try:
|
12 |
+
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
13 |
+
except ImportError:
|
14 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
|
15 |
+
from flash_attn.bert_padding import unpad_input, pad_input
|
16 |
+
|
17 |
+
|
18 |
+
def forward(
|
19 |
+
self,
|
20 |
+
hidden_states: torch.Tensor,
|
21 |
+
attention_mask: Optional[torch.Tensor] = None,
|
22 |
+
position_ids: Optional[torch.Tensor] = None,
|
23 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
24 |
+
output_attentions: bool = False,
|
25 |
+
use_cache: bool = False,
|
26 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
27 |
+
if output_attentions:
|
28 |
+
warnings.warn(
|
29 |
+
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
30 |
+
)
|
31 |
+
|
32 |
+
bsz, q_len, _ = hidden_states.size()
|
33 |
+
|
34 |
+
query_states = (
|
35 |
+
self.q_proj(hidden_states)
|
36 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
37 |
+
.transpose(1, 2)
|
38 |
+
)
|
39 |
+
key_states = (
|
40 |
+
self.k_proj(hidden_states)
|
41 |
+
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
42 |
+
.transpose(1, 2)
|
43 |
+
)
|
44 |
+
value_states = (
|
45 |
+
self.v_proj(hidden_states)
|
46 |
+
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
47 |
+
.transpose(1, 2)
|
48 |
+
) # shape: (b, num_heads, s, head_dim)
|
49 |
+
|
50 |
+
kv_seq_len = key_states.shape[-2]
|
51 |
+
if past_key_value is not None:
|
52 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
53 |
+
|
54 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
55 |
+
query_states, key_states = apply_rotary_pos_emb(
|
56 |
+
query_states, key_states, cos, sin, position_ids
|
57 |
+
)
|
58 |
+
|
59 |
+
if past_key_value is not None:
|
60 |
+
# reuse k, v
|
61 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
62 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
63 |
+
|
64 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
65 |
+
|
66 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
67 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
68 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
69 |
+
|
70 |
+
# Transform the data into the format required by flash attention
|
71 |
+
qkv = torch.stack([query_states, key_states, value_states], dim=2)
|
72 |
+
qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
|
73 |
+
key_padding_mask = attention_mask
|
74 |
+
|
75 |
+
if key_padding_mask is None:
|
76 |
+
qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
|
77 |
+
cu_q_lens = torch.arange(
|
78 |
+
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
|
79 |
+
)
|
80 |
+
max_s = q_len
|
81 |
+
output = flash_attn_unpadded_qkvpacked_func(
|
82 |
+
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
83 |
+
)
|
84 |
+
output = output.view(bsz, q_len, -1)
|
85 |
+
else:
|
86 |
+
qkv = qkv.reshape(bsz, q_len, -1)
|
87 |
+
qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
|
88 |
+
qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
|
89 |
+
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
90 |
+
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
91 |
+
)
|
92 |
+
output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
|
93 |
+
output = pad_input(output_unpad, indices, bsz, q_len)
|
94 |
+
|
95 |
+
return self.o_proj(output), None, past_key_value
|
96 |
+
|
97 |
+
|
98 |
+
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
99 |
+
# requires the attention mask to be the same as the key_padding_mask
|
100 |
+
def _prepare_decoder_attention_mask(
|
101 |
+
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
102 |
+
):
|
103 |
+
# [bsz, seq_len]
|
104 |
+
return attention_mask
|
105 |
+
|
106 |
+
|
107 |
+
def replace_llama_attn_with_flash_attn():
|
108 |
+
cuda_major, cuda_minor = torch.cuda.get_device_capability()
|
109 |
+
if cuda_major < 8:
|
110 |
+
warnings.warn(
|
111 |
+
"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
|
112 |
+
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
|
113 |
+
)
|
114 |
+
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
115 |
+
_prepare_decoder_attention_mask
|
116 |
+
)
|
117 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
flash_vstream/train/llama_xformers_attn_monkey_patch.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on https://github.com/haotian-liu/LLaVA.
|
2 |
+
|
3 |
+
"""
|
4 |
+
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
|
5 |
+
"""
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import math
|
9 |
+
from typing import Optional, Tuple
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import transformers.models.llama.modeling_llama
|
13 |
+
from torch import nn
|
14 |
+
|
15 |
+
try:
|
16 |
+
import xformers.ops
|
17 |
+
except ImportError:
|
18 |
+
logging.error("xformers not found! Please install it before trying to use it.")
|
19 |
+
|
20 |
+
|
21 |
+
def replace_llama_attn_with_xformers_attn():
|
22 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
23 |
+
|
24 |
+
|
25 |
+
def xformers_forward(
|
26 |
+
self,
|
27 |
+
hidden_states: torch.Tensor,
|
28 |
+
attention_mask: Optional[torch.Tensor] = None,
|
29 |
+
position_ids: Optional[torch.LongTensor] = None,
|
30 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
31 |
+
output_attentions: bool = False,
|
32 |
+
use_cache: bool = False,
|
33 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
34 |
+
# pylint: disable=duplicate-code
|
35 |
+
bsz, q_len, _ = hidden_states.size()
|
36 |
+
|
37 |
+
query_states = (
|
38 |
+
self.q_proj(hidden_states)
|
39 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
40 |
+
.transpose(1, 2)
|
41 |
+
)
|
42 |
+
key_states = (
|
43 |
+
self.k_proj(hidden_states)
|
44 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
45 |
+
.transpose(1, 2)
|
46 |
+
)
|
47 |
+
value_states = (
|
48 |
+
self.v_proj(hidden_states)
|
49 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
50 |
+
.transpose(1, 2)
|
51 |
+
)
|
52 |
+
|
53 |
+
kv_seq_len = key_states.shape[-2]
|
54 |
+
if past_key_value is not None:
|
55 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
56 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
57 |
+
(
|
58 |
+
query_states,
|
59 |
+
key_states,
|
60 |
+
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
61 |
+
query_states, key_states, cos, sin, position_ids
|
62 |
+
)
|
63 |
+
# [bsz, nh, t, hd]
|
64 |
+
|
65 |
+
if past_key_value is not None:
|
66 |
+
# reuse k, v, self_attention
|
67 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
68 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
69 |
+
|
70 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
71 |
+
|
72 |
+
# We only apply xformers optimizations if we don't need to output the whole attention matrix
|
73 |
+
if not output_attentions:
|
74 |
+
query_states = query_states.transpose(1, 2)
|
75 |
+
key_states = key_states.transpose(1, 2)
|
76 |
+
value_states = value_states.transpose(1, 2)
|
77 |
+
|
78 |
+
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
79 |
+
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
80 |
+
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
81 |
+
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
82 |
+
attn_output = xformers.ops.memory_efficient_attention(
|
83 |
+
query_states, key_states, value_states, attn_bias=None
|
84 |
+
)
|
85 |
+
else:
|
86 |
+
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
87 |
+
attn_output = xformers.ops.memory_efficient_attention(
|
88 |
+
query_states,
|
89 |
+
key_states,
|
90 |
+
value_states,
|
91 |
+
attn_bias=xformers.ops.LowerTriangularMask(),
|
92 |
+
)
|
93 |
+
attn_weights = None
|
94 |
+
else:
|
95 |
+
attn_weights = torch.matmul(
|
96 |
+
query_states, key_states.transpose(2, 3)
|
97 |
+
) / math.sqrt(self.head_dim)
|
98 |
+
|
99 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
100 |
+
raise ValueError(
|
101 |
+
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
102 |
+
f" {attn_weights.size()}"
|
103 |
+
)
|
104 |
+
|
105 |
+
if attention_mask is not None:
|
106 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
107 |
+
raise ValueError(
|
108 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
109 |
+
)
|
110 |
+
attn_weights = attn_weights + attention_mask
|
111 |
+
attn_weights = torch.max(
|
112 |
+
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
113 |
+
)
|
114 |
+
|
115 |
+
# upcast attention to fp32
|
116 |
+
attn_weights = nn.functional.softmax(
|
117 |
+
attn_weights, dim=-1, dtype=torch.float32
|
118 |
+
).to(query_states.dtype)
|
119 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
120 |
+
|
121 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
122 |
+
raise ValueError(
|
123 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
124 |
+
f" {attn_output.size()}"
|
125 |
+
)
|
126 |
+
|
127 |
+
attn_output = attn_output.transpose(1, 2)
|
128 |
+
|
129 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
130 |
+
attn_output = self.o_proj(attn_output)
|
131 |
+
return attn_output, attn_weights, past_key_value
|
flash_vstream/train/train.py
ADDED
@@ -0,0 +1,1069 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file may have been modified by Flash-VStream Authors (Flash-VStream Modifications”). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
|
2 |
+
# ------------------------------------------------------------------------
|
3 |
+
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
|
4 |
+
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
|
5 |
+
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
|
6 |
+
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
7 |
+
#
|
8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
9 |
+
# you may not use this file except in compliance with the License.
|
10 |
+
# You may obtain a copy of the License at
|
11 |
+
#
|
12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
13 |
+
#
|
14 |
+
# Unless required by applicable law or agreed to in writing, software
|
15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
17 |
+
# See the License for the specific language governing permissions and
|
18 |
+
# limitations under the License.
|
19 |
+
|
20 |
+
import os
|
21 |
+
import copy
|
22 |
+
import json
|
23 |
+
import torch
|
24 |
+
import random
|
25 |
+
import logging
|
26 |
+
import pathlib
|
27 |
+
import transformers
|
28 |
+
from dataclasses import dataclass, field
|
29 |
+
from typing import Dict, Optional, Sequence, List
|
30 |
+
|
31 |
+
|
32 |
+
from flash_vstream.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
33 |
+
from torch.utils.data import Dataset
|
34 |
+
from flash_vstream.train.vstream_trainer import VStreamTrainer
|
35 |
+
|
36 |
+
from flash_vstream import conversation as conversation_lib
|
37 |
+
from flash_vstream.model import VStreamLlamaForCausalLM, VStreamConfig
|
38 |
+
from flash_vstream.mm_utils import tokenizer_image_token
|
39 |
+
|
40 |
+
from PIL import Image
|
41 |
+
from decord import VideoReader
|
42 |
+
from safetensors.torch import load_file, save_file
|
43 |
+
|
44 |
+
|
45 |
+
local_rank = None
|
46 |
+
|
47 |
+
|
48 |
+
def rank0_print(*args):
|
49 |
+
if local_rank == 0:
|
50 |
+
print(*args)
|
51 |
+
|
52 |
+
|
53 |
+
@dataclass
|
54 |
+
class ModelArguments:
|
55 |
+
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
|
56 |
+
version: Optional[str] = field(default="v0")
|
57 |
+
freeze_backbone: bool = field(default=False)
|
58 |
+
tune_mm_mlp_adapter: bool = field(default=False)
|
59 |
+
vision_tower: Optional[str] = field(default=None)
|
60 |
+
mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
|
61 |
+
pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
|
62 |
+
mm_projector_type: Optional[str] = field(default='linear')
|
63 |
+
mm_use_im_start_end: bool = field(default=False)
|
64 |
+
mm_use_im_patch_token: bool = field(default=True)
|
65 |
+
mm_vision_select_feature: Optional[str] = field(default="patch")
|
66 |
+
mm_use_4_vision_tokens: bool = field(default=False)
|
67 |
+
compress_type: Optional[str] = field(default=None)
|
68 |
+
compress_size: int = field(default=4)
|
69 |
+
compress_long_memory_size: int = field(default=1)
|
70 |
+
compress_Turing_memory_size: int = field(default=1)
|
71 |
+
compress_Turing_hidden_dim: int = field(default=32)
|
72 |
+
compress_Turing_update_ratio: float = field(default=0.2)
|
73 |
+
|
74 |
+
|
75 |
+
@dataclass
|
76 |
+
class DataArguments:
|
77 |
+
data_path: str = field(default=None,
|
78 |
+
metadata={"help": "Path to the training data."})
|
79 |
+
lazy_preprocess: bool = False
|
80 |
+
is_multimodal: bool = False
|
81 |
+
image_folder: Optional[str] = field(default=None)
|
82 |
+
video_folder: Optional[str] = field(default=None)
|
83 |
+
video_fps: Optional[int] = field(default=1)
|
84 |
+
video_token: Optional[int] = field(default=2)
|
85 |
+
video_max_frames: Optional[int] = field(default=50)
|
86 |
+
video_long_memory_length: Optional[int] = field(default=10)
|
87 |
+
video_Turing_memory_length: Optional[int] = field(default=10)
|
88 |
+
video_short_memory_length: Optional[int] = field(default=10)
|
89 |
+
video_current_memory_length: Optional[int] = field(default=1)
|
90 |
+
video_sample_type: Optional[str] = field(default='center') # center, uniform, drop, merge
|
91 |
+
image_aspect_ratio: str = 'square'
|
92 |
+
|
93 |
+
|
94 |
+
@dataclass
|
95 |
+
class TrainingArguments(transformers.TrainingArguments):
|
96 |
+
cache_dir: Optional[str] = field(default=None)
|
97 |
+
optim: str = field(default="adamw_torch")
|
98 |
+
remove_unused_columns: bool = field(default=False)
|
99 |
+
freeze_mm_mlp_adapter: bool = field(default=False)
|
100 |
+
model_max_length: int = field(
|
101 |
+
default=512,
|
102 |
+
metadata={
|
103 |
+
"help":
|
104 |
+
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
|
105 |
+
},
|
106 |
+
)
|
107 |
+
double_quant: bool = field(
|
108 |
+
default=True,
|
109 |
+
metadata={"help": "Compress the quantization statistics through double quantization."}
|
110 |
+
)
|
111 |
+
quant_type: str = field(
|
112 |
+
default="nf4",
|
113 |
+
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
|
114 |
+
)
|
115 |
+
bits: int = field(
|
116 |
+
default=16,
|
117 |
+
metadata={"help": "How many bits to use."}
|
118 |
+
)
|
119 |
+
lora_enable: bool = False
|
120 |
+
lora_r: int = 64
|
121 |
+
lora_alpha: int = 16
|
122 |
+
lora_dropout: float = 0.05
|
123 |
+
lora_weight_path: str = ""
|
124 |
+
lora_bias: str = "none"
|
125 |
+
mm_projector_lr: Optional[float] = None
|
126 |
+
group_by_modality_length: bool = field(default=False)
|
127 |
+
|
128 |
+
|
129 |
+
def maybe_zero_3(param, ignore_status=False, name=None):
|
130 |
+
from deepspeed import zero
|
131 |
+
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
132 |
+
if hasattr(param, "ds_id"):
|
133 |
+
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
|
134 |
+
if not ignore_status:
|
135 |
+
logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
|
136 |
+
with zero.GatheredParameters([param]):
|
137 |
+
param = param.data.detach().cpu().clone()
|
138 |
+
else:
|
139 |
+
param = param.detach().cpu().clone()
|
140 |
+
return param
|
141 |
+
|
142 |
+
|
143 |
+
# Borrowed from peft.utils.get_peft_model_state_dict
|
144 |
+
def get_peft_state_maybe_zero_3(named_params, bias):
|
145 |
+
if bias == "none":
|
146 |
+
to_return = {k: t for k, t in named_params if "lora_" in k}
|
147 |
+
elif bias == "all":
|
148 |
+
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
|
149 |
+
elif bias == "lora_only":
|
150 |
+
to_return = {}
|
151 |
+
maybe_lora_bias = {}
|
152 |
+
lora_bias_names = set()
|
153 |
+
for k, t in named_params:
|
154 |
+
if "lora_" in k:
|
155 |
+
to_return[k] = t
|
156 |
+
bias_name = k.split("lora_")[0] + "bias"
|
157 |
+
lora_bias_names.add(bias_name)
|
158 |
+
elif "bias" in k:
|
159 |
+
maybe_lora_bias[k] = t
|
160 |
+
for k, t in maybe_lora_bias:
|
161 |
+
if bias_name in lora_bias_names:
|
162 |
+
to_return[bias_name] = t
|
163 |
+
else:
|
164 |
+
raise NotImplementedError
|
165 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
|
166 |
+
return to_return
|
167 |
+
|
168 |
+
|
169 |
+
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
|
170 |
+
to_return = {k: t for k, t in named_params if "lora_" not in k}
|
171 |
+
if require_grad_only:
|
172 |
+
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
|
173 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
|
174 |
+
return to_return
|
175 |
+
|
176 |
+
|
177 |
+
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
|
178 |
+
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
|
179 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
|
180 |
+
return to_return
|
181 |
+
|
182 |
+
|
183 |
+
def find_all_linear_names(model):
|
184 |
+
cls = torch.nn.Linear
|
185 |
+
lora_module_names = set()
|
186 |
+
multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
|
187 |
+
for name, module in model.named_modules():
|
188 |
+
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
|
189 |
+
continue
|
190 |
+
if isinstance(module, cls):
|
191 |
+
names = name.split('.')
|
192 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
193 |
+
|
194 |
+
if 'lm_head' in lora_module_names: # needed for 16-bit
|
195 |
+
lora_module_names.remove('lm_head')
|
196 |
+
return list(lora_module_names)
|
197 |
+
|
198 |
+
|
199 |
+
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
|
200 |
+
output_dir: str):
|
201 |
+
"""Collects the state dict and dump to disk."""
|
202 |
+
|
203 |
+
if getattr(trainer.args, "tune_mm_mlp_adapter", False):
|
204 |
+
# Only save Adapter
|
205 |
+
keys_to_match = ['mm_projector']
|
206 |
+
if getattr(trainer.args, "use_im_start_end", False):
|
207 |
+
keys_to_match.extend(['embed_tokens', 'embed_in'])
|
208 |
+
|
209 |
+
weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
|
210 |
+
trainer.model.config.save_pretrained(output_dir)
|
211 |
+
|
212 |
+
current_folder = output_dir.split('/')[-1]
|
213 |
+
parent_folder = os.path.dirname(output_dir)
|
214 |
+
if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
|
215 |
+
if current_folder.startswith('checkpoint-'):
|
216 |
+
mm_projector_folder = os.path.join(parent_folder, "mm_projector")
|
217 |
+
os.makedirs(mm_projector_folder, exist_ok=True)
|
218 |
+
torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
|
219 |
+
else:
|
220 |
+
torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
|
221 |
+
return
|
222 |
+
|
223 |
+
if trainer.deepspeed:
|
224 |
+
torch.cuda.synchronize()
|
225 |
+
trainer.save_model(output_dir)
|
226 |
+
return
|
227 |
+
|
228 |
+
state_dict = trainer.model.state_dict()
|
229 |
+
if trainer.args.should_save:
|
230 |
+
cpu_state_dict = {
|
231 |
+
key: value.cpu()
|
232 |
+
for key, value in state_dict.items()
|
233 |
+
}
|
234 |
+
del state_dict
|
235 |
+
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
|
236 |
+
|
237 |
+
|
238 |
+
def smart_tokenizer_and_embedding_resize(
|
239 |
+
special_tokens_dict: Dict,
|
240 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
241 |
+
model: transformers.PreTrainedModel,
|
242 |
+
):
|
243 |
+
"""Resize tokenizer and embedding.
|
244 |
+
|
245 |
+
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
246 |
+
"""
|
247 |
+
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
248 |
+
model.resize_token_embeddings(len(tokenizer))
|
249 |
+
|
250 |
+
if num_new_tokens > 0:
|
251 |
+
input_embeddings = model.get_input_embeddings().weight.data
|
252 |
+
output_embeddings = model.get_output_embeddings().weight.data
|
253 |
+
|
254 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
255 |
+
dim=0, keepdim=True)
|
256 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
257 |
+
dim=0, keepdim=True)
|
258 |
+
|
259 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
260 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
261 |
+
|
262 |
+
|
263 |
+
def _tokenize_fn(strings: Sequence[str],
|
264 |
+
tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
265 |
+
"""Tokenize a list of strings."""
|
266 |
+
tokenized_list = [
|
267 |
+
tokenizer(
|
268 |
+
text,
|
269 |
+
return_tensors="pt",
|
270 |
+
padding="longest",
|
271 |
+
max_length=tokenizer.model_max_length,
|
272 |
+
truncation=True,
|
273 |
+
) for text in strings
|
274 |
+
]
|
275 |
+
input_ids = labels = [
|
276 |
+
tokenized.input_ids[0] for tokenized in tokenized_list
|
277 |
+
]
|
278 |
+
input_ids_lens = labels_lens = [
|
279 |
+
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
|
280 |
+
for tokenized in tokenized_list
|
281 |
+
]
|
282 |
+
return dict(
|
283 |
+
input_ids=input_ids,
|
284 |
+
labels=labels,
|
285 |
+
input_ids_lens=input_ids_lens,
|
286 |
+
labels_lens=labels_lens,
|
287 |
+
)
|
288 |
+
|
289 |
+
|
290 |
+
def _mask_targets(target, tokenized_lens, speakers):
|
291 |
+
# cur_idx = 0
|
292 |
+
cur_idx = tokenized_lens[0]
|
293 |
+
tokenized_lens = tokenized_lens[1:]
|
294 |
+
target[:cur_idx] = IGNORE_INDEX
|
295 |
+
for tokenized_len, speaker in zip(tokenized_lens, speakers):
|
296 |
+
if speaker == "human":
|
297 |
+
target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
|
298 |
+
cur_idx += tokenized_len
|
299 |
+
|
300 |
+
|
301 |
+
def _add_speaker_and_signal(header, source, get_conversation=True):
|
302 |
+
"""Add speaker and start/end signal on each round."""
|
303 |
+
BEGIN_SIGNAL = "### "
|
304 |
+
END_SIGNAL = "\n"
|
305 |
+
conversation = header
|
306 |
+
for sentence in source:
|
307 |
+
from_str = sentence["from"]
|
308 |
+
if from_str.lower() == "human":
|
309 |
+
from_str = conversation_lib.default_conversation.roles[0]
|
310 |
+
elif from_str.lower() == "gpt":
|
311 |
+
from_str = conversation_lib.default_conversation.roles[1]
|
312 |
+
else:
|
313 |
+
from_str = 'unknown'
|
314 |
+
sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
|
315 |
+
sentence["value"] + END_SIGNAL)
|
316 |
+
if get_conversation:
|
317 |
+
conversation += sentence["value"]
|
318 |
+
conversation += BEGIN_SIGNAL
|
319 |
+
return conversation
|
320 |
+
|
321 |
+
|
322 |
+
def preprocess_multimodal(
|
323 |
+
sources: Sequence[str],
|
324 |
+
data_args: DataArguments
|
325 |
+
) -> Dict:
|
326 |
+
is_multimodal = data_args.is_multimodal
|
327 |
+
if not is_multimodal:
|
328 |
+
return sources
|
329 |
+
|
330 |
+
for source in sources:
|
331 |
+
for sentence in source:
|
332 |
+
if DEFAULT_IMAGE_TOKEN in sentence['value']:
|
333 |
+
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
|
334 |
+
sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
|
335 |
+
sentence['value'] = sentence['value'].strip()
|
336 |
+
if "mmtag" in conversation_lib.default_conversation.version:
|
337 |
+
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
|
338 |
+
replace_token = DEFAULT_IMAGE_TOKEN
|
339 |
+
if data_args.mm_use_im_start_end:
|
340 |
+
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
|
341 |
+
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
342 |
+
|
343 |
+
return sources
|
344 |
+
|
345 |
+
|
346 |
+
def preprocess_llama_2(
|
347 |
+
sources,
|
348 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
349 |
+
has_image: bool = False
|
350 |
+
) -> Dict:
|
351 |
+
conv = conversation_lib.default_conversation.copy()
|
352 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
353 |
+
|
354 |
+
# Apply prompt templates
|
355 |
+
conversations = []
|
356 |
+
for i, source in enumerate(sources):
|
357 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
358 |
+
# Skip the first one if it is not from human
|
359 |
+
source = source[1:]
|
360 |
+
|
361 |
+
conv.messages = []
|
362 |
+
for j, sentence in enumerate(source):
|
363 |
+
role = roles[sentence["from"]]
|
364 |
+
assert role == conv.roles[j % 2], f"{i}"
|
365 |
+
conv.append_message(role, sentence["value"])
|
366 |
+
conversations.append(conv.get_prompt())
|
367 |
+
|
368 |
+
# Tokenize conversations
|
369 |
+
|
370 |
+
if has_image:
|
371 |
+
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
372 |
+
else:
|
373 |
+
input_ids = tokenizer(
|
374 |
+
conversations,
|
375 |
+
return_tensors="pt",
|
376 |
+
padding="longest",
|
377 |
+
max_length=tokenizer.model_max_length,
|
378 |
+
truncation=True,
|
379 |
+
).input_ids
|
380 |
+
|
381 |
+
targets = input_ids.clone()
|
382 |
+
|
383 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
|
384 |
+
|
385 |
+
# Mask targets
|
386 |
+
sep = "[/INST] "
|
387 |
+
for conversation, target in zip(conversations, targets):
|
388 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
389 |
+
|
390 |
+
rounds = conversation.split(conv.sep2)
|
391 |
+
cur_len = 1
|
392 |
+
target[:cur_len] = IGNORE_INDEX
|
393 |
+
for i, rou in enumerate(rounds):
|
394 |
+
if rou == "":
|
395 |
+
break
|
396 |
+
|
397 |
+
parts = rou.split(sep)
|
398 |
+
if len(parts) != 2:
|
399 |
+
break
|
400 |
+
parts[0] += sep
|
401 |
+
|
402 |
+
if has_image:
|
403 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
404 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
|
405 |
+
else:
|
406 |
+
round_len = len(tokenizer(rou).input_ids)
|
407 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
408 |
+
|
409 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
410 |
+
|
411 |
+
cur_len += round_len
|
412 |
+
target[cur_len:] = IGNORE_INDEX
|
413 |
+
|
414 |
+
if cur_len < tokenizer.model_max_length:
|
415 |
+
if cur_len != total_len:
|
416 |
+
target[:] = IGNORE_INDEX
|
417 |
+
print(
|
418 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
419 |
+
f" (ignored)"
|
420 |
+
)
|
421 |
+
|
422 |
+
return dict(
|
423 |
+
input_ids=input_ids,
|
424 |
+
labels=targets,
|
425 |
+
)
|
426 |
+
|
427 |
+
|
428 |
+
def preprocess_v1(
|
429 |
+
sources,
|
430 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
431 |
+
has_image: bool = False
|
432 |
+
) -> Dict:
|
433 |
+
conv = conversation_lib.default_conversation.copy()
|
434 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
435 |
+
|
436 |
+
# Apply prompt templates
|
437 |
+
conversations = []
|
438 |
+
for i, source in enumerate(sources):
|
439 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
440 |
+
# Skip the first one if it is not from human
|
441 |
+
source = source[1:]
|
442 |
+
|
443 |
+
conv.messages = []
|
444 |
+
for j, sentence in enumerate(source):
|
445 |
+
role = roles[sentence["from"]]
|
446 |
+
assert role == conv.roles[j % 2], f"{i}"
|
447 |
+
conv.append_message(role, sentence["value"])
|
448 |
+
conversations.append(conv.get_prompt())
|
449 |
+
|
450 |
+
# Tokenize conversations
|
451 |
+
|
452 |
+
if has_image:
|
453 |
+
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
454 |
+
else:
|
455 |
+
input_ids = tokenizer(
|
456 |
+
conversations,
|
457 |
+
return_tensors="pt",
|
458 |
+
padding="longest",
|
459 |
+
max_length=tokenizer.model_max_length,
|
460 |
+
truncation=True,
|
461 |
+
).input_ids
|
462 |
+
|
463 |
+
targets = input_ids.clone()
|
464 |
+
|
465 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
|
466 |
+
|
467 |
+
# Mask targets
|
468 |
+
sep = conv.sep + conv.roles[1] + ": "
|
469 |
+
for conversation, target in zip(conversations, targets):
|
470 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
471 |
+
|
472 |
+
rounds = conversation.split(conv.sep2)
|
473 |
+
cur_len = 1
|
474 |
+
target[:cur_len] = IGNORE_INDEX
|
475 |
+
for i, rou in enumerate(rounds):
|
476 |
+
if rou == "":
|
477 |
+
break
|
478 |
+
|
479 |
+
parts = rou.split(sep)
|
480 |
+
if len(parts) != 2:
|
481 |
+
break
|
482 |
+
parts[0] += sep
|
483 |
+
|
484 |
+
if has_image:
|
485 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
486 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
|
487 |
+
else:
|
488 |
+
round_len = len(tokenizer(rou).input_ids)
|
489 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
490 |
+
|
491 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
492 |
+
|
493 |
+
cur_len += round_len
|
494 |
+
target[cur_len:] = IGNORE_INDEX
|
495 |
+
|
496 |
+
if cur_len < tokenizer.model_max_length:
|
497 |
+
if cur_len != total_len:
|
498 |
+
target[:] = IGNORE_INDEX
|
499 |
+
print(
|
500 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
501 |
+
f" (ignored)"
|
502 |
+
)
|
503 |
+
|
504 |
+
return dict(
|
505 |
+
input_ids=input_ids,
|
506 |
+
labels=targets,
|
507 |
+
)
|
508 |
+
|
509 |
+
|
510 |
+
def preprocess_mpt(
|
511 |
+
sources,
|
512 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
513 |
+
) -> Dict:
|
514 |
+
conv = conversation_lib.default_conversation.copy()
|
515 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
516 |
+
|
517 |
+
# Apply prompt templates
|
518 |
+
conversations = []
|
519 |
+
for i, source in enumerate(sources):
|
520 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
521 |
+
# Skip the first one if it is not from human
|
522 |
+
source = source[1:]
|
523 |
+
|
524 |
+
conv.messages = []
|
525 |
+
for j, sentence in enumerate(source):
|
526 |
+
role = roles[sentence["from"]]
|
527 |
+
assert role == conv.roles[j % 2], f"{i}"
|
528 |
+
conv.append_message(role, sentence["value"])
|
529 |
+
conversations.append(conv.get_prompt())
|
530 |
+
|
531 |
+
# Tokenize conversations
|
532 |
+
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
533 |
+
targets = input_ids.clone()
|
534 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
|
535 |
+
|
536 |
+
# Mask targets
|
537 |
+
sep = conv.sep + conv.roles[1]
|
538 |
+
for conversation, target in zip(conversations, targets):
|
539 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
540 |
+
|
541 |
+
rounds = conversation.split(conv.sep)
|
542 |
+
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
|
543 |
+
for conv_idx in range(3, len(rounds), 2):
|
544 |
+
re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
|
545 |
+
cur_len = 0
|
546 |
+
target[:cur_len] = IGNORE_INDEX
|
547 |
+
for i, rou in enumerate(re_rounds):
|
548 |
+
if rou == "":
|
549 |
+
break
|
550 |
+
|
551 |
+
parts = rou.split(sep)
|
552 |
+
if len(parts) != 2:
|
553 |
+
break
|
554 |
+
parts[0] += sep
|
555 |
+
round_len = len(tokenizer_image_token(rou, tokenizer)) + len(tokenizer_image_token(conv.sep, tokenizer))
|
556 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
|
557 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
558 |
+
|
559 |
+
cur_len += round_len
|
560 |
+
target[cur_len:] = IGNORE_INDEX
|
561 |
+
|
562 |
+
if cur_len < tokenizer.model_max_length:
|
563 |
+
if cur_len != total_len:
|
564 |
+
target[:] = IGNORE_INDEX
|
565 |
+
print(
|
566 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
567 |
+
f" (ignored)"
|
568 |
+
)
|
569 |
+
|
570 |
+
return dict(
|
571 |
+
input_ids=input_ids,
|
572 |
+
labels=targets,
|
573 |
+
)
|
574 |
+
|
575 |
+
|
576 |
+
def preprocess_plain(
|
577 |
+
sources: Sequence[str],
|
578 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
579 |
+
) -> Dict:
|
580 |
+
# add end signal and concatenate together
|
581 |
+
conversations = []
|
582 |
+
for source in sources:
|
583 |
+
assert len(source) == 2
|
584 |
+
assert DEFAULT_IMAGE_TOKEN in source[0]['value']
|
585 |
+
source[0]['value'] = DEFAULT_IMAGE_TOKEN
|
586 |
+
conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
|
587 |
+
conversations.append(conversation)
|
588 |
+
# tokenize conversations
|
589 |
+
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
|
590 |
+
targets = copy.deepcopy(input_ids)
|
591 |
+
for target, source in zip(targets, sources):
|
592 |
+
tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
|
593 |
+
target[:tokenized_len] = IGNORE_INDEX
|
594 |
+
|
595 |
+
return dict(input_ids=input_ids, labels=targets)
|
596 |
+
|
597 |
+
|
598 |
+
def preprocess(
|
599 |
+
sources: Sequence[str],
|
600 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
601 |
+
has_image: bool = False
|
602 |
+
) -> Dict:
|
603 |
+
"""
|
604 |
+
Given a list of sources, each is a conversation list. This transform:
|
605 |
+
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
|
606 |
+
2. Concatenate conversations together;
|
607 |
+
3. Tokenize the concatenated conversation;
|
608 |
+
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
|
609 |
+
"""
|
610 |
+
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
|
611 |
+
return preprocess_plain(sources, tokenizer)
|
612 |
+
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
|
613 |
+
return preprocess_llama_2(sources, tokenizer, has_image=has_image)
|
614 |
+
if conversation_lib.default_conversation.version.startswith("v1"):
|
615 |
+
return preprocess_v1(sources, tokenizer, has_image=has_image)
|
616 |
+
if conversation_lib.default_conversation.version == "mpt":
|
617 |
+
return preprocess_mpt(sources, tokenizer)
|
618 |
+
# add end signal and concatenate together
|
619 |
+
conversations = []
|
620 |
+
for source in sources:
|
621 |
+
header = f"{conversation_lib.default_conversation.system}\n\n"
|
622 |
+
conversation = _add_speaker_and_signal(header, source)
|
623 |
+
conversations.append(conversation)
|
624 |
+
# tokenize conversations
|
625 |
+
def get_tokenize_len(prompts):
|
626 |
+
return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
|
627 |
+
|
628 |
+
if has_image:
|
629 |
+
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
|
630 |
+
else:
|
631 |
+
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
|
632 |
+
input_ids = conversations_tokenized["input_ids"]
|
633 |
+
targets = copy.deepcopy(input_ids)
|
634 |
+
for target, source in zip(targets, sources):
|
635 |
+
if has_image:
|
636 |
+
tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
|
637 |
+
else:
|
638 |
+
tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
|
639 |
+
speakers = [sentence["from"] for sentence in source]
|
640 |
+
_mask_targets(target, tokenized_lens, speakers)
|
641 |
+
|
642 |
+
return dict(input_ids=input_ids, labels=targets)
|
643 |
+
|
644 |
+
|
645 |
+
class LazySupervisedDataset(Dataset):
|
646 |
+
"""Dataset for supervised fine-tuning."""
|
647 |
+
|
648 |
+
def __init__(self, data_path: str,
|
649 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
650 |
+
data_args: DataArguments):
|
651 |
+
super(LazySupervisedDataset, self).__init__()
|
652 |
+
list_data_dict = json.load(open(data_path, "r"))
|
653 |
+
|
654 |
+
rank0_print("Formatting inputs...Skip in lazy mode")
|
655 |
+
self.tokenizer = tokenizer
|
656 |
+
self.list_data_dict = list_data_dict
|
657 |
+
self.data_args = data_args
|
658 |
+
|
659 |
+
def __len__(self):
|
660 |
+
return len(self.list_data_dict)
|
661 |
+
|
662 |
+
@property
|
663 |
+
def lengths(self):
|
664 |
+
length_list = []
|
665 |
+
for sample in self.list_data_dict:
|
666 |
+
img_tokens = 128 if 'image' in sample else 0
|
667 |
+
length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
|
668 |
+
return length_list
|
669 |
+
|
670 |
+
@property
|
671 |
+
def modality_lengths(self):
|
672 |
+
length_list = []
|
673 |
+
for sample in self.list_data_dict:
|
674 |
+
cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
|
675 |
+
cur_len = cur_len if ('image' in sample) or ('video' in sample) else -cur_len
|
676 |
+
length_list.append(cur_len)
|
677 |
+
return length_list
|
678 |
+
|
679 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
680 |
+
attempt, max_attempt = 0, 10
|
681 |
+
while attempt < max_attempt:
|
682 |
+
try:
|
683 |
+
sources = self.list_data_dict[i]
|
684 |
+
if isinstance(i, int):
|
685 |
+
sources = [sources]
|
686 |
+
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
|
687 |
+
feature = None
|
688 |
+
if 'image' in sources[0]:
|
689 |
+
image_file = self.list_data_dict[i]['image']
|
690 |
+
image_folder = self.data_args.image_folder
|
691 |
+
image_file = os.path.join(image_folder, image_file)
|
692 |
+
suffix = image_file.split('.')[-1]
|
693 |
+
|
694 |
+
if 'features' in image_folder:
|
695 |
+
# TODO: load video feature, not supported yet
|
696 |
+
image_file = image_file.replace(suffix, 'safetensors')
|
697 |
+
if not os.path.exists(image_file):
|
698 |
+
print('Image file {} not exist!'.format(image_file))
|
699 |
+
feature = load_file(image_file)['feature'].unsqueeze(0)
|
700 |
+
sources = preprocess_multimodal(
|
701 |
+
copy.deepcopy([e["conversations"] for e in sources]),
|
702 |
+
self.data_args)
|
703 |
+
|
704 |
+
else:
|
705 |
+
processor = self.data_args.image_processor
|
706 |
+
image = Image.open().convert('RGB')
|
707 |
+
if self.data_args.image_aspect_ratio == 'pad':
|
708 |
+
def expand2square(pil_img, background_color):
|
709 |
+
width, height = pil_img.size
|
710 |
+
if width == height:
|
711 |
+
return pil_img
|
712 |
+
elif width > height:
|
713 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
714 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
715 |
+
return result
|
716 |
+
else:
|
717 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
718 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
719 |
+
return result
|
720 |
+
image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
|
721 |
+
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
722 |
+
else:
|
723 |
+
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
724 |
+
sources = preprocess_multimodal(
|
725 |
+
copy.deepcopy([e["conversations"] for e in sources]),
|
726 |
+
self.data_args)
|
727 |
+
|
728 |
+
elif 'video' in sources[0]:
|
729 |
+
video_file = self.list_data_dict[i]['video']
|
730 |
+
video_folder = self.data_args.video_folder
|
731 |
+
video_file = os.path.join(video_folder, video_file)
|
732 |
+
suffix = video_file.split('.')[-1]
|
733 |
+
|
734 |
+
if 'features' in video_folder:
|
735 |
+
# TODO: load video feature, not supported yet
|
736 |
+
video_file = video_file.replace(suffix, 'safetensors')
|
737 |
+
if not os.path.exists(video_file):
|
738 |
+
print('Video file {} not exist!'.format(video_file))
|
739 |
+
feature = load_file(video_file)['feature']
|
740 |
+
if 'time' in self.list_data_dict[i]: # breakpoint mode
|
741 |
+
if 'time_9dense' in self.list_data_dict[i]:
|
742 |
+
tim = self.list_data_dict[i]['time_9dense'] // 4
|
743 |
+
start = max(tim - 6 * 9, 0)
|
744 |
+
end = min(tim + 6 * 9, feature.shape[0])
|
745 |
+
feature = feature[start:end]
|
746 |
+
else:
|
747 |
+
expansion = 15
|
748 |
+
if 'time_9' in self.list_data_dict[i]:
|
749 |
+
expansion = 9
|
750 |
+
tim = self.list_data_dict[i]['time']
|
751 |
+
start = max(tim - expansion, 0)
|
752 |
+
end = min(tim + expansion, feature.shape[0])
|
753 |
+
feature = feature[start:end]
|
754 |
+
elif 'time_9dense' in self.list_data_dict[i]:
|
755 |
+
feature = feature[::6]
|
756 |
+
|
757 |
+
sources = preprocess_multimodal(
|
758 |
+
copy.deepcopy([e["conversations"] for e in sources]),
|
759 |
+
self.data_args)
|
760 |
+
else:
|
761 |
+
# directly load video file
|
762 |
+
if not os.path.exists(video_file):
|
763 |
+
print('File {} not exist!'.format(video_file))
|
764 |
+
vr = VideoReader(video_file, num_threads=4)
|
765 |
+
sample_fps = round(vr.get_avg_fps()/self.data_args.video_fps)
|
766 |
+
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
767 |
+
if len(frame_idx) > self.data_args.video_max_frames:
|
768 |
+
if self.data_args.video_sample_type == 'center':
|
769 |
+
# select middle frames
|
770 |
+
start_pos = (len(frame_idx) - self.data_args.video_max_frames) // 2
|
771 |
+
frame_idx = frame_idx[start_pos:start_pos + self.data_args.video_max_frames]
|
772 |
+
elif self.data_args.video_sample_type == 'uniform':
|
773 |
+
scale = 1.0 * len(frame_idx) / self.data_args.video_max_frames
|
774 |
+
uniform_idx = [round((i + 1) * scale - 1) for i in range(self.data_args.video_max_frames)]
|
775 |
+
frame_idx = [frame_idx[i] for i in uniform_idx]
|
776 |
+
elif len(frame_idx) > 18000:
|
777 |
+
scale = 1.0 * len(frame_idx) / 180
|
778 |
+
uniform_idx = [round((i + 1) * scale - 1) for i in range(180)]
|
779 |
+
frame_idx = [frame_idx[i] for i in uniform_idx]
|
780 |
+
video = vr.get_batch(frame_idx).asnumpy()
|
781 |
+
processor = self.data_args.image_processor
|
782 |
+
image = processor.preprocess(video, return_tensors='pt')['pixel_values']
|
783 |
+
sources = preprocess_multimodal(
|
784 |
+
copy.deepcopy([e["conversations"] for e in sources]),
|
785 |
+
self.data_args)
|
786 |
+
|
787 |
+
else:
|
788 |
+
sources = copy.deepcopy([e["conversations"] for e in sources])
|
789 |
+
break
|
790 |
+
except Exception as e:
|
791 |
+
attempt += 1
|
792 |
+
print(f"Error in loading id:{i} sample, retrying {attempt} time... Error={e}")
|
793 |
+
i = random.randint(0, len(self.list_data_dict)-1)
|
794 |
+
|
795 |
+
has_image = ('image' in self.list_data_dict[i]) or ('video' in self.list_data_dict[i])
|
796 |
+
data_dict = preprocess(
|
797 |
+
sources,
|
798 |
+
self.tokenizer,
|
799 |
+
has_image=has_image)
|
800 |
+
if isinstance(i, int):
|
801 |
+
data_dict = dict(input_ids=data_dict["input_ids"][0],
|
802 |
+
labels=data_dict["labels"][0])
|
803 |
+
|
804 |
+
# image exist in the data
|
805 |
+
if 'image' in self.list_data_dict[i] or 'video' in self.list_data_dict[i]:
|
806 |
+
if feature is not None:
|
807 |
+
data_dict['feature'] = feature
|
808 |
+
else:
|
809 |
+
data_dict['image'] = image
|
810 |
+
elif self.data_args.is_multimodal:
|
811 |
+
# image does not exist in the data, but the model is multimodal
|
812 |
+
crop_size = self.data_args.image_processor.crop_size
|
813 |
+
patch_size = 14
|
814 |
+
data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
|
815 |
+
data_dict['feature'] = torch.zeros((crop_size['height'] // patch_size) * (crop_size['width'] // patch_size), self.data_args.mm_hidden_size)
|
816 |
+
return data_dict
|
817 |
+
|
818 |
+
|
819 |
+
@dataclass
|
820 |
+
class DataCollatorForSupervisedDataset(object):
|
821 |
+
"""Collate examples for supervised fine-tuning."""
|
822 |
+
|
823 |
+
tokenizer: transformers.PreTrainedTokenizer
|
824 |
+
|
825 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
826 |
+
input_ids, labels = tuple([instance[key] for instance in instances]
|
827 |
+
for key in ("input_ids", "labels"))
|
828 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
829 |
+
input_ids,
|
830 |
+
batch_first=True,
|
831 |
+
padding_value=self.tokenizer.pad_token_id)
|
832 |
+
labels = torch.nn.utils.rnn.pad_sequence(labels,
|
833 |
+
batch_first=True,
|
834 |
+
padding_value=IGNORE_INDEX)
|
835 |
+
input_ids = input_ids[:, :self.tokenizer.model_max_length]
|
836 |
+
labels = labels[:, :self.tokenizer.model_max_length]
|
837 |
+
batch = dict(
|
838 |
+
input_ids=input_ids,
|
839 |
+
labels=labels,
|
840 |
+
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
841 |
+
)
|
842 |
+
|
843 |
+
if 'feature' in instances[0]:
|
844 |
+
batch['features'] = [instance['feature'] for instance in instances]
|
845 |
+
elif 'image' in instances[0]:
|
846 |
+
images = [instance['image'] for instance in instances]
|
847 |
+
if all(x is not None and x.shape == images[0].shape for x in images):
|
848 |
+
batch['images'] = torch.stack(images)
|
849 |
+
else:
|
850 |
+
batch['images'] = images
|
851 |
+
|
852 |
+
|
853 |
+
return batch
|
854 |
+
|
855 |
+
|
856 |
+
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
|
857 |
+
data_args) -> Dict:
|
858 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
859 |
+
train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
|
860 |
+
data_path=data_args.data_path,
|
861 |
+
data_args=data_args)
|
862 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
863 |
+
return dict(train_dataset=train_dataset,
|
864 |
+
eval_dataset=None,
|
865 |
+
data_collator=data_collator)
|
866 |
+
|
867 |
+
|
868 |
+
def train():
|
869 |
+
global local_rank
|
870 |
+
|
871 |
+
parser = transformers.HfArgumentParser(
|
872 |
+
(ModelArguments, DataArguments, TrainingArguments))
|
873 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
874 |
+
local_rank = training_args.local_rank
|
875 |
+
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
876 |
+
|
877 |
+
bnb_model_from_pretrained_args = {}
|
878 |
+
if training_args.bits in [4, 8]:
|
879 |
+
from transformers import BitsAndBytesConfig
|
880 |
+
bnb_model_from_pretrained_args.update(dict(
|
881 |
+
device_map={"": training_args.device},
|
882 |
+
load_in_4bit=training_args.bits == 4,
|
883 |
+
load_in_8bit=training_args.bits == 8,
|
884 |
+
quantization_config=BitsAndBytesConfig(
|
885 |
+
load_in_4bit=training_args.bits == 4,
|
886 |
+
load_in_8bit=training_args.bits == 8,
|
887 |
+
llm_int8_skip_modules=["mm_projector"],
|
888 |
+
llm_int8_threshold=6.0,
|
889 |
+
llm_int8_has_fp16_weight=False,
|
890 |
+
bnb_4bit_compute_dtype=compute_dtype,
|
891 |
+
bnb_4bit_use_double_quant=training_args.double_quant,
|
892 |
+
bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
|
893 |
+
)
|
894 |
+
))
|
895 |
+
|
896 |
+
if model_args.vision_tower is not None:
|
897 |
+
model = VStreamLlamaForCausalLM.from_pretrained(
|
898 |
+
model_args.model_name_or_path,
|
899 |
+
cache_dir=training_args.cache_dir,
|
900 |
+
**bnb_model_from_pretrained_args
|
901 |
+
)
|
902 |
+
else:
|
903 |
+
model = transformers.LlamaForCausalLM.from_pretrained(
|
904 |
+
model_args.model_name_or_path,
|
905 |
+
cache_dir=training_args.cache_dir,
|
906 |
+
**bnb_model_from_pretrained_args
|
907 |
+
)
|
908 |
+
model.config.use_cache = False
|
909 |
+
|
910 |
+
if model_args.freeze_backbone:
|
911 |
+
model.model.requires_grad_(False)
|
912 |
+
|
913 |
+
if training_args.bits in [4, 8]:
|
914 |
+
from peft import prepare_model_for_kbit_training
|
915 |
+
model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
916 |
+
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
|
917 |
+
|
918 |
+
if training_args.gradient_checkpointing:
|
919 |
+
if hasattr(model, "enable_input_require_grads"):
|
920 |
+
model.enable_input_require_grads()
|
921 |
+
else:
|
922 |
+
def make_inputs_require_grad(module, input, output):
|
923 |
+
output.requires_grad_(True)
|
924 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
925 |
+
|
926 |
+
if training_args.lora_enable:
|
927 |
+
from peft import LoraConfig, get_peft_model
|
928 |
+
lora_config = LoraConfig(
|
929 |
+
r=training_args.lora_r,
|
930 |
+
lora_alpha=training_args.lora_alpha,
|
931 |
+
target_modules=find_all_linear_names(model),
|
932 |
+
lora_dropout=training_args.lora_dropout,
|
933 |
+
bias=training_args.lora_bias,
|
934 |
+
task_type="CAUSAL_LM",
|
935 |
+
)
|
936 |
+
if training_args.bits == 16:
|
937 |
+
if training_args.bf16:
|
938 |
+
model.to(torch.bfloat16)
|
939 |
+
if training_args.fp16:
|
940 |
+
model.to(torch.float16)
|
941 |
+
rank0_print("Adding LoRA adapters...")
|
942 |
+
model = get_peft_model(model, lora_config)
|
943 |
+
|
944 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
945 |
+
model_args.model_name_or_path,
|
946 |
+
cache_dir=training_args.cache_dir,
|
947 |
+
model_max_length=training_args.model_max_length,
|
948 |
+
padding_side="right",
|
949 |
+
use_fast=False,
|
950 |
+
)
|
951 |
+
|
952 |
+
if model_args.version == "v0":
|
953 |
+
if tokenizer.pad_token is None:
|
954 |
+
smart_tokenizer_and_embedding_resize(
|
955 |
+
special_tokens_dict=dict(pad_token="[PAD]"),
|
956 |
+
tokenizer=tokenizer,
|
957 |
+
model=model,
|
958 |
+
)
|
959 |
+
elif model_args.version == "v0.5":
|
960 |
+
tokenizer.pad_token = tokenizer.unk_token
|
961 |
+
else:
|
962 |
+
tokenizer.pad_token = tokenizer.unk_token
|
963 |
+
if model_args.version in conversation_lib.conv_templates:
|
964 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
|
965 |
+
else:
|
966 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
|
967 |
+
|
968 |
+
if model_args.vision_tower is not None:
|
969 |
+
model_args.video_sample_type = data_args.video_sample_type
|
970 |
+
model_args.video_max_frames = data_args.video_max_frames
|
971 |
+
model_args.video_long_memory_length = data_args.video_long_memory_length
|
972 |
+
model_args.video_Turing_memory_length = data_args.video_Turing_memory_length
|
973 |
+
model_args.video_short_memory_length = data_args.video_short_memory_length
|
974 |
+
model_args.video_current_memory_length = data_args.video_current_memory_length
|
975 |
+
model.get_model().initialize_vision_modules(
|
976 |
+
model_args=model_args,
|
977 |
+
fsdp=training_args.fsdp
|
978 |
+
)
|
979 |
+
|
980 |
+
vision_tower = model.get_vision_tower()
|
981 |
+
vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
|
982 |
+
|
983 |
+
data_args.image_processor = vision_tower.image_processor
|
984 |
+
data_args.is_multimodal = True
|
985 |
+
|
986 |
+
model.config.image_aspect_ratio = data_args.image_aspect_ratio
|
987 |
+
model.config.tokenizer_padding_side = tokenizer.padding_side
|
988 |
+
model.config.tokenizer_model_max_length = tokenizer.model_max_length
|
989 |
+
|
990 |
+
model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
|
991 |
+
if model_args.tune_mm_mlp_adapter:
|
992 |
+
model.requires_grad_(False)
|
993 |
+
for p in model.get_model().mm_projector.parameters():
|
994 |
+
p.requires_grad = True
|
995 |
+
for p in model.get_model().attention_model.parameters():
|
996 |
+
p.requires_grad = True
|
997 |
+
|
998 |
+
model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
|
999 |
+
if training_args.freeze_mm_mlp_adapter:
|
1000 |
+
for p in model.get_model().mm_projector.parameters():
|
1001 |
+
p.requires_grad = False
|
1002 |
+
for p in model.get_model().attention_model.parameters():
|
1003 |
+
p.requires_grad = False
|
1004 |
+
|
1005 |
+
if training_args.bits in [4, 8]:
|
1006 |
+
model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
|
1007 |
+
|
1008 |
+
model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
|
1009 |
+
model.config.mm_projector_lr = training_args.mm_projector_lr
|
1010 |
+
training_args.use_im_start_end = model_args.mm_use_im_start_end
|
1011 |
+
model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
|
1012 |
+
model.config.mm_use_4_vision_tokens = model_args.mm_use_4_vision_tokens
|
1013 |
+
model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
|
1014 |
+
|
1015 |
+
if training_args.bits in [4, 8]:
|
1016 |
+
from peft.tuners.lora import LoraLayer
|
1017 |
+
for name, module in model.named_modules():
|
1018 |
+
if isinstance(module, LoraLayer):
|
1019 |
+
if training_args.bf16:
|
1020 |
+
module = module.to(torch.bfloat16)
|
1021 |
+
if 'norm' in name:
|
1022 |
+
module = module.to(torch.float32)
|
1023 |
+
if 'lm_head' in name or 'embed_tokens' in name:
|
1024 |
+
if hasattr(module, 'weight'):
|
1025 |
+
if training_args.bf16 and module.weight.dtype == torch.float32:
|
1026 |
+
module = module.to(torch.bfloat16)
|
1027 |
+
|
1028 |
+
data_args.mm_hidden_size = model.get_vision_tower().hidden_size
|
1029 |
+
data_module = make_supervised_data_module(tokenizer=tokenizer,
|
1030 |
+
data_args=data_args)
|
1031 |
+
trainer = VStreamTrainer(model=model,
|
1032 |
+
tokenizer=tokenizer,
|
1033 |
+
args=training_args,
|
1034 |
+
**data_module)
|
1035 |
+
|
1036 |
+
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
|
1037 |
+
trainer.train(resume_from_checkpoint=True)
|
1038 |
+
else:
|
1039 |
+
trainer.train()
|
1040 |
+
trainer.save_state()
|
1041 |
+
|
1042 |
+
model.config.use_cache = True
|
1043 |
+
|
1044 |
+
if training_args.lora_enable:
|
1045 |
+
state_dict = get_peft_state_maybe_zero_3(
|
1046 |
+
model.named_parameters(), training_args.lora_bias
|
1047 |
+
)
|
1048 |
+
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
|
1049 |
+
model.named_parameters()
|
1050 |
+
)
|
1051 |
+
if training_args.local_rank == 0 or training_args.local_rank == -1:
|
1052 |
+
model.config.save_pretrained(training_args.output_dir)
|
1053 |
+
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
|
1054 |
+
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
|
1055 |
+
else:
|
1056 |
+
safe_save_model_for_hf_trainer(trainer=trainer,
|
1057 |
+
output_dir=training_args.output_dir)
|
1058 |
+
|
1059 |
+
|
1060 |
+
if __name__ == "__main__":
|
1061 |
+
# random.seed(42)
|
1062 |
+
# np.random.seed(42)
|
1063 |
+
# torch.manual_seed(42)
|
1064 |
+
# torch.cuda.manual_seed(42)
|
1065 |
+
# torch.cuda.manual_seed_all(42)
|
1066 |
+
# torch.backends.cudnn.deterministic = True
|
1067 |
+
# torch.backends.cudnn.benchmark = False
|
1068 |
+
|
1069 |
+
train()
|
flash_vstream/train/train_mem.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adopted from https://github.com/haotian-liu/LLaVA.
|
2 |
+
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
|
3 |
+
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
|
4 |
+
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
|
5 |
+
|
6 |
+
# Need to call this before importing transformers.
|
7 |
+
from flash_vstream.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
8 |
+
|
9 |
+
replace_llama_attn_with_flash_attn()
|
10 |
+
|
11 |
+
from flash_vstream.train.train import train
|
12 |
+
|
13 |
+
if __name__ == "__main__":
|
14 |
+
train()
|
flash_vstream/train/train_xformers.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on https://github.com/haotian-liu/LLaVA.
|
2 |
+
|
3 |
+
# Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
|
4 |
+
|
5 |
+
# Need to call this before importing transformers.
|
6 |
+
from flash_vstream.train.llama_xformers_attn_monkey_patch import (
|
7 |
+
replace_llama_attn_with_xformers_attn,
|
8 |
+
)
|
9 |
+
|
10 |
+
replace_llama_attn_with_xformers_attn()
|
11 |
+
|
12 |
+
from flash_vstream.train.train import train
|
13 |
+
|
14 |
+
if __name__ == "__main__":
|
15 |
+
train()
|
flash_vstream/train/vstream_trainer.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file may have been modified by Flash-VStream Authors (Flash-VStream Modifications”). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
|
2 |
+
# ------------------------------------------------------------------------
|
3 |
+
# Based on https://github.com/haotian-liu/LLaVA. Below is the original copyright:
|
4 |
+
# Copyright 2023 Haotian Liu
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import os
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
|
22 |
+
from torch.utils.data import Sampler
|
23 |
+
|
24 |
+
from transformers import Trainer
|
25 |
+
from transformers.trainer import (
|
26 |
+
is_sagemaker_mp_enabled,
|
27 |
+
get_parameter_names,
|
28 |
+
has_length,
|
29 |
+
ALL_LAYERNORM_LAYERS,
|
30 |
+
ShardedDDPOption,
|
31 |
+
logger,
|
32 |
+
)
|
33 |
+
from typing import List, Optional
|
34 |
+
|
35 |
+
|
36 |
+
def maybe_zero_3(param, ignore_status=False, name=None):
|
37 |
+
from deepspeed import zero
|
38 |
+
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
39 |
+
if hasattr(param, "ds_id"):
|
40 |
+
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
|
41 |
+
if not ignore_status:
|
42 |
+
print(name, 'no ignore status')
|
43 |
+
with zero.GatheredParameters([param]):
|
44 |
+
param = param.data.detach().cpu().clone()
|
45 |
+
else:
|
46 |
+
param = param.detach().cpu().clone()
|
47 |
+
return param
|
48 |
+
|
49 |
+
|
50 |
+
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
|
51 |
+
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
|
52 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
|
53 |
+
return to_return
|
54 |
+
|
55 |
+
|
56 |
+
def split_to_even_chunks(indices, lengths, num_chunks):
|
57 |
+
"""
|
58 |
+
Split a list of indices into `chunks` chunks of roughly equal lengths.
|
59 |
+
"""
|
60 |
+
|
61 |
+
if len(indices) % num_chunks != 0:
|
62 |
+
return [indices[i::num_chunks] for i in range(num_chunks)]
|
63 |
+
|
64 |
+
num_indices_per_chunk = len(indices) // num_chunks
|
65 |
+
|
66 |
+
chunks = [[] for _ in range(num_chunks)]
|
67 |
+
chunks_lengths = [0 for _ in range(num_chunks)]
|
68 |
+
for index in indices:
|
69 |
+
shortest_chunk = chunks_lengths.index(min(chunks_lengths))
|
70 |
+
chunks[shortest_chunk].append(index)
|
71 |
+
chunks_lengths[shortest_chunk] += lengths[index]
|
72 |
+
if len(chunks[shortest_chunk]) == num_indices_per_chunk:
|
73 |
+
chunks_lengths[shortest_chunk] = float("inf")
|
74 |
+
|
75 |
+
return chunks
|
76 |
+
|
77 |
+
|
78 |
+
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
|
79 |
+
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
|
80 |
+
assert all(l != 0 for l in lengths), "Should not have zero length."
|
81 |
+
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
|
82 |
+
# all samples are in the same modality
|
83 |
+
return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
|
84 |
+
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
|
85 |
+
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
|
86 |
+
|
87 |
+
mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
|
88 |
+
lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
|
89 |
+
megabatch_size = world_size * batch_size
|
90 |
+
mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
|
91 |
+
lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
|
92 |
+
|
93 |
+
last_mm = mm_megabatches[-1]
|
94 |
+
last_lang = lang_megabatches[-1]
|
95 |
+
additional_batch = last_mm + last_lang
|
96 |
+
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
|
97 |
+
megabatch_indices = torch.randperm(len(megabatches), generator=generator)
|
98 |
+
megabatches = [megabatches[i] for i in megabatch_indices]
|
99 |
+
|
100 |
+
if len(additional_batch) > 0:
|
101 |
+
megabatches.append(sorted(additional_batch))
|
102 |
+
|
103 |
+
return [i for megabatch in megabatches for i in megabatch]
|
104 |
+
|
105 |
+
|
106 |
+
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
|
107 |
+
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
|
108 |
+
indices = torch.randperm(len(lengths), generator=generator)
|
109 |
+
megabatch_size = world_size * batch_size
|
110 |
+
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
|
111 |
+
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
|
112 |
+
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
|
113 |
+
|
114 |
+
return [i for megabatch in megabatches for batch in megabatch for i in batch]
|
115 |
+
|
116 |
+
|
117 |
+
class LengthGroupedSampler(Sampler):
|
118 |
+
r"""
|
119 |
+
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
|
120 |
+
keeping a bit of randomness.
|
121 |
+
"""
|
122 |
+
|
123 |
+
def __init__(
|
124 |
+
self,
|
125 |
+
batch_size: int,
|
126 |
+
world_size: int,
|
127 |
+
lengths: Optional[List[int]] = None,
|
128 |
+
generator=None,
|
129 |
+
group_by_modality: bool = False,
|
130 |
+
):
|
131 |
+
if lengths is None:
|
132 |
+
raise ValueError("Lengths must be provided.")
|
133 |
+
|
134 |
+
self.batch_size = batch_size
|
135 |
+
self.world_size = world_size
|
136 |
+
self.lengths = lengths
|
137 |
+
self.generator = generator
|
138 |
+
self.group_by_modality = group_by_modality
|
139 |
+
|
140 |
+
def __len__(self):
|
141 |
+
return len(self.lengths)
|
142 |
+
|
143 |
+
def __iter__(self):
|
144 |
+
if self.group_by_modality:
|
145 |
+
indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
|
146 |
+
else:
|
147 |
+
indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
|
148 |
+
return iter(indices)
|
149 |
+
|
150 |
+
|
151 |
+
class VStreamTrainer(Trainer):
|
152 |
+
|
153 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
154 |
+
if self.train_dataset is None or not has_length(self.train_dataset):
|
155 |
+
return None
|
156 |
+
|
157 |
+
if self.args.group_by_modality_length:
|
158 |
+
lengths = self.train_dataset.modality_lengths
|
159 |
+
return LengthGroupedSampler(
|
160 |
+
self.args.train_batch_size,
|
161 |
+
world_size=self.args.world_size * self.args.gradient_accumulation_steps,
|
162 |
+
lengths=lengths,
|
163 |
+
group_by_modality=True,
|
164 |
+
)
|
165 |
+
else:
|
166 |
+
return super()._get_train_sampler()
|
167 |
+
|
168 |
+
def create_optimizer(self):
|
169 |
+
"""
|
170 |
+
Setup the optimizer.
|
171 |
+
|
172 |
+
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
173 |
+
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
|
174 |
+
"""
|
175 |
+
if is_sagemaker_mp_enabled():
|
176 |
+
return super().create_optimizer()
|
177 |
+
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
178 |
+
return super().create_optimizer()
|
179 |
+
|
180 |
+
opt_model = self.model
|
181 |
+
|
182 |
+
if self.optimizer is None:
|
183 |
+
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
|
184 |
+
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
185 |
+
if self.args.mm_projector_lr is not None:
|
186 |
+
projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
|
187 |
+
optimizer_grouped_parameters = [
|
188 |
+
{
|
189 |
+
"params": [
|
190 |
+
p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
|
191 |
+
],
|
192 |
+
"weight_decay": self.args.weight_decay,
|
193 |
+
},
|
194 |
+
{
|
195 |
+
"params": [
|
196 |
+
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
|
197 |
+
],
|
198 |
+
"weight_decay": 0.0,
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"params": [
|
202 |
+
p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
|
203 |
+
],
|
204 |
+
"weight_decay": self.args.weight_decay,
|
205 |
+
"lr": self.args.mm_projector_lr,
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"params": [
|
209 |
+
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
|
210 |
+
],
|
211 |
+
"weight_decay": 0.0,
|
212 |
+
"lr": self.args.mm_projector_lr,
|
213 |
+
},
|
214 |
+
]
|
215 |
+
else:
|
216 |
+
optimizer_grouped_parameters = [
|
217 |
+
{
|
218 |
+
"params": [
|
219 |
+
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
|
220 |
+
],
|
221 |
+
"weight_decay": self.args.weight_decay,
|
222 |
+
},
|
223 |
+
{
|
224 |
+
"params": [
|
225 |
+
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
|
226 |
+
],
|
227 |
+
"weight_decay": 0.0,
|
228 |
+
},
|
229 |
+
]
|
230 |
+
|
231 |
+
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
|
232 |
+
|
233 |
+
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
234 |
+
if optimizer_cls.__name__ == "Adam8bit":
|
235 |
+
import bitsandbytes
|
236 |
+
|
237 |
+
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
238 |
+
|
239 |
+
skipped = 0
|
240 |
+
for module in opt_model.modules():
|
241 |
+
if isinstance(module, nn.Embedding):
|
242 |
+
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
|
243 |
+
logger.info(f"skipped {module}: {skipped/2**20}M params")
|
244 |
+
manager.register_module_override(module, "weight", {"optim_bits": 32})
|
245 |
+
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
|
246 |
+
logger.info(f"skipped: {skipped/2**20}M params")
|
247 |
+
|
248 |
+
return self.optimizer
|
flash_vstream/utils.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on https://github.com/haotian-liu/LLaVA.
|
2 |
+
|
3 |
+
import datetime
|
4 |
+
import logging
|
5 |
+
import logging.handlers
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
|
9 |
+
import requests
|
10 |
+
|
11 |
+
from flash_vstream.constants import LOGDIR
|
12 |
+
|
13 |
+
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
|
14 |
+
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
|
15 |
+
|
16 |
+
handler = None
|
17 |
+
|
18 |
+
|
19 |
+
def build_logger(logger_name, logger_filename):
|
20 |
+
global handler
|
21 |
+
|
22 |
+
formatter = logging.Formatter(
|
23 |
+
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
24 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
25 |
+
)
|
26 |
+
|
27 |
+
# Set the format of root handlers
|
28 |
+
if not logging.getLogger().handlers:
|
29 |
+
logging.basicConfig(level=logging.INFO)
|
30 |
+
logging.getLogger().handlers[0].setFormatter(formatter)
|
31 |
+
|
32 |
+
# Redirect stdout and stderr to loggers
|
33 |
+
stdout_logger = logging.getLogger("stdout")
|
34 |
+
stdout_logger.setLevel(logging.INFO)
|
35 |
+
sl = StreamToLogger(stdout_logger, logging.INFO)
|
36 |
+
sys.stdout = sl
|
37 |
+
|
38 |
+
stderr_logger = logging.getLogger("stderr")
|
39 |
+
stderr_logger.setLevel(logging.ERROR)
|
40 |
+
sl = StreamToLogger(stderr_logger, logging.ERROR)
|
41 |
+
sys.stderr = sl
|
42 |
+
|
43 |
+
# Get logger
|
44 |
+
logger = logging.getLogger(logger_name)
|
45 |
+
logger.setLevel(logging.INFO)
|
46 |
+
|
47 |
+
# Add a file handler for all loggers
|
48 |
+
if handler is None:
|
49 |
+
os.makedirs(LOGDIR, exist_ok=True)
|
50 |
+
filename = os.path.join(LOGDIR, logger_filename)
|
51 |
+
handler = logging.handlers.TimedRotatingFileHandler(
|
52 |
+
filename, when='D', utc=True, encoding='UTF-8')
|
53 |
+
handler.setFormatter(formatter)
|
54 |
+
|
55 |
+
for name, item in logging.root.manager.loggerDict.items():
|
56 |
+
if isinstance(item, logging.Logger):
|
57 |
+
item.addHandler(handler)
|
58 |
+
|
59 |
+
return logger
|
60 |
+
|
61 |
+
|
62 |
+
class StreamToLogger(object):
|
63 |
+
"""
|
64 |
+
Fake file-like stream object that redirects writes to a logger instance.
|
65 |
+
"""
|
66 |
+
def __init__(self, logger, log_level=logging.INFO):
|
67 |
+
self.terminal = sys.stdout
|
68 |
+
self.logger = logger
|
69 |
+
self.log_level = log_level
|
70 |
+
self.linebuf = ''
|
71 |
+
|
72 |
+
def __getattr__(self, attr):
|
73 |
+
return getattr(self.terminal, attr)
|
74 |
+
|
75 |
+
def write(self, buf):
|
76 |
+
temp_linebuf = self.linebuf + buf
|
77 |
+
self.linebuf = ''
|
78 |
+
for line in temp_linebuf.splitlines(True):
|
79 |
+
# From the io.TextIOWrapper docs:
|
80 |
+
# On output, if newline is None, any '\n' characters written
|
81 |
+
# are translated to the system default line separator.
|
82 |
+
# By default sys.stdout.write() expects '\n' newlines and then
|
83 |
+
# translates them so this is still cross platform.
|
84 |
+
if line[-1] == '\n':
|
85 |
+
self.logger.log(self.log_level, line.rstrip())
|
86 |
+
else:
|
87 |
+
self.linebuf += line
|
88 |
+
|
89 |
+
def flush(self):
|
90 |
+
if self.linebuf != '':
|
91 |
+
self.logger.log(self.log_level, self.linebuf.rstrip())
|
92 |
+
self.linebuf = ''
|
93 |
+
|
94 |
+
|
95 |
+
def disable_torch_init():
|
96 |
+
"""
|
97 |
+
Disable the redundant torch default initialization to accelerate model creation.
|
98 |
+
"""
|
99 |
+
import torch
|
100 |
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
101 |
+
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
102 |
+
|
103 |
+
|
104 |
+
def violates_moderation(text):
|
105 |
+
"""
|
106 |
+
Check whether the text violates OpenAI moderation API.
|
107 |
+
"""
|
108 |
+
url = "https://api.openai.com/v1/moderations"
|
109 |
+
headers = {"Content-Type": "application/json",
|
110 |
+
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
|
111 |
+
text = text.replace("\n", "")
|
112 |
+
data = "{" + '"input": ' + f'"{text}"' + "}"
|
113 |
+
data = data.encode("utf-8")
|
114 |
+
try:
|
115 |
+
ret = requests.post(url, headers=headers, data=data, timeout=5)
|
116 |
+
flagged = ret.json()["results"][0]["flagged"]
|
117 |
+
except requests.exceptions.RequestException as e:
|
118 |
+
flagged = False
|
119 |
+
except KeyError as e:
|
120 |
+
flagged = False
|
121 |
+
|
122 |
+
return flagged
|
123 |
+
|
124 |
+
|
125 |
+
def pretty_print_semaphore(semaphore):
|
126 |
+
if semaphore is None:
|
127 |
+
return "None"
|
128 |
+
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
|
requirements.txt
CHANGED
@@ -1 +1 @@
|
|
1 |
-
huggingface_hub==0.22.2
|
|
|
1 |
+
huggingface_hub==0.22.2
|