|
import datetime |
|
import pathlib |
|
import re |
|
import tempfile |
|
|
|
import pandas as pd |
|
import requests |
|
from apscheduler.schedulers.background import BackgroundScheduler |
|
from huggingface_hub import HfApi, Repository |
|
from huggingface_hub.utils import RepositoryNotFoundError |
|
|
|
|
|
class SpaceRestarter:
    """Restart a Hugging Face Space through the Hub API.

    The constructor checks up front that the configured token can write and
    that the Space exists, so a misconfiguration fails at start-up rather
    than at the first restart attempt.
    """

    def __init__(self, space_id: str):
        """Validate the token and the Space, then remember the Space ID.

        Args:
            space_id: Repo ID of the Space to restart.

        Raises:
            ValueError: If the token does not have ``write`` permission, or
                if the Hub reports that the Space does not exist.
        """
        self.api = HfApi()
        if self.api.get_token_permission() != "write":
            raise ValueError("The HF token must have write permission.")
        try:
            self.api.space_info(repo_id=space_id)
        except RepositoryNotFoundError as err:
            # Chain the Hub error so the traceback shows the original cause.
            raise ValueError("The Space ID does not exist.") from err
        self.space_id = space_id

    def restart(self) -> None:
        """Ask the Hub to restart the Space given at construction."""
        self.api.restart_space(self.space_id)
|
|
|
|
|
# Matches ``https://github.com/<owner>/<repo>`` with an optional
# ``/tree|blob/<ref>/<path-segment>`` suffix.  Path segments stop at any
# whitespace (not just a literal space) and at ``/``; the repo and final
# segment also stop at ``)``, ``}`` and ``,`` so surrounding punctuation is
# not swallowed into the link.
_GITHUB_LINK_RE = re.compile(
    r"https://github\.com/[^/\s]+/[^/)},\s]+(?:/(?:tree|blob)/[^/\s]+/[^/)},\s]+)?"
)


def find_github_links(summary: str) -> str:
    """Return the single GitHub link mentioned in *summary*, if any.

    Trailing periods (sentence punctuation) are removed from each match, and
    repeated mentions of the same link count as one.

    Args:
        summary: Free text to scan, e.g. a paper abstract.

    Returns:
        The GitHub URL, or ``""`` when the text contains none.

    Raises:
        RuntimeError: If more than one distinct GitHub link is found.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order.
    links = list(dict.fromkeys(match.rstrip(".") for match in _GITHUB_LINK_RE.findall(summary)))
    if not links:
        return ""
    if len(links) > 1:
        raise RuntimeError(f"Found multiple GitHub links: {links}")
    return links[0]
|
|
|
|
|
class RepoUpdater:
    """Maintain a local clone of a Hub repo and append new daily papers to its CSV.

    The repo is cloned into the system temporary directory; ``papers.csv``
    inside that clone is the file that :meth:`update` rewrites and
    :meth:`push` uploads.
    """

    # Seconds to wait on the daily-papers API; without a timeout a stalled
    # connection would block the caller indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, repo_id: str, repo_type: str):
        """Clone (or reuse) the repo locally and pull the latest revision.

        Args:
            repo_id: Hub repo ID; its last ``/``-separated component is used
                as the local directory name.
            repo_type: Repo type forwarded to ``Repository``.

        Raises:
            ValueError: If the token does not have ``write`` permission.
        """
        api = HfApi()
        if api.get_token_permission() != "write":
            raise ValueError("The HF token must have write permission.")

        name = api.whoami()["name"]

        # ``tempfile.tempdir`` is ``None`` until something initialises it, so
        # ask ``gettempdir()`` for the real directory instead of reading the
        # module global directly.
        repo_dir = pathlib.Path(tempfile.gettempdir()) / repo_id.split("/")[-1]
        self.csv_path = repo_dir / "papers.csv"
        self.repo = Repository(
            local_dir=repo_dir,
            clone_from=repo_id,
            repo_type=repo_type,
            git_user=name,
            git_email=f"{name}@users.noreply.huggingface.co",
        )
        self.repo.git_pull()

    def _fetch_daily_papers(self, date: str) -> list:
        """Return the decoded daily-papers API response for *date* (``YYYY-MM-DD``).

        The response is assumed to be a list of paper entries -- that is how
        :meth:`update` consumes it.
        """
        response = requests.get(
            f"https://huggingface.co/api/daily_papers?date={date}",
            timeout=self.REQUEST_TIMEOUT,
        )
        # Fail with a clear HTTP error rather than on a malformed payload later.
        response.raise_for_status()
        return response.json()

    def update(self) -> None:
        """Append papers listed for yesterday and today that are not yet in the CSV.

        Each new row records ``date``, ``arxiv_id`` and ``github`` (the link
        found in the paper summary, or ``""``).  Papers whose summary
        contains several GitHub links are reported with ``print`` and
        skipped.  The CSV is rewritten even when nothing new was found.
        """
        # Take one timestamp so both dates are consistent even if the call
        # happens right around midnight.
        # NOTE(review): this is naive local time -- confirm it matches the
        # date convention the API expects.
        now = datetime.datetime.now()
        dates = [
            (now - datetime.timedelta(days=1)).strftime("%Y-%m-%d"),
            now.strftime("%Y-%m-%d"),
        ]
        daily_papers = [{"date": date, "papers": self._fetch_daily_papers(date)} for date in dates]

        self.repo.git_pull()
        df = pd.read_csv(self.csv_path, dtype=str).fillna("")
        known_ids = set(df["arxiv_id"])

        new_rows = []
        for entry in daily_papers:
            for paper in entry["papers"]:
                arxiv_id = paper["paper"]["id"]
                if arxiv_id in known_ids:
                    continue
                try:
                    github = find_github_links(paper["paper"]["summary"])
                except RuntimeError as e:
                    print(e)
                    continue
                # Remember the ID so a paper listed under both dates is
                # only added once.
                known_ids.add(arxiv_id)
                new_rows.append({"date": entry["date"], "arxiv_id": arxiv_id, "github": github})

        if new_rows:
            # Concatenating keeps the existing columns (and header) intact.
            df = pd.concat([df, pd.DataFrame(new_rows)], ignore_index=True)
        df.to_csv(self.csv_path, index=False)

    def push(self) -> None:
        """Push the local clone to the Hub."""
        self.repo.push_to_hub()
|
|
|
|
|
class UpdateScheduler:
    """Run the paper update on a daily UTC cron schedule in a background thread."""

    def __init__(self, space_id: str, cron_hour: str, cron_minute: str, cron_second: str = "0"):
        """Build the restarter and updater for *space_id* and register the cron job.

        Args:
            space_id: Space whose repo is updated and which is restarted.
            cron_hour: Hour field of the cron trigger.
            cron_minute: Minute field of the cron trigger.
            cron_second: Second field of the cron trigger.
        """
        self.space_restarter = SpaceRestarter(space_id=space_id)
        self.repo_updater = RepoUpdater(repo_id=space_id, repo_type="space")

        cron_fields = {
            "hour": cron_hour,
            "minute": cron_minute,
            "second": cron_second,
        }
        self.scheduler = BackgroundScheduler()
        self.scheduler.add_job(func=self._update, trigger="cron", timezone="UTC", **cron_fields)

    def _update(self) -> None:
        """Refresh the CSV, then push if it changed or restart the Space if not."""
        updater = self.repo_updater
        updater.update()
        if not updater.repo.is_repo_clean():
            # The working tree has changes: upload them.
            updater.push()
            return
        # Nothing changed locally: restart the Space instead.
        self.space_restarter.restart()

    def start(self) -> None:
        """Start the background scheduler."""
        self.scheduler.start()
|
|