File size: 1,032 Bytes
7b34e37
 
 
91ce433
 
a0650fb
 
7b34e37
24fe512
91ce433
7b34e37
6fb939f
 
 
 
a0650fb
 
6fb939f
 
7b34e37
 
 
 
 
0461bfe
7b34e37
 
f901672
7b34e37
91ce433
24fe512
7b34e37
 
91ce433
6fb939f
7b34e37
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""
A script that is run when the server starts.
"""
import asyncio
from concurrent.futures import ThreadPoolExecutor

from huggingface_hub import logging as huggingface_hub_logging
from transformers import logging as transformers_logging

from download_repo import download_pytorch_model


def disable_progress_bar():
    """
    Disables the progress bars when downloading models.

    Turns off the progress bars of both ``transformers`` and
    ``huggingface_hub``, and keeps ``huggingface_hub`` log propagation
    disabled.
    """
    # Imported locally so this function does not depend on a new module-level import.
    from huggingface_hub.utils import disable_progress_bars

    transformers_logging.disable_progress_bar()
    # disable_propagation() only affects log records, not progress bars, so
    # the hub's own progress bars must be switched off explicitly.
    disable_progress_bars()
    huggingface_hub_logging.disable_propagation()


def download_useful_models():
    """
    Downloads the models that are useful for this project.
    So that the user doesn't have to wait for the models to download when they first use the app.

    Each model is downloaded on its own worker thread. A failed download is
    reported on stdout and does not stop the remaining downloads.
    """
    print("Downloading useful models...")
    useful_models = (
        "facebook/opt-125m",
        "facebook/opt-iml-max-30b",
    )
    with ThreadPoolExecutor() as executor:
        # submit() rather than map(): map() only surfaces a worker's exception
        # when its result iterator is consumed, which never happened here, so
        # failed downloads used to vanish without a trace.
        downloads = {
            model_name: executor.submit(download_pytorch_model, model_name)
            for model_name in useful_models
        }
        for model_name, download in downloads.items():
            # exception() blocks until the download finishes and returns
            # None on success.
            error = download.exception()
            if error is not None:
                print(f"Failed to download {model_name}: {error!r}")


async def main():
    """
    Entry point of the start-up script.

    Calls disable_progress_bar() and then download_useful_models().
    """
    disable_progress_bar()
    download_useful_models()


if __name__ == "__main__":
    # main() is a coroutine function: calling it only creates a coroutine
    # object. It has to be driven by an event loop, otherwise nothing runs
    # and Python emits "coroutine 'main' was never awaited".
    asyncio.run(main())