papers / update_scheduler.py
pxiaoer's picture
Update update_scheduler.py
8ed0f40
raw
history blame
4.27 kB
import datetime
import pathlib
import re
import tempfile
import os
import pandas as pd
import requests
from apscheduler.schedulers.background import BackgroundScheduler
from huggingface_hub import HfApi, Repository
from huggingface_hub.utils import RepositoryNotFoundError
class SpaceRestarter:
    """Restart a Hugging Face Space through the Hub API.

    The constructor validates up front that the configured token has write
    permission and that the Space exists, so that configuration mistakes
    surface at start-up rather than when a restart is first attempted.
    """

    def __init__(self, space_id: str):
        """Validate the token and the Space, then remember ``space_id``.

        Args:
            space_id: Identifier of the Space to restart.

        Raises:
            ValueError: If the token does not have write permission, or if
                the Hub reports that the Space does not exist.
        """
        self.api = HfApi()
        if self.api.get_token_permission() != "write":
            raise ValueError("The HF token must have write permission.")
        try:
            self.api.space_info(repo_id=space_id)
        except RepositoryNotFoundError as err:
            # Chain the original error so the Hub's diagnostic details stay
            # attached to the traceback.
            raise ValueError("The Space ID does not exist.") from err
        self.space_id = space_id

    def restart(self) -> None:
        """Ask the Hub to restart the Space given at construction time."""
        self.api.restart_space(self.space_id)
# Matches ``https://github.com/<owner>/<repo>`` with an optional
# ``/tree/<ref>/<path>`` or ``/blob/<ref>/<path>`` suffix. Every segment
# excludes *all* whitespace (not only a literal space) so that a URL followed
# by a newline or tab does not swallow the next word of the text.
_GITHUB_LINK_PATTERN = re.compile(
    r"https://github\.com/[^/\s]+/[^/)},\s]+(?:/(?:tree|blob)/[^/\s]+/[^/)},\s]+)?"
)


def find_github_links(summary: str) -> str:
    """Extract the single GitHub link mentioned in ``summary``.

    A trailing period (sentence punctuation) is removed from each match, and
    repeated mentions of the same link are treated as one link.

    Args:
        summary: Free-form text to search.

    Returns:
        The GitHub URL, or ``""`` when the text contains none.

    Raises:
        RuntimeError: If more than one distinct GitHub link is found.
    """
    normalized = (
        match[:-1] if match.endswith(".") else match
        for match in _GITHUB_LINK_PATTERN.findall(summary)
    )
    # dict.fromkeys de-duplicates while preserving first-seen order.
    links = list(dict.fromkeys(normalized))
    if not links:
        return ""
    if len(links) > 1:
        raise RuntimeError(f"Found multiple GitHub links: {links}")
    return links[0]
class RepoUpdater:
    """Maintain ``papers.csv`` inside a local clone of a Hugging Face repo.

    The repo is cloned into the system temporary directory. ``update`` pulls
    the latest state, appends papers from the daily-papers API that are not
    in the CSV yet, and rewrites the file; ``push`` publishes the result.
    """

    # Seconds to wait on the daily-papers API. Without a timeout a stalled
    # connection would block the calling thread indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, repo_id: str, repo_type: str):
        """Clone (or reuse) the repo locally and pull the latest revision.

        Args:
            repo_id: Hub repo identifier; its last ``/``-separated component
                is used as the local directory name.
            repo_type: Repo type forwarded to ``Repository``.

        Raises:
            ValueError: If the HF token does not have write permission.
        """
        api = HfApi()
        if api.get_token_permission() != "write":
            raise ValueError("The HF token must have write permission.")

        name = api.whoami()["name"]

        # tempfile.tempdir stays None until a tempfile function has been
        # called, which would make pathlib.Path() raise; gettempdir() always
        # resolves to a real directory (and honours tempfile.tempdir if set).
        repo_dir = pathlib.Path(tempfile.gettempdir()) / repo_id.split("/")[-1]
        self.csv_path = repo_dir / "papers.csv"
        self.repo = Repository(
            local_dir=repo_dir,
            clone_from=repo_id,
            repo_type=repo_type,
            git_user=name,
            git_email=f"{name}@users.noreply.huggingface.co",
        )
        self.repo.git_pull()

    def _fetch_daily_papers(self, date: str) -> list:
        """Return the daily-papers API payload for ``date`` (``YYYY-MM-DD``).

        Raises:
            requests.HTTPError: If the API answers with an error status.
            requests.Timeout: If the API does not answer in time.
        """
        response = requests.get(
            f"https://huggingface.co/api/daily_papers?date={date}",
            timeout=self.REQUEST_TIMEOUT,
        )
        # Fail here with a clear HTTP error instead of later, when an error
        # payload would be iterated as if it were a list of papers.
        response.raise_for_status()
        return response.json()

    def update(self) -> None:
        """Append yesterday's and today's unseen papers to ``papers.csv``.

        Papers whose arXiv id is already in the CSV are skipped. Papers whose
        summary contains several GitHub links are reported with ``print`` and
        skipped. The CSV is rewritten even when nothing was added.
        """
        # Sample the clock once so both dates come from the same instant;
        # two separate now() calls could straddle midnight and skip a day.
        # NOTE(review): this is naive local time while the scheduler job in
        # this file is configured for UTC - confirm which the API expects.
        now = datetime.datetime.now()
        dates = [
            (now - datetime.timedelta(days=1)).strftime("%Y-%m-%d"),
            now.strftime("%Y-%m-%d"),
        ]
        daily_papers = [(date, self._fetch_daily_papers(date)) for date in dates]

        self.repo.git_pull()
        df = pd.read_csv(self.csv_path, dtype=str).fillna("")
        known_ids = set(df["arxiv_id"])

        new_rows = []
        for date, papers in daily_papers:
            for paper in papers:
                arxiv_id = paper["paper"]["id"]
                if arxiv_id in known_ids:
                    continue
                try:
                    github = find_github_links(paper["paper"]["summary"])
                except RuntimeError as e:
                    print(e)
                    continue
                # Remember the id so a paper listed under both dates is only
                # recorded once (under the earlier date).
                known_ids.add(arxiv_id)
                new_rows.append(
                    {
                        "date": date,
                        "arxiv_id": arxiv_id,
                        "github": github,
                    }
                )

        if new_rows:
            df = pd.concat([df, pd.DataFrame(new_rows)], ignore_index=True)
        df.to_csv(self.csv_path, index=False)

    def push(self) -> None:
        """Push the local clone to the Hub."""
        self.repo.push_to_hub()
class UpdateScheduler:
    """Run the paper update on a cron schedule in a background scheduler.

    Wires together a ``SpaceRestarter`` and a ``RepoUpdater`` for the same
    Space and registers a single cron job (evaluated in UTC) that invokes
    ``_update``.
    """

    def __init__(self, space_id: str, cron_hour: str, cron_minute: str, cron_second: str = "0"):
        """Create the collaborators and register the cron job.

        Args:
            space_id: Space to update and restart.
            cron_hour: Hour field of the cron trigger.
            cron_minute: Minute field of the cron trigger.
            cron_second: Second field of the cron trigger.
        """
        self.space_restarter = SpaceRestarter(space_id=space_id)
        self.repo_updater = RepoUpdater(repo_id=space_id, repo_type="space")

        cron_fields = {
            "hour": cron_hour,
            "minute": cron_minute,
            "second": cron_second,
            "timezone": "UTC",
        }
        self.scheduler = BackgroundScheduler()
        self.scheduler.add_job(func=self._update, trigger="cron", **cron_fields)

    def _update(self) -> None:
        """Refresh the repo, then push if it changed or restart the Space if not."""
        updater = self.repo_updater
        updater.update()

        has_changes = not updater.repo.is_repo_clean()
        if has_changes:
            updater.push()
            return
        # Nothing to push: restart the Space instead.
        self.space_restarter.restart()

    def start(self) -> None:
        """Start the background scheduler."""
        self.scheduler.start()