capradeepgujaran committed on
Commit 0052d38
1 Parent(s): 3f506ce

Update app.py

Files changed (1)
  1. app.py +169 -197
app.py CHANGED
@@ -4,243 +4,219 @@ from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForCondi
  import torch
  from PIL import Image
  import faiss
- from typing import List, Dict, Tuple
  import logging
  import gradio as gr
  import tempfile
  import os
- import shutil
- from tqdm import tqdm
  from pathlib import Path
- from moviepy.video.io.VideoFileClip import VideoFileClip
 
- class VideoRAGSystem:
  def __init__(self):
- self.logger = self.setup_logger()
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- self.logger.info(f"Using device: {self.device}")
 
- # Initialize models
- self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
  self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
-
- self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
  self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
  "Salesforce/blip2-opt-2.7b",
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
  ).to(self.device)
 
- # Vector store setup
- self.frame_index = None
- self.frame_data = []
- self.target_size = (224, 224)
-
- # Create directories for storing processed data
- self.temp_dir = tempfile.mkdtemp()
- self.frames_dir = os.path.join(self.temp_dir, "frames")
- os.makedirs(self.frames_dir, exist_ok=True)
-
- def setup_logger(self) -> logging.Logger:
- logger = logging.getLogger('VideoRAGSystem')
- if logger.handlers:
- logger.handlers.clear()
- logger.setLevel(logging.INFO)
- handler = logging.StreamHandler()
- formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
- handler.setFormatter(formatter)
- logger.addHandler(handler)
- return logger
-
- def split_video(self, video_path: str, timestamp_ms: int, context_seconds: int = 3) -> str:
- """Extract a clip around the specified timestamp"""
- timestamp_sec = timestamp_ms / 1000
- output_path = os.path.join(self.temp_dir, "clip.mp4")
-
- with VideoFileClip(video_path) as video:
- duration = video.duration
- start_time = max(timestamp_sec - context_seconds, 0)
- end_time = min(timestamp_sec + context_seconds, duration)
- clip = video.subclip(start_time, end_time)
- clip.write_videofile(output_path, audio_codec='aac')
-
- return output_path
 
  @torch.no_grad()
- def analyze_frame(self, image: Image.Image) -> Dict:
- """Comprehensive frame analysis"""
  try:
- # Generate caption
- inputs = self.blip_processor(image, return_tensors="pt").to(self.device)
- if self.device.type == "cuda":
- inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
- caption = self.blip_model.generate(**inputs, max_length=50)
- caption_text = self.blip_processor.decode(caption[0], skip_special_tokens=True)
-
- # Get visual features
- clip_inputs = self.clip_processor(images=image, return_tensors="pt").to(self.device)
  if self.device.type == "cuda":
  clip_inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in clip_inputs.items()}
  features = self.clip_model.get_image_features(**clip_inputs)
-
- return {
- "caption": caption_text,
- "features": features.cpu().numpy()
- }
  except Exception as e:
- self.logger.error(f"Frame analysis error: {str(e)}")
- return None
 
- def extract_keyframes(self, video_path: str, max_frames: int = 15) -> List[Dict]:
- """Extract and analyze key frames"""
  cap = cv2.VideoCapture(video_path)
- frames_info = []
- frame_count = 0
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
- interval = max(1, total_frames // max_frames)
-
- with tqdm(total=max_frames, desc="Analyzing frames") as pbar:
- while len(frames_info) < max_frames and cap.isOpened():
  ret, frame = cap.read()
  if not ret:
  break
-
- if frame_count % interval == 0:
- # Save frame
- frame_path = os.path.join(self.frames_dir, f"frame_{frame_count}.jpg")
- cv2.imwrite(frame_path, frame)
 
- # Analyze frame
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
- image = Image.fromarray(frame_rgb).resize(self.target_size, Image.LANCZOS)
- analysis = self.analyze_frame(image)
 
- if analysis is not None:
- frames_info.append({
- "frame_number": frame_count,
- "timestamp": frame_count / cap.get(cv2.CAP_PROP_FPS),
- "path": frame_path,
- "caption": analysis["caption"],
- "features": analysis["features"]
- })
- pbar.update(1)
-
  frame_count += 1
-
- cap.release()
- return frames_info
-
- def process_video(self, video_path: str):
- """Process video and build search index"""
- self.logger.info(f"Processing video: {video_path}")
-
- try:
- # Extract and analyze frames
- frames_info = self.extract_keyframes(video_path)
- self.frame_data = frames_info
-
- # Build FAISS index
- if frames_info:
- features = np.vstack([frame["features"] for frame in frames_info])
- self.frame_index = faiss.IndexFlatL2(features.shape[1])
- self.frame_index.add(features)
 
- self.logger.info(f"Processed {len(frames_info)} frames successfully")
- return True
 
  except Exception as e:
- self.logger.error(f"Video processing error: {str(e)}")
- return False
-
- @torch.no_grad()
- def search_frames(self, query: str, k: int = 4) -> List[Dict]:
- """Search for relevant frames based on the query"""
- try:
- # Process query
- inputs = self.clip_processor(text=[query], return_tensors="pt").to(self.device)
- if self.device.type == "cuda":
- inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
- query_features = self.clip_model.get_text_features(**inputs)
-
- # Search
- distances, indices = self.frame_index.search(
- query_features.cpu().numpy(),
- k
- )
-
- # Prepare results
- results = []
- for distance, idx in zip(distances[0], indices[0]):
- frame_info = self.frame_data[idx].copy()
- frame_info["relevance"] = float(1 / (1 + distance))
- results.append(frame_info)
-
- return results
-
- except Exception as e:
- self.logger.error(f"Search error: {str(e)}")
- return []
 
  class VideoQAInterface:
  def __init__(self):
- self.rag_system = VideoRAGSystem()
- self.current_video = None
  self.processed = False
 
- def process_video(self, video_file):
- """Handle video upload and processing"""
  try:
  if video_file is None:
- return "Please upload a video first.", gr.Progress(0)
-
- self.current_video = video_file.name
- success = self.rag_system.process_video(self.current_video)
 
- if success:
  self.processed = True
- return "Video processed successfully! You can now ask questions.", gr.Progress(100)
  else:
- return "Error processing video. Please try again.", gr.Progress(0)
-
  except Exception as e:
  self.processed = False
- return f"Error: {str(e)}", gr.Progress(0)
 
  def answer_question(self, query):
- """Handle question answering"""
  if not self.processed:
  return None, "Please process a video first."
-
  try:
  # Search for relevant frames
- results = self.rag_system.search_frames(query)
 
- if not results:
- return None, "No relevant frames found."
-
- # Prepare output
- frames = []
  descriptions = []
-
  for result in results:
- # Load frame
- frame = Image.open(result["path"])
- frames.append(frame)
 
- # Prepare description
- desc = f"Timestamp: {result['timestamp']:.2f}s\n"
- desc += f"Scene Description: {result['caption']}\n"
- desc += f"Relevance Score: {result['relevance']:.2f}"
- descriptions.append(desc)
-
- # Combine descriptions
  combined_desc = "\n\nFrame Analysis:\n\n"
  for i, desc in enumerate(descriptions, 1):
  combined_desc += f"Frame {i}:\n{desc}\n\n"
-
  return frames, combined_desc
-
  except Exception as e:
- return None, f"Error: {str(e)}"
 
  def create_interface(self):
  """Create Gradio interface"""
@@ -249,45 +225,42 @@ class VideoQAInterface:
  gr.Markdown("Upload a video and ask questions about any aspect of its content!")
 
  with gr.Row():
- video_input = gr.File(
- label="Upload Video",
- file_types=["video"],
- )
- process_button = gr.Button("Process Video")
-
- status_output = gr.Textbox(
- label="Status",
- interactive=False
- )
 
  with gr.Row():
  query_input = gr.Textbox(
  label="Ask about the video",
  placeholder="What's happening in the video?"
  )
- query_button = gr.Button("Search")
 
  gallery = gr.Gallery(
  label="Retrieved Frames",
  show_label=True,
  columns=[2],
- rows=[2],
- height="auto"
  )
-
  descriptions = gr.Textbox(
- label="Scene Analysis",
  interactive=False,
  lines=10
  )
 
- process_button.click(
  fn=self.process_video,
  inputs=[video_input],
- outputs=[status_output]
  )
 
- query_button.click(
  fn=self.answer_question,
  inputs=[query_input],
  outputs=[gallery, descriptions]
@@ -295,10 +268,9 @@ class VideoQAInterface:
 
  return interface
 
- # Initialize and create the interface
  app = VideoQAInterface()
  interface = app.create_interface()
 
- # Launch the app
  if __name__ == "__main__":
  interface.launch()
  import torch
  from PIL import Image
  import faiss
  import logging
  import gradio as gr
  import tempfile
  import os
+ from tqdm.auto import tqdm
  from pathlib import Path
+ import time
 
+ class VideoProcessor:
  def __init__(self):
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load models with optimizations
+ self.load_models()
+
+ # Processing settings
+ self.frame_interval = 30 # Process 1 frame every 30 frames
+ self.max_frames = 50 # Maximum frames to process
+ self.target_size = (224, 224)
+ self.batch_size = 4 if torch.cuda.is_available() else 2
 
+ def load_models(self):
+ """Load models with optimizations"""
+ # Load CLIP
+ self.clip_model = CLIPModel.from_pretrained(
+ "openai/clip-vit-base-patch32",
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+ ).to(self.device)
  self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ # Load BLIP2 with reduced size
  self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
  "Salesforce/blip2-opt-2.7b",
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+ device_map="auto" if torch.cuda.is_available() else None
  ).to(self.device)
+ self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
 
+ # Set models to evaluation mode
+ self.clip_model.eval()
+ self.blip_model.eval()
 
  @torch.no_grad()
+ def process_frame_batch(self, frames):
+ """Process a batch of frames efficiently"""
  try:
+ # Convert frames to PIL Images
+ pil_frames = [Image.fromarray(cv2.cvtColor(f, cv2.COLOR_BGR2RGB)).resize(self.target_size) for f in frames]
+
+ # Get CLIP features
+ clip_inputs = self.clip_processor(images=pil_frames, return_tensors="pt", padding=True).to(self.device)
  if self.device.type == "cuda":
  clip_inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in clip_inputs.items()}
  features = self.clip_model.get_image_features(**clip_inputs)
+
+ # Get BLIP captions
+ blip_inputs = self.blip_processor(images=pil_frames, return_tensors="pt", padding=True).to(self.device)
+ if self.device.type == "cuda":
+ blip_inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in blip_inputs.items()}
+ captions = self.blip_model.generate(**blip_inputs, max_length=30)
+ captions = [self.blip_processor.decode(c, skip_special_tokens=True) for c in captions]
+
+ return features.cpu().numpy(), captions
  except Exception as e:
+ print(f"Error in batch processing: {str(e)}")
+ return None, None
 
+ def process_video(self, video_path: str, progress=gr.Progress()):
+ """Process video with batching and progress updates"""
  cap = cv2.VideoCapture(video_path)
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ fps = cap.get(cv2.CAP_PROP_FPS)
+
+ # Calculate frames to process
+ frames_to_process = min(self.max_frames, total_frames // self.frame_interval)
+ progress(0, desc="Initializing...")
+
+ features_list = []
+ frame_data = []
+ current_batch = []
+ batch_positions = []
+
+ try:
+ frame_count = 0
+ processed_count = 0
+
+ while cap.isOpened() and processed_count < frames_to_process:
  ret, frame = cap.read()
  if not ret:
  break
 
+ if frame_count % self.frame_interval == 0:
+ current_batch.append(frame)
+ batch_positions.append(frame_count)
 
+ # Process batch when full
+ if len(current_batch) == self.batch_size or frame_count == total_frames - 1:
+ features, captions = self.process_frame_batch(current_batch)
+
+ if features is not None and captions is not None:
+ for i, (feat, cap) in enumerate(zip(features, captions)):
+ features_list.append(feat)
+ frame_data.append({
+ 'frame_number': batch_positions[i],
+ 'timestamp': batch_positions[i] / fps,
+ 'caption': cap
+ })
+
+ processed_count += len(current_batch)
+ current_batch = []
+ batch_positions = []
+
+ # Update progress
+ progress(processed_count / frames_to_process,
+ desc=f"Processing frames... {processed_count}/{frames_to_process}")
+
  frame_count += 1
 
+ cap.release()
 
+ # Create FAISS index
+ if features_list:
+ features_array = np.vstack(features_list)
+ frame_index = faiss.IndexFlatL2(features_array.shape[1])
+ frame_index.add(features_array)
+
+ return frame_index, frame_data, "Video processed successfully!"
+ else:
+ return None, None, "No frames were processed successfully."
+
  except Exception as e:
+ cap.release()
+ return None, None, f"Error processing video: {str(e)}"
 
  class VideoQAInterface:
  def __init__(self):
+ self.processor = VideoProcessor()
+ self.frame_index = None
+ self.frame_data = None
  self.processed = False
 
+ def process_video(self, video_file, progress=gr.Progress()):
+ """Process video with progress tracking"""
  try:
  if video_file is None:
+ return "Please upload a video first."
+
+ progress(0, desc="Starting video processing...")
+ self.frame_index, self.frame_data, message = self.processor.process_video(
+ video_file.name, progress
+ )
 
+ if self.frame_index is not None:
  self.processed = True
+ return "Video processed successfully! You can now ask questions."
  else:
+ self.processed = False
+ return message
+
  except Exception as e:
  self.processed = False
+ return f"Error: {str(e)}"
 
+ @torch.no_grad()
  def answer_question(self, query):
+ """Answer questions about the video"""
  if not self.processed:
  return None, "Please process a video first."
+
  try:
+ # Get query features
+ inputs = self.processor.clip_processor(text=[query], return_tensors="pt").to(self.processor.device)
+ query_features = self.processor.clip_model.get_text_features(**inputs)
+
  # Search for relevant frames
+ k = 4 # Number of frames to retrieve
+ D, I = self.frame_index.search(query_features.cpu().numpy(), k)
 
+ results = []
+ for distance, idx in zip(D[0], I[0]):
+ frame_info = self.frame_data[idx].copy()
+ frame_info['relevance'] = float(1 / (1 + distance))
+ results.append(frame_info)
+
+ # Format output
  descriptions = []
+ frames = []
+
+ cap = cv2.VideoCapture(video_file.name)
  for result in results:
+ frame_number = result['frame_number']
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
+ ret, frame = cap.read()
 
+ if ret:
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frames.append(Image.fromarray(frame_rgb))
+
+ desc = f"Timestamp: {result['timestamp']:.2f}s\n"
+ desc += f"Scene Description: {result['caption']}\n"
+ desc += f"Relevance Score: {result['relevance']:.2f}"
+ descriptions.append(desc)
+
+ cap.release()
+
  combined_desc = "\n\nFrame Analysis:\n\n"
  for i, desc in enumerate(descriptions, 1):
  combined_desc += f"Frame {i}:\n{desc}\n\n"
+
  return frames, combined_desc
+
  except Exception as e:
+ return None, f"Error answering question: {str(e)}"
 
  def create_interface(self):
  """Create Gradio interface"""
  gr.Markdown("Upload a video and ask questions about any aspect of its content!")
 
  with gr.Row():
+ with gr.Column():
+ video_input = gr.File(
+ label="Upload Video",
+ file_types=["video"]
+ )
+ status = gr.Textbox(label="Status", interactive=False)
+ process_btn = gr.Button("Process Video")
 
  with gr.Row():
  query_input = gr.Textbox(
  label="Ask about the video",
  placeholder="What's happening in the video?"
  )
+ query_btn = gr.Button("Search")
 
  gallery = gr.Gallery(
  label="Retrieved Frames",
  show_label=True,
  columns=[2],
+ rows=[2]
  )
+
  descriptions = gr.Textbox(
+ label="Analysis",
  interactive=False,
  lines=10
  )
 
+ # Set up event handlers
+ process_btn.click(
  fn=self.process_video,
  inputs=[video_input],
+ outputs=[status]
  )
 
+ query_btn.click(
  fn=self.answer_question,
  inputs=[query_input],
  outputs=[gallery, descriptions]
 
  return interface
 
+ # Create and launch the app
  app = VideoQAInterface()
  interface = app.create_interface()
 
  if __name__ == "__main__":
  interface.launch()