capradeepgujaran committed
Commit 17e6c9d
1 Parent(s): 17991a3

Update app.py

Files changed (1)
  1. app.py +85 -24
app.py CHANGED
@@ -13,9 +13,19 @@ from tqdm.auto import tqdm
 from pathlib import Path
 from typing import List, Dict, Tuple
 import time
+from huggingface_hub import snapshot_download
+import warnings
+warnings.filterwarnings("ignore")
+
+# Configure model caching and environment
+os.environ["TRANSFORMERS_CACHE"] = "./model_cache"
+os.environ["HF_HOME"] = "./model_cache"
+os.makedirs("./model_cache", exist_ok=True)
+
 class VideoProcessor:
     def __init__(self):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"Using device: {self.device}")
 
         # Load models with optimizations
         self.load_models()
@@ -27,25 +37,42 @@ class VideoProcessor:
         self.batch_size = 4 if torch.cuda.is_available() else 2
 
     def load_models(self):
-        """Load models with optimizations"""
-        # Load CLIP
+        """Load models with optimizations and proper configurations"""
+        print("Loading CLIP model...")
         self.clip_model = CLIPModel.from_pretrained(
             "openai/clip-vit-base-patch32",
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            cache_dir="./model_cache"
         ).to(self.device)
-        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+        self.clip_processor = CLIPProcessor.from_pretrained(
+            "openai/clip-vit-base-patch32",
+            cache_dir="./model_cache"
+        )
+
+        print("Loading BLIP2 model...")
+        model_name = "Salesforce/blip2-opt-2.7b"
+
+        # Initialize BLIP2 processor with updated configuration
+        self.blip_processor = Blip2Processor.from_pretrained(
+            model_name,
+            cache_dir="./model_cache"
+        )
+        self.blip_processor.config.use_fast_tokenizer = True
+        self.blip_processor.config.processor_class = "Blip2Processor"
 
-        # Load BLIP2 with reduced size
+        # Load BLIP2 model with optimizations
         self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
-            "Salesforce/blip2-opt-2.7b",
+            model_name,
             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-            device_map="auto" if torch.cuda.is_available() else None
+            device_map="auto" if torch.cuda.is_available() else None,
+            cache_dir="./model_cache",
+            low_cpu_mem_usage=True
         ).to(self.device)
-        self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
 
         # Set models to evaluation mode
         self.clip_model.eval()
         self.blip_model.eval()
+        print("Models loaded successfully!")
 
     @torch.no_grad()
     def process_frame_batch(self, frames):
@@ -55,16 +82,37 @@ class VideoProcessor:
         pil_frames = [Image.fromarray(cv2.cvtColor(f, cv2.COLOR_BGR2RGB)).resize(self.target_size) for f in frames]
 
         # Get CLIP features
-        clip_inputs = self.clip_processor(images=pil_frames, return_tensors="pt", padding=True).to(self.device)
+        clip_inputs = self.clip_processor(
+            images=pil_frames,
+            return_tensors="pt",
+            padding=True
+        ).to(self.device)
+
         if self.device.type == "cuda":
             clip_inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in clip_inputs.items()}
         features = self.clip_model.get_image_features(**clip_inputs)
 
-        # Get BLIP captions
-        blip_inputs = self.blip_processor(images=pil_frames, return_tensors="pt", padding=True).to(self.device)
+        # Get BLIP captions with updated processing
+        blip_inputs = self.blip_processor(
+            images=pil_frames,
+            return_tensors="pt",
+            padding=True
+        ).to(self.device)
+
         if self.device.type == "cuda":
             blip_inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in blip_inputs.items()}
-        captions = self.blip_model.generate(**blip_inputs, max_length=30)
+
+        # Generate captions with better parameters
+        captions = self.blip_model.generate(
+            **blip_inputs,
+            max_length=30,
+            min_length=10,
+            num_beams=5,
+            length_penalty=1.0,
+            temperature=0.7,
+            do_sample=False
+        )
+
         captions = [self.blip_processor.decode(c, skip_special_tokens=True) for c in captions]
 
         return features.cpu().numpy(), captions
@@ -75,12 +123,15 @@ class VideoProcessor:
     def process_video(self, video_path: str, progress=gr.Progress()):
         """Process video with batching and progress updates"""
         cap = cv2.VideoCapture(video_path)
+        if not cap.isOpened():
+            raise ValueError("Could not open video file")
+
         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
         fps = cap.get(cv2.CAP_PROP_FPS)
 
         # Calculate frames to process
         frames_to_process = min(self.max_frames, total_frames // self.frame_interval)
-        progress(0, desc="Initializing...")
+        progress(0, desc="Initializing video processing...")
 
         features_list = []
         frame_data = []
@@ -102,6 +153,9 @@ class VideoProcessor:
 
                 # Process batch when full
                 if len(current_batch) == self.batch_size or frame_count == total_frames - 1:
+                    progress(processed_count / frames_to_process,
+                             desc=f"Processing frames... {processed_count}/{frames_to_process}")
+
                     features, captions = self.process_frame_batch(current_batch)
 
                     if features is not None and captions is not None:
@@ -116,13 +170,9 @@ class VideoProcessor:
                     processed_count += len(current_batch)
                     current_batch = []
                     batch_positions = []
-
-                    # Update progress
-                    progress(processed_count / frames_to_process,
-                             desc=f"Processing frames... {processed_count}/{frames_to_process}")
 
                 frame_count += 1
-
+
             cap.release()
 
             # Create FAISS index
@@ -137,7 +187,7 @@ class VideoProcessor:
 
         except Exception as e:
             cap.release()
-            return None, None, f"Error processing video: {str(e)}"
+            raise e
 
 class VideoQAInterface:
     def __init__(self):
@@ -145,13 +195,18 @@ class VideoQAInterface:
         self.frame_index = None
         self.frame_data = None
         self.processed = False
-        self.current_video_path = None  # Store the video path
+        self.current_video_path = None
         self.temp_dir = tempfile.mkdtemp()
+        print(f"Initialized temp directory: {self.temp_dir}")
 
     def __del__(self):
         """Cleanup temporary files"""
         if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
-            shutil.rmtree(self.temp_dir, ignore_errors=True)
+            try:
+                shutil.rmtree(self.temp_dir)
+                print(f"Cleaned up temp directory: {self.temp_dir}")
+            except Exception as e:
+                print(f"Error cleaning up temp directory: {str(e)}")
 
     def process_video(self, video_file, progress=gr.Progress()):
         """Process video with progress tracking"""
@@ -163,6 +218,7 @@ class VideoQAInterface:
             temp_video_path = os.path.join(self.temp_dir, "input_video.mp4")
             shutil.copy2(video_file.name, temp_video_path)
             self.current_video_path = temp_video_path
+            print(f"Saved video to: {self.current_video_path}")
 
             progress(0, desc="Starting video processing...")
             self.frame_index, self.frame_data, message = self.processor.process_video(
@@ -178,7 +234,7 @@ class VideoQAInterface:
 
         except Exception as e:
            self.processed = False
-            return f"Error: {str(e)}"
+            return f"Error processing video: {str(e)}"
 
     @torch.no_grad()
     def answer_question(self, query):
@@ -222,7 +278,7 @@ class VideoQAInterface:
                 desc += f"Relevance Score: {result['relevance']:.2f}"
                 descriptions.append(desc)
         finally:
-            cap.release()  # Ensure video capture is released
+            cap.release()
 
         if not frames:
             return None, "No relevant frames found."
@@ -291,4 +347,9 @@ app = VideoQAInterface()
 interface = app.create_interface()
 
 if __name__ == "__main__":
-    interface.launch()
+    interface.launch(
+        server_name="0.0.0.0",
+        share=False,  # Set to True if you want to create a public link
+        cache_examples=True,
+        max_threads=4
+    )