|
import datetime |
|
import pathlib |
|
import re |
|
import tempfile |
|
|
|
import pandas as pd |
|
import requests |
|
from apscheduler.schedulers.background import BackgroundScheduler |
|
from huggingface_hub import HfApi, Repository |
|
from huggingface_hub.utils import RepositoryNotFoundError |
|
|
|
|
|
class SpaceRestarter:
    """Restart a Hugging Face Space through the Hub API.

    Args:
        space_id: Repo ID of the Space; it is passed to ``HfApi.space_info``
            on construction and to ``HfApi.restart_space`` on restart.

    Raises:
        ValueError: If ``HfApi.get_token_permission()`` is not ``'write'``,
            or if looking up ``space_id`` raises
            ``RepositoryNotFoundError``.
    """

    def __init__(self, space_id: str):
        self.api = HfApi()
        if self.api.get_token_permission() != 'write':
            raise ValueError('The HF token must have write permission.')
        try:
            # The lookup is done only for its side effect of raising when
            # the Space is missing; the returned info is not used.
            self.api.space_info(repo_id=space_id)
        except RepositoryNotFoundError as err:
            # Chain the Hub error so it is reported as the direct cause.
            raise ValueError('The Space ID does not exist.') from err
        self.space_id = space_id

    def restart(self) -> None:
        """Call ``restart_space`` on the Hub API for this Space."""
        self.api.restart_space(self.space_id)
|
|
|
|
|
# GitHub repository URL, optionally followed by a ``tree``/``blob`` path.
# Any whitespace, ``/``, ``)``, ``}`` or ``,`` terminates a path segment so
# that surrounding prose and punctuation are not absorbed into the URL.
_GITHUB_LINK_PATTERN = re.compile(
    r'https://github\.com/[^/\s]+/[^/\s)},]+'
    r'(?:/(?:tree|blob)/[^/\s]+/[^/\s)},]+)?')


def find_github_links(summary: str) -> str:
    """Return the single GitHub link mentioned in ``summary``.

    Args:
        summary: Free text to scan (e.g. a paper abstract).

    Returns:
        The matched URL with any sentence-ending periods removed, or an
        empty string when the text contains no GitHub link.

    Raises:
        RuntimeError: If more than one GitHub link is found.
    """
    links = _GITHUB_LINK_PATTERN.findall(summary)
    if not links:
        return ''
    if len(links) != 1:
        raise RuntimeError(f'Found multiple GitHub links: {links}')
    # A URL that closes a sentence captures the full stop(s); drop them.
    return links[0].rstrip('.')
|
|
|
|
|
class RepoUpdater:
    """Maintain a local clone of a Hub repo and append papers to its CSV.

    On construction the repo is cloned into a directory under the system
    temp dir and pulled. ``update`` fetches the daily-papers listing for
    yesterday and today and appends unseen entries to ``papers.csv`` in the
    clone; ``push`` pushes the clone to the Hub.

    Args:
        repo_id: Hub repo ID to clone; its last ``/``-separated component
            names the local directory.
        repo_type: Forwarded to ``Repository`` as ``repo_type``.

    Raises:
        ValueError: If ``HfApi.get_token_permission()`` is not ``'write'``.
    """

    # Seconds to wait for the daily-papers endpoint before giving up, so a
    # stalled connection cannot block the caller indefinitely.
    REQUEST_TIMEOUT_SECONDS = 30

    def __init__(self, repo_id: str, repo_type: str):
        api = HfApi()
        if api.get_token_permission() != 'write':
            raise ValueError('The HF token must have write permission.')

        name = api.whoami()['name']

        repo_dir = self._local_repo_dir(repo_id)
        self.csv_path = repo_dir / 'papers.csv'
        self.repo = Repository(
            local_dir=repo_dir,
            clone_from=repo_id,
            repo_type=repo_type,
            git_user=name,
            git_email=f'{name}@users.noreply.huggingface.co')
        self.repo.git_pull()

    @staticmethod
    def _local_repo_dir(repo_id: str) -> pathlib.Path:
        """Return the clone directory for ``repo_id`` under the temp dir."""
        # ``tempfile.tempdir`` is ``None`` until ``gettempdir()`` has been
        # called, so the function must be used instead of the attribute.
        return pathlib.Path(tempfile.gettempdir()) / repo_id.split('/')[-1]

    def _fetch_daily_papers(self, date: str):
        """Return the decoded JSON listing of daily papers for ``date``."""
        response = requests.get(
            f'https://huggingface.co/api/daily_papers?date={date}',
            timeout=self.REQUEST_TIMEOUT_SECONDS)
        # Fail here on an HTTP error instead of iterating an error payload.
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _collect_new_rows(daily_papers, known_ids) -> list:
        """Build CSV rows for papers whose arXiv ID is not in ``known_ids``.

        Papers whose summary contains several GitHub links are reported with
        ``print`` and skipped. ``known_ids`` itself is not modified.
        """
        seen = set(known_ids)
        new_rows = []
        for entry in daily_papers:
            date = entry['date']
            for paper in entry['papers']:
                arxiv_id = paper['paper']['id']
                if arxiv_id in seen:
                    continue
                try:
                    github = find_github_links(paper['paper']['summary'])
                except RuntimeError as e:
                    print(e)
                    continue
                # Remember the ID so a paper listed on both days is only
                # appended once.
                seen.add(arxiv_id)
                new_rows.append({
                    'date': date,
                    'arxiv_id': arxiv_id,
                    'github': github,
                })
        return new_rows

    def update(self) -> None:
        """Append yesterday's and today's unseen papers to ``papers.csv``.

        The listings are fetched first, then the clone is pulled, the CSV is
        read, new rows are appended and the file is rewritten in place.
        """
        # Sample the clock once so both dates derive from the same instant.
        # NOTE(review): this is naive local time — confirm it matches the
        # timezone the API keys its dates on.
        now = datetime.datetime.now()
        dates = [
            (now - datetime.timedelta(days=1)).strftime('%Y-%m-%d'),
            now.strftime('%Y-%m-%d'),
        ]
        daily_papers = [{
            'date': date,
            'papers': self._fetch_daily_papers(date),
        } for date in dates]

        self.repo.git_pull()
        df = pd.read_csv(self.csv_path, dtype=str).fillna('')
        new_rows = self._collect_new_rows(daily_papers, set(df['arxiv_id']))
        if new_rows:
            df = pd.concat([df, pd.DataFrame(new_rows)], ignore_index=True)
        df.to_csv(self.csv_path, index=False)

    def push(self) -> None:
        """Push the local clone to the Hub via ``Repository.push_to_hub``."""
        self.repo.push_to_hub()
|
|
|
|
|
class UpdateScheduler:
    """Run the repo update on a cron schedule in a background scheduler.

    Args:
        space_id: Used both as the Space to restart and as the repo ID
            (with ``repo_type='space'``) whose CSV is updated.
        cron_hour: Forwarded as the ``hour`` field of the cron trigger.
        cron_minute: Forwarded as the ``minute`` field of the cron trigger.
    """

    def __init__(self, space_id: str, cron_hour: str, cron_minute: str):
        self.space_restarter = SpaceRestarter(space_id=space_id)
        self.repo_updater = RepoUpdater(repo_id=space_id, repo_type='space')

        # The job fires at second 0 of the given hour/minute, in UTC.
        cron_fields = {
            'hour': cron_hour,
            'minute': cron_minute,
            'second': 0,
            'timezone': 'UTC',
        }
        self.scheduler = BackgroundScheduler()
        self.scheduler.add_job(func=self._update,
                               trigger='cron',
                               **cron_fields)

    def _update(self) -> None:
        """Update the repo, then push if it is dirty or restart if clean."""
        updater = self.repo_updater
        updater.update()
        if not updater.repo.is_repo_clean():
            # The update left changes in the working tree: publish them.
            updater.push()
            return
        # Nothing changed in the clone; restart the Space instead.
        self.space_restarter.restart()

    def start(self) -> None:
        """Start the background scheduler."""
        self.scheduler.start()
|
|