# derek-thomas's picture
# derek-thomas HF staff
# Init commit
# 749d1d8
# raw
# history blame
# No virus
# 6.13 kB
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional
import pandas as pd
import requests
from my_logger import setup_logger
# Module-level logger (built by the project's my_logger.setup_logger) used by
# every function below for progress and error messages.
logger = setup_logger(__name__)
def get_pushshift_data(subreddit: str, before: Optional[int] = None,
                       after: Optional[int] = None, aggs: Optional[str] = None,
                       timeout: float = 30.0) -> Optional[Dict[str, Any]]:
    """
    Fetch one page of submissions from the Pushshift API for the specified subreddit.

    Results are requested sorted by created_utc, newest first, up to 1000 per call.

    :param subreddit: The name of the subreddit to scrape.
    :param before: The upper limit for the created_utc attribute of the submissions.
    :param after: The lower limit for the created_utc attribute of the submissions.
    :param aggs: The aggregation summary option to use.
    :param timeout: Seconds to wait for the HTTP request before giving up.
    :return: The decoded JSON response (data and aggregations if available), or
        None if the request failed, returned a non-200 status, or was not valid JSON.
    """
    url = "https://api.pushshift.io/reddit/search/submission/"
    params: Dict[str, Any] = {
        "subreddit": subreddit,
        "size": 1000,
        "sort": "created_utc",
        "sort_type": "desc",
    }
    # Only send the optional filters the caller actually supplied.
    optional_filters = {"before": before, "after": after, "aggs": aggs}
    params.update({key: value for key, value in optional_filters.items() if value is not None})

    try:
        # An explicit timeout keeps a stalled connection from hanging the scraper forever.
        response = requests.get(url, params=params, timeout=timeout)
    except requests.RequestException as exc:
        logger.error(f"Error fetching data: {exc}")
        return None

    if response.status_code != 200:
        logger.error(f"Error fetching data: {response.status_code}")
        return None

    try:
        return response.json()
    except ValueError:
        # requests' JSON decode errors subclass ValueError.
        logger.error("Error fetching data: response body is not valid JSON")
        return None
def get_post_count_for_day(subreddit: str, day_to_scrape: str) -> int:
    """
    Get the total number of posts for a specific day in the specified subreddit using the Pushshift API.

    The day is interpreted as a UTC calendar day, matching the UTC timestamps
    used elsewhere in this module.

    :param subreddit: The name of the subreddit to get the post count for.
    :param day_to_scrape: The date for which to get the post count (format: "YYYY-MM-DD").
    :return: The total number of posts for the specified day, or 0 if the request
        failed or no aggregation data came back.
    """
    # Pin the parsed date to UTC: calling .timestamp() on a naive datetime uses the
    # machine's local timezone, which would shift the window off the UTC day.
    day_start = datetime.strptime(day_to_scrape, "%Y-%m-%d").replace(tzinfo=timezone.utc)
    after = int(day_start.timestamp())
    before = int((day_start + timedelta(days=1)).timestamp())

    response = get_pushshift_data(subreddit, before=before, after=after, aggs="created_utc")
    if response is None:
        return 0

    buckets = (response.get("aggs") or {}).get("created_utc") or []
    # The aggregation may split the window into several buckets; count all of them
    # rather than only the first.
    return sum(bucket.get("doc_count", 0) for bucket in buckets)
def fetch_data(subreddit: str, before: int, after: int) -> Optional[Dict[str, Any]]:
    """
    Fetch one page of submissions created between ``after`` and ``before``.

    Both bounds are required here. The request itself (endpoint, page size, sort
    order, error logging) is delegated to get_pushshift_data so the two stay in sync.

    :param subreddit: The name of the subreddit to scrape.
    :param before: The upper limit for the created_utc attribute of the submissions.
    :param after: The lower limit for the created_utc attribute of the submissions.
    :return: The decoded JSON response, or None if the request failed.
    """
    return get_pushshift_data(subreddit, before=before, after=after)
def convert_timestamp_to_datetime(timestamp: int) -> str:
    """
    Format a Unix epoch timestamp as a UTC ``"YYYY-MM-DD HH:MM:SS"`` string.

    :param timestamp: Seconds since the Unix epoch.
    :return: The timestamp rendered in UTC, without a timezone suffix.
    """
    # fromtimestamp(..., tz=timezone.utc) yields an aware UTC datetime directly and
    # replaces datetime.utcfromtimestamp, which is deprecated since Python 3.12.
    datetime_obj_utc = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    return datetime_obj_utc.strftime('%Y-%m-%d %H:%M:%S')
def scrape_submissions_by_day(subreddit_to_scrape: str, day_to_scrape: str) -> List[Dict[str, Any]]:
    """
    Scrape all submissions posted to a subreddit on one UTC calendar day.

    Pages backwards through the day with get_pushshift_data, moving the ``before``
    bound to the created_utc of the last submission on each page, and sleeps one
    second between requests.

    :param subreddit_to_scrape: The name of the subreddit to scrape.
    :param day_to_scrape: The day to scrape (format: "YYYY-MM-DD"), interpreted as UTC.
    :return: The raw submission dicts in the order the API returned them. The list
        is empty if the day is within the last 7 days or the first request fails.
    """
    start_time = time.time()
    scraped_submissions: List[Dict[str, Any]] = []

    # created_utc values are UTC epochs; calling .timestamp() on a naive datetime
    # would use the machine's local timezone and shift the window off the requested day.
    day_start = datetime.strptime(day_to_scrape, "%Y-%m-%d").replace(tzinfo=timezone.utc)
    if day_start > datetime.now(timezone.utc) - timedelta(days=7):
        logger.error("The specified date might not be available in the Pushshift API yet. "
                     "Please try an earlier date or wait for the API to be updated.")
        return scraped_submissions

    after = int(day_start.timestamp())
    before = int((day_start + timedelta(days=1)).timestamp())

    # TODO: get_post_count_for_day did not seem to work, so the total number of
    # requests is not estimated up front.
    actual_requests = 0
    while after < before:
        after_str, before_str = convert_timestamp_to_datetime(after), convert_timestamp_to_datetime(before)
        logger.info(f"Fetching data between timestamps {after_str} and {before_str}")
        data = get_pushshift_data(subreddit_to_scrape, before=before, after=after)
        page = data.get("data") if data is not None else None
        if not page:
            break
        scraped_submissions.extend(page)
        # Pages are requested newest-first (sort_type=desc), so the last item is the
        # oldest; continue the walk from there.
        # NOTE(review): if Pushshift treats `before` as exclusive, submissions that
        # share this exact timestamp but did not fit on the page may be skipped.
        before = page[-1]["created_utc"]
        actual_requests += 1
        time.sleep(1)

    elapsed_time = time.time() - start_time
    if actual_requests:
        minutes, seconds = divmod(elapsed_time, 60)
        logger.info(
            f"{actual_requests}it [{int(minutes):02d}:{seconds:05.2f} {elapsed_time / actual_requests:.2f}s/it]")
    logger.info(
        f"Finished scraping {len(scraped_submissions)} submissions in {elapsed_time:.2f} seconds in {actual_requests} requests")
    return scraped_submissions
def submissions_to_dataframe(submissions: List[Dict[str, Any]]) -> pd.DataFrame:
    """
    Parse a list of submissions into a pandas DataFrame.

    Only the columns listed in ``cols`` are kept, in that order. A column missing
    from every submission is still created and filled with missing values. An empty
    input yields an empty DataFrame with those columns.

    :param submissions: A list of dictionaries containing the scraped submission data.
    :return: A pandas DataFrame containing the submission data, with created_utc
        rendered as a UTC "YYYY-MM-DD HH:MM:SS" string.
    """
    cols = ['score', 'num_comments', 'title', 'permalink', 'selftext', 'url', 'created_utc', 'author', 'id',
            'downs', 'ups']
    if not submissions:
        # pd.DataFrame([]) has no columns at all, so selecting cols would raise KeyError.
        return pd.DataFrame(columns=cols)

    # reindex (unlike df[cols]) tolerates submissions that lack some of the fields.
    # Converting dtypes after the selection avoids inferring types for discarded columns.
    df = pd.DataFrame(submissions).reindex(columns=cols).convert_dtypes()
    # Convert the epoch-seconds "created_utc" column to a UTC-formatted string
    df['created_utc'] = pd.to_datetime(df['created_utc'], unit='s').dt.tz_localize('UTC').dt.strftime(
        '%Y-%m-%d %H:%M:%S')
    return df
if __name__ == '__main__':
    subreddit_to_scrape = "askreddit"
    day_to_scrape = "2013-03-01"
    submissions = scrape_submissions_by_day(subreddit_to_scrape, day_to_scrape)
    if submissions:
        df = submissions_to_dataframe(submissions)
        print(df.head().to_string())
    else:
        # Nothing came back (e.g. the request failed), so skip the DataFrame preview
        # instead of building a frame from an empty list.
        logger.info(f"No submissions to preview for r/{subreddit_to_scrape} on {day_to_scrape}")
    logger.info(f"Scraped {len(submissions)} submissions from r/{subreddit_to_scrape} on {day_to_scrape}")