File size: 3,790 Bytes
b4eb3ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import datetime
import pathlib
import re
import tempfile

import pandas as pd
import requests
from apscheduler.schedulers.background import BackgroundScheduler
from huggingface_hub import HfApi, Repository
from huggingface_hub.utils import RepositoryNotFoundError


class SpaceRestarter:
    """Restart a Hugging Face Space through the Hub API.

    The constructor checks that the configured HF token has write permission
    and that the Space exists before the instance is usable.
    """

    def __init__(self, space_id: str):
        """Validate access to *space_id* and remember it.

        Args:
            space_id: Hub ID of the Space to restart.

        Raises:
            ValueError: If the HF token does not have write permission, or if
                the Space ID does not exist on the Hub.
        """
        self.api = HfApi()
        if self.api.get_token_permission() != 'write':
            raise ValueError('The HF token must have write permission.')
        try:
            self.api.space_info(repo_id=space_id)
        except RepositoryNotFoundError as err:
            # Chain the Hub error so the underlying cause stays visible in
            # the traceback instead of appearing as a secondary failure.
            raise ValueError('The Space ID does not exist.') from err
        self.space_id = space_id

    def restart(self) -> None:
        """Ask the Hub to restart the Space given at construction time."""
        self.api.restart_space(self.space_id)


def find_github_links(summary: str) -> str:
    """Return the single GitHub link mentioned in a paper summary.

    A link is ``https://github.com/<owner>/<repo>``, optionally followed by
    ``/tree/<ref>/<path>`` or ``/blob/<ref>/<path>``. Matching stops at any
    whitespace (summaries may contain line breaks) and at ``)``, ``}`` or
    ``,`` in the repository/path segments. One trailing period (sentence
    punctuation) is removed from each match. Repeated mentions of the same
    link count as one.

    Args:
        summary: Free-form summary text to search.

    Returns:
        The GitHub link, or ``''`` if the summary contains none.

    Raises:
        RuntimeError: If the summary mentions more than one distinct link.
    """
    matches = re.findall(
        r'https://github\.com/[^/\s]+/[^/)},\s]+'
        r'(?:/(?:tree|blob)/[^/\s]+/[^/)},\s]+)?', summary)
    # Drop sentence-ending punctuation, then de-duplicate while keeping the
    # order of first appearance so the error message lists links in order.
    normalized = [m[:-1] if m.endswith('.') else m for m in matches]
    links = list(dict.fromkeys(normalized))
    if not links:
        return ''
    if len(links) != 1:
        raise RuntimeError(f'Found multiple GitHub links: {links}')
    return links[0]


class RepoUpdater:
    """Maintain ``papers.csv`` in a local clone of a Hugging Face Hub repository.

    ``update`` appends papers listed by the Hub daily-papers API for
    yesterday and today (UTC) that are not yet in ``papers.csv``; ``push``
    publishes the local checkout.
    """

    # Timeout in seconds for daily-papers API requests, so a stalled
    # connection cannot block the calling job indefinitely.
    _REQUEST_TIMEOUT = 30

    def __init__(self, repo_id: str, repo_type: str):
        """Clone *repo_id* into the system temp directory and pull it.

        Args:
            repo_id: Hub repository ID; its last ``/``-separated component
                names the local checkout directory.
            repo_type: Repository type passed to ``Repository``
                (e.g. ``'space'``).
        """
        api = HfApi()
        name = api.whoami()['name']

        # tempfile.tempdir is None until something initialises it, which
        # would make pathlib.Path() raise; gettempdir() always resolves it.
        self.repo_dir = (pathlib.Path(tempfile.gettempdir()) /
                         repo_id.split('/')[-1])
        self.repo = Repository(
            local_dir=self.repo_dir,
            clone_from=repo_id,
            repo_type=repo_type,
            git_user=name,
            git_email=f'{name}@users.noreply.huggingface.co')
        self.repo.git_pull()

    def _fetch_daily_papers(self, date: str) -> list:
        """Return the daily-papers API entries for *date* (``YYYY-MM-DD``).

        Raises:
            requests.HTTPError: If the API answers with an error status,
                rather than feeding an error payload into the paper list.
        """
        response = requests.get(
            f'https://huggingface.co/api/daily_papers?date={date}',
            timeout=self._REQUEST_TIMEOUT)
        response.raise_for_status()
        return response.json()

    def update(self) -> None:
        """Append newly listed daily papers to ``papers.csv`` in the checkout.

        Each new paper gets a row with its arXiv ID and the result of
        ``find_github_links`` on its summary. Papers for which
        ``find_github_links`` raises ``RuntimeError`` are printed and skipped.
        The file is only written locally; call ``push`` to publish it.
        """
        # Derive both dates from a single UTC timestamp: they cannot straddle
        # midnight, and they match the UTC cron schedule used by the caller.
        now = datetime.datetime.now(datetime.timezone.utc)
        yesterday = (now - datetime.timedelta(days=1)).strftime('%Y-%m-%d')
        today = now.strftime('%Y-%m-%d')
        daily_papers = (self._fetch_daily_papers(yesterday) +
                        self._fetch_daily_papers(today))

        self.repo.git_pull()
        csv_path = self.repo_dir / 'papers.csv'
        df = pd.read_csv(csv_path, dtype=str).fillna('')
        known_ids = set(df['arxiv_id'])

        new_rows = []
        for paper in daily_papers:
            arxiv_id = paper['paper']['id']
            if arxiv_id in known_ids:
                continue
            # The same paper can be listed on both days; consider it once.
            known_ids.add(arxiv_id)
            try:
                github = find_github_links(paper['paper']['summary'])
            except RuntimeError as e:
                print(e)
                continue
            new_rows.append({'arxiv_id': arxiv_id, 'github': github})

        # Concatenate onto the existing frame so its header and column order
        # survive even when the file has no rows and nothing new was found.
        if new_rows:
            df = pd.concat([df, pd.DataFrame(new_rows)], ignore_index=True)
        df.to_csv(csv_path, index=False)

    def push(self) -> None:
        """Publish the local checkout with ``Repository.push_to_hub``."""
        self.repo.push_to_hub()


class UpdateScheduler:
    """Run the ``papers.csv`` refresh for a Space on a daily UTC cron schedule.

    On each run the Space repository is updated. If that left uncommitted
    changes they are pushed; otherwise the Space is restarted.
    """

    def __init__(self, space_id: str, cron_hour: str, cron_minute: str):
        """Prepare the restarter, the repository updater and the cron job.

        Args:
            space_id: Hub ID of the Space to update and restart.
            cron_hour: APScheduler cron ``hour`` expression (UTC).
            cron_minute: APScheduler cron ``minute`` expression (UTC).
        """
        self.space_restarter = SpaceRestarter(space_id=space_id)
        self.repo_updater = RepoUpdater(repo_id=space_id, repo_type='space')

        schedule = {
            'hour': cron_hour,
            'minute': cron_minute,
            'second': 0,
            'timezone': 'UTC',
        }
        self.scheduler = BackgroundScheduler()
        self.scheduler.add_job(func=self._update, trigger='cron', **schedule)

    def _update(self) -> None:
        """Refresh the repository, then push changes or restart the Space."""
        updater = self.repo_updater
        updater.update()
        if not updater.repo.is_repo_clean():
            updater.push()
            return
        self.space_restarter.restart()

    def start(self) -> None:
        """Start the APScheduler ``BackgroundScheduler`` holding the job."""
        self.scheduler.start()