File size: 6,130 Bytes
749d1d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

import pandas as pd
import requests

from my_logger import setup_logger

# Module-level logger, created by the project's setup_logger helper and named after this module.
logger = setup_logger(__name__)


def get_pushshift_data(subreddit: str, before: Optional[int] = None,
                       after: Optional[int] = None, aggs: Optional[str] = None,
                       timeout: float = 30.0) -> Optional[Dict[str, Any]]:
    """
    Fetch data from the Pushshift API for the specified subreddit.

    Every failure mode (network error, non-200 status, non-JSON body) is logged
    and reported to the caller as ``None`` instead of raising.

    :param subreddit: The name of the subreddit to scrape.
    :param before: The upper limit for the created_utc attribute of the submissions.
    :param after: The lower limit for the created_utc attribute of the submissions.
    :param aggs: The aggregation summary option to use.
    :param timeout: Seconds to wait for the server before giving up on the request.
    :return: A dictionary containing the fetched data and aggregations if available,
             or None when the request failed.
    """
    url = "https://api.pushshift.io/reddit/search/submission/"
    # NOTE(review): confirm against the Pushshift API docs which of "sort" /
    # "sort_type" carries the field name and which the direction.
    params: Dict[str, Any] = {
        "subreddit": subreddit,
        "size": 1000,
        "sort": "created_utc",
        "sort_type": "desc",
        }
    # Only send the optional filters that were actually supplied.
    optional_filters = {"before": before, "after": after, "aggs": aggs}
    params.update({key: value for key, value in optional_filters.items() if value is not None})

    try:
        # Without a timeout, requests can block indefinitely on an unresponsive server.
        response = requests.get(url, params=params, timeout=timeout)
    except requests.RequestException as exc:
        logger.error(f"Error fetching data: {exc}")
        return None

    if response.status_code != 200:
        logger.error(f"Error fetching data: {response.status_code}")
        return None

    try:
        return response.json()
    except ValueError:
        # A 200 response with a non-JSON body (e.g. an HTML error page).
        logger.error("Error fetching data: response body is not valid JSON")
        return None


def get_post_count_for_day(subreddit: str, day_to_scrape: str) -> int:
    """
    Get the total number of posts for a specific day in the specified subreddit using the Pushshift API.

    The day is interpreted as a UTC calendar day, matching the ``created_utc``
    field the query filters on.

    :param subreddit: The name of the subreddit to get the post count for.
    :param day_to_scrape: The date for which to get the post count (format: "YYYY-MM-DD").
    :return: The total number of posts for the specified day, or 0 when the
             request failed or no aggregation data was returned.
    :raises ValueError: If day_to_scrape does not match the "YYYY-MM-DD" format.
    """
    # A naive datetime's .timestamp() is interpreted in the machine's local
    # time zone; pin the day to UTC so the window does not shift with the host.
    day_start = datetime.strptime(day_to_scrape, "%Y-%m-%d").replace(tzinfo=timezone.utc)
    after = int(day_start.timestamp())
    before = int((day_start + timedelta(days=1)).timestamp())

    response = get_pushshift_data(subreddit, before=before, after=after, aggs="created_utc")
    if not response:
        return 0

    # Tolerate a missing or null "aggs" section.
    buckets = (response.get("aggs") or {}).get("created_utc") or []
    # Sum every bucket: the window may be split across more than one
    # aggregation bucket, and the day's total is the sum of all of them.
    return sum(bucket.get("doc_count", 0) for bucket in buckets)


def fetch_data(subreddit: str, before: int, after: int) -> Optional[Dict[str, Any]]:
    """
    Fetch submissions for a subreddit within a ``created_utc`` window.

    Thin wrapper around get_pushshift_data, kept so existing callers continue
    to work; the request logic lives in one place.

    :param subreddit: The name of the subreddit to scrape.
    :param before: The upper limit for the created_utc attribute of the submissions.
    :param after: The lower limit for the created_utc attribute of the submissions.
    :return: The decoded JSON response, or None when the request failed.
    """
    return get_pushshift_data(subreddit, before=before, after=after)


def convert_timestamp_to_datetime(timestamp: int) -> str:
    """
    Format a Unix timestamp as a UTC date-time string.

    :param timestamp: Seconds since the Unix epoch.
    :return: The corresponding UTC time formatted as "YYYY-MM-DD HH:MM:SS".
    """
    # Build a timezone-aware UTC datetime directly; datetime.utcfromtimestamp()
    # is deprecated since Python 3.12.
    utc_moment = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    return utc_moment.strftime('%Y-%m-%d %H:%M:%S')


def scrape_submissions_by_day(subreddit_to_scrape: str, day_to_scrape: str) -> List[Dict[str, Any]]:
    """
    Scrape every submission of a subreddit for one UTC calendar day.

    Pages backwards through the day: each request asks for submissions between
    ``after`` (start of day) and ``before``, then moves ``before`` down to the
    ``created_utc`` of the last submission of the page just received.

    :param subreddit_to_scrape: The name of the subreddit to scrape.
    :param day_to_scrape: The day to scrape (format: "YYYY-MM-DD"), interpreted in UTC.
    :return: The raw submission dictionaries; empty when the day is less than
             seven days old or nothing could be fetched.
    :raises ValueError: If day_to_scrape does not match the "YYYY-MM-DD" format.
    """
    start_time = time.time()
    scraped_submissions: List[Dict[str, Any]] = []
    # A naive datetime's .timestamp() is interpreted in local time; pin the
    # day to UTC so the window matches created_utc and the UTC log output.
    date_obj = datetime.strptime(day_to_scrape, "%Y-%m-%d").replace(tzinfo=timezone.utc)

    if date_obj > datetime.now(timezone.utc) - timedelta(days=7):
        logger.error("The specified date might not be available in the Pushshift API yet. "
                     "Please try an earlier date or wait for the API to be updated.")
        return scraped_submissions

    after = int(date_obj.timestamp())
    before = int((date_obj + timedelta(days=1)).timestamp())

    # TODO: get_post_count_for_day did not seem to work, so the number of
    # requests is not estimated up front.

    actual_requests = 0
    while after < before:
        after_str = convert_timestamp_to_datetime(after)
        before_str = convert_timestamp_to_datetime(before)
        logger.info(f"Fetching data between timestamps {after_str} and {before_str}")
        data = get_pushshift_data(subreddit_to_scrape, before=before, after=after)
        # Treat a failed request, a missing "data" key and an empty page alike.
        batch = data.get("data") if data else None
        if not batch:
            break

        scraped_submissions.extend(batch)
        actual_requests += 1

        # Pagination relies on the page's last item being its oldest.
        oldest = batch[-1]["created_utc"]
        if oldest >= before:
            # The cursor must strictly decrease, otherwise the same page
            # would be requested forever.
            logger.error("Pagination cursor did not move backwards; stopping to avoid an endless loop.")
            break
        before = oldest

        time.sleep(1)

    elapsed_time = time.time() - start_time
    if actual_requests:
        minutes, seconds = divmod(elapsed_time, 60)
        logger.info(
                f"{actual_requests}it [{int(minutes):02d}:{seconds:05.2f} {elapsed_time / actual_requests:.2f}s/it]")
    logger.info(
            f"Finished scraping {len(scraped_submissions)} submissions in {elapsed_time:.2f} seconds in {actual_requests} requests")
    return scraped_submissions


def submissions_to_dataframe(submissions: List[Dict[str, Any]]) -> pd.DataFrame:
    """
    Parse a list of submissions into a pandas DataFrame.

    The result always has exactly the expected columns, in a fixed order.
    Columns absent from the input are filled with missing values, and an empty
    input yields an empty DataFrame with those columns.

    :param submissions: A list of dictionaries containing the scraped submission data.
    :return: A pandas DataFrame containing the submission data, with
             "created_utc" rendered as a "YYYY-MM-DD HH:MM:SS" UTC string.
    """
    cols = ['score', 'num_comments', 'title', 'permalink', 'selftext', 'url', 'created_utc', 'author', 'id',
            'downs', 'ups']
    if not submissions:
        # Nothing to convert; still hand back the expected schema.
        return pd.DataFrame(columns=cols)

    df = pd.DataFrame(submissions).convert_dtypes()
    # reindex (unlike df[cols]) does not raise when a column is missing from
    # the scraped data; it adds the column filled with missing values.
    df = df.reindex(columns=cols)
    # Convert the "created_utc" column to a datetime column with timezone information
    df['created_utc'] = pd.to_datetime(df['created_utc'], unit='s').dt.tz_localize('UTC').dt.strftime(
            '%Y-%m-%d %H:%M:%S')
    return df


if __name__ == '__main__':
    # Example run: scrape one day of r/askreddit and preview the result.
    subreddit_to_scrape = "askreddit"
    day_to_scrape = "2013-03-01"
    submissions = scrape_submissions_by_day(subreddit_to_scrape, day_to_scrape)
    # The scraper returns an empty list when the API is unavailable or the day
    # is too recent; skip the preview then, since there is nothing to show.
    if submissions:
        df = submissions_to_dataframe(submissions)
        print(df.head().to_string())
    logger.info(f"Scraped {len(submissions)} submissions from r/{subreddit_to_scrape} on {day_to_scrape}")