# Spaces: Runtime error
import os | |
import gradio as gr | |
import datetime | |
import tempfile | |
from huggingface_hub import hf_hub_download | |
import subprocess | |
def md5(filename):
    """Return the MD5 checksum line for *filename*, formatted like ``md5sum``.

    The digest is computed in-process with ``hashlib`` instead of spawning
    the external ``md5sum`` binary, so it also works on systems where that
    tool is not installed.

    Args:
        filename: Path of the file to hash (``str``, ``bytes`` or path-like).

    Returns:
        ``bytes`` of the form ``b"<hex digest>  <filename>\\n"``, the same
        layout ``md5sum <filename>`` prints for ordinary file names.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Local import: hashlib is not among the file-level imports.
    import hashlib

    hasher = hashlib.md5()
    with open(filename, "rb") as stream:
        # Read in 1 MiB pieces so large model files never sit fully in memory.
        for chunk in iter(lambda: stream.read(1 << 20), b""):
            hasher.update(chunk)

    # NOTE: md5sum additionally escapes names containing backslashes or
    # newlines; that corner case is not reproduced here.
    return hasher.hexdigest().encode("ascii") + b"  " + os.fsencode(filename) + b"\n"
def download_very_slow(repo_id):
    """Download ``pytorch_model.bin`` from *repo_id* in the "very slow" mode.

    Removes ``HF_TRANSFER`` from the process environment and sets
    ``HF_CHUNK_SIZE`` to ``"1024"`` before fetching the file into a
    throw-away cache directory.

    NOTE(review): both variables are presumably read by the installed
    huggingface_hub build at download time -- confirm against that version.

    Args:
        repo_id: Hub repository to download from.

    Returns:
        Whatever ``md5`` returns for the downloaded file.
    """
    # Environment changes are process-wide and persist after this call.
    os.environ.pop("HF_TRANSFER", None)
    os.environ["HF_CHUNK_SIZE"] = "1024"

    fetch_options = {
        "filename": "pytorch_model.bin",
        "force_download": True,  # never reuse a previously cached copy
    }
    with tempfile.TemporaryDirectory() as scratch_dir:
        local_path = hf_hub_download(repo_id, cache_dir=scratch_dir, **fetch_options)
        # Hash while the temporary directory still exists.
        checksum = md5(local_path)
    return checksum
def download_slow(repo_id):
    """Download ``pytorch_model.bin`` from *repo_id* in the "slow" mode.

    Removes ``HF_TRANSFER`` from the process environment and sets
    ``HF_CHUNK_SIZE`` to ``"10485760"`` (10 * 1024 * 1024) before fetching
    the file into a throw-away cache directory.

    NOTE(review): both variables are presumably read by the installed
    huggingface_hub build at download time -- confirm against that version.

    Args:
        repo_id: Hub repository to download from.

    Returns:
        Whatever ``md5`` returns for the downloaded file.
    """
    # Environment changes are process-wide and persist after this call.
    os.environ.pop("HF_TRANSFER", None)
    os.environ["HF_CHUNK_SIZE"] = "10485760"

    fetch_options = {
        "filename": "pytorch_model.bin",
        "force_download": True,  # never reuse a previously cached copy
    }
    with tempfile.TemporaryDirectory() as scratch_dir:
        local_path = hf_hub_download(repo_id, cache_dir=scratch_dir, **fetch_options)
        # Hash while the temporary directory still exists.
        checksum = md5(local_path)
    return checksum
def download_fast(repo_id):
    """Download ``pytorch_model.bin`` from *repo_id* in the "fast" mode.

    Sets ``HF_TRANSFER`` to ``"1"`` in the process environment before
    fetching the file into a throw-away cache directory. ``HF_CHUNK_SIZE``
    is not touched, so it keeps whatever value an earlier call left behind.

    NOTE(review): ``HF_TRANSFER`` is presumably read by the installed
    huggingface_hub build at download time -- confirm against that version.

    Args:
        repo_id: Hub repository to download from.

    Returns:
        Whatever ``md5`` returns for the downloaded file.
    """
    # Environment change is process-wide and persists after this call.
    os.environ["HF_TRANSFER"] = "1"

    fetch_options = {
        "filename": "pytorch_model.bin",
        "force_download": True,  # never reuse a previously cached copy
    }
    with tempfile.TemporaryDirectory() as scratch_dir:
        local_path = hf_hub_download(repo_id, cache_dir=scratch_dir, **fetch_options)
        # Hash while the temporary directory still exists.
        checksum = md5(local_path)
    return checksum
def _digest_text(raw):
    """Return the bare hex digest from an ``md5`` result as display text."""
    if isinstance(raw, bytes):
        text = raw.decode("utf-8", errors="replace")
    else:
        text = str(raw)
    # md5sum-style output is "<digest>  <path>"; the path points into a
    # temporary directory that differs per run, so only the digest is kept.
    fields = text.split()
    return fields[0] if fields else text.strip()


def _timed_download(fetch, repo_id):
    """Run ``fetch(repo_id)`` and return ``(digest_text, elapsed_timedelta)``."""
    # Local import: time is not among the file-level imports.
    import time

    # perf_counter is monotonic, unlike wall-clock datetime.now().
    began = time.perf_counter()
    raw = fetch(repo_id)
    elapsed = datetime.timedelta(seconds=time.perf_counter() - began)
    return _digest_text(raw), elapsed


def download(repo_id):
    """Benchmark the three download modes for *repo_id* and report the results.

    Runs ``download_very_slow``, ``download_slow`` and ``download_fast`` one
    after another, timing each call on its own.

    Args:
        repo_id: Hub repository to download from.

    Returns:
        A multi-line string with the elapsed time and MD5 digest of each mode.
        Digests are shown as plain hex text rather than as a bytes repr, so
        the three values can be compared at a glance.
    """
    md5_very_slow, taken_very_slow = _timed_download(download_very_slow, repo_id)
    md5_slow, taken_slow = _timed_download(download_slow, repo_id)
    md5_fast, taken_fast = _timed_download(download_fast, repo_id)

    return f"""
Very slow (huggingface_hub previous to https://github.com/huggingface/huggingface_hub/pull/1267): {taken_very_slow}
MD5: {md5_very_slow}
Slow (huggingface_hub after): {taken_slow}
MD5: {md5_slow}
Fast (with hf_transfer): {taken_fast}
MD5: {md5_fast}
"""
# Repo ids offered as ready-made examples in the UI.
examples = ["gpt2", "openai/whisper-large-v2"]

with gr.Blocks() as demo:
    # First row: one column for the input form, one for the results.
    with gr.Row():
        with gr.Column():
            # should be set to " " when plugged into a real API
            inputs = gr.Textbox(value="gpt2", label="Repo id")
            submit = gr.Button("Submit")
        with gr.Column():
            outputs = gr.Textbox(label="Download speeds")

    # Second row: the example repo ids, wired to the same benchmark function.
    # NOTE(review): with cache_examples=True gradio presumably runs `download`
    # for every example up front -- confirm that cost is intended.
    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=[inputs],
            outputs=[outputs],
            fn=download,
            cache_examples=True,
        )

    # Clicking the button runs the benchmark on the typed repo id.
    submit.click(download, inputs=[inputs], outputs=[outputs])

demo.launch()