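# Hugging Face Spaces page residue ("Spaces: Sleeping Sleeping") captured when this file was scraped; not part of the program.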
import time
from collections import Counter
from io import BytesIO
from typing import Optional
from urllib.parse import urljoin

import pydeck as pdk
import requests
import reverse_geocoder as rg
import spacy
import streamlit as st
import torch
from bs4 import BeautifulSoup
from geopy.geocoders import Nominatim
from PIL import Image
from spacy.cli import download
from torchvision import transforms

from models.huggingface import Geolocalizer
@st.cache_resource
def load_spacy_model(model_name="en_core_web_md"):
    """Load a spaCy pipeline, downloading the package first if it is missing.

    Streamlit re-executes the whole script on every widget interaction, and this
    function is called at module level. ``st.cache_resource`` makes the slow load
    happen once per server process and per ``model_name`` instead of on every
    rerun.

    Args:
        model_name: Name of an installable spaCy pipeline package.

    Returns:
        The loaded spaCy ``Language`` pipeline.
    """
    try:
        return spacy.load(model_name)
    except IOError:
        # spacy.load signals a missing package with OSError (IOError is an alias).
        print(f"Model {model_name} not found, downloading...")
        download(model_name)
        return spacy.load(model_name)
# Geolocation model configuration.
GEOLOC_MODEL_NAME = "osv5m/baseline"  # presumably a Hugging Face Hub repo id; passed to Geolocalizer.from_pretrained
IMAGE_SIZE = (224, 224)  # target size handed to transforms.Resize before inference

# Shared spaCy pipeline used to extract location entities from text.
nlp = load_spacy_model()
# Load geolocation model
@st.cache_resource(show_spinner=False)
def _load_geoloc_model_cached(model_name: str) -> Geolocalizer:
    """Load the Geolocalizer weights once per server process and cache them.

    Exceptions propagate and Streamlit does not cache a raised call, so a failed
    load is retried on the next request instead of being remembered.
    """
    model = Geolocalizer.from_pretrained(model_name)
    model.eval()  # evaluation mode for inference (assumes a torch nn.Module — TODO confirm)
    return model


def load_geoloc_model() -> Optional[Geolocalizer]:
    """Return the geolocation model for ``GEOLOC_MODEL_NAME``, or ``None`` on failure.

    The heavy load is cached across reruns and sessions, so callers may call this
    freely (e.g. once per image). Load errors are shown with ``st.error``.
    """
    with st.spinner('Loading model...'):
        try:
            return _load_geoloc_model_cached(GEOLOC_MODEL_NAME)
        except Exception as e:
            st.error(f"Failed to load the model: {e}")
            return None
# Function to find the most frequent location
def most_frequent_locations(text: str, top_n: int = 2) -> str:
    """Summarise the location entities that the module-level ``nlp`` finds in ``text``.

    Args:
        text: Free text to analyse.
        top_n: How many of the most frequently mentioned locations to report.

    Returns:
        A string such as ``"Most Mentioned Locations: Paris (3 occurrences),
        France (1 occurrence)"``, or ``"No locations found"`` when no ``LOC`` or
        ``GPE`` entity is detected.
    """
    doc = nlp(text)
    locations = []
    # Collect every entity spaCy tags as a location (LOC) or geopolitical entity (GPE).
    for ent in doc.ents:
        if ent.label_ in ('LOC', 'GPE'):
            print(f"Entity: {ent.text} | Label: {ent.label_} | Sentence: {ent.sent}")
            locations.append(ent.text)

    if not locations:
        return "No locations found"

    most_common_locations = Counter(locations).most_common(top_n)
    common_locations_str = ', '.join(
        f"{name} ({count} {'occurrence' if count == 1 else 'occurrences'})"
        for name, count in most_common_locations
    )
    return f"Most Mentioned Locations: {common_locations_str}"
# Transform image for model prediction
def transform_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalised single-image batch for the model.

    The image is resized to ``IMAGE_SIZE``, converted to a ``(C, H, W)`` float
    tensor, normalised per channel, and given a leading batch dimension, so the
    result is ``(1, C, H, W)``. ``Normalize`` uses three channel statistics, so
    callers should pass an RGB image (they do ``.convert('RGB')``).
    """
    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        # Commonly used ImageNet channel statistics; NOTE(review): confirm they
        # match the preprocessing the osv5m/baseline model was trained with.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0)
# Fetch city GeoJSON data
def get_city_geojson(location_name: str) -> Optional[dict]:
    """Geocode ``location_name`` with Nominatim and return its GeoJSON geometry.

    Returns:
        The ``geojson`` entry of the best Nominatim match, or ``None`` when
        nothing matches, the match carries no geometry, or the request fails.
        Request failures are reported with ``st.error``.
    """
    geolocator = Nominatim(user_agent="predictGeolocforImage")
    try:
        # geopy's default timeout (1 s) is often too short for the public Nominatim service.
        location = geolocator.geocode(location_name, geometry='geojson', timeout=10)
    except Exception as e:
        st.error(f"Failed to geocode location: {e}")
        return None
    if location is None:
        return None
    # Defensive: treat a response without a 'geojson' key as "no geometry".
    return location.raw.get('geojson')
# Fetch media from URL
def get_media(url: str, timeout: float = 10) -> Optional[list]:
    """Fetch a JSON feed from ``url`` and collect its ``(media_url, full_text)`` pairs.

    The payload is expected to be a JSON list of post objects, each with a
    ``full_text`` string and an optional ``media`` list whose items may carry a
    ``media_url``. This schema is inferred from the lookups below and is not
    otherwise validated.

    Args:
        url: Endpoint returning the JSON feed.
        timeout: Seconds to wait for the server. Without a timeout, ``requests``
            can block the Streamlit script indefinitely.

    Returns:
        One ``(media_url, full_text)`` tuple per media item (possibly an empty
        list), or ``None`` if the request failed or the payload had an
        unexpected shape. Failures are reported with ``st.error``.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # Older requests versions raise a plain ValueError for non-JSON bodies;
        # newer ones raise requests.JSONDecodeError, which is also a RequestException.
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        st.error(f"Failed to fetch media URL: {e}")
        return None

    try:
        return [(media['media_url'], entry['full_text'])
                for entry in data
                for media in entry.get('media', [])
                if 'media_url' in media]
    except (AttributeError, KeyError, TypeError) as e:
        # e.g. a top-level object instead of a list, or a post without 'full_text'.
        st.error(f"Unexpected response format from media URL: {e}")
        return None
# Predict location from image
def predict_location(image: Image.Image, model: Geolocalizer) -> Optional[tuple]:
    """Predict where ``image`` was taken and resolve the coordinates to a place.

    Args:
        image: PIL image; callers pass RGB images (``.convert('RGB')``).
        model: Geolocation model as returned by ``load_geoloc_model``.

    Returns:
        ``(gps_degrees, location_query, city_geojson, processing_time)``, where:

        - ``gps_degrees`` is ``[latitude, longitude]`` in degrees.
        - ``location_query`` is the nearest ``reverse_geocoder`` record
          (``name`` / ``admin1`` / ``cc`` keys are read).
        - ``city_geojson`` is the Nominatim geometry or ``None``.
        - ``processing_time`` is elapsed seconds, including the geocoding request.

        Returns ``None`` if any step fails; the error is shown with ``st.error``.
    """
    with st.spinner('Processing image and predicting location...'):
        start_time = time.perf_counter()
        try:
            img_tensor = transform_image(image)
            # Pure inference: don't record an autograd graph (saves memory and time).
            with torch.no_grad():
                gps_radians = model(img_tensor)
            # Assumes the model returns a (1, 2) tensor of (lat, lon) in radians — TODO confirm.
            gps_degrees = torch.rad2deg(gps_radians).squeeze(0).cpu().tolist()
            lat, lon = gps_degrees[0], gps_degrees[1]
            location_query = rg.search((lat, lon))[0]
            location_name = f"{location_query['name']}, {location_query['admin1']}, {location_query['cc']}"
            city_geojson = get_city_geojson(location_name)
            processing_time = time.perf_counter() - start_time
            return gps_degrees, location_query, city_geojson, processing_time
        except Exception as e:
            st.error(f"Failed to predict the location: {e}")
            return None
# Display map in Streamlit
def display_map(city_geojson: dict, gps_degrees: list) -> None:
    """Draw ``city_geojson`` as a filled outline on a map centred on ``gps_degrees``.

    ``gps_degrees`` is indexed as ``[latitude, longitude]``.
    NOTE(review): ``mapbox://`` styles may require a Mapbox token depending on the
    Streamlit/pydeck version — confirm the basemap actually renders.
    """
    lat, lon = gps_degrees[0], gps_degrees[1]
    view = pdk.ViewState(latitude=lat, longitude=lon, zoom=8, pitch=0)

    boundary_layer = pdk.Layer(
        'GeoJsonLayer',
        data=city_geojson,
        get_fill_color=[255, 180, 0, 140],  # RGBA fill, partially transparent
        pickable=True,
        stroked=True,
        filled=True,
        extruded=False,
        line_width_min_pixels=1,
    )

    deck = pdk.Deck(
        map_style='mapbox://styles/mapbox/light-v9',
        initial_view_state=view,
        layers=[boundary_layer],
    )
    st.pydeck_chart(deck)
# Display image
def display_image(image_url: str, timeout: float = 10) -> None:
    """Download ``image_url`` and show it in the app.

    Args:
        image_url: Absolute URL of the image.
        timeout: Seconds to wait for the server. Without a timeout, a stalled
            host would block the whole Streamlit script.

    Errors are never raised: HTTP/network failures and any other exception
    raised while fetching or rendering are reported with ``st.error``.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        image_bytes = BytesIO(response.content)
        st.image(image_bytes, caption=f'Image from URL: {image_url}', use_column_width=True)
    except requests.RequestException as e:
        st.error(f"Failed to fetch image at URL {image_url}: {e}")
    except Exception as e:
        st.error(f"An error occurred: {e}")
# Scrape webpage for text and images
def scrape_webpage(url: str, timeout: float = 10) -> tuple:
    """Download ``url`` and extract its paragraph text and absolute image URLs.

    Args:
        url: Page to scrape.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        ``(text, images)`` where ``text`` is the text of every ``<p>`` element
        joined by newlines, and ``images`` is every non-empty ``<img src>``
        resolved to an absolute URL. Returns ``(None, None)`` if the request
        fails; the error is shown with ``st.error``.
    """
    with st.spinner('Scraping web page...'):
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
        except requests.RequestException as e:
            st.error(f"Failed to fetch and parse the URL: {e}")
            return None, None

        soup = BeautifulSoup(response.content, 'html.parser')

        # Resolve relative links against the final URL (after redirects), and
        # honour a <base href> if the page declares one.
        base_url = response.url or url
        base_tag = soup.find('base', href=True)
        if base_tag:
            base_url = urljoin(base_url, base_tag['href'])

        # Separate paragraphs so words from adjacent <p> tags don't fuse together
        # (which would also confuse the downstream entity recognition).
        text = '\n'.join(p.get_text() for p in soup.find_all('p'))
        images = [urljoin(base_url, img['src']) for img in soup.find_all('img') if img.get('src')]
        return text, images
def main():
    """Render the title and sidebar navigation, then show the selected page."""
    st.title('Welcome to Geolocation Predictor Demo 👋')

    # Sidebar navigation; the first entry ("Home") is preselected.
    actions = ("Home", "Upload Images", "Social Media URL", "Web Page URL")
    choice = st.sidebar.selectbox("Choose your action:", actions, index=0)
    st.sidebar.success("Select a demo above.")

    if choice == "Home":
        st.write("Welcome to the Geolocation Predictor. Please select an action from the sidebar dropdown.")
        return

    page_renderers = {
        "Upload Images": upload_images_page,
        "Social Media URL": social_media_page,
        "Web Page URL": web_page_url_page,
    }
    render_page = page_renderers.get(choice)
    if render_page is not None:
        render_page()
def upload_images_page():
    """Let the user upload images and show a location prediction for each one.

    Unreadable uploads are reported with ``st.error`` and skipped, so they don't
    abort processing of the remaining files.
    """
    st.header("Image Upload for Geolocation Prediction")
    uploaded_files = st.file_uploader("Choose images...", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
    if not uploaded_files:
        return

    # Obtain the model once for the whole batch rather than once per file.
    model = load_geoloc_model()

    for file in uploaded_files:
        with st.spinner(f"Processing {file.name}..."):
            try:
                image = Image.open(file).convert('RGB')
            except (OSError, Image.DecompressionBombError) as e:
                # PIL's UnidentifiedImageError (corrupt/unsupported file) is an OSError.
                st.error(f"Could not read {file.name} as an image: {e}")
                continue
            st.image(image, caption=f'Uploaded Image: {file.name}', use_column_width=True)

            if model is None:
                continue
            result = predict_location(image, model)
            if result is None:
                continue

            gps_degrees, location_query, city_geojson, processing_time = result
            st.write(
                f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
            if city_geojson:
                display_map(city_geojson, gps_degrees)
            st.write(f"Processing Time (seconds): {processing_time}")
def social_media_page(image_timeout: float = 10):
    """Analyse a social-media JSON feed entered by the user.

    The page shows the text attached to the first media item, the locations
    mentioned most often in that text, and a location prediction for every
    media image. Fetch and decode failures for a single image are reported and
    skipped instead of crashing the page.

    Args:
        image_timeout: Seconds to wait when downloading each media image.
    """
    st.header("Social Media Image Analyser")
    social_media_url = st.text_input("Enter a social media URL to analyse:", key='social_media_url_input')
    if not social_media_url:
        return
    media_data = get_media(social_media_url)
    if not media_data:
        return

    # Only the text of the first media entry is analysed.
    full_text = media_data[0][1]
    st.subheader("Full Text")
    st.write(full_text)
    most_used_location = most_frequent_locations(full_text)
    st.subheader("Most Frequent Location")
    st.write(most_used_location)

    # Obtain the model once rather than once per image.
    model = load_geoloc_model()

    for idx, (media_url, _) in enumerate(media_data, start=1):
        st.subheader(f"Image {idx}")
        try:
            response = requests.get(media_url, timeout=image_timeout)
        except requests.RequestException as e:
            st.error(f"Failed to fetch image at URL {media_url}: {e}")
            continue
        if response.status_code != 200:
            st.error(f"Failed to fetch image at URL {media_url}: HTTP {response.status_code}")
            continue
        try:
            image = Image.open(BytesIO(response.content)).convert('RGB')
        except (OSError, Image.DecompressionBombError) as e:
            st.error(f"Could not decode image at URL {media_url}: {e}")
            continue
        st.image(image, caption=f'Image from URL: {media_url}', use_column_width=True)

        if model is None:
            continue
        result = predict_location(image, model)
        if result is None:
            continue

        gps_degrees, location_query, city_geojson, processing_time = result
        st.write(
            f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
        if city_geojson:
            display_map(city_geojson, gps_degrees)
        st.write(f"Processing Time (seconds): {processing_time}")
def web_page_url_page():
    """Scrape a user-supplied web page, summarise its locations, and optionally show its images."""
    st.header("Web Page Scraper")
    web_page_url = st.text_input("Enter a web page URL to scrape:", key='web_page_url_input')
    if not web_page_url:
        return

    text, images = scrape_webpage(web_page_url)
    # scrape_webpage returns (None, None) when the fetch fails. A page may also
    # have images but no <p> text, or text but no images.
    if not text and not images:
        st.write("No data found or unable to parse the webpage.")
        return

    if text:
        st.subheader("Extracted Text (First 500 Characters):")
        st.write(text[:500])
        most_used_location = most_frequent_locations(text)
        st.subheader("Most Frequent Location")
        st.write(most_used_location)

    if images:
        if st.checkbox('Show Images', key='show_images'):
            st.subheader("Images Found")
            for image_url in images:
                display_image(image_url)
    else:
        st.write("No images found on the page.")
# Launch the app when this file is executed as the main script (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()