yunusserhat committed
Commit b82263e
Parent: 0109d52

Update app.py

Files changed (1)
  1. app.py +276 -264
app.py CHANGED
@@ -1,264 +1,276 @@
-import streamlit as st
-from PIL import Image
-import torch
-from torchvision import transforms
-import pydeck as pdk
-from geopy.geocoders import Nominatim
-import time
-import requests
-from io import BytesIO
-import reverse_geocoder as rg
-from bs4 import BeautifulSoup
-from urllib.parse import urljoin
-from models.huggingface import Geolocalizer
-import spacy
-from collections import Counter
-
-nlp = spacy.load("en_core_web_md")
-
-IMAGE_SIZE = (224, 224)
-GEOLOC_MODEL_NAME = "osv5m/baseline"
-
-
-# Load geolocation model
-@st.cache_resource(show_spinner=True)
-def load_geoloc_model() -> Geolocalizer:
-    with st.spinner('Loading model...'):
-        try:
-            model = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
-            model.eval()
-            return model
-        except Exception as e:
-            st.error(f"Failed to load the model: {e}")
-            return None
-
-
-# Function to find the most frequent location
-def most_frequent_locations(text: str):
-    doc = nlp(text)
-    locations = []
-
-    # Collect all identified location entities
-    for ent in doc.ents:
-        if ent.label_ in ['LOC', 'GPE']:
-            print(f"Entity: {ent.text} | Label: {ent.label_} | Sentence: {ent.sent}")
-            locations.append(ent.text)
-
-    # Count occurrences and extract the most common locations
-    if locations:
-        location_counts = Counter(locations)
-        most_common_locations = location_counts.most_common(2) # Adjust the number as needed
-        # Format the output to show location names along with their counts
-        common_locations_str = ', '.join([f"{loc[0]} ({loc[1]} occurrences)" for loc in most_common_locations])
-
-        return f"Most Mentioned Locations: {common_locations_str}"
-    else:
-        return "No locations found"
-
-
-# Transform image for model prediction
-def transform_image(image: Image) -> torch.Tensor:
-    transform = transforms.Compose([
-        transforms.Resize(IMAGE_SIZE),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-    ])
-    return transform(image).unsqueeze(0)
-
-
-# Fetch city GeoJSON data
-def get_city_geojson(location_name: str) -> dict:
-    geolocator = Nominatim(user_agent="predictGeolocforImage")
-    try:
-        location = geolocator.geocode(location_name, geometry='geojson')
-        return location.raw['geojson'] if location else None
-    except Exception as e:
-        st.error(f"Failed to geocode location: {e}")
-        return None
-
-
-# Fetch media from URL
-def get_media(url: str) -> list:
-    try:
-        response = requests.get(url)
-        response.raise_for_status()
-        data = response.json()
-        return [(media['media_url'], entry['full_text'])
-                for entry in data for media in entry.get('media', []) if 'media_url' in media]
-    except requests.RequestException as e:
-        st.error(f"Failed to fetch media URL: {e}")
-        return None
-
-
-# Predict location from image
-def predict_location(image: Image, model: Geolocalizer) -> tuple:
-    with st.spinner('Processing image and predicting location...'):
-        start_time = time.time()
-        try:
-            img_tensor = transform_image(image)
-            gps_radians = model(img_tensor)
-            gps_degrees = torch.rad2deg(gps_radians).squeeze(0).cpu().tolist()
-            location_query = rg.search((gps_degrees[0], gps_degrees[1]))[0]
-            location_name = f"{location_query['name']}, {location_query['admin1']}, {location_query['cc']}"
-            city_geojson = get_city_geojson(location_name)
-            processing_time = time.time() - start_time
-            return gps_degrees, location_query, city_geojson, processing_time
-        except Exception as e:
-            st.error(f"Failed to predict the location: {e}")
-            return None
-
-# Display map in Streamlit
-def display_map(city_geojson: dict, gps_degrees: list) -> None:
-    map_view = pdk.Deck(
-        map_style='mapbox://styles/mapbox/light-v9',
-        initial_view_state=pdk.ViewState(
-            latitude=gps_degrees[0],
-            longitude=gps_degrees[1],
-            zoom=8,
-            pitch=0,
-        ),
-        layers=[
-            pdk.Layer(
-                'GeoJsonLayer',
-                data=city_geojson,
-                get_fill_color=[255, 180, 0, 140],
-                pickable=True,
-                stroked=True,
-                filled=True,
-                extruded=False,
-                line_width_min_pixels=1,
-            ),
-        ],
-    )
-    st.pydeck_chart(map_view)
-
-
-# Display image
-def display_image(image_url: str) -> None:
-    try:
-        response = requests.get(image_url)
-        response.raise_for_status()
-        image_bytes = BytesIO(response.content)
-        st.image(image_bytes, caption=f'Image from URL: {image_url}', use_column_width=True)
-    except requests.RequestException as e:
-        st.error(f"Failed to fetch image at URL {image_url}: {e}")
-    except Exception as e:
-        st.error(f"An error occurred: {e}")
-
-
-# Scrape webpage for text and images
-def scrape_webpage(url: str) -> tuple:
-    with st.spinner('Scraping web page...'):
-        try:
-            response = requests.get(url)
-            response.raise_for_status()
-            soup = BeautifulSoup(response.content, 'html.parser')
-            base_url = url # Adjust based on <base> tags or other HTML clues
-            text = ''.join(p.text for p in soup.find_all('p'))
-            images = [urljoin(base_url, img['src']) for img in soup.find_all('img') if 'src' in img.attrs]
-            return text, images
-        except requests.RequestException as e:
-            st.error(f"Failed to fetch and parse the URL: {e}")
-            return None, None
-
-
-def main():
-    st.title('Welcome to Geolocation Predictor Demo 👋')
-
-    # Define page navigation using the sidebar
-    page = st.sidebar.selectbox(
-        "Choose your action:",
-        ("Home", "Upload Images", "Social Media URL", "Web Page URL"),
-        index=0 # Default to Home
-    )
-    st.sidebar.success("Select a demo above.")
-    if page == "Home":
-        st.write("Welcome to the Geolocation Predictor. Please select an action from the sidebar dropdown.")
-
-    elif page == "Upload Images":
-        upload_images_page()
-
-    elif page == "Social Media URL":
-        social_media_page()
-
-    elif page == "Web Page URL":
-        web_page_url_page()
-
-
-def upload_images_page():
-    st.header("Image Upload for Geolocation Prediction")
-    uploaded_files = st.file_uploader("Choose images...", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
-    if uploaded_files:
-        for idx, file in enumerate(uploaded_files, start=1):
-            with st.spinner(f"Processing {file.name}..."):
-                image = Image.open(file).convert('RGB')
-                st.image(image, caption=f'Uploaded Image: {file.name}', use_column_width=True)
-                model = load_geoloc_model()
-                if model:
-                    result = predict_location(image, model) # Assume this function is defined elsewhere
-                    if result:
-                        gps_degrees, location_query, city_geojson, processing_time = result
-                        st.write(
-                            f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
-                        if city_geojson:
-                            display_map(city_geojson, gps_degrees)
-                        st.write(f"Processing Time (seconds): {processing_time}")
-
-
-def social_media_page():
-    st.header("Social Media Image Analyser")
-    social_media_url = st.text_input("Enter a social media URL to analyse:", key='social_media_url_input')
-    if social_media_url:
-        media_data = get_media(social_media_url) # Assume this function is defined elsewhere
-        if media_data:
-            # Display the full text of the first media found
-            full_text = media_data[0][1]
-            st.subheader("Full Text")
-            st.write(full_text)
-            most_used_location = most_frequent_locations(full_text)
-            st.subheader("Most Frequent Location")
-            st.write(most_used_location)
-
-            # Process and display each image found in the media data
-            for idx, (media_url, _) in enumerate(media_data, start=1):
-                st.subheader(f"Image {idx}")
-                response = requests.get(media_url)
-                if response.status_code == 200:
-                    image = Image.open(BytesIO(response.content)).convert('RGB')
-                    st.image(image, caption=f'Image from URL: {media_url}', use_column_width=True)
-                    model = load_geoloc_model() # Assume this function is defined elsewhere
-                    if model:
-                        result = predict_location(image, model) # Assume this function is defined elsewhere
-                        if result:
-                            gps_degrees, location_query, city_geojson, processing_time = result
-                            st.write(
-                                f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
-                            if city_geojson:
-                                display_map(city_geojson, gps_degrees)
-                            st.write(f"Processing Time (seconds): {processing_time}")
-                else:
-                    st.error(f"Failed to fetch image at URL {media_url}: HTTP {response.status_code}")
-
-
-def web_page_url_page():
-    st.header("Web Page Scraper")
-    web_page_url = st.text_input("Enter a web page URL to scrape:", key='web_page_url_input')
-    if web_page_url:
-        text, images = scrape_webpage(web_page_url) # Assume this function is defined elsewhere
-        if text:
-            st.subheader("Extracted Text First 500 Chracter:")
-            st.write(text[:500])
-            most_used_location = most_frequent_locations(text)
-            st.subheader("Most Frequent Location")
-            st.write(most_used_location)
-            show_images = st.checkbox('Show Images', key='show_images')
-            if show_images:
-                st.subheader("Images Found")
-                for image_url in images:
-                    display_image(image_url) # Assumes a function to display images with error handling
-        else:
-            st.write("No data found or unable to parse the webpage.")
-
-
-if __name__ == '__main__':
-    main()
+import streamlit as st
+from PIL import Image
+import torch
+from torchvision import transforms
+import pydeck as pdk
+from geopy.geocoders import Nominatim
+import time
+import requests
+from io import BytesIO
+import reverse_geocoder as rg
+from bs4 import BeautifulSoup
+from urllib.parse import urljoin
+from models.huggingface import Geolocalizer
+import spacy
+from collections import Counter
+from spacy.cli import download
+
+
+def load_spacy_model(model_name="en_core_web_md"):
+    try:
+        return spacy.load(model_name)
+    except IOError:
+        print(f"Model {model_name} not found, downloading...")
+        download(model_name)
+        return spacy.load(model_name)
+
+
+
+nlp = load_spacy_model()
+
+IMAGE_SIZE = (224, 224)
+GEOLOC_MODEL_NAME = "osv5m/baseline"
+
+
+# Load geolocation model
+@st.cache_resource(show_spinner=True)
+def load_geoloc_model() -> Geolocalizer:
+    with st.spinner('Loading model...'):
+        try:
+            model = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
+            model.eval()
+            return model
+        except Exception as e:
+            st.error(f"Failed to load the model: {e}")
+            return None
+
+
+# Function to find the most frequent location
+def most_frequent_locations(text: str):
+    doc = nlp(text)
+    locations = []
+
+    # Collect all identified location entities
+    for ent in doc.ents:
+        if ent.label_ in ['LOC', 'GPE']:
+            print(f"Entity: {ent.text} | Label: {ent.label_} | Sentence: {ent.sent}")
+            locations.append(ent.text)
+
+    # Count occurrences and extract the most common locations
+    if locations:
+        location_counts = Counter(locations)
+        most_common_locations = location_counts.most_common(2) # Adjust the number as needed
+        # Format the output to show location names along with their counts
+        common_locations_str = ', '.join([f"{loc[0]} ({loc[1]} occurrences)" for loc in most_common_locations])
+
+        return f"Most Mentioned Locations: {common_locations_str}"
+    else:
+        return "No locations found"
+
+
+# Transform image for model prediction
+def transform_image(image: Image) -> torch.Tensor:
+    transform = transforms.Compose([
+        transforms.Resize(IMAGE_SIZE),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    ])
+    return transform(image).unsqueeze(0)
+
+
+# Fetch city GeoJSON data
+def get_city_geojson(location_name: str) -> dict:
+    geolocator = Nominatim(user_agent="predictGeolocforImage")
+    try:
+        location = geolocator.geocode(location_name, geometry='geojson')
+        return location.raw['geojson'] if location else None
+    except Exception as e:
+        st.error(f"Failed to geocode location: {e}")
+        return None
+
+
+# Fetch media from URL
+def get_media(url: str) -> list:
+    try:
+        response = requests.get(url)
+        response.raise_for_status()
+        data = response.json()
+        return [(media['media_url'], entry['full_text'])
+                for entry in data for media in entry.get('media', []) if 'media_url' in media]
+    except requests.RequestException as e:
+        st.error(f"Failed to fetch media URL: {e}")
+        return None
+
+
+# Predict location from image
+def predict_location(image: Image, model: Geolocalizer) -> tuple:
+    with st.spinner('Processing image and predicting location...'):
+        start_time = time.time()
+        try:
+            img_tensor = transform_image(image)
+            gps_radians = model(img_tensor)
+            gps_degrees = torch.rad2deg(gps_radians).squeeze(0).cpu().tolist()
+            location_query = rg.search((gps_degrees[0], gps_degrees[1]))[0]
+            location_name = f"{location_query['name']}, {location_query['admin1']}, {location_query['cc']}"
+            city_geojson = get_city_geojson(location_name)
+            processing_time = time.time() - start_time
+            return gps_degrees, location_query, city_geojson, processing_time
+        except Exception as e:
+            st.error(f"Failed to predict the location: {e}")
+            return None
+
+# Display map in Streamlit
+def display_map(city_geojson: dict, gps_degrees: list) -> None:
+    map_view = pdk.Deck(
+        map_style='mapbox://styles/mapbox/light-v9',
+        initial_view_state=pdk.ViewState(
+            latitude=gps_degrees[0],
+            longitude=gps_degrees[1],
+            zoom=8,
+            pitch=0,
+        ),
+        layers=[
+            pdk.Layer(
+                'GeoJsonLayer',
+                data=city_geojson,
+                get_fill_color=[255, 180, 0, 140],
+                pickable=True,
+                stroked=True,
+                filled=True,
+                extruded=False,
+                line_width_min_pixels=1,
+            ),
+        ],
+    )
+    st.pydeck_chart(map_view)
+
+
+# Display image
+def display_image(image_url: str) -> None:
+    try:
+        response = requests.get(image_url)
+        response.raise_for_status()
+        image_bytes = BytesIO(response.content)
+        st.image(image_bytes, caption=f'Image from URL: {image_url}', use_column_width=True)
+    except requests.RequestException as e:
+        st.error(f"Failed to fetch image at URL {image_url}: {e}")
+    except Exception as e:
+        st.error(f"An error occurred: {e}")
+
+
+# Scrape webpage for text and images
+def scrape_webpage(url: str) -> tuple:
+    with st.spinner('Scraping web page...'):
+        try:
+            response = requests.get(url)
+            response.raise_for_status()
+            soup = BeautifulSoup(response.content, 'html.parser')
+            base_url = url # Adjust based on <base> tags or other HTML clues
+            text = ''.join(p.text for p in soup.find_all('p'))
+            images = [urljoin(base_url, img['src']) for img in soup.find_all('img') if 'src' in img.attrs]
+            return text, images
+        except requests.RequestException as e:
+            st.error(f"Failed to fetch and parse the URL: {e}")
+            return None, None
+
+
+def main():
+    st.title('Welcome to Geolocation Predictor Demo 👋')
+
+    # Define page navigation using the sidebar
+    page = st.sidebar.selectbox(
+        "Choose your action:",
+        ("Home", "Upload Images", "Social Media URL", "Web Page URL"),
+        index=0 # Default to Home
+    )
+    st.sidebar.success("Select a demo above.")
+    if page == "Home":
+        st.write("Welcome to the Geolocation Predictor. Please select an action from the sidebar dropdown.")
+
+    elif page == "Upload Images":
+        upload_images_page()
+
+    elif page == "Social Media URL":
+        social_media_page()
+
+    elif page == "Web Page URL":
+        web_page_url_page()
+
+
+def upload_images_page():
+    st.header("Image Upload for Geolocation Prediction")
+    uploaded_files = st.file_uploader("Choose images...", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
+    if uploaded_files:
+        for idx, file in enumerate(uploaded_files, start=1):
+            with st.spinner(f"Processing {file.name}..."):
+                image = Image.open(file).convert('RGB')
+                st.image(image, caption=f'Uploaded Image: {file.name}', use_column_width=True)
+                model = load_geoloc_model()
+                if model:
+                    result = predict_location(image, model) # Assume this function is defined elsewhere
+                    if result:
+                        gps_degrees, location_query, city_geojson, processing_time = result
+                        st.write(
+                            f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
+                        if city_geojson:
+                            display_map(city_geojson, gps_degrees)
+                        st.write(f"Processing Time (seconds): {processing_time}")
+
+
+def social_media_page():
+    st.header("Social Media Image Analyser")
+    social_media_url = st.text_input("Enter a social media URL to analyse:", key='social_media_url_input')
+    if social_media_url:
+        media_data = get_media(social_media_url) # Assume this function is defined elsewhere
+        if media_data:
+            # Display the full text of the first media found
+            full_text = media_data[0][1]
+            st.subheader("Full Text")
+            st.write(full_text)
+            most_used_location = most_frequent_locations(full_text)
+            st.subheader("Most Frequent Location")
+            st.write(most_used_location)
+
+            # Process and display each image found in the media data
+            for idx, (media_url, _) in enumerate(media_data, start=1):
+                st.subheader(f"Image {idx}")
+                response = requests.get(media_url)
+                if response.status_code == 200:
+                    image = Image.open(BytesIO(response.content)).convert('RGB')
+                    st.image(image, caption=f'Image from URL: {media_url}', use_column_width=True)
+                    model = load_geoloc_model() # Assume this function is defined elsewhere
+                    if model:
+                        result = predict_location(image, model) # Assume this function is defined elsewhere
+                        if result:
+                            gps_degrees, location_query, city_geojson, processing_time = result
+                            st.write(
+                                f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
+                            if city_geojson:
+                                display_map(city_geojson, gps_degrees)
+                            st.write(f"Processing Time (seconds): {processing_time}")
+                else:
+                    st.error(f"Failed to fetch image at URL {media_url}: HTTP {response.status_code}")
+
+
+def web_page_url_page():
+    st.header("Web Page Scraper")
+    web_page_url = st.text_input("Enter a web page URL to scrape:", key='web_page_url_input')
+    if web_page_url:
+        text, images = scrape_webpage(web_page_url) # Assume this function is defined elsewhere
+        if text:
+            st.subheader("Extracted Text First 500 Chracter:")
+            st.write(text[:500])
+            most_used_location = most_frequent_locations(text)
+            st.subheader("Most Frequent Location")
+            st.write(most_used_location)
+            show_images = st.checkbox('Show Images', key='show_images')
+            if show_images:
+                st.subheader("Images Found")
+                for image_url in images:
+                    display_image(image_url) # Assumes a function to display images with error handling
+        else:
+            st.write("No data found or unable to parse the webpage.")
+
+
+if __name__ == '__main__':
+    main()
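
The functional change in this commit is the `load_spacy_model()` fallback: instead of calling `spacy.load("en_core_web_md")` directly at import time, app.py now downloads the model on first run if it is not installed. A minimal standalone sketch of the same pattern, assuming a fresh environment where the model may be missing; the sample sentence and printed checks are illustrative and not part of the commit:

```python
# Hypothetical standalone check of the download-on-demand pattern used in app.py.
import spacy
from spacy.cli import download

MODEL_NAME = "en_core_web_md"

try:
    nlp = spacy.load(MODEL_NAME)
except IOError:  # spacy.load raises OSError (IOError is its alias) when the package is missing
    download(MODEL_NAME)  # same call the new load_spacy_model() helper makes
    nlp = spacy.load(MODEL_NAME)

# Sanity check: the GPE/LOC entities that most_frequent_locations() counts.
doc = nlp("Street photos from Istanbul and Ankara, mostly Istanbul.")
print([(ent.text, ent.label_) for ent in doc.ents])
```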