"""Gradio demo for DepthPro metric depth estimation.

On import this module downloads the pretrained weights and loads the DepthPro
model (on ``cuda:0`` when available, otherwise the CPU). It then builds a Gradio
interface that turns an uploaded image into a rendered depth map plus the focal
length reported by the model.
"""

import logging
import os
import subprocess
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from PIL import Image

import src.depth_pro as depth_pro

logger = logging.getLogger(__name__)

# Fetch the pretrained checkpoint before the model is built. check=True aborts
# start-up with a clear CalledProcessError if the download script fails, rather
# than failing later with a confusing checkpoint-loading error.
subprocess.run(["bash", "get_pretrained_models.sh"], check=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the model and its matching preprocessing transform once at start-up.
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()


def resize_image(image_path, max_size=1024):
    """Copy an image to a temporary PNG, downscaling it if it is too large.

    The image is shrunk (preserving aspect ratio) only when its longest side
    exceeds ``max_size``. Smaller images are copied at their original size
    instead of being upscaled.

    Args:
        image_path: Path of the image to read.
        max_size: Maximum allowed length, in pixels, of the longest side.

    Returns:
        Path of a new temporary PNG file. The caller must delete it.
    """
    with Image.open(image_path) as img:
        if img.mode == "CMYK":
            # PNG cannot store CMYK; convert so the save below does not fail.
            img = img.convert("RGB")

        longest_side = max(img.size)
        if longest_side > max_size:
            ratio = max_size / longest_side
            # round() avoids float truncation (e.g. 1023 instead of 1024).
            # max(1, ...) keeps extreme aspect ratios from producing a 0-px axis.
            new_size = tuple(max(1, round(dim * ratio)) for dim in img.size)
            img = img.resize(new_size, Image.LANCZOS)

        # Save while the source file is still open: when no resize happened,
        # `img` is the lazily loaded original image.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        try:
            with temp_file:
                img.save(temp_file, format="PNG")
        except Exception:
            # Do not leave a half-written file behind if encoding fails.
            os.remove(temp_file.name)
            raise
    return temp_file.name


def _to_2d_array(depth):
    """Return the model's depth output as a 2-D NumPy array.

    Torch tensors are detached, cast to float32 and moved to the CPU first.
    Singleton dimensions (e.g. a size-1 batch or channel axis) are squeezed out.

    Raises:
        ValueError: If the array is still not 2-D after squeezing.
    """
    if isinstance(depth, torch.Tensor):
        # .float() also covers half/bfloat16 outputs, which .numpy() rejects
        # (bfloat16) or which are awkward to plot.
        depth = depth.detach().float().cpu().numpy()
    depth = np.asarray(depth)
    if depth.ndim != 2:
        depth = depth.squeeze()
    if depth.ndim != 2:
        raise ValueError(f"Expected a 2-D depth map, got shape {depth.shape}")
    return depth


def _render_depth_map(depth):
    """Render a 2-D depth map with a colour bar and save it to a temporary PNG.

    matplotlib scales the colour map to the map's own min/max. The colours
    therefore match a min-max normalized plot, while the colour bar shows the
    raw values in metres. A constant map renders as a single colour instead
    of dividing by zero.

    Returns:
        Path of the saved PNG. Each call writes a new file, so concurrent
        requests cannot overwrite each other's output. NOTE(review): these
        files are not deleted here; Gradio copies outputs into its own cache.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        image = ax.imshow(depth, cmap="viridis")
        fig.colorbar(image, ax=ax, label="Depth [m]")
        ax.set_title("Predicted Depth Map")
        ax.axis("off")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as out_file:
            fig.savefig(out_file, format="png")
        return out_file.name
    finally:
        # Always release the figure so failed renders do not accumulate memory.
        plt.close(fig)


@spaces.GPU(duration=20)
def predict_depth(input_image):
    """Predict a depth map for an uploaded image.

    Args:
        input_image: Filesystem path of the uploaded image, or ``None`` when
            the user submitted without an image.

    Returns:
        A ``(depth_map_path, message)`` pair. On success the path points to a
        rendered PNG and the message reports the focal length in pixels. On
        failure the path is ``None`` and the message describes the problem.
        Errors are reported to the UI rather than raised, because this is the
        top-level request handler.
    """
    if input_image is None:
        return None, "Please upload an image."

    temp_file = None
    try:
        temp_file = resize_image(input_image)

        # NOTE(review): assumes depth_pro.load_rgb returns a tuple whose first
        # item is the image and whose last item is the focal length (f_px).
        result = depth_pro.load_rgb(temp_file)
        image = transform(result[0]).to(device)
        f_px = result[-1]

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth in [m]
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

        output_path = _render_depth_map(_to_2d_array(depth))
        # float() accepts Python numbers, NumPy scalars and one-element tensors.
        return output_path, f"Focal length: {float(focallength_px):.2f} pixels"
    except Exception as e:
        # Keep the full traceback in the server logs; the UI only shows str(e).
        logger.exception("Depth prediction failed")
        return None, f"An error occurred: {str(e)}"
    finally:
        # Remove the resized intermediate copy of the upload.
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)


# Create Gradio interface
iface = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="filepath"),
    outputs=[gr.Image(type="filepath", label="Depth Map"), gr.Textbox(label="Focal Length or Error Message")],
    title="DepthPro Demo",
    description="[DepthPro](https://huggingface.co/apple/DepthPro) is a fast metric depth prediction model. Simply upload an image to predict its depth map and focal length. Large images will be automatically resized."
)

if __name__ == "__main__":
    # Launch the interface
    iface.launch()