import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

import pandas as pd
import requests

from my_logger import setup_logger

logger = setup_logger(__name__)

def get_pushshift_data(subreddit: str, before: Optional[int] = None,
                       after: Optional[int] = None, aggs: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """
    Fetch data from the Pushshift API for the specified subreddit.

    :param subreddit: The name of the subreddit to scrape.
    :param before: The upper limit for the created_utc attribute of the submissions.
    :param after: The lower limit for the created_utc attribute of the submissions.
    :param aggs: The aggregation summary option to use.
    :return: A dictionary containing the fetched data and aggregations if available.
    """
    url = "https://api.pushshift.io/reddit/search/submission/"
    params = {
        "subreddit": subreddit,
        "size": 1000,
        "sort": "created_utc",
        "sort_type": "desc",
    }
    if before is not None:
        params["before"] = before
    if after is not None:
        params["after"] = after
    if aggs is not None:
        params["aggs"] = aggs

    response = requests.get(url, params=params)
    if response.status_code == 200:
        return response.json()
    else:
        logger.error(f"Error fetching data: {response.status_code}")
        return None
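
# Example usage (illustrative sketch; assumes the Pushshift API is reachable and these epoch bounds are valid):
#   data = get_pushshift_data("askreddit", after=1362096000, before=1362182400)
#   if data is not None:
#       print(len(data["data"]), "submissions fetched")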

def get_post_count_for_day(subreddit: str, day_to_scrape: str) -> int:
    """
    Get the total number of posts for a specific day in the specified subreddit using the Pushshift API.

    :param subreddit: The name of the subreddit to get the post count for.
    :param day_to_scrape: The date for which to get the post count (format: "YYYY-MM-DD").
    :return: The total number of posts for the specified day.
    """
    date_obj = datetime.strptime(day_to_scrape, "%Y-%m-%d")
    after = int(date_obj.timestamp())
    before = int((date_obj + timedelta(days=1)).timestamp())

    response = get_pushshift_data(subreddit, before=before, after=after, aggs="created_utc")
    if response is not None:
        aggs = response.get("aggs", {}).get("created_utc", [])
        if aggs:
            return aggs[0]["doc_count"]
    return 0
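
# Example usage (illustrative; note the TODO in scrape_submissions_by_day below --
# this aggregation did not appear to work reliably against the API):
#   post_count = get_post_count_for_day("askreddit", "2013-03-01")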

def fetch_data(subreddit: str, before: int, after: int) -> Optional[Dict[str, Any]]:
    """
    Fetch a single page of submissions for the subreddit between the given epoch bounds.

    :param subreddit: The name of the subreddit to scrape.
    :param before: The upper limit for the created_utc attribute of the submissions.
    :param after: The lower limit for the created_utc attribute of the submissions.
    :return: The decoded JSON response, or None if the request failed.
    """
    url = "https://api.pushshift.io/reddit/search/submission/"
    params = {
        "subreddit": subreddit,
        "size": 1000,
        "sort": "created_utc",
        "sort_type": "desc",
        "before": before,
        "after": after,
    }

    response = requests.get(url, params=params)
    if response.status_code == 200:
        return response.json()
    else:
        logger.error(f"Error fetching data: {response.status_code}")
        return None

def convert_timestamp_to_datetime(timestamp: int) -> str:
    # Convert the timestamp to a datetime object
    datetime_obj = datetime.utcfromtimestamp(timestamp)

    # Add timezone information
    datetime_obj_utc = datetime_obj.replace(tzinfo=timezone.utc)

    # Convert the datetime object to a formatted string
    datetime_str = datetime_obj_utc.strftime('%Y-%m-%d %H:%M:%S')
    return datetime_str
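
# Example (illustrative): 1362096000 is 2013-03-01 00:00:00 UTC, so
#   convert_timestamp_to_datetime(1362096000)  ->  '2013-03-01 00:00:00'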

def scrape_submissions_by_day(subreddit_to_scrape: str, day_to_scrape: str) -> List[Dict[str, Any]]:
    start_time = time.time()
    scraped_submissions = []
    date_obj = datetime.strptime(day_to_scrape, "%Y-%m-%d")

    if date_obj > datetime.now() - timedelta(days=7):
        logger.error("The specified date might not be available in the Pushshift API yet. "
                     "Please try an earlier date or wait for the API to be updated.")
        return scraped_submissions

    after = int(date_obj.timestamp())
    before = int((date_obj + timedelta(days=1)).timestamp())

    # todo get_post_count_for_day didn't seem to work
    # post_count = get_post_count_for_day(subreddit_to_scrape, day_to_scrape)
    # total_requests = (post_count + 99) // 100  # Estimate the total number of requests

    actual_requests = 0
    while after < before:
        after_str, before_str = convert_timestamp_to_datetime(after), convert_timestamp_to_datetime(before)
        logger.info(f"Fetching data between timestamps {after_str} and {before_str}")
        data = get_pushshift_data(subreddit_to_scrape, before=before, after=after)
        if data is None or len(data["data"]) == 0:
            break

        scraped_submissions.extend(data["data"])
        # Page backwards: the next request ends where the oldest submission in this batch was created
        before = data["data"][-1]["created_utc"]
        actual_requests += 1
        time.sleep(1)

    elapsed_time = time.time() - start_time
    if actual_requests:
        logger.info(
            f"{actual_requests}it [{int(elapsed_time // 60):02d}:{elapsed_time % 60:.2f} {elapsed_time / actual_requests:.2f}s/it]")
        logger.info(
            f"Finished scraping {len(scraped_submissions)} submissions in {elapsed_time:.2f} seconds in {actual_requests} requests")
    return scraped_submissions

def submissions_to_dataframe(submissions: List[Dict[str, Any]]) -> pd.DataFrame:
    """
    Parse a list of submissions into a pandas DataFrame.

    :param submissions: A list of dictionaries containing the scraped submission data.
    :return: A pandas DataFrame containing the submission data.
    """
    cols = ['score', 'num_comments', 'title', 'permalink', 'selftext', 'url', 'created_utc', 'author', 'id',
            'downs', 'ups']
    df = pd.DataFrame(submissions)
    df = df.convert_dtypes()
    df = df[cols]

    # Convert the "created_utc" column to a datetime column with timezone information
    df['created_utc'] = pd.to_datetime(df['created_utc'], unit='s').dt.tz_localize('UTC')
    df['date'] = df['created_utc'].dt.date
    df['time'] = df['created_utc'].dt.time
    return df
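
# Example (illustrative; a minimal synthetic submission dict standing in for real API output):
#   sample = [{"score": 1, "num_comments": 0, "title": "example", "permalink": "/r/askreddit/...",
#              "selftext": "", "url": "https://example.com", "created_utc": 1362096000,
#              "author": "someone", "id": "abc123", "downs": 0, "ups": 1}]
#   submissions_to_dataframe(sample)[["title", "date", "time"]]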

if __name__ == '__main__':
    subreddit_to_scrape = "askreddit"
    day_to_scrape = "2013-03-01"
    submissions = scrape_submissions_by_day(subreddit_to_scrape, day_to_scrape)
    df = submissions_to_dataframe(submissions)
    print(df.head().to_string())
    logger.info(f"Scraped {len(submissions)} submissions from r/{subreddit_to_scrape} on {day_to_scrape}")