import os
import requests
from transformers import AutoModel
def model_file_exists_and_valid(model_file_path):
    # Check if the model file exists and has a size greater than 0
    return os.path.exists(model_file_path) and os.path.getsize(model_file_path) > 0
def write_model_path_to_txt_file(model_file_path):
    # Write the model path to model_path.txt so other processes can locate it
    with open('/home/user/data/models/model_path.txt', 'w') as f:
        f.write(model_file_path)
def download_hf_model():
    '''
    Download a Hugging Face model named by the HF_MODEL_NAME environment variable.

    The existence check looks for a default model file (pytorch_model.bin).
    Adjust this path based on the expected model file type (e.g., TensorFlow or
    Flax models have different file names).
    '''
    model_name = os.getenv("HF_MODEL_NAME")
    model_dir = f"/home/user/data/models/{model_name}"
    model_file_path = os.path.join(model_dir, "pytorch_model.bin")  # Assuming a PyTorch model for simplicity
    if model_file_exists_and_valid(model_file_path):
        print(f"Model {model_name} already downloaded.")
        write_model_path_to_txt_file(model_file_path)
        return

    # Authenticate with Hugging Face using the token, if available
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        from huggingface_hub import HfFolder
        HfFolder.save_token(hf_token)

    print(f"Downloading model: {model_name}...")
    model = AutoModel.from_pretrained(model_name)
    model.save_pretrained(model_dir)
    print(f"Model {model_name} downloaded and saved to {model_dir}")
    write_model_path_to_txt_file(model_file_path)
def download_gguf_model():
    model_name = os.getenv("HF_MODEL_NAME")
    model_dir = f"/home/user/data/models/{model_name}"
    os.makedirs(model_dir, exist_ok=True)
    model_url = os.getenv("GGUF_MODEL_URL")  # The download URL is provided as an env variable
    model_file_path = os.path.join(model_dir, os.path.basename(model_url))
    if model_file_exists_and_valid(model_file_path):
        print(f"Model {model_name} already downloaded.")
        write_model_path_to_txt_file(model_file_path)
        return

    print(f"Downloading model from {model_url}...")
    response = requests.get(model_url, stream=True)
    if response.status_code == 200:
        # Stream the response to disk in chunks to avoid holding the whole file in memory
        with open(model_file_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Model downloaded and saved to {model_file_path}")
        # Only record the path once the download has actually succeeded
        write_model_path_to_txt_file(model_file_path)
    else:
        print(f"Failed to download the model. Status code: {response.status_code}")
def download_model():
    # Dispatch on MODEL_CLASS: 'gguf' downloads a raw GGUF file, 'hf' uses transformers
    model_class = os.getenv("MODEL_CLASS")
    if model_class == 'gguf':
        download_gguf_model()
    elif model_class == 'hf':
        download_hf_model()
    else:
        print(f"Unsupported model class: {model_class}")
if __name__ == "__main__":
    download_model()
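
# Example invocation (a minimal sketch; the file name download_model.py and the
# model names/URLs below are illustrative assumptions, not part of this script):
#   MODEL_CLASS=hf HF_MODEL_NAME=bert-base-uncased python download_model.py
#   MODEL_CLASS=gguf HF_MODEL_NAME=my-gguf-model \
#       GGUF_MODEL_URL=https://example.com/model.Q4_K_M.gguf python download_model.py
# HF_TOKEN can optionally be set for gated or private Hugging Face models.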