File size: 6,130 Bytes
749d1d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional
import pandas as pd
import requests
from my_logger import setup_logger
# Module-wide logger, named after this module. Handler/format/level setup is
# delegated to the project's `my_logger.setup_logger` helper (not visible here —
# TODO confirm what it configures).
logger = setup_logger(__name__)
def get_pushshift_data(subreddit: str, before: Optional[int] = None,
                       after: Optional[int] = None, aggs: Optional[str] = None,
                       timeout: float = 30.0) -> Optional[Dict[str, Any]]:
    """
    Fetch one page of submissions from the Pushshift API for the specified subreddit.

    Results are requested newest-first (sorted descending on ``created_utc``),
    up to 1000 per call.

    :param subreddit: The name of the subreddit to scrape.
    :param before: The upper limit for the created_utc attribute of the submissions;
        left out of the query when None.
    :param after: The lower limit for the created_utc attribute of the submissions;
        left out of the query when None.
    :param aggs: The aggregation summary option to use; left out of the query when None.
    :param timeout: Seconds to wait for the server. Without a timeout ``requests``
        can block indefinitely on a stalled connection.
    :return: The decoded JSON response (data and aggregations if available), or None
        if the request failed, the status code was not 200, or the body was not JSON.
    """
    url = "https://api.pushshift.io/reddit/search/submission/"
    params: Dict[str, Any] = {
        "subreddit": subreddit,
        "size": 1000,
        "sort": "created_utc",
        "sort_type": "desc",
    }
    if before is not None:
        params["before"] = before
    if after is not None:
        params["after"] = after
    if aggs is not None:
        params["aggs"] = aggs

    try:
        response = requests.get(url, params=params, timeout=timeout)
    except requests.RequestException as exc:
        # Connection errors and timeouts are reported like HTTP errors so callers
        # only ever have to handle a None result.
        logger.error(f"Error fetching data: {exc}")
        return None

    if response.status_code != 200:
        logger.error(f"Error fetching data: {response.status_code}")
        return None

    try:
        return response.json()
    except ValueError:
        # A 200 response with a non-JSON body (e.g. an HTML error page).
        logger.error("Error fetching data: response body is not valid JSON")
        return None
def get_post_count_for_day(subreddit: str, day_to_scrape: str) -> int:
    """
    Get the total number of posts for a specific day in the specified subreddit using the Pushshift API.

    The day is treated as a UTC calendar day, because Reddit's ``created_utc`` is UTC.
    The ``doc_count`` of every ``created_utc`` aggregation bucket returned for that
    window is summed, so the total does not depend on how finely the API buckets time.

    NOTE(review): assumes the API still returns ``aggs.created_utc`` buckets carrying
    a ``doc_count`` field; verify against the live API.

    :param subreddit: The name of the subreddit to get the post count for.
    :param day_to_scrape: The UTC date for which to get the post count (format: "YYYY-MM-DD").
    :return: The total number of posts for the specified day, or 0 if the request
        failed or no aggregation data was returned.
    """
    # A naive datetime's .timestamp() uses the host's local timezone, which would
    # shift the window away from the UTC day on non-UTC machines.
    day_start = datetime.strptime(day_to_scrape, "%Y-%m-%d").replace(tzinfo=timezone.utc)
    after = int(day_start.timestamp())
    before = int((day_start + timedelta(days=1)).timestamp())

    response = get_pushshift_data(subreddit, before=before, after=after, aggs="created_utc")
    if response is None:
        return 0
    buckets = (response.get("aggs") or {}).get("created_utc") or []
    return sum(bucket.get("doc_count", 0) for bucket in buckets)
def fetch_data(subreddit: str, before: int, after: int) -> Optional[Dict[str, Any]]:
    """
    Fetch one page of submissions created between ``after`` and ``before``.

    This issues the same query as ``get_pushshift_data`` (same endpoint, page size
    and sort order). It therefore delegates to that function rather than duplicating
    the request and error-handling code.

    :param subreddit: The name of the subreddit to scrape.
    :param before: The upper limit for the created_utc attribute of the submissions.
    :param after: The lower limit for the created_utc attribute of the submissions.
    :return: The decoded JSON response, or None if the request failed.
    """
    return get_pushshift_data(subreddit, before=before, after=after)
def convert_timestamp_to_datetime(timestamp: int) -> str:
    """
    Format a Unix timestamp as a UTC ``'YYYY-MM-DD HH:MM:SS'`` string.

    :param timestamp: Seconds since the Unix epoch. Floats are accepted; the
        fractional part is not shown.
    :return: The UTC date and time, without a timezone suffix.
    """
    # datetime.utcfromtimestamp is deprecated since Python 3.12. Building an aware
    # UTC datetime directly yields the same wall-clock fields.
    return datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
def scrape_submissions_by_day(subreddit_to_scrape: str, day_to_scrape: str) -> List[Dict[str, Any]]:
    """
    Scrape all submissions posted to a subreddit during one UTC calendar day.

    The day is paged backwards, newest first, with one ``get_pushshift_data``
    request per page and a one-second pause after each productive request.

    :param subreddit_to_scrape: The name of the subreddit to scrape.
    :param day_to_scrape: The UTC date to scrape (format: "YYYY-MM-DD").
    :return: The submission dicts returned by Pushshift, newest first and
        de-duplicated by ``id``. The list is empty if the date is less than 7 days
        ago, or if the first request fails or returns nothing.
    """
    start_time = time.time()
    scraped_submissions: List[Dict[str, Any]] = []

    # Reddit's created_utc is UTC. A naive datetime's .timestamp() would use the
    # host's local timezone and shift the day window on non-UTC machines.
    day_start = datetime.strptime(day_to_scrape, "%Y-%m-%d").replace(tzinfo=timezone.utc)
    if day_start > datetime.now(timezone.utc) - timedelta(days=7):
        logger.error("The specified date might not be available in the Pushshift API yet. "
                     "Please try an earlier date or wait for the API to be updated.")
        return scraped_submissions

    after = int(day_start.timestamp())
    before = int((day_start + timedelta(days=1)).timestamp())

    # TODO: get_post_count_for_day did not seem to work; once it does, use its
    # count to report progress against an estimated total number of requests.
    seen_ids = set()
    actual_requests = 0
    while after < before:
        after_str, before_str = convert_timestamp_to_datetime(after), convert_timestamp_to_datetime(before)
        logger.info(f"Fetching data between timestamps {after_str} and {before_str}")
        data = get_pushshift_data(subreddit_to_scrape, before=before, after=after)
        page = (data or {}).get("data") or []

        added = 0
        for submission in page:
            if submission["id"] not in seen_ids:
                seen_ids.add(submission["id"])
                scraped_submissions.append(submission)
                added += 1
        # Stop on failure, an empty page, or a page made only of posts already collected.
        if not added:
            break

        # Pages are newest-first, so the last item is the oldest. Several posts can
        # share one created_utc second and may not all fit on this page. The next
        # query therefore re-includes that second (+1), and the id check above
        # drops the posts that were already collected.
        before = page[-1]["created_utc"] + 1
        actual_requests += 1
        time.sleep(1)

    elapsed_time = time.time() - start_time
    if actual_requests:
        minutes, seconds = divmod(elapsed_time, 60)
        logger.info(
            f"{actual_requests}it [{int(minutes):02d}:{seconds:05.2f} {elapsed_time / actual_requests:.2f}s/it]")
    logger.info(
        f"Finished scraping {len(scraped_submissions)} submissions in {elapsed_time:.2f} seconds in {actual_requests} requests")
    return scraped_submissions
def submissions_to_dataframe(submissions: List[Dict[str, Any]]) -> pd.DataFrame:
    """
    Parse a list of submissions into a pandas DataFrame.

    Only a fixed set of columns is kept, in a fixed order. A column missing from
    the input is filled with missing values instead of raising ``KeyError``. An
    empty input yields an empty frame that still has those columns.

    :param submissions: A list of dictionaries containing the scraped submission data.
    :return: A pandas DataFrame with one row per submission. ``created_utc`` is
        rendered as a UTC ``'YYYY-MM-DD HH:MM:SS'`` string.
    """
    cols = ['score', 'num_comments', 'title', 'permalink', 'selftext', 'url', 'created_utc', 'author', 'id',
            'downs', 'ups']
    if not submissions:
        return pd.DataFrame(columns=cols)

    # Select the columns first, so convert_dtypes only processes the columns kept.
    # Unlike df[cols], reindex tolerates missing columns. It also returns a new
    # frame, so the assignment below is not a write to a slice.
    df = pd.DataFrame(submissions).reindex(columns=cols).convert_dtypes()
    # Epoch seconds -> UTC datetime -> formatted string.
    df['created_utc'] = pd.to_datetime(df['created_utc'], unit='s', utc=True).dt.strftime(
        '%Y-%m-%d %H:%M:%S')
    return df
if __name__ == '__main__':
    subreddit_to_scrape = "askreddit"
    day_to_scrape = "2013-03-01"
    submissions = scrape_submissions_by_day(subreddit_to_scrape, day_to_scrape)
    # scrape_submissions_by_day returns [] for too-recent dates or when the first
    # request fails. Only build and print the table when there is data to show.
    if submissions:
        df = submissions_to_dataframe(submissions)
        print(df.head().to_string())
    logger.info(f"Scraped {len(submissions)} submissions from r/{subreddit_to_scrape} on {day_to_scrape}")
|