capradeepgujaran committed on
Commit a2433fb
1 Parent(s): ca72d14

Update app.py

Files changed (1): app.py (+106 -87)
app.py CHANGED
@@ -11,9 +11,8 @@ import os
 import shutil
 from tqdm.auto import tqdm
 from pathlib import Path
-from typing import List, Dict, Tuple
-import time
-from huggingface_hub import snapshot_download
+from typing import List, Dict, Tuple, Optional
+import gc
 import warnings
 warnings.filterwarnings("ignore")
 
@@ -22,13 +21,15 @@ os.environ["TRANSFORMERS_CACHE"] = "./model_cache"
 os.environ["HF_HOME"] = "./model_cache"
 os.makedirs("./model_cache", exist_ok=True)
 
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
 class VideoProcessor:
     def __init__(self):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        print(f"Using device: {self.device}")
+        logging.info(f"Using device: {self.device}")
 
         # Load models with optimizations
-        self.load_models()
+        self._load_models()
 
         # Processing settings
         self.frame_interval = 30  # Process 1 frame every 30 frames
@@ -36,48 +37,57 @@ class VideoProcessor:
         self.target_size = (224, 224)
         self.batch_size = 4 if torch.cuda.is_available() else 2
 
-    def load_models(self):
+    def _load_models(self):
         """Load models with optimizations and proper configurations"""
-        print("Loading CLIP model...")
-        self.clip_model = CLIPModel.from_pretrained(
-            "openai/clip-vit-base-patch32",
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-            cache_dir="./model_cache"
-        ).to(self.device)
-        self.clip_processor = CLIPProcessor.from_pretrained(
-            "openai/clip-vit-base-patch32",
-            cache_dir="./model_cache"
-        )
-
-        print("Loading BLIP2 model...")
-        model_name = "Salesforce/blip2-opt-2.7b"
-
-        # Initialize BLIP2 processor without config modifications
-        self.blip_processor = Blip2Processor.from_pretrained(
-            model_name,
-            cache_dir="./model_cache"
-        )
-
-        # Load BLIP2 model with optimizations
-        self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
-            model_name,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-            device_map="auto" if torch.cuda.is_available() else None,
-            cache_dir="./model_cache",
-            low_cpu_mem_usage=True
-        ).to(self.device)
-
-        # Set models to evaluation mode
-        self.clip_model.eval()
-        self.blip_model.eval()
-        print("Models loaded successfully!")
+        try:
+            logging.info("Loading CLIP model...")
+            self.clip_model = CLIPModel.from_pretrained(
+                "openai/clip-vit-base-patch32",
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                cache_dir="./model_cache"
+            ).to(self.device)
+            self.clip_processor = CLIPProcessor.from_pretrained(
+                "openai/clip-vit-base-patch32",
+                cache_dir="./model_cache"
+            )
+
+            logging.info("Loading BLIP2 model...")
+            model_name = "Salesforce/blip2-opt-2.7b"
+
+            # Initialize BLIP2 with minimal configuration
+            self.blip_processor = Blip2Processor.from_pretrained(
+                model_name,
+                cache_dir="./model_cache"
+            )
+
+            self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
+                model_name,
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                device_map="auto" if torch.cuda.is_available() else None,
+                cache_dir="./model_cache",
+                low_cpu_mem_usage=True
+            ).to(self.device)
+
+            # Set models to evaluation mode
+            self.clip_model.eval()
+            self.blip_model.eval()
+
+            logging.info("Models loaded successfully!")
+        except Exception as e:
+            logging.error(f"Error loading models: {str(e)}")
+            raise
+
+    def _preprocess_frame(self, frame: np.ndarray) -> Image.Image:
+        """Preprocess a single frame"""
+        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        return Image.fromarray(rgb_frame).resize(self.target_size, Image.LANCZOS)
 
     @torch.no_grad()
-    def process_frame_batch(self, frames):
+    def process_frame_batch(self, frames: List[np.ndarray]) -> Tuple[Optional[np.ndarray], Optional[List[str]]]:
         """Process a batch of frames efficiently"""
         try:
             # Convert frames to PIL Images
-            pil_frames = [Image.fromarray(cv2.cvtColor(f, cv2.COLOR_BGR2RGB)).resize(self.target_size) for f in frames]
+            pil_frames = [self._preprocess_frame(f) for f in frames]
 
             # Get CLIP features
             clip_inputs = self.clip_processor(
@@ -90,7 +100,7 @@ class VideoProcessor:
                 clip_inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in clip_inputs.items()}
             features = self.clip_model.get_image_features(**clip_inputs)
 
-            # Get BLIP captions with updated processing
+            # Get BLIP captions
             blip_inputs = self.blip_processor(
                 images=pil_frames,
                 return_tensors="pt",
@@ -100,7 +110,7 @@ class VideoProcessor:
             if self.device.type == "cuda":
                 blip_inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in blip_inputs.items()}
 
-            # Generate captions with better parameters
+            # Generate captions
            captions = self.blip_model.generate(
                 **blip_inputs,
                 max_length=30,
@@ -113,38 +123,44 @@ class VideoProcessor:
 
             captions = [self.blip_processor.decode(c, skip_special_tokens=True) for c in captions]
 
+            # Clear GPU memory if needed
+            if self.device.type == "cuda":
+                torch.cuda.empty_cache()
+
             return features.cpu().numpy(), captions
+
         except Exception as e:
-            print(f"Error in batch processing: {str(e)}")
+            logging.error(f"Error in batch processing: {str(e)}")
             return None, None
 
-    def process_video(self, video_path: str, progress=gr.Progress()):
+    def process_video(self, video_path: str, progress: gr.Progress) -> Tuple[Optional[faiss.Index], Optional[List[Dict]], str]:
         """Process video with batching and progress updates"""
-        cap = cv2.VideoCapture(video_path)
-        if not cap.isOpened():
-            raise ValueError("Could not open video file")
-
-        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        fps = cap.get(cv2.CAP_PROP_FPS)
-
-        # Calculate frames to process
-        frames_to_process = min(self.max_frames, total_frames // self.frame_interval)
-        progress(0, desc="Initializing video processing...")
-
-        features_list = []
-        frame_data = []
-        current_batch = []
-        batch_positions = []
-
+        cap = None
         try:
+            cap = cv2.VideoCapture(video_path)
+            if not cap.isOpened():
+                raise ValueError(f"Could not open video file: {video_path}")
+
+            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+            fps = cap.get(cv2.CAP_PROP_FPS)
+
+            # Calculate frames to process
+            frames_to_process = min(self.max_frames, total_frames // self.frame_interval)
+            progress(0, desc="Initializing video processing...")
+
+            features_list = []
+            frame_data = []
+            current_batch = []
+            batch_positions = []
+
             frame_count = 0
             processed_count = 0
 
-            while cap.isOpened() and processed_count < frames_to_process:
+            while processed_count < frames_to_process:
                 ret, frame = cap.read()
                 if not ret:
                     break
-
+
                 if frame_count % self.frame_interval == 0:
                     current_batch.append(frame)
                     batch_positions.append(frame_count)
@@ -157,12 +173,12 @@ class VideoProcessor:
                     features, captions = self.process_frame_batch(current_batch)
 
                     if features is not None and captions is not None:
-                        for i, (feat, cap) in enumerate(zip(features, captions)):
+                        for i, (feat, cap_text) in enumerate(zip(features, captions)):
                             features_list.append(feat)
                             frame_data.append({
                                 'frame_number': batch_positions[i],
                                 'timestamp': batch_positions[i] / fps,
-                                'caption': cap
+                                'caption': cap_text
                             })
 
                     processed_count += len(current_batch)
@@ -171,21 +187,25 @@ class VideoProcessor:
 
                 frame_count += 1
 
-            cap.release()
-
             # Create FAISS index
             if features_list:
                 features_array = np.vstack(features_list)
                 frame_index = faiss.IndexFlatL2(features_array.shape[1])
                 frame_index.add(features_array)
-
                 return frame_index, frame_data, "Video processed successfully!"
             else:
                 return None, None, "No frames were processed successfully."
-
+
         except Exception as e:
-            cap.release()
-            raise e
+            logging.error(f"Error processing video: {str(e)}")
+            return None, None, f"Error processing video: {str(e)}"
+
+        finally:
+            if cap is not None:
+                cap.release()
+            gc.collect()
+            if self.device.type == "cuda":
+                torch.cuda.empty_cache()
 
 class VideoQAInterface:
     def __init__(self):
@@ -195,28 +215,28 @@ class VideoQAInterface:
         self.processed = False
         self.current_video_path = None
         self.temp_dir = tempfile.mkdtemp()
-        print(f"Initialized temp directory: {self.temp_dir}")
+        logging.info(f"Initialized temp directory: {self.temp_dir}")
 
     def __del__(self):
         """Cleanup temporary files"""
-        if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
-            try:
+        try:
+            if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
                 shutil.rmtree(self.temp_dir)
-                print(f"Cleaned up temp directory: {self.temp_dir}")
-            except Exception as e:
-                print(f"Error cleaning up temp directory: {str(e)}")
+                logging.info(f"Cleaned up temp directory: {self.temp_dir}")
+        except Exception as e:
+            logging.error(f"Error cleaning up temp directory: {str(e)}")
 
     def process_video(self, video_file, progress=gr.Progress()):
         """Process video with progress tracking"""
-        try:
-            if video_file is None:
-                return "Please upload a video first."
+        if video_file is None:
+            return "Please upload a video first."
 
+        try:
             # Save uploaded video to temp directory
             temp_video_path = os.path.join(self.temp_dir, "input_video.mp4")
             shutil.copy2(video_file.name, temp_video_path)
             self.current_video_path = temp_video_path
-            print(f"Saved video to: {self.current_video_path}")
+            logging.info(f"Saved video to: {self.current_video_path}")
 
             progress(0, desc="Starting video processing...")
             self.frame_index, self.frame_data, message = self.processor.process_video(
@@ -232,6 +252,7 @@ class VideoQAInterface:
 
         except Exception as e:
             self.processed = False
+            logging.error(f"Error processing video: {str(e)}")
             return f"Error processing video: {str(e)}"
 
     @torch.no_grad()
@@ -259,7 +280,6 @@ class VideoQAInterface:
             descriptions = []
             frames = []
 
-            # Use cv2.VideoCapture to read frames
             cap = cv2.VideoCapture(self.current_video_path)
             try:
                 for result in results:
@@ -288,6 +308,7 @@ class VideoQAInterface:
                 return frames, combined_desc
 
         except Exception as e:
+            logging.error(f"Error answering question: {str(e)}")
            return None, f"Error answering question: {str(e)}"
 
    def create_interface(self):
@@ -341,14 +362,12 @@ class VideoQAInterface:
        return interface
 
 # Create and launch the app
-
 app = VideoQAInterface()
 interface = app.create_interface()
 
 if __name__ == "__main__":
     interface.launch(
-        server_name="0.0.0.0",  # Allow external connections
-        share=False,  # Set to True for public URL
-        show_error=True,  # Show detailed error messages
-        quiet=False  # Show server logs
+        server_name="0.0.0.0",
+        share=False,
+        show_error=True
    )
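
For reference, the retrieval pattern this file builds on: `process_video` stores CLIP image features in a `faiss.IndexFlatL2`, so a question is answered by embedding the query text with the same CLIP model and running a nearest-neighbor search. The commit does not show the internals of `answer_question`, so the following is a minimal sketch only: `search_frames` is a hypothetical helper, and everything beyond the `CLIPModel`/`CLIPProcessor` calls and the `frame_index`/`frame_data` structures visible in the diff is an assumption.

```python
# Illustrative sketch (not part of this commit): querying the IndexFlatL2
# built in process_video() with a CLIP text embedding.
# Assumes the clip_model/clip_processor loaded in _load_models() and the
# frame_index/frame_data returned by process_video(); search_frames itself
# is a hypothetical helper, not code from app.py.
import torch

@torch.no_grad()
def search_frames(question, clip_model, clip_processor, frame_index, frame_data, k=4):
    """Embed the question with CLIP and return metadata for the k nearest frames."""
    inputs = clip_processor(text=[question], return_tensors="pt", padding=True)
    inputs = {key: v.to(clip_model.device) for key, v in inputs.items()}
    text_features = clip_model.get_text_features(**inputs)
    query = text_features.float().cpu().numpy()        # FAISS expects float32
    distances, indices = frame_index.search(query, k)  # L2 search, matching IndexFlatL2
    return [dict(frame_data[i], distance=float(d))
            for i, d in zip(indices[0], distances[0]) if i >= 0]
```

Each returned dict would then carry the `frame_number`, `timestamp`, and `caption` recorded during processing, which is enough to seek back into the video with `cv2.VideoCapture` the way `answer_question` does.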